Compare commits

..

2 Commits

Author SHA1 Message Date
b8e1235396 add file 2022-04-25 19:37:18 +02:00
b4b4f837cf error responses 2022-04-25 19:36:26 +02:00
5 changed files with 125 additions and 34 deletions

View File

@ -17,9 +17,7 @@ use crate::{
models::users::User, models::users::User,
requests::registration::RegistrationRequest, requests::registration::RegistrationRequest,
responses::username_available::UsernameAvailable, responses::username_available::UsernameAvailable,
types::{ types::{authentication_data::AuthenticationData, matrix_user_id::UserId},
authentication_data::AuthenticationData, identifier::Identifier, matrix_user_id::UserId,
},
Config, Config,
}; };
@ -52,7 +50,9 @@ async fn get_username_available(
let username = params let username = params
.get("username") .get("username")
.ok_or(RegistrationError::MissingUserId)?; .ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, &config.homeserver_name)?; let user_id = UserId::new(username, &config.homeserver_name)
.ok()
.ok_or(RegistrationError::InvalidUserId)?;
let exists = User::exists(&db, &user_id).await?; let exists = User::exists(&db, &user_id).await?;
Ok(Json(UsernameAvailable::new(!exists))) Ok(Json(UsernameAvailable::new(!exists)))
@ -63,26 +63,21 @@ async fn post_register(
Extension(config): Extension<Arc<Config>>, Extension(config): Extension<Arc<Config>>,
Extension(db): Extension<SqlitePool>, Extension(db): Extension<SqlitePool>,
Json(body): Json<RegistrationRequest>, Json(body): Json<RegistrationRequest>,
Query(params): Query<HashMap<String, String>>, ) -> Result<Json<RegistrationResponse>, ApiError> {
) -> Result<(StatusCode, Json<RegistrationResponse>), ApiError> { body.auth()
// Client tries to get available flows .ok_or(RegistrationError::AdditionalAuthenticationInformation)?;
if body.auth().is_none() {
return Ok((
StatusCode::UNAUTHORIZED,
Json(RegistrationResponse::user_interactive_authorization_info()),
));
}
let (user, device) = match &body.auth().unwrap() { let (user, device) = match &body.auth().expect("must be Some") {
AuthenticationData::Password(auth_data) => { AuthenticationData::Password(auth_data) => {
let username = body.username().ok_or(RegistrationError::MissingUserId)?; let username = body.username().ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, &config.homeserver_name) let user_id = UserId::new(username, &config.homeserver_name)
.ok() .ok()
.ok_or(RegistrationError::InvalidUserId)?; .ok_or(RegistrationError::InvalidUserId)?;
if User::exists(&db, &user_id).await.unwrap() { User::exists(&db, &user_id)
todo!("Error out") .await?
} .then(|| ())
.ok_or(RegistrationError::UserIdTaken)?;
let password = auth_data.password(); let password = auth_data.password();
@ -91,12 +86,8 @@ async fn post_register(
None => "Random displayname", None => "Random displayname",
}; };
let user = User::create(&db, &user_id, &user_id.to_string(), password) let user = User::create(&db, &user_id, &user_id.to_string(), password).await?;
.await let device = Device::create(&db, &user, "test", display_name).await?;
.unwrap();
let device = Device::create(&db, &user, "test", display_name)
.await
.unwrap();
(user, device) (user, device)
} }
@ -105,12 +96,12 @@ async fn post_register(
if body.inhibit_login().unwrap_or(false) { if body.inhibit_login().unwrap_or(false) {
let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id()); let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id());
Ok((StatusCode::OK, Json(RegistrationResponse::Success(resp)))) Ok(Json(RegistrationResponse::Success(resp)))
} else { } else {
let session = device.create_session(&db).await.unwrap(); let session = device.create_session(&db).await?;
let resp = let resp =
RegistrationSuccess::new(Some(session.value()), device.device_id(), user.user_id()); RegistrationSuccess::new(Some(session.value()), device.device_id(), user.user_id());
Ok((StatusCode::OK, Json(RegistrationResponse::Success(resp)))) Ok(Json(RegistrationResponse::Success(resp)))
} }
} }

View File

