finish login route
Some checks failed
continuous-integration/drone/push Build is failing

This commit is contained in:
Patrick Michl 2022-05-01 21:11:06 +02:00
parent c20b4c6a23
commit 2c2ac27c26
16 changed files with 277 additions and 27 deletions

View File

@ -8,4 +8,10 @@ steps:
environment:
SQLX_OFFLINE: 'true'
commands:
- cargo check
- cargo check
- name: cargo test
image: rust:latest
environment:
SQLX_OFFLINE: 'true'
commands:
- cargo test

29
Cargo.lock generated
View File

@ -604,7 +604,7 @@ dependencies = [
"anyhow",
"argon2",
"axum",
"rand_core",
"rand",
"serde",
"sqlx",
"thiserror",
@ -809,6 +809,12 @@ version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
[[package]]
name = "ppv-lite86"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
[[package]]
name = "proc-macro2"
version = "1.0.37"
@ -827,6 +833,27 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.3"

View File

@ -16,5 +16,5 @@ sqlx = { version = "0.5", features = ["sqlite", "macros", "runtime-tokio-rustls"
anyhow = "1.0"
thiserror = "1.0"
argon2 = { version = "0.4", features = ["std"] }
rand_core = { version = "0.6", features = ["std"] }
rand = { version = "0.8.5", features = ["std"] }
uuid = { version = "1.0", features = ["v4"] }

View File

@ -14,6 +14,8 @@ pub enum RegistrationError {
InvalidUserId,
#[error("The desired user ID is already taken")]
UserIdTaken,
#[error("Registration is disabled")]
RegistrationDisabled,
}
impl IntoResponse for RegistrationError {
@ -42,6 +44,15 @@ impl IntoResponse for RegistrationError {
)),
)
.into_response(),
RegistrationError::RegistrationDisabled => (
StatusCode::FORBIDDEN,
Json(ErrorResponse::new(
ErrorCode::Forbidden,
&self.to_string(),
None,
)),
)
.into_response(),
}
}
}

View File

@ -1,3 +1,3 @@
pub mod auth;
pub mod errors;
pub mod r0;
pub mod versions;

View File

