use switch to ruma and remove unneeded code
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
@ -1,9 +1,7 @@
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Json;
|
||||
use sqlx::Statement;
|
||||
|
||||
use crate::responses::registration::RegistrationResponse;
|
||||
use crate::types::error_code::ErrorCode;
|
||||
|
||||
use super::authentication_error::AuthenticationError;
|
||||
|
@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc};
|
||||
use axum::{
|
||||
extract::Query,
|
||||
routing::{get, post},
|
||||
Extension, Json,
|
||||
Extension,
|
||||
};
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
@ -12,23 +12,14 @@ use crate::{
|
||||
api_error::ApiError, authentication_error::AuthenticationError,
|
||||
registration_error::RegistrationError,
|
||||
},
|
||||
models::sessions::Session,
|
||||
responses::{
|
||||
authentication::{AuthenticationResponse, AuthenticationSuccess},
|
||||
registration::RegistrationResponse,
|
||||
},
|
||||
ruma_wrapper::RumaResponse,
|
||||
};
|
||||
use crate::{models::devices::Device, responses::registration::RegistrationSuccess};
|
||||
use crate::{
|
||||
models::users::User,
|
||||
requests::registration::RegistrationRequest,
|
||||
types::{authentication_data::AuthenticationData, user_id::UserId},
|
||||
models::{devices::Device, sessions::Session, users::User},
|
||||
ruma_wrapper::{RumaRequest, RumaResponse},
|
||||
Config,
|
||||
};
|
||||
|
||||
use ruma::api::client::{
|
||||
account,
|
||||
session::get_login_types::v3::{LoginType, PasswordLoginType},
|
||||
account, session,
|
||||
uiaa::{IncomingAuthData, IncomingUserIdentifier},
|
||||
};
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
@ -38,9 +29,10 @@ pub fn routes() -> axum::Router {
|
||||
.route("/r0/register/available", get(get_username_available))
|
||||
}
|
||||
|
||||
use ruma::api::client::session;
|
||||
#[tracing::instrument]
|
||||
async fn get_login() -> Result<RumaResponse<session::get_login_types::v3::Response>, ApiError> {
|
||||
use session::get_login_types::v3::*;
|
||||
|
||||
Ok(RumaResponse(session::get_login_types::v3::Response::new(
|
||||
vec![LoginType::Password(PasswordLoginType::new())],
|
||||
)))
|
||||
@ -48,43 +40,50 @@ async fn get_login() -> Result<RumaResponse<session::get_login_types::v3::Respon
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn post_login(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(db): Extension<SqlitePool>,
|
||||
Json(body): Json<AuthenticationData>,
|
||||
) -> Result<Json<AuthenticationResponse>, ApiError> {
|
||||
match body {
|
||||
AuthenticationData::Password(auth_data) => {
|
||||
let user = auth_data.user().unwrap();
|
||||
let user_id = UserId::new(user, config.server_name())
|
||||
.ok()
|
||||
.ok_or(AuthenticationError::InvalidUserId)?;
|
||||
RumaRequest(req): RumaRequest<session::login::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<session::login::v3::Response>, ApiError> {
|
||||
use session::login::v3::*;
|
||||
|
||||
let user = User::find_by_user_id(&db, &user_id).await?;
|
||||
match req.login_info {
|
||||
IncomingLoginInfo::Password(incoming_password) => {
|
||||
let password = incoming_password.password;
|
||||
let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) =
|
||||
incoming_password.identifier
|
||||
{
|
||||
ruma::UserId::parse(user_id).map_err(|e| anyhow::anyhow!(e))?
|
||||
} else {
|
||||
return Err(AuthenticationError::InvalidUserId.into())
|
||||
};
|
||||
|
||||
user.password_correct(auth_data.password())
|
||||
let db_user = User::find_by_user_id(&db, user_id.as_str()).await?;
|
||||
db_user
|
||||
.password_correct(&password)
|
||||
.ok()
|
||||
.ok_or(AuthenticationError::Forbidden)?;
|
||||
|
||||
let device = if let Some(device_id) = auth_data.device_id() {
|
||||
Device::find_for_user(&db, &user, device_id).await?
|
||||
let device = if let Some(device_id) = req.device_id {
|
||||
Device::find_for_user(&db, &db_user, device_id.as_str()).await?
|
||||
} else {
|
||||
let device_id = uuid::Uuid::new_v4().to_string();
|
||||
let display_name =
|
||||
if let Some(display_name) = auth_data.initial_device_display_name() {
|
||||
display_name.as_ref()
|
||||
} else {
|
||||
"Generic Device"
|
||||
};
|
||||
Device::new(&user, &device_id, display_name)?
|
||||
let display_name = req
|
||||
.initial_device_display_name
|
||||
.unwrap_or_else(|| "Generic Device".into());
|
||||
Device::new(&db_user, &device_id, &display_name)?
|
||||
.create(&db)
|
||||
.await?
|
||||
};
|
||||
|
||||
let session = Session::new(&device)?.create(&db).await?;
|
||||
let response = Response::new(
|
||||
user_id,
|
||||
session.key,
|
||||
ruma::OwnedDeviceId::from(device.device_id),
|
||||
);
|
||||
|
||||
let resp = AuthenticationSuccess::new(session.key(), device.device_id(), &user_id);
|
||||
Ok(Json(AuthenticationResponse::Success(resp)))
|
||||
return Ok(RumaResponse(response));
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,76 +93,77 @@ async fn get_username_available(
|
||||
Extension(db): Extension<SqlitePool>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Result<RumaResponse<account::get_username_availability::v3::Response>, ApiError> {
|
||||
use account::get_username_availability::v3::*;
|
||||
|
||||
let username = params
|
||||
.get("username")
|
||||
.ok_or(RegistrationError::MissingUserId)?;
|
||||
let user_id = UserId::new(username, config.server_name())
|
||||
.ok()
|
||||
.ok_or(RegistrationError::InvalidUserId)?;
|
||||
let user_id = ruma::UserId::parse(username).map_err(|_| RegistrationError::InvalidUserId)?;
|
||||
let exists = User::exists(&db, &user_id).await?;
|
||||
|
||||
Ok(RumaResponse(
|
||||
account::get_username_availability::v3::Response::new(!exists),
|
||||
))
|
||||
Ok(RumaResponse(Response::new(!exists)))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn post_register(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(db): Extension<SqlitePool>,
|
||||
Json(body): Json<RegistrationRequest>,
|
||||
) -> Result<Json<RegistrationResponse>, ApiError> {
|
||||
RumaRequest(req): RumaRequest<account::register::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<account::register::v3::Response>, ApiError> {
|
||||
use account::register::v3::*;
|
||||
|
||||
config
|
||||
.enable_registration()
|
||||
.then(|| true)
|
||||
.ok_or(RegistrationError::RegistrationDisabled)?;
|
||||
body.auth()
|
||||
.ok_or(RegistrationError::AdditionalAuthenticationInformation)?;
|
||||
|
||||
let (user, device) = match &body.auth().expect("must be Some") {
|
||||
AuthenticationData::Password(auth_data) => {
|
||||
let username = body.username().ok_or(RegistrationError::MissingUserId)?;
|
||||
let user_id = UserId::new(username, config.server_name())
|
||||
.ok()
|
||||
.ok_or(RegistrationError::InvalidUserId)?;
|
||||
match req.auth {
|
||||
Some(auth) => match auth {
|
||||
IncomingAuthData::Password(incoming_password) => {
|
||||
let password = incoming_password.password;
|
||||
let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) =
|
||||
incoming_password.identifier
|
||||
{
|
||||
ruma::UserId::parse(user_id).map_err(|e| anyhow::anyhow!(e))?
|
||||
} else {
|
||||
Err(AuthenticationError::InvalidUserId)?
|
||||
};
|
||||
|
||||
if User::exists(&db, &user_id).await? {
|
||||
return Err(ApiError::from(RegistrationError::UserIdTaken));
|
||||
}
|
||||
if User::exists(&db, &user_id).await? {
|
||||
return Err(ApiError::from(RegistrationError::UserIdTaken));
|
||||
}
|
||||
|
||||
let display_name = match body.initial_device_display_name() {
|
||||
Some(display_name) => display_name.as_ref(),
|
||||
None => "Random displayname",
|
||||
};
|
||||
let display_name = req
|
||||
.initial_device_display_name
|
||||
.unwrap_or_else(|| "Random Display Name".into());
|
||||
|
||||
let user = User::new(&user_id, &user_id.to_string(), auth_data.password())?
|
||||
let user = User::new(&user_id, &user_id.to_string(), &password)?
|
||||
.create(&db)
|
||||
.await?;
|
||||
|
||||
let device = Device::new(
|
||||
&user,
|
||||
uuid::Uuid::new_v4().to_string().as_ref(),
|
||||
&display_name,
|
||||
)?
|
||||
.create(&db)
|
||||
.await?;
|
||||
let mut response = Response::new(
|
||||
ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?,
|
||||
);
|
||||
if !req.inhibit_login {
|
||||
let session = Session::new(&device)?.create(&db).await?;
|
||||
response.access_token = Some(session.key);
|
||||
}
|
||||
if !req.inhibit_login {
|
||||
response.device_id = Some(ruma::OwnedDeviceId::from(device.device_id));
|
||||
}
|
||||
|
||||
let device = Device::new(
|
||||
&user,
|
||||
uuid::Uuid::new_v4().to_string().as_ref(),
|
||||
display_name,
|
||||
)?
|
||||
.create(&db)
|
||||
.await?;
|
||||
|
||||
(user, device)
|
||||
}
|
||||
return Ok(RumaResponse(response));
|
||||
}
|
||||
_ => todo!(),
|
||||
},
|
||||
None => Err(RegistrationError::AdditionalAuthenticationInformation)?,
|
||||
};
|
||||
|
||||
if body.inhibit_login().unwrap_or(false) {
|
||||
let resp = RegistrationSuccess::new(None, device.device_id(), &user.user_id().to_string());
|
||||
|
||||
Ok(Json(RegistrationResponse::Success(resp)))
|
||||
} else {
|
||||
let session = Session::new(&device)?.create(&db).await?;
|
||||
let resp = RegistrationSuccess::new(
|
||||
Some(session.key()),
|
||||
device.device_id(),
|
||||
&user.user_id().to_string(),
|
||||
);
|
||||
|
||||
Ok(Json(RegistrationResponse::Success(resp)))
|
||||
}
|
||||
unreachable!()
|
||||
}
|
||||
|
@ -13,8 +13,8 @@ use crate::{models::sessions::Session, types::error_code::ErrorCode};
|
||||
use super::errors::ErrorResponse;
|
||||
|
||||
pub mod auth;
|
||||
pub mod thirdparty;
|
||||
pub mod create_room;
|
||||
pub mod thirdparty;
|
||||
|
||||
async fn authentication_middleware<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||
let db: &SqlitePool = req.extensions().get().unwrap();
|
||||
|
@ -1,12 +1,16 @@
|
||||
use axum::{routing::get, Json};
|
||||
use axum::routing::get;
|
||||
|
||||
use crate::responses::versions::Versions;
|
||||
use crate::ruma_wrapper::RumaResponse;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new().route("/versions", get(get_client_versions))
|
||||
}
|
||||
|
||||
use ruma::api::client::discovery;
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn get_client_versions() -> Json<Versions> {
|
||||
Json(Versions::default())
|
||||
async fn get_client_versions() -> RumaResponse<discovery::get_supported_versions::Response> {
|
||||
use discovery::get_supported_versions::*;
|
||||
|
||||
RumaResponse(Response::new(vec!["v1.2".into()]))
|
||||
}
|
||||
|
Reference in New Issue
Block a user