diff --git a/migrations/20220423204756_add_user.sql b/migrations/20220423204756_add_user.sql index 031e081..adde1cc 100644 --- a/migrations/20220423204756_add_user.sql +++ b/migrations/20220423204756_add_user.sql @@ -3,7 +3,7 @@ CREATE TABLE users( uuid TEXT PRIMARY KEY NOT NULL, user_id CHAR(255) NOT NULL, display_name TEXT NOT NULL, - password TEXT NOT NULL + password_hash TEXT NOT NULL ); CREATE INDEX user_id_index ON users (user_id); \ No newline at end of file diff --git a/sqlx-data.json b/sqlx-data.json index ebd1741..6ab5411 100644 --- a/sqlx-data.json +++ b/sqlx-data.json @@ -102,7 +102,7 @@ }, "query": "insert into sessions(uuid, device_uuid, key)\n values(?, ?, ?)\n returning uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key" }, - "383949b72c69bca95bf23ef06900cd1ac5a136cdd4a525cbb624d327ce0cdefb": { + "2b3409859921423dc051ce76a0166116f39ca7f26053bac5bde0a61313bfd68c": { "describe": { "columns": [ { @@ -111,7 +111,7 @@ "type_info": "Text" }, { - "name": "user_id: UserId", + "name": "user_id", "ordinal": 1, "type_info": "Text" }, @@ -121,7 +121,7 @@ "type_info": "Text" }, { - "name": "password", + "name": "password_hash", "ordinal": 3, "type_info": "Text" } @@ -136,9 +136,9 @@ "Right": 1 } }, - "query": "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password\n from users where uuid = ?" + "query": "select uuid as 'uuid: Uuid', user_id, display_name, password_hash\n from users where user_id = ?" }, - "3fead3dac0e110757bc30be40bb0c6c2bc02127b6d9b6145bfc40fa5fe22ad06": { + "33f7c796b21878b2f06f3e012ada151226bd1ab58677ca6acc4edb10e0e1493a": { "describe": { "columns": [ { @@ -147,7 +147,7 @@ "type_info": "Text" }, { - "name": "user_id: UserId", + "name": "user_id", "ordinal": 1, "type_info": "Text" }, @@ -157,7 +157,7 @@ "type_info": "Text" }, { - "name": "password", + "name": "password_hash", "ordinal": 3, "type_info": "Text" } @@ -172,7 +172,7 @@ "Right": 4 } }, - "query": "insert into users(uuid, user_id, display_name, password)\n values (?, ?, ?, ?)\n returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password" + "query": "insert into users(uuid, user_id, display_name, password_hash)\n values (?, ?, ?, ?)\n returning uuid as 'uuid: Uuid', user_id, display_name, password_hash" }, "58d27b1d424297504f1da2e3b9b4020121251c1155fbf5dc870dafbef97659f3": { "describe": { @@ -192,7 +192,7 @@ }, "query": "select user_id from users where user_id = ?" }, - "778b7f0a1c66f00812f0232a5904b7b6b295720ebb75d1c2720afeeda4f66936": { + "9673bbe9506ba700923467fe8aaa141f9030158790db74234c13a5800adf2575": { "describe": { "columns": [ { @@ -201,7 +201,7 @@ "type_info": "Text" }, { - "name": "user_id: UserId", + "name": "user_id", "ordinal": 1, "type_info": "Text" }, @@ -211,7 +211,7 @@ "type_info": "Text" }, { - "name": "password", + "name": "password_hash", "ordinal": 3, "type_info": "Text" } @@ -223,10 +223,46 @@ false ], "parameters": { - "Right": 5 + "Right": 1 } }, - "query": "update users set uuid = ?, user_id = ?, display_name = ?, password = ?\n where uuid = ?\n returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password" + "query": "select uuid as 'uuid: Uuid', user_id, display_name, password_hash\n from users where uuid = ?" + }, + "9ee4afab2c653a23144bcb05943aa4ff1e8dc1ae5baa9c87827b52671ae47784": { + "describe": { + "columns": [ + { + "name": "uuid: Uuid", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "user_uuid: Uuid", + "ordinal": 1, + "type_info": "Int64" + }, + { + "name": "device_id", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "display_name", + "ordinal": 3, + "type_info": "Text" + } + ], + "nullable": [ + false, + false, + false, + false + ], + "parameters": { + "Right": 2 + } + }, + "query": "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ? and device_id = ?" }, "b38fd90504bea0c63e6517738c2354e6b057fcc6c643283019b27689e286bf2d": { "describe": { @@ -258,7 +294,7 @@ }, "query": "select uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key\n from sessions where key = ?" }, - "ddcc531c080b2a1c70166d29a940aaada6701abe2933c305a879e7f18baeaf3a": { + "f1b148ebcfe22d9680b8ea3dc3c334523496e88fca724ec7d08c5e2948e58526": { "describe": { "columns": [ { @@ -267,17 +303,17 @@ "type_info": "Text" }, { - "name": "user_uuid: Uuid", + "name": "user_id", "ordinal": 1, - "type_info": "Int64" - }, - { - "name": "device_id", - "ordinal": 2, "type_info": "Text" }, { "name": "display_name", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "password_hash", "ordinal": 3, "type_info": "Text" } @@ -289,45 +325,9 @@ false ], "parameters": { - "Right": 1 + "Right": 5 } }, - "query": "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ?" - }, - "f57c49e5390c81f971851ff9ab35242a472b9efbb1ffa658de9b102188769750": { - "describe": { - "columns": [ - { - "name": "uuid: Uuid", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "user_id: UserId", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "display_name", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "password", - "ordinal": 3, - "type_info": "Text" - } - ], - "nullable": [ - false, - false, - false, - false - ], - "parameters": { - "Right": 1 - } - }, - "query": "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password\n from users where user_id = ?" + "query": "update users set uuid = ?, user_id = ?, display_name = ?, password_hash = ?\n where uuid = ?\n returning uuid as 'uuid: Uuid', user_id, display_name, password_hash" } } \ No newline at end of file diff --git a/src/api/client_server/errors/api_error.rs b/src/api/client_server/errors/api_error.rs index 4c10970..f2770e2 100644 --- a/src/api/client_server/errors/api_error.rs +++ b/src/api/client_server/errors/api_error.rs @@ -1,9 +1,7 @@ use axum::http::StatusCode; use axum::response::IntoResponse; use axum::Json; -use sqlx::Statement; -use crate::responses::registration::RegistrationResponse; use crate::types::error_code::ErrorCode; use super::authentication_error::AuthenticationError; diff --git a/src/api/client_server/r0/auth.rs b/src/api/client_server/r0/auth.rs index 664a3b3..aa8ee35 100644 --- a/src/api/client_server/r0/auth.rs +++ b/src/api/client_server/r0/auth.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use axum::{ extract::Query, routing::{get, post}, - Extension, Json, + Extension, }; use sqlx::SqlitePool; @@ -12,23 +12,14 @@ use crate::{ api_error::ApiError, authentication_error::AuthenticationError, registration_error::RegistrationError, }, - models::sessions::Session, - responses::{ - authentication::{AuthenticationResponse, AuthenticationSuccess}, - registration::RegistrationResponse, - }, - ruma_wrapper::RumaResponse, -}; -use crate::{models::devices::Device, responses::registration::RegistrationSuccess}; -use crate::{ - models::users::User, - requests::registration::RegistrationRequest, - types::{authentication_data::AuthenticationData, user_id::UserId}, + models::{devices::Device, sessions::Session, users::User}, + ruma_wrapper::{RumaRequest, RumaResponse}, Config, }; + use ruma::api::client::{ - account, - session::get_login_types::v3::{LoginType, PasswordLoginType}, + account, session, + uiaa::{IncomingAuthData, IncomingUserIdentifier}, }; pub fn routes() -> axum::Router { @@ -38,9 +29,10 @@ pub fn routes() -> axum::Router { .route("/r0/register/available", get(get_username_available)) } -use ruma::api::client::session; #[tracing::instrument] async fn get_login() -> Result, ApiError> { + use session::get_login_types::v3::*; + Ok(RumaResponse(session::get_login_types::v3::Response::new( vec![LoginType::Password(PasswordLoginType::new())], ))) @@ -48,43 +40,50 @@ async fn get_login() -> Result>, Extension(db): Extension, - Json(body): Json, -) -> Result, ApiError> { - match body { - AuthenticationData::Password(auth_data) => { - let user = auth_data.user().unwrap(); - let user_id = UserId::new(user, config.server_name()) - .ok() - .ok_or(AuthenticationError::InvalidUserId)?; + RumaRequest(req): RumaRequest, +) -> Result, ApiError> { + use session::login::v3::*; - let user = User::find_by_user_id(&db, &user_id).await?; + match req.login_info { + IncomingLoginInfo::Password(incoming_password) => { + let password = incoming_password.password; + let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) = + incoming_password.identifier + { + ruma::UserId::parse(user_id).map_err(|e| anyhow::anyhow!(e))? + } else { + return Err(AuthenticationError::InvalidUserId.into()) + }; - user.password_correct(auth_data.password()) + let db_user = User::find_by_user_id(&db, user_id.as_str()).await?; + db_user + .password_correct(&password) .ok() .ok_or(AuthenticationError::Forbidden)?; - let device = if let Some(device_id) = auth_data.device_id() { - Device::find_for_user(&db, &user, device_id).await? + let device = if let Some(device_id) = req.device_id { + Device::find_for_user(&db, &db_user, device_id.as_str()).await? } else { let device_id = uuid::Uuid::new_v4().to_string(); - let display_name = - if let Some(display_name) = auth_data.initial_device_display_name() { - display_name.as_ref() - } else { - "Generic Device" - }; - Device::new(&user, &device_id, display_name)? + let display_name = req + .initial_device_display_name + .unwrap_or_else(|| "Generic Device".into()); + Device::new(&db_user, &device_id, &display_name)? .create(&db) .await? }; let session = Session::new(&device)?.create(&db).await?; + let response = Response::new( + user_id, + session.key, + ruma::OwnedDeviceId::from(device.device_id), + ); - let resp = AuthenticationSuccess::new(session.key(), device.device_id(), &user_id); - Ok(Json(AuthenticationResponse::Success(resp))) + return Ok(RumaResponse(response)); } + _ => todo!(), } } @@ -94,76 +93,77 @@ async fn get_username_available( Extension(db): Extension, Query(params): Query>, ) -> Result, ApiError> { + use account::get_username_availability::v3::*; + let username = params .get("username") .ok_or(RegistrationError::MissingUserId)?; - let user_id = UserId::new(username, config.server_name()) - .ok() - .ok_or(RegistrationError::InvalidUserId)?; + let user_id = ruma::UserId::parse(username).map_err(|_| RegistrationError::InvalidUserId)?; let exists = User::exists(&db, &user_id).await?; - Ok(RumaResponse( - account::get_username_availability::v3::Response::new(!exists), - )) + Ok(RumaResponse(Response::new(!exists))) } #[tracing::instrument(skip_all)] async fn post_register( Extension(config): Extension>, Extension(db): Extension, - Json(body): Json, -) -> Result, ApiError> { + RumaRequest(req): RumaRequest, +) -> Result, ApiError> { + use account::register::v3::*; + config .enable_registration() .then(|| true) .ok_or(RegistrationError::RegistrationDisabled)?; - body.auth() - .ok_or(RegistrationError::AdditionalAuthenticationInformation)?; - let (user, device) = match &body.auth().expect("must be Some") { - AuthenticationData::Password(auth_data) => { - let username = body.username().ok_or(RegistrationError::MissingUserId)?; - let user_id = UserId::new(username, config.server_name()) - .ok() - .ok_or(RegistrationError::InvalidUserId)?; + match req.auth { + Some(auth) => match auth { + IncomingAuthData::Password(incoming_password) => { + let password = incoming_password.password; + let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) = + incoming_password.identifier + { + ruma::UserId::parse(user_id).map_err(|e| anyhow::anyhow!(e))? + } else { + Err(AuthenticationError::InvalidUserId)? + }; - if User::exists(&db, &user_id).await? { - return Err(ApiError::from(RegistrationError::UserIdTaken)); - } + if User::exists(&db, &user_id).await? { + return Err(ApiError::from(RegistrationError::UserIdTaken)); + } - let display_name = match body.initial_device_display_name() { - Some(display_name) => display_name.as_ref(), - None => "Random displayname", - }; + let display_name = req + .initial_device_display_name + .unwrap_or_else(|| "Random Display Name".into()); - let user = User::new(&user_id, &user_id.to_string(), auth_data.password())? + let user = User::new(&user_id, &user_id.to_string(), &password)? + .create(&db) + .await?; + + let device = Device::new( + &user, + uuid::Uuid::new_v4().to_string().as_ref(), + &display_name, + )? .create(&db) .await?; + let mut response = Response::new( + ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?, + ); + if !req.inhibit_login { + let session = Session::new(&device)?.create(&db).await?; + response.access_token = Some(session.key); + } + if !req.inhibit_login { + response.device_id = Some(ruma::OwnedDeviceId::from(device.device_id)); + } - let device = Device::new( - &user, - uuid::Uuid::new_v4().to_string().as_ref(), - display_name, - )? - .create(&db) - .await?; - - (user, device) - } + return Ok(RumaResponse(response)); + } + _ => todo!(), + }, + None => Err(RegistrationError::AdditionalAuthenticationInformation)?, }; - - if body.inhibit_login().unwrap_or(false) { - let resp = RegistrationSuccess::new(None, device.device_id(), &user.user_id().to_string()); - - Ok(Json(RegistrationResponse::Success(resp))) - } else { - let session = Session::new(&device)?.create(&db).await?; - let resp = RegistrationSuccess::new( - Some(session.key()), - device.device_id(), - &user.user_id().to_string(), - ); - - Ok(Json(RegistrationResponse::Success(resp))) - } + unreachable!() } diff --git a/src/api/client_server/r0/mod.rs b/src/api/client_server/r0/mod.rs index de478da..ee52723 100644 --- a/src/api/client_server/r0/mod.rs +++ b/src/api/client_server/r0/mod.rs @@ -13,8 +13,8 @@ use crate::{models::sessions::Session, types::error_code::ErrorCode}; use super::errors::ErrorResponse; pub mod auth; -pub mod thirdparty; pub mod create_room; +pub mod thirdparty; async fn authentication_middleware(mut req: Request, next: Next) -> impl IntoResponse { let db: &SqlitePool = req.extensions().get().unwrap(); diff --git a/src/api/client_server/versions.rs b/src/api/client_server/versions.rs index 07c90fe..18967de 100644 --- a/src/api/client_server/versions.rs +++ b/src/api/client_server/versions.rs @@ -1,12 +1,16 @@ -use axum::{routing::get, Json}; +use axum::routing::get; -use crate::responses::versions::Versions; +use crate::ruma_wrapper::RumaResponse; pub fn routes() -> axum::Router { axum::Router::new().route("/versions", get(get_client_versions)) } +use ruma::api::client::discovery; + #[tracing::instrument] -async fn get_client_versions() -> Json { - Json(Versions::default()) +async fn get_client_versions() -> RumaResponse { + use discovery::get_supported_versions::*; + + RumaResponse(Response::new(vec!["v1.2".into()])) } diff --git a/src/config.rs b/src/config.rs index b5260c9..971118a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,7 @@ use crate::types::server_name::ServerName; pub struct Config { db_path: String, server_name: ServerName, - enable_registration: bool + enable_registration: bool, } impl Config { @@ -31,7 +31,7 @@ impl Default for Config { Self { db_path: "sqlite://db.sqlite3".into(), server_name: ServerName::new("fuckwit.dev").unwrap(), - enable_registration: true + enable_registration: true, } } -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 83101b5..a154ad3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,10 +12,8 @@ use tower_http::{cors::CorsLayer, trace::TraceLayer}; mod api; mod config; mod models; -mod requests; mod responses; mod ruma_wrapper; -mod state_resolution; mod types; #[tokio::main] diff --git a/src/models/devices.rs b/src/models/devices.rs index 9b8d2c9..9e9d213 100644 --- a/src/models/devices.rs +++ b/src/models/devices.rs @@ -2,20 +2,20 @@ use sqlx::SqlitePool; use crate::types::uuid::Uuid; -use super::{sessions::Session, users::User}; +use super::users::User; pub struct Device { - uuid: Uuid, - user_uuid: Uuid, - device_id: String, - display_name: String, + pub uuid: Uuid, + pub user_uuid: Uuid, + pub device_id: String, + pub display_name: String, } impl Device { pub fn new(user: &User, device_id: &str, display_name: &str) -> anyhow::Result { Ok(Self { uuid: uuid::Uuid::new_v4().into(), - user_uuid: user.uuid().clone(), + user_uuid: user.uuid.clone(), device_id: device_id.to_owned(), display_name: display_name.to_owned(), }) @@ -40,11 +40,11 @@ impl Device { user: &User, device_id: &str, ) -> anyhow::Result { - let user_uuid = user.uuid(); + let user_uuid = user.uuid.clone(); Ok(sqlx::query_as!( Self, - "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ?", - user_uuid) + "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ? and device_id = ?", + user_uuid, device_id) .fetch_one(conn).await? ) } @@ -67,22 +67,4 @@ impl Device { pub fn uuid(&self) -> &Uuid { &self.uuid } - - /// Get the device's user id. - #[must_use] - pub fn user_uuid(&self) -> &Uuid { - &self.user_uuid - } - - /// Get a reference to the device's device id. - #[must_use] - pub fn device_id(&self) -> &str { - self.device_id.as_ref() - } - - /// Get a reference to the device's display name. - #[must_use] - pub fn display_name(&self) -> &str { - self.display_name.as_ref() - } } diff --git a/src/models/events.rs b/src/models/events.rs index 1ea5dc9..5c90909 100644 --- a/src/models/events.rs +++ b/src/models/events.rs @@ -1,8 +1,8 @@ -use sqlx::SqlitePool; +/* use sqlx::SqlitePool; use crate::types::{uuid::Uuid, event_type::EventType}; -use super::{rooms::Room, users::User}; +use super::{users::User}; #[derive(Debug, PartialEq, Eq, Hash)] pub struct Event { @@ -70,3 +70,4 @@ impl Event { serde_json::from_str(&self.content).expect("has to be valid json") } } + */ diff --git a/src/models/mod.rs b/src/models/mod.rs index 4abd50f..61b1e5c 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,5 +1,4 @@ pub mod devices; pub mod events; -pub mod rooms; pub mod sessions; pub mod users; diff --git a/src/models/rooms.rs b/src/models/rooms.rs deleted file mode 100644 index 8fe6b44..0000000 --- a/src/models/rooms.rs +++ /dev/null @@ -1,42 +0,0 @@ -use sqlx::SqlitePool; - -use crate::types::uuid::Uuid; - -use super::events::Event; - -pub struct Room { - uuid: Uuid, - name: String, -} - -impl Room { - fn new(name: &str) -> anyhow::Result { - Ok(Self { - uuid: uuid::Uuid::new_v4().into(), - name: name.to_owned(), - }) - } - - pub async fn create(&self, conn: &SqlitePool) -> anyhow::Result { - Ok(sqlx::query_as!( - Self, - "insert into rooms(uuid, name) - values(?, ?) - returning uuid as 'uuid: Uuid', name", - self.uuid, - self.name - ) - .fetch_one(conn) - .await?) - } - - pub async fn events(&self, conn: &SqlitePool) -> anyhow::Result> { - Event::all_for_room(conn, self).await - } - - /// Get a reference to the room's uuid. - #[must_use] - pub fn uuid(&self) -> &Uuid { - &self.uuid - } -} diff --git a/src/models/sessions.rs b/src/models/sessions.rs index 226f24e..70c4b4c 100644 --- a/src/models/sessions.rs +++ b/src/models/sessions.rs @@ -6,9 +6,9 @@ use crate::types::uuid::Uuid; use super::devices::Device; pub struct Session { - uuid: Uuid, - device_uuid: Uuid, - key: String, + pub uuid: Uuid, + pub device_uuid: Uuid, + pub key: String, } impl Session { @@ -51,22 +51,4 @@ impl Session { pub async fn device(&self, conn: &SqlitePool) -> anyhow::Result { Device::find_by_uuid(conn, &self.device_uuid).await } - - /// Get the session's id. - #[must_use] - pub fn uuid(&self) -> &Uuid { - &self.uuid - } - - /// Get the session's device id. - #[must_use] - pub fn device_uuid(&self) -> &Uuid { - &self.device_uuid - } - - /// Get a reference to the session's value. - #[must_use] - pub fn key(&self) -> &str { - self.key.as_ref() - } } diff --git a/src/models/users.rs b/src/models/users.rs index 7b0a5d0..1bcad06 100644 --- a/src/models/users.rs +++ b/src/models/users.rs @@ -1,35 +1,35 @@ use crate::types::uuid::Uuid; use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier}; use rand::rngs::OsRng; -use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool}; - -use crate::types::user_id::UserId; +use ruma::OwnedUserId; +use sqlx::SqlitePool; #[derive(Debug)] pub struct User { - uuid: Uuid, - user_id: UserId, - display_name: String, - password: String, + pub uuid: Uuid, + pub user_id: String, + pub display_name: String, + pub password_hash: String, } impl User { - pub fn new(user_id: &UserId, display_name: &str, password: &str) -> anyhow::Result { + pub fn new(user_id: &OwnedUserId, display_name: &str, password: &str) -> anyhow::Result { let argon2 = Argon2::default(); let salt = SaltString::generate(OsRng); - let password = argon2 + let password_hash = argon2 .hash_password(password.as_bytes(), &salt)? .to_string(); Ok(Self { uuid: uuid::Uuid::new_v4().into(), - user_id: user_id.clone(), + user_id: user_id.to_string(), display_name: display_name.to_owned(), - password, + password_hash, }) } - pub async fn exists(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result { + pub async fn exists(conn: &SqlitePool, user_id: &OwnedUserId) -> anyhow::Result { + let user_id = user_id.to_string(); Ok( sqlx::query!("select user_id from users where user_id = ?", user_id) .fetch_optional(conn) @@ -41,13 +41,13 @@ impl User { pub async fn create(&self, conn: &SqlitePool) -> anyhow::Result { Ok(sqlx::query_as!( Self, - "insert into users(uuid, user_id, display_name, password) + "insert into users(uuid, user_id, display_name, password_hash) values (?, ?, ?, ?) - returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password", + returning uuid as 'uuid: Uuid', user_id, display_name, password_hash", self.uuid, self.user_id, self.display_name, - self.password + self.password_hash ) .fetch_one(conn) .await?) @@ -56,13 +56,13 @@ impl User { pub async fn update(&self, conn: &SqlitePool) -> anyhow::Result { Ok(sqlx::query_as!( Self, - "update users set uuid = ?, user_id = ?, display_name = ?, password = ? + "update users set uuid = ?, user_id = ?, display_name = ?, password_hash = ? where uuid = ? - returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password", + returning uuid as 'uuid: Uuid', user_id, display_name, password_hash", self.uuid, self.user_id, self.display_name, - self.password, + self.password_hash, self.uuid ) .fetch_one(conn) @@ -72,7 +72,7 @@ impl User { pub async fn find_by_uuid(conn: &SqlitePool, uuid: &Uuid) -> anyhow::Result { Ok(sqlx::query_as!( Self, - "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password + "select uuid as 'uuid: Uuid', user_id, display_name, password_hash from users where uuid = ?", uuid ) @@ -80,10 +80,10 @@ impl User { .await?) } - pub async fn find_by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result { + pub async fn find_by_user_id(conn: &SqlitePool, user_id: &str) -> anyhow::Result { Ok(sqlx::query_as!( Self, - "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password + "select uuid as 'uuid: Uuid', user_id, display_name, password_hash from users where user_id = ?", user_id ) @@ -92,28 +92,10 @@ impl User { } pub fn password_correct(&self, password: &str) -> anyhow::Result { - let password_hash = PasswordHash::new(self.password())?; + let password_hash = PasswordHash::new(&self.password_hash)?; Ok(Argon2::default() .verify_password(password.as_bytes(), &password_hash) .is_ok()) } - - /// Get the user's id. - #[must_use] - pub fn uuid(&self) -> &Uuid { - &self.uuid - } - - /// Get a reference to the user's user id. - #[must_use] - pub fn user_id(&self) -> &UserId { - &self.user_id - } - - /// Get a reference to the user's password. - #[must_use] - pub fn password(&self) -> &str { - self.password.as_ref() - } } diff --git a/src/requests/create_room_request.rs b/src/requests/create_room_request.rs deleted file mode 100644 index 5bbaf48..0000000 --- a/src/requests/create_room_request.rs +++ /dev/null @@ -1,29 +0,0 @@ -use crate::types::user_id::UserId; - -#[derive(Debug, serde::Deserialize)] -pub struct CreateRoomRequest { - /// Extra keys, such as `m.federate`, to be added to the content of the `m.room.create` event. - creation_content: Option<()>, - /// List of state events to set in the initial room. Used for overriding the default state - initial_state: Vec<()>, - /// List of user IDs to invite to the room - invite: Option>, - /// List of thirdparty IDs to invite to the room - invite_3pid: Option>, - /// Indicate if room is a direct chat room - is_direct: Option, - /// Set name of the room - name: Option, - /// Used to override the default power level event - power_level_content_override: Option<()>, - /// Preset for room creation - preset: Option<()>, - /// Desired room alias local part - room_alias_name: Option, - /// Version of room to create. Defaults to server default - room_version: Option, - /// Sets rooms topic - topic: Option, - /// Sets rooms visibility - visibility: () -} \ No newline at end of file diff --git a/src/requests/mod.rs b/src/requests/mod.rs deleted file mode 100644 index 6c83275..0000000 --- a/src/requests/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod registration; -pub mod create_room_request; \ No newline at end of file diff --git a/src/requests/registration.rs b/src/requests/registration.rs deleted file mode 100644 index 97c0f1e..0000000 --- a/src/requests/registration.rs +++ /dev/null @@ -1,48 +0,0 @@ -use crate::types::{authentication_data::AuthenticationData, flow::Flow, identifier::Identifier}; - -#[derive(Debug, serde::Deserialize)] -pub struct RegistrationRequest { - auth: Option, - device_id: Option, - inhibit_login: Option, - initial_device_display_name: Option, - password: Option, - username: Option, -} - -impl RegistrationRequest { - #[must_use] - pub fn auth(&self) -> Option<&AuthenticationData> { - self.auth.as_ref() - } - - /// Get a reference to the registration request's device id. - #[must_use] - pub fn device_id(&self) -> Option<&String> { - self.device_id.as_ref() - } - - /// Get the registration request's inhibit login. - #[must_use] - pub fn inhibit_login(&self) -> Option { - self.inhibit_login - } - - /// Get a reference to the registration request's initial device display name. - #[must_use] - pub fn initial_device_display_name(&self) -> Option<&String> { - self.initial_device_display_name.as_ref() - } - - /// Get a reference to the registration request's password. - #[must_use] - pub fn password(&self) -> Option<&String> { - self.password.as_ref() - } - - /// Get a reference to the registration request's username. - #[must_use] - pub fn username(&self) -> Option<&String> { - self.username.as_ref() - } -} diff --git a/src/responses/authentication.rs b/src/responses/authentication.rs deleted file mode 100644 index 68493bf..0000000 --- a/src/responses/authentication.rs +++ /dev/null @@ -1,26 +0,0 @@ -use axum::{response::IntoResponse, Json}; - -use crate::types::user_id::UserId; - -#[derive(Debug, serde::Serialize)] -#[serde(untagged)] -pub enum AuthenticationResponse { - Success(AuthenticationSuccess), -} - -#[derive(Debug, serde::Serialize)] -pub struct AuthenticationSuccess { - access_token: String, - device_id: String, - user_id: String, -} - -impl AuthenticationSuccess { - pub fn new(access_token: &str, device_id: &str, user_id: &UserId) -> Self { - Self { - access_token: access_token.to_owned(), - device_id: device_id.to_owned(), - user_id: user_id.to_string(), - } - } -} diff --git a/src/responses/flow.rs b/src/responses/flow.rs deleted file mode 100644 index 39493ae..0000000 --- a/src/responses/flow.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::types::flow::Flow; - -#[derive(Debug, Clone, serde::Serialize)] -struct FlowWrapper { - #[serde(rename = "type")] - _type: Flow, -} - -#[derive(Debug, Clone, serde::Serialize)] -pub struct Flows { - flows: Vec, -} - -impl Flows { - pub fn new() -> Self { - Self { - flows: vec![FlowWrapper { - _type: Flow::Password, - }], - } - } -} diff --git a/src/responses/mod.rs b/src/responses/mod.rs index dbee683..038a1c7 100644 --- a/src/responses/mod.rs +++ b/src/responses/mod.rs @@ -1,5 +1 @@ -pub mod authentication; -pub mod flow; pub mod registration; -pub mod username_available; -pub mod versions; diff --git a/src/responses/registration.rs b/src/responses/registration.rs index c42c99b..2342889 100644 --- a/src/responses/registration.rs +++ b/src/responses/registration.rs @@ -3,7 +3,6 @@ use crate::types::user_interactive_authorization::UserInteractiveAuthorizationIn #[derive(Debug, serde::Serialize)] #[serde(untagged)] pub enum RegistrationResponse { - Success(RegistrationSuccess), UserInteractiveAuthorizationInfo(UserInteractiveAuthorizationInfo), } @@ -13,22 +12,4 @@ impl RegistrationResponse { UserInteractiveAuthorizationInfo::new(), ) } -} - -#[derive(Debug, serde::Serialize)] -pub struct RegistrationSuccess { - #[serde(skip_serializing_if = "Option::is_none")] - access_token: Option, - device_id: String, - user_id: String, -} - -impl RegistrationSuccess { - pub fn new(access_token: Option<&str>, device_id: &str, user_id: &str) -> Self { - Self { - access_token: access_token.map(|v| v.to_owned()), - device_id: device_id.to_owned(), - user_id: user_id.to_owned(), - } - } -} +} \ No newline at end of file diff --git a/src/responses/username_available.rs b/src/responses/username_available.rs deleted file mode 100644 index 6ca7805..0000000 --- a/src/responses/username_available.rs +++ /dev/null @@ -1,10 +0,0 @@ -#[derive(Debug, serde::Serialize)] -pub struct UsernameAvailable { - available: bool, -} - -impl UsernameAvailable { - pub fn new(available: bool) -> Self { - Self { available } - } -} diff --git a/src/responses/versions.rs b/src/responses/versions.rs deleted file mode 100644 index 07be3cd..0000000 --- a/src/responses/versions.rs +++ /dev/null @@ -1,17 +0,0 @@ -use std::collections::HashMap; - -#[derive(Debug, serde::Serialize, serde::Deserialize)] -pub struct Versions { - #[serde(skip_serializing_if = "Option::is_none")] - unstable_features: Option>, - versions: Vec, -} - -impl Default for Versions { - fn default() -> Self { - Self { - unstable_features: None, - versions: vec!["v1.2".into()], - } - } -} diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 2ca287d..ee51216 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,5 +1,5 @@ use axum::{ - body::{Bytes, HttpBody, Full}, + body::{Bytes, Full, HttpBody}, extract::{FromRequest, Path}, response::IntoResponse, BoxError, diff --git a/src/state_resolution/mod.rs b/src/state_resolution/mod.rs deleted file mode 100644 index 7d88584..0000000 --- a/src/state_resolution/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod v2; \ No newline at end of file diff --git a/src/state_resolution/v2.rs b/src/state_resolution/v2.rs deleted file mode 100644 index fb8630d..0000000 --- a/src/state_resolution/v2.rs +++ /dev/null @@ -1,102 +0,0 @@ -use crate::models::events::Event; -use std::{ - collections::{HashMap, HashSet}, - future::Future, -}; -use tracing::info; - -type StateMap = HashMap; - -#[derive(Debug, PartialEq, Eq, Hash, Clone)] -pub struct StateTuple { - event_type: String, - state_key: String, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct EventId { - id: Box, -} - -#[tracing::instrument(skip(state_sets, auth_chain_sets, get_event_callback))] -pub async fn resolve( - room_id: &str, // TODO: own type - state_sets: Vec>, - auth_chain_sets: Vec>, - get_event_callback: F, -) -> StateMap -where - F: Fn(&EventId) -> Fut, - Fut: Future>, -{ - info!("Calculating conflicted state"); - - let (unconflicted_state, conflicted_state) = separate_state(&state_sets); - - if conflicted_state.is_empty() { - return unconflicted_state; - } - - info!("{} conflicted_state entries", conflicted_state.len()); - info!("Calculating auth_chain differences"); - - let conflicted_set = - get_auth_chain_differences(auth_chain_sets).chain(conflicted_state.into_values().flatten()); - let mut conflicted = HashSet::new(); - for eid in conflicted_set { - if let Some(event) = get_event_callback(&eid).await { - conflicted.insert(event); - } - } - - todo!() -} - -/// separates states from multiple state_maps into unconflicted and conflicted state -/// -/// For the set of all state_tuples find all event_ids. -/// If one event_id is found it is unconflicted, otherwise it is conflicted -fn separate_state( - state_sets: &[StateMap], -) -> (StateMap, StateMap>) { - let mut unconflicted_state: StateMap = StateMap::new(); - let mut conflicted_state: HashMap> = StateMap::new(); - - for key in state_sets - .iter() - .flat_map(HashMap::keys) - .map(ToOwned::to_owned) - .collect::>() - { - let mut event_ids: HashSet = state_sets - .iter() - .filter_map(|state_set| state_set.get(&key)) - .map(ToOwned::to_owned) - .collect(); - - if event_ids.len() == 1 { - unconflicted_state.insert(key, event_ids.into_iter().next().expect("len() is 1")); - } else { - conflicted_state.insert(key, event_ids); - } - } - - (unconflicted_state, conflicted_state) -} - -fn get_auth_chain_differences( - auth_chain_sets: Vec>, -) -> impl Iterator { - let num_sets = auth_chain_sets.len(); - - let mut id_counts: HashMap = HashMap::new(); - for id in auth_chain_sets.into_iter().flatten() { - *id_counts.entry(id).or_default() += 1; - } - - id_counts - .into_iter() - .filter_map(move |(id, count)| (count < num_sets).then(move || id)) -} - -fn is_control_event() {} diff --git a/src/types/client_event.rs b/src/types/client_event.rs deleted file mode 100644 index eb5287c..0000000 --- a/src/types/client_event.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::{uuid::Uuid, user_id::UserId}; - -pub struct ClientEvent { - content: (), - event_id: Uuid, - origin_server_ts: u64, - room_id: String, - sender: UserId, - state_key: Option, - r#type: String, - unsigned: () -} \ No newline at end of file diff --git a/src/types/event_type.rs b/src/types/event_type.rs index 046bcd3..29a2fec 100644 --- a/src/types/event_type.rs +++ b/src/types/event_type.rs @@ -20,7 +20,7 @@ impl<'e> sqlx::Encode<'e, Sqlite> for EventType { buf.push(sqlx::sqlite::SqliteArgumentValue::Text( match self { EventType::RoomCreate => "m.room.create", - EventType::Unknown => "???" + EventType::Unknown => "???", } .into(), )); @@ -47,7 +47,7 @@ impl serde::Serialize for EventType { { serializer.serialize_str(match self { EventType::RoomCreate => "m.room.create", - EventType::Unknown => "dev.fuckwit.unknown_event" + EventType::Unknown => "dev.fuckwit.unknown_event", }) } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 9a6c61b..7e097a7 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,12 +1,10 @@ pub mod authentication_data; pub mod error_code; +pub mod event_type; pub mod flow; pub mod identifier; pub mod identifier_type; +pub mod server_name; pub mod user_id; pub mod user_interactive_authorization; pub mod uuid; -pub mod server_name; -pub mod client_event; -pub mod room_event; -pub mod event_type; \ No newline at end of file diff --git a/src/types/room_event.rs b/src/types/room_event.rs deleted file mode 100644 index d0fff51..0000000 --- a/src/types/room_event.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub enum RoomEvent { - Create(RoomCreateEvent) -} - -pub struct RoomCreateEvent { - -} \ No newline at end of file diff --git a/src/types/server_name.rs b/src/types/server_name.rs index 2f16a10..8af6269 100644 --- a/src/types/server_name.rs +++ b/src/types/server_name.rs @@ -1,15 +1,15 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; #[derive(Debug, PartialEq, Eq)] -enum Hostname { +pub enum Hostname { IPv4(Ipv4Addr), IPv6(Ipv6Addr), Fqdn(String), } pub struct ServerName { - hostname: Hostname, - port: Option, + pub hostname: Hostname, + pub port: Option, } impl ServerName { @@ -57,18 +57,6 @@ impl ServerName { }) } } - - /// Get a reference to the server name's hostname. - #[must_use] - fn hostname(&self) -> &Hostname { - &self.hostname - } - - /// Get the server name's port. - #[must_use] - fn port(&self) -> Option { - self.port - } } impl std::fmt::Display for ServerName { @@ -93,59 +81,59 @@ mod tests { fn parse_ipv4_without_port() { let server_name = ServerName::new("127.0.0.1").unwrap(); assert_eq!( - server_name.hostname(), - &Hostname::IPv4("127.0.0.1".parse().unwrap()) + server_name.hostname, + Hostname::IPv4("127.0.0.1".parse().unwrap()) ); - assert_eq!(server_name.port(), None); + assert_eq!(server_name.port, None); } #[test] fn parse_ipv4_with_port() { let server_name = ServerName::new("127.0.0.1:8080").unwrap(); assert_eq!( - server_name.hostname(), - &Hostname::IPv4("127.0.0.1".parse().unwrap()) + server_name.hostname, + Hostname::IPv4("127.0.0.1".parse().unwrap()) ); - assert_eq!(server_name.port(), Some(8080)); + assert_eq!(server_name.port, Some(8080)); } #[test] fn parse_ipv6_without_port() { let server_name = ServerName::new("[::1]").unwrap(); assert_eq!( - server_name.hostname(), - &Hostname::IPv6("::1".parse().unwrap()) + server_name.hostname, + Hostname::IPv6("::1".parse().unwrap()) ); - assert_eq!(server_name.port(), None); + assert_eq!(server_name.port, None); } #[test] fn parse_ipv6_with_port() { let server_name = ServerName::new("[::1]:8080").unwrap(); assert_eq!( - server_name.hostname(), - &Hostname::IPv6("::1".parse().unwrap()) + server_name.hostname, + Hostname::IPv6("::1".parse().unwrap()) ); - assert_eq!(server_name.port(), Some(8080)); + assert_eq!(server_name.port, Some(8080)); } #[test] fn parse_fqdn_without_port() { let server_name = ServerName::new("example.com").unwrap(); assert_eq!( - server_name.hostname(), - &Hostname::Fqdn("example.com".into()) + server_name.hostname, + Hostname::Fqdn("example.com".into()) ); - assert_eq!(server_name.port(), None); + assert_eq!(server_name.port, None); } #[test] fn parse_fqdn_with_port() { let server_name = ServerName::new("example.com:8080").unwrap(); assert_eq!( - server_name.hostname(), - &Hostname::Fqdn("example.com".into()) + server_name.hostname, + Hostname::Fqdn("example.com".into()) ); - assert_eq!(server_name.port(), Some(8080)); + assert_eq!(server_name.port, Some(8080)); } }