diff --git a/.drone.yml b/.drone.yml index 8f0fe56..44b339b 100644 --- a/.drone.yml +++ b/.drone.yml @@ -8,4 +8,10 @@ steps: environment: SQLX_OFFLINE: 'true' commands: - - cargo check \ No newline at end of file + - cargo check +- name: cargo test + image: rust:latest + environment: + SQLX_OFFLINE: 'true' + commands: + - cargo test \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 25f9b25..27c20fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -604,7 +604,7 @@ dependencies = [ "anyhow", "argon2", "axum", - "rand_core", + "rand", "serde", "sqlx", "thiserror", @@ -809,6 +809,12 @@ version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" +[[package]] +name = "ppv-lite86" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" + [[package]] name = "proc-macro2" version = "1.0.37" @@ -827,6 +833,27 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + [[package]] name = "rand_core" version = "0.6.3" diff --git a/Cargo.toml b/Cargo.toml index e82c226..4e81b08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,5 +16,5 @@ sqlx = { version = "0.5", features = ["sqlite", "macros", "runtime-tokio-rustls" anyhow = "1.0" thiserror = "1.0" argon2 = { version = "0.4", features = ["std"] } -rand_core = { version = "0.6", features = ["std"] } +rand = { version = "0.8.5", features = ["std"] } uuid = { version = "1.0", features = ["v4"] } \ No newline at end of file diff --git a/src/api/client_server/errors/registration_error.rs b/src/api/client_server/errors/registration_error.rs index 0a8d577..8cf4a51 100644 --- a/src/api/client_server/errors/registration_error.rs +++ b/src/api/client_server/errors/registration_error.rs @@ -14,6 +14,8 @@ pub enum RegistrationError { InvalidUserId, #[error("The desired user ID is already taken")] UserIdTaken, + #[error("Registration is disabled")] + RegistrationDisabled, } impl IntoResponse for RegistrationError { @@ -42,6 +44,15 @@ impl IntoResponse for RegistrationError { )), ) .into_response(), + RegistrationError::RegistrationDisabled => ( + StatusCode::FORBIDDEN, + Json(ErrorResponse::new( + ErrorCode::Forbidden, + &self.to_string(), + None, + )), + ) + .into_response(), } } } diff --git a/src/api/client_server/mod.rs b/src/api/client_server/mod.rs index 1840809..9d09faf 100644 --- a/src/api/client_server/mod.rs +++ b/src/api/client_server/mod.rs @@ -1,3 +1,3 @@ -pub mod auth; pub mod errors; +pub mod r0; pub mod versions; diff --git a/src/api/client_server/auth.rs b/src/api/client_server/r0/auth.rs similarity index 70% rename from src/api/client_server/auth.rs rename to src/api/client_server/r0/auth.rs index 5965d61..6392388 100644 --- a/src/api/client_server/auth.rs +++ b/src/api/client_server/r0/auth.rs @@ -7,16 +7,19 @@ use axum::{ routing::{get, post}, Extension, Json, }; -use rand_core::OsRng; use sqlx::SqlitePool; use crate::{ - api::client_server::errors::authentication_error::AuthenticationError, + api::client_server::errors::{ + api_error::ApiError, authentication_error::AuthenticationError, + registration_error::RegistrationError, + }, models::sessions::Session, responses::{ authentication::{AuthenticationResponse, AuthenticationSuccess}, registration::RegistrationResponse, }, + types::uuid::Uuid, }; use crate::{ models::devices::Device, @@ -30,8 +33,6 @@ use crate::{ Config, }; -use super::errors::{api_error::ApiError, registration_error::RegistrationError}; - pub fn routes() -> axum::Router { axum::Router::new() .route("/r0/login", get(get_login).post(post_login)) @@ -53,17 +54,34 @@ async fn post_login( match body { AuthenticationData::Password(auth_data) => { let user = auth_data.user().unwrap(); - let user_id = UserId::new(&user, config.homeserver_name()) + let user_id = UserId::new(&user, config.server_name()) .ok() .ok_or(AuthenticationError::InvalidUserId)?; - + let user = User::find_by_user_id(&db, &user_id).await?; - user.password_correct(auth_data.password()).ok().ok_or(AuthenticationError::Forbidden)?; + user.password_correct(auth_data.password()) + .ok() + .ok_or(AuthenticationError::Forbidden)?; - todo!("find_or_create device for user and create a session"); + let device = if let Some(device_id) = auth_data.device_id() { + Device::find_for_user(&db, &user, device_id).await? + } else { + let device_id = uuid::Uuid::new_v4().to_string(); + let display_name = + if let Some(display_name) = auth_data.initial_device_display_name() { + display_name.as_ref() + } else { + "Generic Device" + }; + Device::new(&user, &device_id, display_name)? + .create(&db) + .await? + }; - let resp = AuthenticationSuccess::new("", "", &user_id); + let session = Session::new(&device)?.create(&db).await?; + + let resp = AuthenticationSuccess::new(session.key(), device.device_id(), &user_id); Ok(Json(AuthenticationResponse::Success(resp))) } } @@ -78,7 +96,7 @@ async fn get_username_available( let username = params .get("username") .ok_or(RegistrationError::MissingUserId)?; - let user_id = UserId::new(username, &config.homeserver_name()) + let user_id = UserId::new(username, &config.server_name()) .ok() .ok_or(RegistrationError::InvalidUserId)?; let exists = User::exists(&db, &user_id).await?; @@ -92,13 +110,14 @@ async fn post_register( Extension(db): Extension, Json(body): Json, ) -> Result, ApiError> { + config.enable_registration().then(|| true).ok_or(RegistrationError::RegistrationDisabled)?; body.auth() .ok_or(RegistrationError::AdditionalAuthenticationInformation)?; let (user, device) = match &body.auth().expect("must be Some") { AuthenticationData::Password(auth_data) => { let username = body.username().ok_or(RegistrationError::MissingUserId)?; - let user_id = UserId::new(username, &config.homeserver_name()) + let user_id = UserId::new(username, &config.server_name()) .ok() .ok_or(RegistrationError::InvalidUserId)?; @@ -115,9 +134,13 @@ async fn post_register( .create(&db) .await?; - let device = Device::new(&user, "test", display_name)? - .create(&db) - .await?; + let device = Device::new( + &user, + uuid::Uuid::new_v4().to_string().as_ref(), + display_name, + )? + .create(&db) + .await?; (user, device) } diff --git a/src/api/client_server/r0/mod.rs b/src/api/client_server/r0/mod.rs new file mode 100644 index 0000000..5696e21 --- /dev/null +++ b/src/api/client_server/r0/mod.rs @@ -0,0 +1 @@ +pub mod auth; \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 606722e..b5260c9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,9 @@ +use crate::types::server_name::ServerName; + pub struct Config { db_path: String, - homeserver_name: String, + server_name: ServerName, + enable_registration: bool } impl Config { @@ -12,8 +15,14 @@ impl Config { /// Get a reference to the config's homeserver name. #[must_use] - pub fn homeserver_name(&self) -> &str { - self.homeserver_name.as_ref() + pub fn server_name(&self) -> &ServerName { + &self.server_name + } + + /// Get the config's enable registration. + #[must_use] + pub fn enable_registration(&self) -> bool { + self.enable_registration } } @@ -21,7 +30,8 @@ impl Default for Config { fn default() -> Self { Self { db_path: "sqlite://db.sqlite3".into(), - homeserver_name: "fuckwit.dev".into(), + server_name: ServerName::new("fuckwit.dev").unwrap(), + enable_registration: true } } } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 38af7dc..9268d73 100644 --- a/src/main.rs +++ b/src/main.rs @@ -43,7 +43,7 @@ async fn main() -> anyhow::Result<()> { let client_server = Router::new() .merge(api::client_server::versions::routes()) - .merge(api::client_server::auth::routes()); + .merge(api::client_server::r0::auth::routes()); let router = Router::new() .nest("/_matrix/client", client_server) diff --git a/src/models/devices.rs b/src/models/devices.rs index 079dcaa..cc8de1c 100644 --- a/src/models/devices.rs +++ b/src/models/devices.rs @@ -35,6 +35,20 @@ impl Device { ) } + pub async fn find_for_user( + conn: &SqlitePool, + user: &User, + device_id: &str, + ) -> anyhow::Result { + let user_uuid = user.uuid(); + Ok(sqlx::query_as!( + Self, + "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ?", + user_uuid) + .fetch_one(conn).await? + ) + } + /// Get the device's id. #[must_use] pub fn uuid(&self) -> &Uuid { diff --git a/src/models/sessions.rs b/src/models/sessions.rs index c1d0907..a5fd718 100644 --- a/src/models/sessions.rs +++ b/src/models/sessions.rs @@ -1,3 +1,4 @@ +use rand::{thread_rng, Rng, distributions::Alphanumeric}; use sqlx::SqlitePool; use crate::types::uuid::Uuid; @@ -12,10 +13,13 @@ pub struct Session { impl Session { pub fn new(device: &Device) -> anyhow::Result { + let mut rng = thread_rng(); + let key: String = (0..32).map(|_| rng.sample(Alphanumeric) as char).collect(); + Ok(Self { uuid: uuid::Uuid::new_v4().into(), device_uuid: device.uuid().clone(), - key: String::new(), + key, }) } diff --git a/src/models/users.rs b/src/models/users.rs index 6d92c92..80a0e0b 100644 --- a/src/models/users.rs +++ b/src/models/users.rs @@ -1,6 +1,6 @@ use crate::types::uuid::Uuid; use argon2::{password_hash::SaltString, Argon2, PasswordHasher, PasswordHash, PasswordVerifier}; -use rand_core::OsRng; +use rand::rngs::OsRng; use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool}; use crate::types::user_id::UserId; diff --git a/src/types/authentication_data.rs b/src/types/authentication_data.rs index cad34fa..3d27032 100644 --- a/src/types/authentication_data.rs +++ b/src/types/authentication_data.rs @@ -44,7 +44,7 @@ impl AuthenticationPassword { /// Get a mutable reference to the authentication password's initial device display name. #[must_use] - pub fn initial_device_display_name_mut(&self) -> Option<&String> { + pub fn initial_device_display_name(&self) -> Option<&String> { self.initial_device_display_name.as_ref() } } diff --git a/src/types/mod.rs b/src/types/mod.rs index c9b765b..a92ff25 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -6,3 +6,4 @@ pub mod identifier_type; pub mod user_id; pub mod user_interactive_authorization; pub mod uuid; +pub mod server_name; \ No newline at end of file diff --git a/src/types/server_name.rs b/src/types/server_name.rs new file mode 100644 index 0000000..8664f9a --- /dev/null +++ b/src/types/server_name.rs @@ -0,0 +1,151 @@ +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; + +#[derive(Debug, PartialEq, Eq)] +enum Hostname { + IPv4(Ipv4Addr), + IPv6(Ipv6Addr), + FQDN(String), +} + +pub struct ServerName { + hostname: Hostname, + port: Option, +} + +impl ServerName { + pub fn new(server_name: &str) -> anyhow::Result { + if let Ok(addr) = server_name.parse::() { + return Ok(Self { + hostname: Hostname::IPv4(addr), + port: None, + }); + }; + + if let Ok(addr) = server_name.parse::() { + return Ok(Self { + hostname: Hostname::IPv4(addr.ip().to_owned()), + port: Some(addr.port()), + }); + }; + + if server_name.ends_with(']') { + if let Ok(addr) = server_name[1..server_name.chars().count() - 1].parse::() { + return Ok(Self { + hostname: Hostname::IPv6(addr), + port: None, + }); + }; + } + + if let Ok(addr) = server_name.parse::() { + return Ok(Self { + hostname: Hostname::IPv6(addr.ip().to_owned()), + port: Some(addr.port()), + }); + }; + + if let Some(idx) = server_name.find(':') { + let (hostname, port) = server_name.split_at(idx); + Ok(Self { + hostname: Hostname::FQDN(hostname.to_owned()), + port: Some(port[1..].parse()?), + }) + } else { + Ok(Self { + hostname: Hostname::FQDN(server_name.to_owned()), + port: None, + }) + } + } + + /// Get a reference to the server name's hostname. + #[must_use] + fn hostname(&self) -> &Hostname { + &self.hostname + } + + /// Get the server name's port. + #[must_use] + fn port(&self) -> Option { + self.port + } +} + +impl std::fmt::Display for ServerName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.hostname { + Hostname::IPv4(hostname) => write!(f, "{hostname}"), + Hostname::IPv6(hostname) => write!(f, "[{hostname}]"), + Hostname::FQDN(hostname) => write!(f, "{hostname}"), + }?; + + if let Some(port) = self.port { + write!(f, ":{port}")?; + } + Ok(()) + } +} + +mod tests { + use super::*; + + #[test] + fn parse_ipv4_without_port() { + let server_name = ServerName::new("127.0.0.1").unwrap(); + assert_eq!( + server_name.hostname(), + &Hostname::IPv4("127.0.0.1".parse().unwrap()) + ); + assert_eq!(server_name.port(), None); + } + + #[test] + fn parse_ipv4_with_port() { + let server_name = ServerName::new("127.0.0.1:8080").unwrap(); + assert_eq!( + server_name.hostname(), + &Hostname::IPv4("127.0.0.1".parse().unwrap()) + ); + assert_eq!(server_name.port(), Some(8080)); + } + + #[test] + fn parse_ipv6_without_port() { + let server_name = ServerName::new("[::1]").unwrap(); + assert_eq!( + server_name.hostname(), + &Hostname::IPv6("::1".parse().unwrap()) + ); + assert_eq!(server_name.port(), None); + } + + #[test] + fn parse_ipv6_with_port() { + let server_name = ServerName::new("[::1]:8080").unwrap(); + assert_eq!( + server_name.hostname(), + &Hostname::IPv6("::1".parse().unwrap()) + ); + assert_eq!(server_name.port(), Some(8080)); + } + + #[test] + fn parse_fqdn_without_port() { + let server_name = ServerName::new("example.com").unwrap(); + assert_eq!( + server_name.hostname(), + &Hostname::FQDN("example.com".into()) + ); + assert_eq!(server_name.port(), None); + } + + #[test] + fn parse_fqdn_with_port() { + let server_name = ServerName::new("example.com:8080").unwrap(); + assert_eq!( + server_name.hostname(), + &Hostname::FQDN("example.com".into()) + ); + assert_eq!(server_name.port(), Some(8080)); + } +} diff --git a/src/types/user_id.rs b/src/types/user_id.rs index 8bda56a..1f4d926 100644 --- a/src/types/user_id.rs +++ b/src/types/user_id.rs @@ -2,6 +2,8 @@ use std::fmt::Display; use sqlx::{encode::IsNull, Sqlite}; +use super::server_name::ServerName; + #[derive(Clone)] #[repr(transparent)] pub struct UserId(String); @@ -36,7 +38,7 @@ impl<'d> sqlx::Decode<'d, Sqlite> for UserId { } impl UserId { - pub fn new(name: &str, server_name: &str) -> anyhow::Result { + pub fn new(name: &str, server_name: &ServerName) -> anyhow::Result { let user_id = Self(format!("@{name}:{server_name}")); user_id.is_valid()?;