use switch to ruma and remove unneeded code
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2022-06-26 21:07:11 +02:00
parent 71590d6c60
commit 29093c51e3
31 changed files with 222 additions and 631 deletions

View File

@ -1,9 +1,7 @@
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::Json;
use sqlx::Statement;
use crate::responses::registration::RegistrationResponse;
use crate::types::error_code::ErrorCode;
use super::authentication_error::AuthenticationError;

View File

@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc};
use axum::{
extract::Query,
routing::{get, post},
Extension, Json,
Extension,
};
use sqlx::SqlitePool;
@ -12,23 +12,14 @@ use crate::{
api_error::ApiError, authentication_error::AuthenticationError,
registration_error::RegistrationError,
},
models::sessions::Session,
responses::{
authentication::{AuthenticationResponse, AuthenticationSuccess},
registration::RegistrationResponse,
},
ruma_wrapper::RumaResponse,
};
use crate::{models::devices::Device, responses::registration::RegistrationSuccess};
use crate::{
models::users::User,
requests::registration::RegistrationRequest,
types::{authentication_data::AuthenticationData, user_id::UserId},
models::{devices::Device, sessions::Session, users::User},
ruma_wrapper::{RumaRequest, RumaResponse},
Config,
};
use ruma::api::client::{
account,
session::get_login_types::v3::{LoginType, PasswordLoginType},
account, session,
uiaa::{IncomingAuthData, IncomingUserIdentifier},
};
pub fn routes() -> axum::Router {
@ -38,9 +29,10 @@ pub fn routes() -> axum::Router {
.route("/r0/register/available", get(get_username_available))
}
use ruma::api::client::session;
#[tracing::instrument]
async fn get_login() -> Result<RumaResponse<session::get_login_types::v3::Response>, ApiError> {
use session::get_login_types::v3::*;
Ok(RumaResponse(session::get_login_types::v3::Response::new(
vec![LoginType::Password(PasswordLoginType::new())],
)))
@ -48,43 +40,50 @@ async fn get_login() -> Result<RumaResponse<session::get_login_types::v3::Respon
#[tracing::instrument(skip_all)]
async fn post_login(
Extension(config): Extension<Arc<Config>>,
Extension(db): Extension<SqlitePool>,
Json(body): Json<AuthenticationData>,
) -> Result<Json<AuthenticationResponse>, ApiError> {
match body {
AuthenticationData::Password(auth_data) => {
let user = auth_data.user().unwrap();
let user_id = UserId::new(user, config.server_name())
.ok()
.ok_or(AuthenticationError::InvalidUserId)?;
RumaRequest(req): RumaRequest<session::login::v3::IncomingRequest>,
) -> Result<RumaResponse<session::login::v3::Response>, ApiError> {
use session::login::v3::*;
let user = User::find_by_user_id(&db, &user_id).await?;
match req.login_info {
IncomingLoginInfo::Password(incoming_password) => {
let password = incoming_password.password;
let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) =
incoming_password.identifier
{
ruma::UserId::parse(user_id).map_err(|e| anyhow::anyhow!(e))?
} else {
return Err(AuthenticationError::InvalidUserId.into())
};
user.password_correct(auth_data.password())
let db_user = User::find_by_user_id(&db, user_id.as_str()).await?;
db_user
.password_correct(&password)
.ok()
.ok_or(AuthenticationError::Forbidden)?;
let device = if let Some(device_id) = auth_data.device_id() {
Device::find_for_user(&db, &user, device_id).await?
let device = if let Some(device_id) = req.device_id {
Device::find_for_user(&db, &db_user, device_id.as_str()).await?
} else {
let device_id = uuid::Uuid::new_v4().to_string();
let display_name =
if let Some(display_name) = auth_data.initial_device_display_name() {
display_name.as_ref()
} else {
"Generic Device"
};
Device::new(&user, &device_id, display_name)?
let display_name = req
.initial_device_display_name
.unwrap_or_else(|| "Generic Device".into());
Device::new(&db_user, &device_id, &display_name)?
.create(&db)
.await?
};
let session = Session::new(&device)?.create(&db).await?;
let response = Response::new(
user_id,
session.key,
ruma::OwnedDeviceId::from(device.device_id),
);
let resp = AuthenticationSuccess::new(session.key(), device.device_id(), &user_id);
Ok(Json(AuthenticationResponse::Success(resp)))
return Ok(RumaResponse(response));
}
_ => todo!(),
}
}
@ -94,76 +93,77 @@ async fn get_username_available(
Extension(db): Extension<SqlitePool>,
Query(params): Query<HashMap<String, String>>,
) -> Result<RumaResponse<account::get_username_availability::v3::Response>, ApiError> {
use account::get_username_availability::v3::*;
let username = params
.get("username")
.ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, config.server_name())
.ok()
.ok_or(RegistrationError::InvalidUserId)?;
let user_id = ruma::UserId::parse(username).map_err(|_| RegistrationError::InvalidUserId)?;
let exists = User::exists(&db, &user_id).await?;
Ok(RumaResponse(
account::get_username_availability::v3::Response::new(!exists),
))
Ok(RumaResponse(Response::new(!exists)))
}
#[tracing::instrument(skip_all)]
async fn post_register(
Extension(config): Extension<Arc<Config>>,
Extension(db): Extension<SqlitePool>,
Json(body): Json<RegistrationRequest>,
) -> Result<Json<RegistrationResponse>, ApiError> {
RumaRequest(req): RumaRequest<account::register::v3::IncomingRequest>,
) -> Result<RumaResponse<account::register::v3::Response>, ApiError> {
use account::register::v3::*;
config
.enable_registration()
.then(|| true)
.ok_or(RegistrationError::RegistrationDisabled)?;
body.auth()
.ok_or(RegistrationError::AdditionalAuthenticationInformation)?;
let (user, device) = match &body.auth().expect("must be Some") {
AuthenticationData::Password(auth_data) => {
let username = body.username().ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, config.server_name())
.ok()
.ok_or(RegistrationError::InvalidUserId)?;
match req.auth {
Some(auth) => match auth {
IncomingAuthData::Password(incoming_password) => {
let password = incoming_password.password;
let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) =
incoming_password.identifier
{
ruma::UserId::parse(user_id).map_err(|e| anyhow::anyhow!(e))?
} else {
Err(AuthenticationError::InvalidUserId)?
};
if User::exists(&db, &user_id).await? {
return Err(ApiError::from(RegistrationError::UserIdTaken));
}
if User::exists(&db, &user_id).await? {
return Err(ApiError::from(RegistrationError::UserIdTaken));
}
let display_name = match body.initial_device_display_name() {
Some(display_name) => display_name.as_ref(),
None => "Random displayname",
};
let display_name = req
.initial_device_display_name
.unwrap_or_else(|| "Random Display Name".into());
let user = User::new(&user_id, &user_id.to_string(), auth_data.password())?
let user = User::new(&user_id, &user_id.to_string(), &password)?
.create(&db)
.await?;
let device = Device::new(
&user,
uuid::Uuid::new_v4().to_string().as_ref(),
&display_name,
)?
.create(&db)
.await?;
let mut response = Response::new(
ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?,
);
if !req.inhibit_login {
let session = Session::new(&device)?.create(&db).await?;
response.access_token = Some(session.key);
}
if !req.inhibit_login {
response.device_id = Some(ruma::OwnedDeviceId::from(device.device_id));
}
let device = Device::new(
&user,
uuid::Uuid::new_v4().to_string().as_ref(),
display_name,
)?
.create(&db)
.await?;
(user, device)
}
return Ok(RumaResponse(response));
}
_ => todo!(),
},
None => Err(RegistrationError::AdditionalAuthenticationInformation)?,
};
if body.inhibit_login().unwrap_or(false) {
let resp = RegistrationSuccess::new(None, device.device_id(), &user.user_id().to_string());
Ok(Json(RegistrationResponse::Success(resp)))
} else {
let session = Session::new(&device)?.create(&db).await?;
let resp = RegistrationSuccess::new(
Some(session.key()),
device.device_id(),
&user.user_id().to_string(),
);
Ok(Json(RegistrationResponse::Success(resp)))
}
unreachable!()
}

View File

@ -13,8 +13,8 @@ use crate::{models::sessions::Session, types::error_code::ErrorCode};
use super::errors::ErrorResponse;
pub mod auth;
pub mod thirdparty;
pub mod create_room;
pub mod thirdparty;
async fn authentication_middleware<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
let db: &SqlitePool = req.extensions().get().unwrap();

View File

@ -1,12 +1,16 @@
use axum::{routing::get, Json};
use axum::routing::get;
use crate::responses::versions::Versions;
use crate::ruma_wrapper::RumaResponse;
pub fn routes() -> axum::Router {
axum::Router::new().route("/versions", get(get_client_versions))
}
use ruma::api::client::discovery;
#[tracing::instrument]
async fn get_client_versions() -> Json<Versions> {
Json(Versions::default())
async fn get_client_versions() -> RumaResponse<discovery::get_supported_versions::Response> {
use discovery::get_supported_versions::*;
RumaResponse(Response::new(vec!["v1.2".into()]))
}