diff --git a/src/api/client_server/auth.rs b/src/api/client_server/auth.rs index 8b4c8db..9210895 100644 --- a/src/api/client_server/auth.rs +++ b/src/api/client_server/auth.rs @@ -23,6 +23,8 @@ use crate::{ Config, }; +use super::errors::{api_error::ApiError, registration_error::RegistrationError}; + pub fn routes() -> axum::Router { axum::Router::new() .route("/r0/login", get(get_login).post(post_login)) @@ -46,11 +48,14 @@ async fn get_username_available( Extension(config): Extension>, Extension(db): Extension, Query(params): Query>, -) -> Json { - let username = params.get("username").unwrap(); - let user_id = UserId::new(&username, &config.homeserver_name).unwrap(); - let exists = User::exists(&db, &user_id).await.unwrap(); - Json(UsernameAvailable::new(!exists)) +) -> Result, ApiError> { + let username = params + .get("username") + .ok_or(RegistrationError::MissingUserId)?; + let user_id = UserId::new(username, &config.homeserver_name)?; + let exists = User::exists(&db, &user_id).await?; + + Ok(Json(UsernameAvailable::new(!exists))) } #[tracing::instrument(skip_all)] @@ -59,29 +64,34 @@ async fn post_register( Extension(db): Extension, Json(body): Json, Query(params): Query>, -) -> (StatusCode, Json) { +) -> Result<(StatusCode, Json), ApiError> { // Client tries to get available flows - if *&body.auth().is_none() { - return ( + if body.auth().is_none() { + return Ok(( StatusCode::UNAUTHORIZED, Json(RegistrationResponse::user_interactive_authorization_info()), - ); + )); } let (user, device) = match &body.auth().unwrap() { AuthenticationData::Password(auth_data) => { - let username = body.username().unwrap(); - let user_id = UserId::new(&username, &config.homeserver_name).unwrap(); + let username = body.username().ok_or(RegistrationError::MissingUserId)?; + let user_id = UserId::new(username, &config.homeserver_name) + .ok() + .ok_or(RegistrationError::InvalidUserId)?; + if User::exists(&db, &user_id).await.unwrap() { todo!("Error out") } + let password = auth_data.password(); - let display_name = match *&body.initial_device_display_name() { + + let display_name = match body.initial_device_display_name() { Some(display_name) => display_name.as_ref(), None => "Random displayname", }; - let user = User::create(&db, &user_id, &user_id.to_string(), &password) + let user = User::create(&db, &user_id, &user_id.to_string(), password) .await .unwrap(); let device = Device::create(&db, &user, "test", display_name) @@ -92,15 +102,15 @@ async fn post_register( } }; - if *&body.inhibit_login().is_some() && *&body.inhibit_login().unwrap() { + if body.inhibit_login().unwrap_or(false) { let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id()); - (StatusCode::OK, Json(RegistrationResponse::Success(resp))) + Ok((StatusCode::OK, Json(RegistrationResponse::Success(resp)))) } else { let session = device.create_session(&db).await.unwrap(); let resp = RegistrationSuccess::new(Some(session.value()), device.device_id(), user.user_id()); - (StatusCode::OK, Json(RegistrationResponse::Success(resp))) + Ok((StatusCode::OK, Json(RegistrationResponse::Success(resp)))) } } diff --git a/src/api/client_server/mod.rs b/src/api/client_server/mod.rs index 49836a5..d0d0efd 100644 --- a/src/api/client_server/mod.rs +++ b/src/api/client_server/mod.rs @@ -1,2 +1,3 @@ pub mod auth; pub mod versions; +pub mod errors;