@ -7,16 +7,19 @@ use axum::{
routing::{get, post},
Extension, Json,
};
use rand_core::OsRng;
use sqlx::SqlitePool;
use crate::{
api::client_server::errors::authentication_error::AuthenticationError,
api::client_server::errors::{
api_error::ApiError, authentication_error::AuthenticationError,
registration_error::RegistrationError,
},
models::sessions::Session,
responses::{
authentication::{AuthenticationResponse, AuthenticationSuccess},
registration::RegistrationResponse,
},
types::uuid::Uuid,
};
use crate::{
models::devices::Device,
@ -30,8 +33,6 @@ use crate::{
Config,
};
use super::errors::{api_error::ApiError, registration_error::RegistrationError};
pub fn routes() -> axum::Router {
axum::Router::new()
.route("/r0/login", get(get_login).post(post_login))
@ -53,17 +54,34 @@ async fn post_login(
match body {
AuthenticationData::Password(auth_data) => {
let user = auth_data.user().unwrap();
let user_id = UserId::new(&user, config.homeserver_name())
let user_id = UserId::new(&user, config.server_name())
.ok()
.ok_or(AuthenticationError::InvalidUserId)?;
let user = User::find_by_user_id(&db, &user_id).await?;
user.password_correct(auth_data.password()).ok().ok_or(AuthenticationError::Forbidden)?;
user.password_correct(auth_data.password())
.ok()
.ok_or(AuthenticationError::Forbidden)?;
todo!("find_or_create device for user and create a session");
let device = if let Some(device_id) = auth_data.device_id() {
Device::find_for_user(&db, &user, device_id).await?
} else {
let device_id = uuid::Uuid::new_v4().to_string();
let display_name =
if let Some(display_name) = auth_data.initial_device_display_name() {
display_name.as_ref()
} else {
"Generic Device"
};
Device::new(&user, &device_id, display_name)?
.create(&db)
.await?
};
let resp = AuthenticationSuccess::new("", "", &user_id);
let session = Session::new(&device)?.create(&db).await?;
let resp = AuthenticationSuccess::new(session.key(), device.device_id(), &user_id);
Ok(Json(AuthenticationResponse::Success(resp)))
}
}
@ -78,7 +96,7 @@ async fn get_username_available(
let username = params
.get("username")
.ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, &config.homeserver_name())
let user_id = UserId::new(username, &config.server_name())
.ok()
.ok_or(RegistrationError::InvalidUserId)?;
let exists = User::exists(&db, &user_id).await?;
@ -92,13 +110,14 @@ async fn post_register(
Extension(db): Extension<SqlitePool>,
Json(body): Json<RegistrationRequest>,
) -> Result<Json<RegistrationResponse>, ApiError> {
config.enable_registration().then(|| true).ok_or(RegistrationError::RegistrationDisabled)?;
body.auth()
.ok_or(RegistrationError::AdditionalAuthenticationInformation)?;
let (user, device) = match &body.auth().expect("must be Some") {
AuthenticationData::Password(auth_data) => {
let username = body.username().ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, &config.homeserver_name())
let user_id = UserId::new(username, &config.server_name())
.ok()
.ok_or(RegistrationError::InvalidUserId)?;
@ -115,9 +134,13 @@ async fn post_register(
.create(&db)
.await?;
let device = Device::new(&user, "test", display_name)?
.create(&db)
.await?;
let device = Device::new(
&user,
uuid::Uuid::new_v4().to_string().as_ref(),
display_name,
)?
.create(&db)
.await?;
(user, device)
}

View File

@ -0,0 +1 @@
pub mod auth;

View File

@ -1,6 +1,9 @@
use crate::types::server_name::ServerName;
pub struct Config {
db_path: String,
homeserver_name: String,
server_name: ServerName,
enable_registration: bool
}
impl Config {
@ -12,8 +15,14 @@ impl Config {
/// Get a reference to the config's homeserver name.
#[must_use]
pub fn homeserver_name(&self) -> &str {
self.homeserver_name.as_ref()
pub fn server_name(&self) -> &ServerName {
&self.server_name
}
/// Get the config's enable registration.
#[must_use]
pub fn enable_registration(&self) -> bool {
self.enable_registration
}
}
@ -21,7 +30,8 @@ impl Default for Config {
fn default() -> Self {
Self {
db_path: "sqlite://db.sqlite3".into(),
homeserver_name: "fuckwit.dev".into(),
server_name: ServerName::new("fuckwit.dev").unwrap(),
enable_registration: true
}
}
}

View File

@ -43,7 +43,7 @@ async fn main() -> anyhow::Result<()> {
let client_server = Router::new()
.merge(api::client_server::versions::routes())
.merge(api::client_server::auth::routes());
.merge(api::client_server::r0::auth::routes());
let router = Router::new()
.nest("/_matrix/client", client_server)

View File

@ -35,6 +35,20 @@ impl Device {
)
}
pub async fn find_for_user(
conn: &SqlitePool,
user: &User,
device_id: &str,
) -> anyhow::Result<Self> {
let user_uuid = user.uuid();
Ok(sqlx::query_as!(
Self,
"select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ?",
user_uuid)
.fetch_one(conn).await?
)
}
/// Get the device's id.
#[must_use]
pub fn uuid(&self) -> &Uuid {

View File

@ -1,3 +1,4 @@
use rand::{thread_rng, Rng, distributions::Alphanumeric};
use sqlx::SqlitePool;
use crate::types::uuid::Uuid;
@ -12,10 +13,13 @@ pub struct Session {
impl Session {
pub fn new(device: &Device) -> anyhow::Result<Self> {
let mut rng = thread_rng();
let key: String = (0..32).map(|_| rng.sample(Alphanumeric) as char).collect();
Ok(Self {
uuid: uuid::Uuid::new_v4().into(),
device_uuid: device.uuid().clone(),
key: String::new(),
key,
})
}

View File

@ -1,6 +1,6 @@
use crate::types::uuid::Uuid;
use argon2::{password_hash::SaltString, Argon2, PasswordHasher, PasswordHash, PasswordVerifier};
use rand_core::OsRng;
use rand::rngs::OsRng;
use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool};
use crate::types::user_id::UserId;

View File

@ -44,7 +44,7 @@ impl AuthenticationPassword {
/// Get a mutable reference to the authentication password's initial device display name.
#[must_use]
pub fn initial_device_display_name_mut(&self) -> Option<&String> {
pub fn initial_device_display_name(&self) -> Option<&String> {
self.initial_device_display_name.as_ref()
}
}

View File

@ -6,3 +6,4 @@ pub mod identifier_type;
pub mod user_id;
pub mod user_interactive_authorization;
pub mod uuid;
pub mod server_name;

151
src/types/server_name.rs Normal file
View File

@ -0,0 +1,151 @@
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
#[derive(Debug, PartialEq, Eq)]
enum Hostname {
IPv4(Ipv4Addr),
IPv6(Ipv6Addr),
FQDN(String),
}
pub struct ServerName {
hostname: Hostname,
port: Option<u16>,
}
impl ServerName {
pub fn new(server_name: &str) -> anyhow::Result<Self> {
if let Ok(addr) = server_name.parse::<Ipv4Addr>() {
return Ok(Self {
hostname: Hostname::IPv4(addr),
port: None,
});
};
if let Ok(addr) = server_name.parse::<SocketAddrV4>() {
return Ok(Self {
hostname: Hostname::IPv4(addr.ip().to_owned()),
port: Some(addr.port()),
});
};
if server_name.ends_with(']') {
if let Ok(addr) = server_name[1..server_name.chars().count() - 1].parse::<Ipv6Addr>() {
return Ok(Self {
hostname: Hostname::IPv6(addr),
port: None,
});
};
}
if let Ok(addr) = server_name.parse::<SocketAddrV6>() {
return Ok(Self {
hostname: Hostname::IPv6(addr.ip().to_owned()),
port: Some(addr.port()),
});
};
if let Some(idx) = server_name.find(':') {
let (hostname, port) = server_name.split_at(idx);
Ok(Self {
hostname: Hostname::FQDN(hostname.to_owned()),
port: Some(port[1..].parse()?),
})
} else {
Ok(Self {
hostname: Hostname::FQDN(server_name.to_owned()),
port: None,
})
}
}
/// Get a reference to the server name's hostname.
#[must_use]
fn hostname(&self) -> &Hostname {
&self.hostname
}
/// Get the server name's port.
#[must_use]
fn port(&self) -> Option<u16> {
self.port
}
}
impl std::fmt::Display for ServerName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.hostname {
Hostname::IPv4(hostname) => write!(f, "{hostname}"),
Hostname::IPv6(hostname) => write!(f, "[{hostname}]"),
Hostname::FQDN(hostname) => write!(f, "{hostname}"),
}?;
if let Some(port) = self.port {
write!(f, ":{port}")?;
}
Ok(())
}
}
mod tests {
use super::*;
#[test]
fn parse_ipv4_without_port() {
let server_name = ServerName::new("127.0.0.1").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv4("127.0.0.1".parse().unwrap())
);
assert_eq!(server_name.port(), None);
}
#[test]
fn parse_ipv4_with_port() {
let server_name = ServerName::new("127.0.0.1:8080").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv4("127.0.0.1".parse().unwrap())
);
assert_eq!(server_name.port(), Some(8080));
}
#[test]
fn parse_ipv6_without_port() {
let server_name = ServerName::new("[::1]").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv6("::1".parse().unwrap())
);
assert_eq!(server_name.port(), None);
}
#[test]
fn parse_ipv6_with_port() {
let server_name = ServerName::new("[::1]:8080").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv6("::1".parse().unwrap())
);
assert_eq!(server_name.port(), Some(8080));
}
#[test]
fn parse_fqdn_without_port() {
let server_name = ServerName::new("example.com").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::FQDN("example.com".into())
);
assert_eq!(server_name.port(), None);
}
#[test]
fn parse_fqdn_with_port() {
let server_name = ServerName::new("example.com:8080").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::FQDN("example.com".into())
);
assert_eq!(server_name.port(), Some(8080));
}
}

View File

@ -2,6 +2,8 @@ use std::fmt::Display;
use sqlx::{encode::IsNull, Sqlite};
use super::server_name::ServerName;
#[derive(Clone)]
#[repr(transparent)]
pub struct UserId(String);
@ -36,7 +38,7 @@ impl<'d> sqlx::Decode<'d, Sqlite> for UserId {
}
impl UserId {
pub fn new(name: &str, server_name: &str) -> anyhow::Result<Self> {
pub fn new(name: &str, server_name: &ServerName) -> anyhow::Result<Self> {
let user_id = Self(format!("@{name}:{server_name}"));
user_id.is_valid()?;