diff --git a/src/api/client_server/auth.rs b/src/api/client_server/auth.rs index 32bc589..3ff81b0 100644 --- a/src/api/client_server/auth.rs +++ b/src/api/client_server/auth.rs @@ -10,7 +10,7 @@ use axum::{ use rand_core::OsRng; use sqlx::SqlitePool; -use crate::responses::registration::RegistrationResponse; +use crate::{responses::{registration::RegistrationResponse, authentication::{AuthenticationResponse, AuthenticationSuccess}}, api::client_server::errors::authentication_error::AuthenticationError}; use crate::{ models::devices::Device, responses::{flow::Flows, registration::RegistrationSuccess}, @@ -38,9 +38,16 @@ async fn get_login() -> Result, ApiError> { } #[tracing::instrument(skip_all)] -async fn post_login(body: String) -> StatusCode { - dbg!(body); - StatusCode::INTERNAL_SERVER_ERROR +async fn post_login( + Extension(config): Extension>, + Extension(db): Extension, + Json(body): Json +) -> Result, ApiError> { + let user = UserId::new("name", "server_name").ok().ok_or(AuthenticationError::InvalidUserId)?; + todo!("Flesh this out more"); + let resp = AuthenticationSuccess::new("", "", &user); + + Ok(Json(AuthenticationResponse::Success(resp))) } #[tracing::instrument(skip_all)] @@ -76,17 +83,17 @@ async fn post_register( .ok() .ok_or(RegistrationError::InvalidUserId)?; - User::exists(&db, &user_id) - .await? - .then(|| ()) - .ok_or(RegistrationError::UserIdTaken)?; + if User::exists(&db, &user_id).await? { + return Err(ApiError::from(RegistrationError::UserIdTaken)); + } let display_name = match body.initial_device_display_name() { Some(display_name) => display_name.as_ref(), None => "Random displayname", }; - let user = User::create(&db, &user_id, &user_id.to_string(), auth_data.password()).await?; + let user = + User::create(&db, &user_id, &user_id.to_string(), auth_data.password()).await?; let device = Device::create(&db, &user, "test", display_name).await?; (user, device) diff --git a/src/api/client_server/errors/api_error.rs b/src/api/client_server/errors/api_error.rs index 4b02b74..dea9eea 100644 --- a/src/api/client_server/errors/api_error.rs +++ b/src/api/client_server/errors/api_error.rs @@ -6,7 +6,9 @@ use sqlx::Statement; use crate::responses::registration::RegistrationResponse; use crate::types::error_code::ErrorCode; +use super::authentication_error::AuthenticationError; use super::registration_error::RegistrationError; +use super::ErrorResponse; macro_rules! map_err { ($err:ident, $($type:path => $target:path),+) => { @@ -18,34 +20,19 @@ macro_rules! map_err { } } -#[derive(Debug, serde::Serialize)] -struct ErrorResponse { - errcode: ErrorCode, - error: String, - #[serde(skip_serializing_if = "Option::is_none")] - retry_after_ms: Option, -} - -impl ErrorResponse { - fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option) -> Self { - Self { - errcode, - error: error.to_owned(), - retry_after_ms, - } - } -} - #[derive(Debug, thiserror::Error)] pub enum ApiError { #[error("Registration Error")] RegistrationError(#[from] RegistrationError), + #[error("Authentication Error")] + AuthenticationError(#[from] AuthenticationError), + #[error("Database Error")] DBError(#[from] sqlx::Error), #[error("Generic Error")] - Generic(anyhow::Error) + Generic(anyhow::Error), } impl From for ApiError { @@ -62,29 +49,8 @@ impl From for ApiError { impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { match self { - ApiError::RegistrationError(registration_error) => match registration_error { - RegistrationError::AdditionalAuthenticationInformation => ( - StatusCode::UNAUTHORIZED, - Json(RegistrationResponse::user_interactive_authorization_info()), - ).into_response(), - RegistrationError::InvalidUserId | RegistrationError::MissingUserId => ( - StatusCode::BAD_REQUEST, - Json(ErrorResponse::new( - ErrorCode::InvalidUsername, - ®istration_error.to_string(), - None, - ) - )).into_response(), - RegistrationError::UserIdTaken => ( - StatusCode::BAD_REQUEST, - Json(ErrorResponse::new( - ErrorCode::UserInUse, - ®istration_error.to_string(), - None, - )), - ) - .into_response(), - }, + ApiError::RegistrationError(e) => e.into_response(), + ApiError::AuthenticationError(e) => e.into_response(), ApiError::DBError(err) => { tracing::error!("{}", err.to_string()); ( diff --git a/src/api/client_server/errors/authentication_error.rs b/src/api/client_server/errors/authentication_error.rs new file mode 100644 index 0000000..470469d --- /dev/null +++ b/src/api/client_server/errors/authentication_error.rs @@ -0,0 +1,41 @@ +use axum::{http::StatusCode, response::IntoResponse, Json}; + +use crate::types::error_code::ErrorCode; + +use super::ErrorResponse; + +#[derive(Debug, thiserror::Error)] +pub enum AuthenticationError { + #[error("UserId is missing")] + MissingUserId, + #[error("The user ID is not a valid user name")] + InvalidUserId, + #[error("The provided authentication data was incorrect")] + Forbidden, + #[error("The user has been deactivated")] + UserDeactivated, +} + +impl IntoResponse for AuthenticationError { + fn into_response(self) -> axum::response::Response { + match self { + Self::InvalidUserId | Self::MissingUserId => ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse::new( + ErrorCode::InvalidUsername, + &self.to_string(), + None, + )), + ) + .into_response(), + Self::Forbidden => ( + StatusCode::FORBIDDEN, + Json(ErrorResponse::new(ErrorCode::Forbidden, &self.to_string(), None)), + ).into_response(), + Self::UserDeactivated => ( + StatusCode::FORBIDDEN, + Json(ErrorResponse::new(ErrorCode::UserDeactivated, &self.to_string(), None)), + ).into_response(), + } + } +} diff --git a/src/api/client_server/errors/mod.rs b/src/api/client_server/errors/mod.rs index 5072c0c..3c74ea9 100644 --- a/src/api/client_server/errors/mod.rs +++ b/src/api/client_server/errors/mod.rs @@ -1,2 +1,23 @@ +use crate::types::error_code::ErrorCode; + pub mod api_error; +pub mod authentication_error; pub mod registration_error; + +#[derive(Debug, serde::Serialize)] +pub struct ErrorResponse { + errcode: ErrorCode, + error: String, + #[serde(skip_serializing_if = "Option::is_none")] + retry_after_ms: Option, +} + +impl ErrorResponse { + fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option) -> Self { + Self { + errcode, + error: error.to_owned(), + retry_after_ms, + } + } +} \ No newline at end of file diff --git a/src/api/client_server/errors/registration_error.rs b/src/api/client_server/errors/registration_error.rs index 5805abc..0a8d577 100644 --- a/src/api/client_server/errors/registration_error.rs +++ b/src/api/client_server/errors/registration_error.rs @@ -1,3 +1,9 @@ +use axum::{http::StatusCode, response::IntoResponse, Json}; + +use crate::{responses::registration::RegistrationResponse, types::error_code::ErrorCode}; + +use super::ErrorResponse; + #[derive(Debug, thiserror::Error)] pub enum RegistrationError { #[error("The homeserver requires additional authentication information")] @@ -7,5 +13,35 @@ pub enum RegistrationError { #[error("The desired user ID is not a valid user name")] InvalidUserId, #[error("The desired user ID is already taken")] - UserIdTaken + UserIdTaken, +} + +impl IntoResponse for RegistrationError { + fn into_response(self) -> axum::response::Response { + match self { + RegistrationError::AdditionalAuthenticationInformation => ( + StatusCode::UNAUTHORIZED, + Json(RegistrationResponse::user_interactive_authorization_info()), + ) + .into_response(), + RegistrationError::InvalidUserId | RegistrationError::MissingUserId => ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse::new( + ErrorCode::InvalidUsername, + &self.to_string(), + None, + )), + ) + .into_response(), + RegistrationError::UserIdTaken => ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse::new( + ErrorCode::UserInUse, + &self.to_string(), + None, + )), + ) + .into_response(), + } + } } diff --git a/src/requests/mod.rs b/src/requests/mod.rs index 038a1c7..9b90583 100644 --- a/src/requests/mod.rs +++ b/src/requests/mod.rs @@ -1 +1 @@ -pub mod registration; +pub mod registration; \ No newline at end of file diff --git a/src/requests/registration.rs b/src/requests/registration.rs index f101e56..97c0f1e 100644 --- a/src/requests/registration.rs +++ b/src/requests/registration.rs @@ -2,22 +2,11 @@ use crate::types::{authentication_data::AuthenticationData, flow::Flow, identifi #[derive(Debug, serde::Deserialize)] pub struct RegistrationRequest { - #[serde(skip_serializing_if = "Option::is_none")] auth: Option, - - #[serde(skip_serializing_if = "Option::is_none")] device_id: Option, - - #[serde(skip_serializing_if = "Option::is_none")] inhibit_login: Option, - - #[serde(skip_serializing_if = "Option::is_none")] initial_device_display_name: Option, - - #[serde(skip_serializing_if = "Option::is_none")] password: Option, - - #[serde(skip_serializing_if = "Option::is_none")] username: Option, } diff --git a/src/responses/authentication.rs b/src/responses/authentication.rs new file mode 100644 index 0000000..63fbf68 --- /dev/null +++ b/src/responses/authentication.rs @@ -0,0 +1,26 @@ +use axum::{response::IntoResponse, Json}; + +use crate::types::user_id::UserId; + +#[derive(Debug, serde::Serialize)] +#[serde(untagged)] +pub enum AuthenticationResponse { + Success(AuthenticationSuccess) +} + +#[derive(Debug, serde::Serialize)] +pub struct AuthenticationSuccess { + access_token: String, + device_id: String, + user_id: String, +} + +impl AuthenticationSuccess { + pub fn new(access_token: &str, device_id: &str, user_id: &UserId) -> Self { + Self { + access_token: access_token.to_owned(), + device_id: device_id.to_owned(), + user_id: user_id.to_string(), + } + } +} \ No newline at end of file diff --git a/src/responses/mod.rs b/src/responses/mod.rs index 2bfa6bf..7c34e93 100644 --- a/src/responses/mod.rs +++ b/src/responses/mod.rs @@ -2,3 +2,4 @@ pub mod flow; pub mod registration; pub mod username_available; pub mod versions; +pub mod authentication; \ No newline at end of file diff --git a/src/types/authentication_data.rs b/src/types/authentication_data.rs index ea29dde..cad34fa 100644 --- a/src/types/authentication_data.rs +++ b/src/types/authentication_data.rs @@ -13,6 +13,8 @@ pub struct AuthenticationPassword { identifier: Identifier, password: String, user: Option, + device_id: Option, + initial_device_display_name: Option, } impl AuthenticationPassword { @@ -33,4 +35,16 @@ impl AuthenticationPassword { pub fn user(&self) -> Option<&String> { self.user.as_ref() } + + /// Get a reference to the authentication password's device id. + #[must_use] + pub fn device_id(&self) -> Option<&String> { + self.device_id.as_ref() + } + + /// Get a mutable reference to the authentication password's initial device display name. + #[must_use] + pub fn initial_device_display_name_mut(&self) -> Option<&String> { + self.initial_device_display_name.as_ref() + } } diff --git a/src/types/error_code.rs b/src/types/error_code.rs index 6febaa8..650f72a 100644 --- a/src/types/error_code.rs +++ b/src/types/error_code.rs @@ -12,6 +12,7 @@ pub enum ErrorCode { UserInUse, InvalidUsername, Exclusive, + UserDeactivated } impl serde::Serialize for ErrorCode { @@ -31,6 +32,7 @@ impl serde::Serialize for ErrorCode { ErrorCode::UserInUse => "M_USER_IN_USE", ErrorCode::InvalidUsername => "M_INVALID_USERNAME", ErrorCode::Exclusive => "M_EXCLUSIVE", + ErrorCode::UserDeactivated => "M_USER_DEACTIVATED", }) } }