@ -1,5 +1,10 @@
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::Json;
use sqlx::Statement;
use crate::responses::registration::RegistrationResponse;
use crate::types::error_code::ErrorCode;
use super::registration_error::RegistrationError; use super::registration_error::RegistrationError;
@ -13,6 +18,24 @@ macro_rules! map_err {
} }
} }
#[derive(Debug, serde::Serialize)]
struct ErrorResponse {
errcode: ErrorCode,
error: String,
#[serde(skip_serializing_if = "Option::is_none")]
retry_after_ms: Option<u64>,
}
impl ErrorResponse {
fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option<u64>) -> Self {
Self {
errcode,
error: error.to_owned(),
retry_after_ms,
}
}
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum ApiError { pub enum ApiError {
#[error("Registration Error")] #[error("Registration Error")]
@ -40,14 +63,50 @@ impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
match self { match self {
ApiError::RegistrationError(registration_error) => match registration_error { ApiError::RegistrationError(registration_error) => match registration_error {
RegistrationError::InvalidUserId => { RegistrationError::AdditionalAuthenticationInformation => (
(StatusCode::OK, String::new()).into_response() StatusCode::UNAUTHORIZED,
} Json(RegistrationResponse::user_interactive_authorization_info()),
RegistrationError::MissingUserId => { ).into_response(),
(StatusCode::OK, String::new()).into_response() RegistrationError::InvalidUserId => (StatusCode::OK, Json(
} ErrorResponse::new(
ErrorCode::InvalidUsername,
&registration_error.to_string(),
None,
)
)).into_response(),
RegistrationError::MissingUserId => (StatusCode::OK, String::new()).into_response(),
RegistrationError::UserIdTaken => (
StatusCode::BAD_REQUEST,
Json(ErrorResponse::new(
ErrorCode::UserInUse,
&registration_error.to_string(),
None,
)),
)
.into_response(),
}, },
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response() ApiError::DBError(err) => {
tracing::error!("{}", err.to_string());
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse::new(
ErrorCode::Unknown,
"Database error! If you are the application owner please take a look at your application logs.",
None,
)),
)
.into_response()
}
ApiError::Generic(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse::new(
ErrorCode::Unknown,
"Fatal error occured! If you are the application owner please take a look at your application logs.",
None,
)),
)
.into_response(),
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
} }
} }
} }

View File

@ -1,7 +1,11 @@
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum RegistrationError { pub enum RegistrationError {
#[error("The homeserver requires additional authentication information")]
AdditionalAuthenticationInformation,
#[error("UserId is missing")] #[error("UserId is missing")]
MissingUserId, MissingUserId,
#[error("UserId is invalid")] #[error("The desired user ID is not a valid user name")]
InvalidUserId, InvalidUserId,
#[error("The desired user ID is already taken")]
UserIdTaken,
} }

36
src/types/error_code.rs Normal file
View File

@ -0,0 +1,36 @@
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum ErrorCode {
Forbidden,
UnknownToken,
MissingToken,
BadJson,
NotJson,
NotFound,
LimitExceeded,
Unknown,
UserInUse,
InvalidUsername,
Exclusive,
}
impl serde::Serialize for ErrorCode {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(match self {
ErrorCode::Forbidden => "M_FORBIDDEN",
ErrorCode::UnknownToken => "M_UNKNOWN_TOKEN",
ErrorCode::MissingToken => "M_MISSING_TOKEN",
ErrorCode::BadJson => "M_BAD_JSON",
ErrorCode::NotJson => "M_NOT_JSON",
ErrorCode::NotFound => "M_NOT_FOUND",
ErrorCode::LimitExceeded => "M_LIMIT_EXCEEDED",
ErrorCode::Unknown => "M_UNKNOWN",
ErrorCode::UserInUse => "M_USER_IN_USE",
ErrorCode::InvalidUsername => "M_INVALID_USERNAME",
ErrorCode::Exclusive => "M_EXCLUSIVE",
})
}
}

View File

@ -4,3 +4,4 @@ pub mod identifier;
pub mod identifier_type; pub mod identifier_type;
pub mod matrix_user_id; pub mod matrix_user_id;
pub mod user_interactive_authorization; pub mod user_interactive_authorization;
pub mod error_code;