diff --git a/sqlx-data.json b/sqlx-data.json index 16006c7..ebd1741 100644 --- a/sqlx-data.json +++ b/sqlx-data.json @@ -36,6 +36,42 @@ }, "query": "insert into devices(uuid, user_uuid, device_id, display_name)\n values(?, ?, ?, ?)\n returning uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name" }, + "1843c2b3e548d1dd13694a65ca1ba123da38668c3fc5bc431fe5884a6fc25f71": { + "describe": { + "columns": [ + { + "name": "uuid: Uuid", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "user_uuid: Uuid", + "ordinal": 1, + "type_info": "Int64" + }, + { + "name": "device_id", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "display_name", + "ordinal": 3, + "type_info": "Text" + } + ], + "nullable": [ + false, + false, + false, + false + ], + "parameters": { + "Right": 1 + } + }, + "query": "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where uuid = ?" + }, "221d0935dff8911fe58ac047d39e11b0472d2180d7c297291a5dc440e00efb80": { "describe": { "columns": [ @@ -66,6 +102,42 @@ }, "query": "insert into sessions(uuid, device_uuid, key)\n values(?, ?, ?)\n returning uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key" }, + "383949b72c69bca95bf23ef06900cd1ac5a136cdd4a525cbb624d327ce0cdefb": { + "describe": { + "columns": [ + { + "name": "uuid: Uuid", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "user_id: UserId", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "display_name", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "password", + "ordinal": 3, + "type_info": "Text" + } + ], + "nullable": [ + false, + false, + false, + false + ], + "parameters": { + "Right": 1 + } + }, + "query": "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password\n from users where uuid = ?" + }, "3fead3dac0e110757bc30be40bb0c6c2bc02127b6d9b6145bfc40fa5fe22ad06": { "describe": { "columns": [ @@ -156,6 +228,36 @@ }, "query": "update users set uuid = ?, user_id = ?, display_name = ?, password = ?\n where uuid = ?\n returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password" }, + "b38fd90504bea0c63e6517738c2354e6b057fcc6c643283019b27689e286bf2d": { + "describe": { + "columns": [ + { + "name": "uuid: Uuid", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "device_uuid: Uuid", + "ordinal": 1, + "type_info": "Int64" + }, + { + "name": "key", + "ordinal": 2, + "type_info": "Text" + } + ], + "nullable": [ + false, + false, + false + ], + "parameters": { + "Right": 1 + } + }, + "query": "select uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key\n from sessions where key = ?" + }, "ddcc531c080b2a1c70166d29a940aaada6701abe2933c305a879e7f18baeaf3a": { "describe": { "columns": [ diff --git a/src/api/client_server/errors/mod.rs b/src/api/client_server/errors/mod.rs index 46bac93..b10ed3b 100644 --- a/src/api/client_server/errors/mod.rs +++ b/src/api/client_server/errors/mod.rs @@ -13,7 +13,7 @@ pub struct ErrorResponse { } impl ErrorResponse { - fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option) -> Self { + pub fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option) -> Self { Self { errcode, error: error.to_owned(), diff --git a/src/api/client_server/r0/auth.rs b/src/api/client_server/r0/auth.rs index 6392388..53d1671 100644 --- a/src/api/client_server/r0/auth.rs +++ b/src/api/client_server/r0/auth.rs @@ -54,7 +54,7 @@ async fn post_login( match body { AuthenticationData::Password(auth_data) => { let user = auth_data.user().unwrap(); - let user_id = UserId::new(&user, config.server_name()) + let user_id = UserId::new(user, config.server_name()) .ok() .ok_or(AuthenticationError::InvalidUserId)?; @@ -96,7 +96,7 @@ async fn get_username_available( let username = params .get("username") .ok_or(RegistrationError::MissingUserId)?; - let user_id = UserId::new(username, &config.server_name()) + let user_id = UserId::new(username, config.server_name()) .ok() .ok_or(RegistrationError::InvalidUserId)?; let exists = User::exists(&db, &user_id).await?; @@ -117,7 +117,7 @@ async fn post_register( let (user, device) = match &body.auth().expect("must be Some") { AuthenticationData::Password(auth_data) => { let username = body.username().ok_or(RegistrationError::MissingUserId)?; - let user_id = UserId::new(username, &config.server_name()) + let user_id = UserId::new(username, config.server_name()) .ok() .ok_or(RegistrationError::InvalidUserId)?; diff --git a/src/api/client_server/r0/mod.rs b/src/api/client_server/r0/mod.rs index 5696e21..bf3e108 100644 --- a/src/api/client_server/r0/mod.rs +++ b/src/api/client_server/r0/mod.rs @@ -1 +1,106 @@ -pub mod auth; \ No newline at end of file +use std::sync::Arc; + +use axum::{ + http::{Request, StatusCode}, + middleware::Next, + response::IntoResponse, + Json, +}; +use sqlx::SqlitePool; + +use crate::{models::sessions::Session, types::error_code::ErrorCode}; + +use super::errors::ErrorResponse; + +pub mod auth; +pub mod thirdparty; + +async fn authentication_middleware(mut req: Request, next: Next) -> impl IntoResponse { + let db: &SqlitePool = req.extensions().get().unwrap(); + let auth_header = req + .headers() + .get(axum::http::header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()); + + if auth_header.is_none() { + return ( + StatusCode::FORBIDDEN, + Json(ErrorResponse::new( + ErrorCode::Forbidden, + "Authorization Header not given", + None, + )), + ) + .into_response(); + } + + let auth_header = auth_header.expect("Validated above"); + let idx = auth_header.find(' '); + + let idx = match idx { + Some(idx) => idx, + None => { + return ( + StatusCode::FORBIDDEN, + Json(ErrorResponse::new( + ErrorCode::Forbidden, + "Invalid Authorization Header", + None, + )), + ) + .into_response() + } + }; + + let session = match Session::find_by_key(db, &auth_header[idx + 1..]).await { + Ok(session) => session, + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse::new( + ErrorCode::Unknown, + "Internal Server Error", + None, + )), + ) + .into_response() + } + }; + + let session = match session { + Some(session) => session, + None => { + return ( + StatusCode::FORBIDDEN, + Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)), + ) + .into_response() + } + }; + + let device = match session.device(db).await { + Ok(device) => device, + Err(_) => { + return ( + StatusCode::FORBIDDEN, + Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)), + ) + .into_response() + } + }; + + let user = match device.user(db).await { + Ok(user) => user, + Err(_) => { + return ( + StatusCode::FORBIDDEN, + Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)), + ) + .into_response() + } + }; + + req.extensions_mut().insert(Arc::new(user)); + + next.run(req).await.into_response() +} diff --git a/src/api/client_server/r0/thirdparty.rs b/src/api/client_server/r0/thirdparty.rs new file mode 100644 index 0000000..ba16ba7 --- /dev/null +++ b/src/api/client_server/r0/thirdparty.rs @@ -0,0 +1,17 @@ +use std::sync::Arc; + +use axum::{routing::get, Extension}; + +use crate::{api::client_server::errors::api_error::ApiError, models::users::User}; + + +pub fn routes() -> axum::Router { + axum::Router::new() + .route("/r0/thirdparty/protocols", get(get_thirdparty_protocols)) + .layer(axum::middleware::from_fn(super::authentication_middleware)) +} + +#[tracing::instrument(skip_all)] +async fn get_thirdparty_protocols(Extension(user): Extension>) -> Result { + Ok("{}".into()) +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 9268d73..88c8818 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,11 +16,11 @@ use tower_http::{ use tracing::Level; mod api; +mod config; mod models; mod requests; mod responses; mod types; -mod config; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -32,7 +32,7 @@ async fn main() -> anyhow::Result<()> { let config = Arc::new(Config::default()); - let pool = sqlx::SqlitePool::connect(&config.db_path()).await?; + let pool = sqlx::SqlitePool::connect(config.db_path()).await?; let cors = CorsLayer::new() .allow_origin(tower_http::cors::Any) @@ -43,7 +43,8 @@ async fn main() -> anyhow::Result<()> { let client_server = Router::new() .merge(api::client_server::versions::routes()) - .merge(api::client_server::r0::auth::routes()); + .merge(api::client_server::r0::auth::routes()) + .merge(api::client_server::r0::thirdparty::routes()); let router = Router::new() .nest("/_matrix/client", client_server) @@ -60,7 +61,7 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -async fn fallback(request: Request) -> StatusCode { - dbg!(request); +async fn fallback(mut request: Request) -> StatusCode { + println!("{} {}", request.method(), request.uri()); StatusCode::INTERNAL_SERVER_ERROR } diff --git a/src/models/devices.rs b/src/models/devices.rs index cc8de1c..9b8d2c9 100644 --- a/src/models/devices.rs +++ b/src/models/devices.rs @@ -49,6 +49,19 @@ impl Device { ) } + pub async fn find_by_uuid(conn: &SqlitePool, uuid: &Uuid) -> anyhow::Result { + Ok(sqlx::query_as!( + Self, + "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where uuid = ?", + uuid) + .fetch_one(conn).await? + ) + } + + pub async fn user(&self, conn: &SqlitePool) -> anyhow::Result { + User::find_by_uuid(conn, &self.user_uuid).await + } + /// Get the device's id. #[must_use] pub fn uuid(&self) -> &Uuid { diff --git a/src/models/sessions.rs b/src/models/sessions.rs index a5fd718..226f24e 100644 --- a/src/models/sessions.rs +++ b/src/models/sessions.rs @@ -1,4 +1,4 @@ -use rand::{thread_rng, Rng, distributions::Alphanumeric}; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; use sqlx::SqlitePool; use crate::types::uuid::Uuid; @@ -14,7 +14,7 @@ pub struct Session { impl Session { pub fn new(device: &Device) -> anyhow::Result { let mut rng = thread_rng(); - let key: String = (0..32).map(|_| rng.sample(Alphanumeric) as char).collect(); + let key: String = (0..256).map(|_| rng.sample(Alphanumeric) as char).collect(); Ok(Self { uuid: uuid::Uuid::new_v4().into(), @@ -37,6 +37,21 @@ impl Session { .await?) } + pub async fn find_by_key(conn: &SqlitePool, key: &str) -> anyhow::Result> { + Ok(sqlx::query_as!( + Self, + "select uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key + from sessions where key = ?", + key + ) + .fetch_optional(conn) + .await?) + } + + pub async fn device(&self, conn: &SqlitePool) -> anyhow::Result { + Device::find_by_uuid(conn, &self.device_uuid).await + } + /// Get the session's id. #[must_use] pub fn uuid(&self) -> &Uuid { diff --git a/src/models/users.rs b/src/models/users.rs index 80a0e0b..7b0a5d0 100644 --- a/src/models/users.rs +++ b/src/models/users.rs @@ -1,10 +1,11 @@ use crate::types::uuid::Uuid; -use argon2::{password_hash::SaltString, Argon2, PasswordHasher, PasswordHash, PasswordVerifier}; +use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier}; use rand::rngs::OsRng; use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool}; use crate::types::user_id::UserId; +#[derive(Debug)] pub struct User { uuid: Uuid, user_id: UserId, @@ -68,6 +69,17 @@ impl User { .await?) } + pub async fn find_by_uuid(conn: &SqlitePool, uuid: &Uuid) -> anyhow::Result { + Ok(sqlx::query_as!( + Self, + "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password + from users where uuid = ?", + uuid + ) + .fetch_one(conn) + .await?) + } + pub async fn find_by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result { Ok(sqlx::query_as!( Self, @@ -82,7 +94,9 @@ impl User { pub fn password_correct(&self, password: &str) -> anyhow::Result { let password_hash = PasswordHash::new(self.password())?; - Ok(Argon2::default().verify_password(password.as_bytes(), &password_hash).is_ok()) + Ok(Argon2::default() + .verify_password(password.as_bytes(), &password_hash) + .is_ok()) } /// Get the user's id. diff --git a/src/responses/registration.rs b/src/responses/registration.rs index 399723c..c42c99b 100644 --- a/src/responses/registration.rs +++ b/src/responses/registration.rs @@ -26,7 +26,7 @@ pub struct RegistrationSuccess { impl RegistrationSuccess { pub fn new(access_token: Option<&str>, device_id: &str, user_id: &str) -> Self { Self { - access_token: access_token.and_then(|v| Some(v.to_owned())), + access_token: access_token.map(|v| v.to_owned()), device_id: device_id.to_owned(), user_id: user_id.to_owned(), } diff --git a/src/types/server_name.rs b/src/types/server_name.rs index 8664f9a..2f16a10 100644 --- a/src/types/server_name.rs +++ b/src/types/server_name.rs @@ -4,7 +4,7 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; enum Hostname { IPv4(Ipv4Addr), IPv6(Ipv6Addr), - FQDN(String), + Fqdn(String), } pub struct ServerName { @@ -47,12 +47,12 @@ impl ServerName { if let Some(idx) = server_name.find(':') { let (hostname, port) = server_name.split_at(idx); Ok(Self { - hostname: Hostname::FQDN(hostname.to_owned()), + hostname: Hostname::Fqdn(hostname.to_owned()), port: Some(port[1..].parse()?), }) } else { Ok(Self { - hostname: Hostname::FQDN(server_name.to_owned()), + hostname: Hostname::Fqdn(server_name.to_owned()), port: None, }) } @@ -76,7 +76,7 @@ impl std::fmt::Display for ServerName { match &self.hostname { Hostname::IPv4(hostname) => write!(f, "{hostname}"), Hostname::IPv6(hostname) => write!(f, "[{hostname}]"), - Hostname::FQDN(hostname) => write!(f, "{hostname}"), + Hostname::Fqdn(hostname) => write!(f, "{hostname}"), }?; if let Some(port) = self.port { @@ -134,7 +134,7 @@ mod tests { let server_name = ServerName::new("example.com").unwrap(); assert_eq!( server_name.hostname(), - &Hostname::FQDN("example.com".into()) + &Hostname::Fqdn("example.com".into()) ); assert_eq!(server_name.port(), None); } @@ -144,7 +144,7 @@ mod tests { let server_name = ServerName::new("example.com:8080").unwrap(); assert_eq!( server_name.hostname(), - &Hostname::FQDN("example.com".into()) + &Hostname::Fqdn("example.com".into()) ); assert_eq!(server_name.port(), Some(8080)); } diff --git a/src/types/user_id.rs b/src/types/user_id.rs index 1f4d926..891d06d 100644 --- a/src/types/user_id.rs +++ b/src/types/user_id.rs @@ -4,7 +4,7 @@ use sqlx::{encode::IsNull, Sqlite}; use super::server_name::ServerName; -#[derive(Clone)] +#[derive(Debug, Clone)] #[repr(transparent)] pub struct UserId(String);