diff --git a/src/api/client_server/auth.rs b/src/api/client_server/auth.rs index 9210895..3971d6d 100644 --- a/src/api/client_server/auth.rs +++ b/src/api/client_server/auth.rs @@ -17,9 +17,7 @@ use crate::{ models::users::User, requests::registration::RegistrationRequest, responses::username_available::UsernameAvailable, - types::{ - authentication_data::AuthenticationData, identifier::Identifier, matrix_user_id::UserId, - }, + types::{authentication_data::AuthenticationData, matrix_user_id::UserId}, Config, }; @@ -52,7 +50,9 @@ async fn get_username_available( let username = params .get("username") .ok_or(RegistrationError::MissingUserId)?; - let user_id = UserId::new(username, &config.homeserver_name)?; + let user_id = UserId::new(username, &config.homeserver_name) + .ok() + .ok_or(RegistrationError::InvalidUserId)?; let exists = User::exists(&db, &user_id).await?; Ok(Json(UsernameAvailable::new(!exists))) @@ -63,26 +63,21 @@ async fn post_register( Extension(config): Extension>, Extension(db): Extension, Json(body): Json, - Query(params): Query>, -) -> Result<(StatusCode, Json), ApiError> { - // Client tries to get available flows - if body.auth().is_none() { - return Ok(( - StatusCode::UNAUTHORIZED, - Json(RegistrationResponse::user_interactive_authorization_info()), - )); - } +) -> Result, ApiError> { + body.auth() + .ok_or(RegistrationError::AdditionalAuthenticationInformation)?; - let (user, device) = match &body.auth().unwrap() { + let (user, device) = match &body.auth().expect("must be Some") { AuthenticationData::Password(auth_data) => { let username = body.username().ok_or(RegistrationError::MissingUserId)?; let user_id = UserId::new(username, &config.homeserver_name) .ok() .ok_or(RegistrationError::InvalidUserId)?; - if User::exists(&db, &user_id).await.unwrap() { - todo!("Error out") - } + User::exists(&db, &user_id) + .await? + .then(|| ()) + .ok_or(RegistrationError::UserIdTaken)?; let password = auth_data.password(); @@ -91,12 +86,8 @@ async fn post_register( None => "Random displayname", }; - let user = User::create(&db, &user_id, &user_id.to_string(), password) - .await - .unwrap(); - let device = Device::create(&db, &user, "test", display_name) - .await - .unwrap(); + let user = User::create(&db, &user_id, &user_id.to_string(), password).await?; + let device = Device::create(&db, &user, "test", display_name).await?; (user, device) } @@ -105,12 +96,12 @@ async fn post_register( if body.inhibit_login().unwrap_or(false) { let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id()); - Ok((StatusCode::OK, Json(RegistrationResponse::Success(resp)))) + Ok(Json(RegistrationResponse::Success(resp))) } else { - let session = device.create_session(&db).await.unwrap(); + let session = device.create_session(&db).await?; let resp = RegistrationSuccess::new(Some(session.value()), device.device_id(), user.user_id()); - Ok((StatusCode::OK, Json(RegistrationResponse::Success(resp)))) + Ok(Json(RegistrationResponse::Success(resp))) } } diff --git a/src/api/client_server/errors/api_error.rs b/src/api/client_server/errors/api_error.rs index 33b9158..2949ae4 100644 --- a/src/api/client_server/errors/api_error.rs +++ b/src/api/client_server/errors/api_error.rs @@ -1,5 +1,10 @@ use axum::http::StatusCode; use axum::response::IntoResponse; +use axum::Json; +use sqlx::Statement; + +use crate::responses::registration::RegistrationResponse; +use crate::types::error_code::ErrorCode; use super::registration_error::RegistrationError; @@ -13,6 +18,24 @@ macro_rules! map_err { } } +#[derive(Debug, serde::Serialize)] +struct ErrorResponse { + errcode: ErrorCode, + error: String, + #[serde(skip_serializing_if = "Option::is_none")] + retry_after_ms: Option, +} + +impl ErrorResponse { + fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option) -> Self { + Self { + errcode, + error: error.to_owned(), + retry_after_ms, + } + } +} + #[derive(Debug, thiserror::Error)] pub enum ApiError { #[error("Registration Error")] @@ -40,14 +63,50 @@ impl IntoResponse for ApiError { fn into_response(self) -> axum::response::Response { match self { ApiError::RegistrationError(registration_error) => match registration_error { - RegistrationError::InvalidUserId => { - (StatusCode::OK, String::new()).into_response() - } - RegistrationError::MissingUserId => { - (StatusCode::OK, String::new()).into_response() - } + RegistrationError::AdditionalAuthenticationInformation => ( + StatusCode::UNAUTHORIZED, + Json(RegistrationResponse::user_interactive_authorization_info()), + ).into_response(), + RegistrationError::InvalidUserId => (StatusCode::OK, Json( + ErrorResponse::new( + ErrorCode::InvalidUsername, + ®istration_error.to_string(), + None, + ) + )).into_response(), + RegistrationError::MissingUserId => (StatusCode::OK, String::new()).into_response(), + RegistrationError::UserIdTaken => ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse::new( + ErrorCode::UserInUse, + ®istration_error.to_string(), + None, + )), + ) + .into_response(), }, - _ => StatusCode::INTERNAL_SERVER_ERROR.into_response() + ApiError::DBError(err) => { + tracing::error!("{}", err.to_string()); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse::new( + ErrorCode::Unknown, + "Database error! If you are the application owner please take a look at your application logs.", + None, + )), + ) + .into_response() + } + ApiError::Generic(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse::new( + ErrorCode::Unknown, + "Fatal error occured! If you are the application owner please take a look at your application logs.", + None, + )), + ) + .into_response(), + _ => StatusCode::INTERNAL_SERVER_ERROR.into_response(), } } } diff --git a/src/api/client_server/errors/registration_error.rs b/src/api/client_server/errors/registration_error.rs index 5ca506f..c03749c 100644 --- a/src/api/client_server/errors/registration_error.rs +++ b/src/api/client_server/errors/registration_error.rs @@ -1,7 +1,11 @@ #[derive(Debug, thiserror::Error)] pub enum RegistrationError { + #[error("The homeserver requires additional authentication information")] + AdditionalAuthenticationInformation, #[error("UserId is missing")] MissingUserId, - #[error("UserId is invalid")] + #[error("The desired user ID is not a valid user name")] InvalidUserId, + #[error("The desired user ID is already taken")] + UserIdTaken, } diff --git a/src/types/mod.rs b/src/types/mod.rs index 383d3cb..9d4c82d 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -4,3 +4,4 @@ pub mod identifier; pub mod identifier_type; pub mod matrix_user_id; pub mod user_interactive_authorization; +pub mod error_code; \ No newline at end of file