use switch to ruma and remove unneeded code
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Patrick Michl 2022-06-26 21:07:11 +02:00
parent 71590d6c60
commit 29093c51e3
31 changed files with 222 additions and 631 deletions

View File

@ -3,7 +3,7 @@ CREATE TABLE users(
uuid TEXT PRIMARY KEY NOT NULL,
user_id CHAR(255) NOT NULL,
display_name TEXT NOT NULL,
password TEXT NOT NULL
password_hash TEXT NOT NULL
);
CREATE INDEX user_id_index ON users (user_id);

View File

@ -102,7 +102,7 @@
},
"query": "insert into sessions(uuid, device_uuid, key)\n values(?, ?, ?)\n returning uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key"
},
"383949b72c69bca95bf23ef06900cd1ac5a136cdd4a525cbb624d327ce0cdefb": {
"2b3409859921423dc051ce76a0166116f39ca7f26053bac5bde0a61313bfd68c": {
"describe": {
"columns": [
{
@ -111,7 +111,7 @@
"type_info": "Text"
},
{
"name": "user_id: UserId",
"name": "user_id",
"ordinal": 1,
"type_info": "Text"
},
@ -121,7 +121,7 @@
"type_info": "Text"
},
{
"name": "password",
"name": "password_hash",
"ordinal": 3,
"type_info": "Text"
}
@ -136,9 +136,9 @@
"Right": 1
}
},
"query": "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password\n from users where uuid = ?"
"query": "select uuid as 'uuid: Uuid', user_id, display_name, password_hash\n from users where user_id = ?"
},
"3fead3dac0e110757bc30be40bb0c6c2bc02127b6d9b6145bfc40fa5fe22ad06": {
"33f7c796b21878b2f06f3e012ada151226bd1ab58677ca6acc4edb10e0e1493a": {
"describe": {
"columns": [
{
@ -147,7 +147,7 @@
"type_info": "Text"
},
{
"name": "user_id: UserId",
"name": "user_id",
"ordinal": 1,
"type_info": "Text"
},
@ -157,7 +157,7 @@
"type_info": "Text"
},
{
"name": "password",
"name": "password_hash",
"ordinal": 3,
"type_info": "Text"
}
@ -172,7 +172,7 @@
"Right": 4
}
},
"query": "insert into users(uuid, user_id, display_name, password)\n values (?, ?, ?, ?)\n returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password"
"query": "insert into users(uuid, user_id, display_name, password_hash)\n values (?, ?, ?, ?)\n returning uuid as 'uuid: Uuid', user_id, display_name, password_hash"
},
"58d27b1d424297504f1da2e3b9b4020121251c1155fbf5dc870dafbef97659f3": {
"describe": {
@ -192,7 +192,7 @@
},
"query": "select user_id from users where user_id = ?"
},
"778b7f0a1c66f00812f0232a5904b7b6b295720ebb75d1c2720afeeda4f66936": {
"9673bbe9506ba700923467fe8aaa141f9030158790db74234c13a5800adf2575": {
"describe": {
"columns": [
{
@ -201,7 +201,7 @@
"type_info": "Text"
},
{
"name": "user_id: UserId",
"name": "user_id",
"ordinal": 1,
"type_info": "Text"
},
@ -211,7 +211,7 @@
"type_info": "Text"
},
{
"name": "password",
"name": "password_hash",
"ordinal": 3,
"type_info": "Text"
}
@ -223,10 +223,46 @@
false
],
"parameters": {
"Right": 5
"Right": 1
}
},
"query": "update users set uuid = ?, user_id = ?, display_name = ?, password = ?\n where uuid = ?\n returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password"
"query": "select uuid as 'uuid: Uuid', user_id, display_name, password_hash\n from users where uuid = ?"
},
"9ee4afab2c653a23144bcb05943aa4ff1e8dc1ae5baa9c87827b52671ae47784": {
"describe": {
"columns": [
{
"name": "uuid: Uuid",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "user_uuid: Uuid",
"ordinal": 1,
"type_info": "Int64"
},
{
"name": "device_id",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "display_name",
"ordinal": 3,
"type_info": "Text"
}
],
"nullable": [
false,
false,
false,
false
],
"parameters": {
"Right": 2
}
},
"query": "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ? and device_id = ?"
},
"b38fd90504bea0c63e6517738c2354e6b057fcc6c643283019b27689e286bf2d": {
"describe": {
@ -258,7 +294,7 @@
},
"query": "select uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key\n from sessions where key = ?"
},
"ddcc531c080b2a1c70166d29a940aaada6701abe2933c305a879e7f18baeaf3a": {
"f1b148ebcfe22d9680b8ea3dc3c334523496e88fca724ec7d08c5e2948e58526": {
"describe": {
"columns": [
{
@ -267,17 +303,17 @@
"type_info": "Text"
},
{
"name": "user_uuid: Uuid",
"name": "user_id",
"ordinal": 1,
"type_info": "Int64"
},
{
"name": "device_id",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "display_name",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "password_hash",
"ordinal": 3,
"type_info": "Text"
}
@ -289,45 +325,9 @@
false
],
"parameters": {
"Right": 1
"Right": 5
}
},
"query": "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ?"
},
"f57c49e5390c81f971851ff9ab35242a472b9efbb1ffa658de9b102188769750": {
"describe": {
"columns": [
{
"name": "uuid: Uuid",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "user_id: UserId",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "display_name",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "password",
"ordinal": 3,
"type_info": "Text"
}
],
"nullable": [
false,
false,
false,
false
],
"parameters": {
"Right": 1
}
},
"query": "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password\n from users where user_id = ?"
"query": "update users set uuid = ?, user_id = ?, display_name = ?, password_hash = ?\n where uuid = ?\n returning uuid as 'uuid: Uuid', user_id, display_name, password_hash"
}
}

View File

@ -1,9 +1,7 @@
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::Json;
use sqlx::Statement;
use crate::responses::registration::RegistrationResponse;
use crate::types::error_code::ErrorCode;
use super::authentication_error::AuthenticationError;

View File

@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc};
use axum::{
extract::Query,
routing::{get, post},
Extension, Json,
Extension,
};
use sqlx::SqlitePool;
@ -12,23 +12,14 @@ use crate::{
api_error::ApiError, authentication_error::AuthenticationError,
registration_error::RegistrationError,
},
models::sessions::Session,
responses::{
authentication::{AuthenticationResponse, AuthenticationSuccess},
registration::RegistrationResponse,
},
ruma_wrapper::RumaResponse,
};
use crate::{models::devices::Device, responses::registration::RegistrationSuccess};
use crate::{
models::users::User,
requests::registration::RegistrationRequest,
types::{authentication_data::AuthenticationData, user_id::UserId},
models::{devices::Device, sessions::Session, users::User},
ruma_wrapper::{RumaRequest, RumaResponse},
Config,
};
use ruma::api::client::{
account,
session::get_login_types::v3::{LoginType, PasswordLoginType},
account, session,
uiaa::{IncomingAuthData, IncomingUserIdentifier},
};
pub fn routes() -> axum::Router {
@ -38,9 +29,10 @@ pub fn routes() -> axum::Router {
.route("/r0/register/available", get(get_username_available))
}
use ruma::api::client::session;
#[tracing::instrument]
async fn get_login() -> Result<RumaResponse<session::get_login_types::v3::Response>, ApiError> {
use session::get_login_types::v3::*;
Ok(RumaResponse(session::get_login_types::v3::Response::new(
vec![LoginType::Password(PasswordLoginType::new())],
)))
@ -48,43 +40,50 @@ async fn get_login() -> Result<RumaResponse<session::get_login_types::v3::Respon
#[tracing::instrument(skip_all)]
async fn post_login(
Extension(config): Extension<Arc<Config>>,
Extension(db): Extension<SqlitePool>,
Json(body): Json<AuthenticationData>,
) -> Result<Json<AuthenticationResponse>, ApiError> {
match body {
AuthenticationData::Password(auth_data) => {
let user = auth_data.user().unwrap();
let user_id = UserId::new(user, config.server_name())
.ok()
.ok_or(AuthenticationError::InvalidUserId)?;
RumaRequest(req): RumaRequest<session::login::v3::IncomingRequest>,
) -> Result<RumaResponse<session::login::v3::Response>, ApiError> {
use session::login::v3::*;
let user = User::find_by_user_id(&db, &user_id).await?;
match req.login_info {
IncomingLoginInfo::Password(incoming_password) => {
let password = incoming_password.password;
let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) =
incoming_password.identifier
{
ruma::UserId::parse(user_id).map_err(|e| anyhow::anyhow!(e))?
} else {
return Err(AuthenticationError::InvalidUserId.into())
};
user.password_correct(auth_data.password())
let db_user = User::find_by_user_id(&db, user_id.as_str()).await?;
db_user
.password_correct(&password)
.ok()
.ok_or(AuthenticationError::Forbidden)?;
let device = if let Some(device_id) = auth_data.device_id() {
Device::find_for_user(&db, &user, device_id).await?
let device = if let Some(device_id) = req.device_id {
Device::find_for_user(&db, &db_user, device_id.as_str()).await?
} else {
let device_id = uuid::Uuid::new_v4().to_string();
let display_name =
if let Some(display_name) = auth_data.initial_device_display_name() {
display_name.as_ref()
} else {
"Generic Device"
};
Device::new(&user, &device_id, display_name)?
let display_name = req
.initial_device_display_name
.unwrap_or_else(|| "Generic Device".into());
Device::new(&db_user, &device_id, &display_name)?
.create(&db)
.await?
};
let session = Session::new(&device)?.create(&db).await?;
let response = Response::new(
user_id,
session.key,
ruma::OwnedDeviceId::from(device.device_id),
);
let resp = AuthenticationSuccess::new(session.key(), device.device_id(), &user_id);
Ok(Json(AuthenticationResponse::Success(resp)))
return Ok(RumaResponse(response));
}
_ => todo!(),
}
}
@ -94,76 +93,77 @@ async fn get_username_available(
Extension(db): Extension<SqlitePool>,
Query(params): Query<HashMap<String, String>>,
) -> Result<RumaResponse<account::get_username_availability::v3::Response>, ApiError> {
use account::get_username_availability::v3::*;
let username = params
.get("username")
.ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, config.server_name())
.ok()
.ok_or(RegistrationError::InvalidUserId)?;
let user_id = ruma::UserId::parse(username).map_err(|_| RegistrationError::InvalidUserId)?;
let exists = User::exists(&db, &user_id).await?;
Ok(RumaResponse(
account::get_username_availability::v3::Response::new(!exists),
))
Ok(RumaResponse(Response::new(!exists)))
}
#[tracing::instrument(skip_all)]
async fn post_register(
Extension(config): Extension<Arc<Config>>,
Extension(db): Extension<SqlitePool>,
Json(body): Json<RegistrationRequest>,
) -> Result<Json<RegistrationResponse>, ApiError> {
RumaRequest(req): RumaRequest<account::register::v3::IncomingRequest>,
) -> Result<RumaResponse<account::register::v3::Response>, ApiError> {
use account::register::v3::*;
config
.enable_registration()
.then(|| true)
.ok_or(RegistrationError::RegistrationDisabled)?;
body.auth()
.ok_or(RegistrationError::AdditionalAuthenticationInformation)?;
let (user, device) = match &body.auth().expect("must be Some") {
AuthenticationData::Password(auth_data) => {
let username = body.username().ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, config.server_name())
.ok()
.ok_or(RegistrationError::InvalidUserId)?;
match req.auth {
Some(auth) => match auth {
IncomingAuthData::Password(incoming_password) => {
let password = incoming_password.password;
let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) =
incoming_password.identifier
{
ruma::UserId::parse(user_id).map_err(|e| anyhow::anyhow!(e))?
} else {
Err(AuthenticationError::InvalidUserId)?
};
if User::exists(&db, &user_id).await? {
return Err(ApiError::from(RegistrationError::UserIdTaken));
}
if User::exists(&db, &user_id).await? {
return Err(ApiError::from(RegistrationError::UserIdTaken));
}
let display_name = match body.initial_device_display_name() {
Some(display_name) => display_name.as_ref(),
None => "Random displayname",
};
let display_name = req
.initial_device_display_name
.unwrap_or_else(|| "Random Display Name".into());
let user = User::new(&user_id, &user_id.to_string(), auth_data.password())?
let user = User::new(&user_id, &user_id.to_string(), &password)?
.create(&db)
.await?;
let device = Device::new(
&user,
uuid::Uuid::new_v4().to_string().as_ref(),
&display_name,
)?
.create(&db)
.await?;
let mut response = Response::new(
ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?,
);
if !req.inhibit_login {
let session = Session::new(&device)?.create(&db).await?;
response.access_token = Some(session.key);
}
if !req.inhibit_login {
response.device_id = Some(ruma::OwnedDeviceId::from(device.device_id));
}
let device = Device::new(
&user,
uuid::Uuid::new_v4().to_string().as_ref(),
display_name,
)?
.create(&db)
.await?;
(user, device)
}
return Ok(RumaResponse(response));
}
_ => todo!(),
},
None => Err(RegistrationError::AdditionalAuthenticationInformation)?,
};
if body.inhibit_login().unwrap_or(false) {
let resp = RegistrationSuccess::new(None, device.device_id(), &user.user_id().to_string());
Ok(Json(RegistrationResponse::Success(resp)))
} else {
let session = Session::new(&device)?.create(&db).await?;
let resp = RegistrationSuccess::new(
Some(session.key()),
device.device_id(),
&user.user_id().to_string(),
);
Ok(Json(RegistrationResponse::Success(resp)))
}
unreachable!()
}

View File

@ -13,8 +13,8 @@ use crate::{models::sessions::Session, types::error_code::ErrorCode};
use super::errors::ErrorResponse;
pub mod auth;
pub mod thirdparty;
pub mod create_room;
pub mod thirdparty;
async fn authentication_middleware<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
let db: &SqlitePool = req.extensions().get().unwrap();

View File

@ -1,12 +1,16 @@
use axum::{routing::get, Json};
use axum::routing::get;
use crate::responses::versions::Versions;
use crate::ruma_wrapper::RumaResponse;
pub fn routes() -> axum::Router {
axum::Router::new().route("/versions", get(get_client_versions))
}
use ruma::api::client::discovery;
#[tracing::instrument]
async fn get_client_versions() -> Json<Versions> {
Json(Versions::default())
async fn get_client_versions() -> RumaResponse<discovery::get_supported_versions::Response> {
use discovery::get_supported_versions::*;
RumaResponse(Response::new(vec!["v1.2".into()]))
}

View File

@ -3,7 +3,7 @@ use crate::types::server_name::ServerName;
pub struct Config {
db_path: String,
server_name: ServerName,
enable_registration: bool
enable_registration: bool,
}
impl Config {
@ -31,7 +31,7 @@ impl Default for Config {
Self {
db_path: "sqlite://db.sqlite3".into(),
server_name: ServerName::new("fuckwit.dev").unwrap(),
enable_registration: true
enable_registration: true,
}
}
}
}

View File

@ -12,10 +12,8 @@ use tower_http::{cors::CorsLayer, trace::TraceLayer};
mod api;
mod config;
mod models;
mod requests;
mod responses;
mod ruma_wrapper;
mod state_resolution;
mod types;
#[tokio::main]

View File

@ -2,20 +2,20 @@ use sqlx::SqlitePool;
use crate::types::uuid::Uuid;
use super::{sessions::Session, users::User};
use super::users::User;
pub struct Device {
uuid: Uuid,
user_uuid: Uuid,
device_id: String,
display_name: String,
pub uuid: Uuid,
pub user_uuid: Uuid,
pub device_id: String,
pub display_name: String,
}
impl Device {
pub fn new(user: &User, device_id: &str, display_name: &str) -> anyhow::Result<Self> {
Ok(Self {
uuid: uuid::Uuid::new_v4().into(),
user_uuid: user.uuid().clone(),
user_uuid: user.uuid.clone(),
device_id: device_id.to_owned(),
display_name: display_name.to_owned(),
})
@ -40,11 +40,11 @@ impl Device {
user: &User,
device_id: &str,
) -> anyhow::Result<Self> {
let user_uuid = user.uuid();
let user_uuid = user.uuid.clone();
Ok(sqlx::query_as!(
Self,
"select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ?",
user_uuid)
"select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ? and device_id = ?",
user_uuid, device_id)
.fetch_one(conn).await?
)
}
@ -67,22 +67,4 @@ impl Device {
pub fn uuid(&self) -> &Uuid {
&self.uuid
}
/// Get the device's user id.
#[must_use]
pub fn user_uuid(&self) -> &Uuid {
&self.user_uuid
}
/// Get a reference to the device's device id.
#[must_use]
pub fn device_id(&self) -> &str {
self.device_id.as_ref()
}
/// Get a reference to the device's display name.
#[must_use]
pub fn display_name(&self) -> &str {
self.display_name.as_ref()
}
}

View File

@ -1,8 +1,8 @@
use sqlx::SqlitePool;
/* use sqlx::SqlitePool;
use crate::types::{uuid::Uuid, event_type::EventType};
use super::{rooms::Room, users::User};
use super::{users::User};
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Event {
@ -70,3 +70,4 @@ impl Event {
serde_json::from_str(&self.content).expect("has to be valid json")
}
}
*/

View File

@ -1,5 +1,4 @@
pub mod devices;
pub mod events;
pub mod rooms;
pub mod sessions;
pub mod users;

View File

@ -1,42 +0,0 @@
use sqlx::SqlitePool;
use crate::types::uuid::Uuid;
use super::events::Event;
pub struct Room {
uuid: Uuid,
name: String,
}
impl Room {
fn new(name: &str) -> anyhow::Result<Self> {
Ok(Self {
uuid: uuid::Uuid::new_v4().into(),
name: name.to_owned(),
})
}
pub async fn create(&self, conn: &SqlitePool) -> anyhow::Result<Self> {
Ok(sqlx::query_as!(
Self,
"insert into rooms(uuid, name)
values(?, ?)
returning uuid as 'uuid: Uuid', name",
self.uuid,
self.name
)
.fetch_one(conn)
.await?)
}
pub async fn events(&self, conn: &SqlitePool) -> anyhow::Result<Vec<Event>> {
Event::all_for_room(conn, self).await
}
/// Get a reference to the room's uuid.
#[must_use]
pub fn uuid(&self) -> &Uuid {
&self.uuid
}
}

View File

@ -6,9 +6,9 @@ use crate::types::uuid::Uuid;
use super::devices::Device;
pub struct Session {
uuid: Uuid,
device_uuid: Uuid,
key: String,
pub uuid: Uuid,
pub device_uuid: Uuid,
pub key: String,
}
impl Session {
@ -51,22 +51,4 @@ impl Session {
pub async fn device(&self, conn: &SqlitePool) -> anyhow::Result<Device> {
Device::find_by_uuid(conn, &self.device_uuid).await
}
/// Get the session's id.
#[must_use]
pub fn uuid(&self) -> &Uuid {
&self.uuid
}
/// Get the session's device id.
#[must_use]
pub fn device_uuid(&self) -> &Uuid {
&self.device_uuid
}
/// Get a reference to the session's value.
#[must_use]
pub fn key(&self) -> &str {
self.key.as_ref()
}
}

View File

@ -1,35 +1,35 @@
use crate::types::uuid::Uuid;
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use rand::rngs::OsRng;
use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool};
use crate::types::user_id::UserId;
use ruma::OwnedUserId;
use sqlx::SqlitePool;
#[derive(Debug)]
pub struct User {
uuid: Uuid,
user_id: UserId,
display_name: String,
password: String,
pub uuid: Uuid,
pub user_id: String,
pub display_name: String,
pub password_hash: String,
}
impl User {
pub fn new(user_id: &UserId, display_name: &str, password: &str) -> anyhow::Result<Self> {
pub fn new(user_id: &OwnedUserId, display_name: &str, password: &str) -> anyhow::Result<Self> {
let argon2 = Argon2::default();
let salt = SaltString::generate(OsRng);
let password = argon2
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)?
.to_string();
Ok(Self {
uuid: uuid::Uuid::new_v4().into(),
user_id: user_id.clone(),
user_id: user_id.to_string(),
display_name: display_name.to_owned(),
password,
password_hash,
})
}
pub async fn exists(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<bool> {
pub async fn exists(conn: &SqlitePool, user_id: &OwnedUserId) -> anyhow::Result<bool> {
let user_id = user_id.to_string();
Ok(
sqlx::query!("select user_id from users where user_id = ?", user_id)
.fetch_optional(conn)
@ -41,13 +41,13 @@ impl User {
pub async fn create(&self, conn: &SqlitePool) -> anyhow::Result<Self> {
Ok(sqlx::query_as!(
Self,
"insert into users(uuid, user_id, display_name, password)
"insert into users(uuid, user_id, display_name, password_hash)
values (?, ?, ?, ?)
returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password",
returning uuid as 'uuid: Uuid', user_id, display_name, password_hash",
self.uuid,
self.user_id,
self.display_name,
self.password
self.password_hash
)
.fetch_one(conn)
.await?)
@ -56,13 +56,13 @@ impl User {
pub async fn update(&self, conn: &SqlitePool) -> anyhow::Result<Self> {
Ok(sqlx::query_as!(
Self,
"update users set uuid = ?, user_id = ?, display_name = ?, password = ?
"update users set uuid = ?, user_id = ?, display_name = ?, password_hash = ?
where uuid = ?
returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password",
returning uuid as 'uuid: Uuid', user_id, display_name, password_hash",
self.uuid,
self.user_id,
self.display_name,
self.password,
self.password_hash,
self.uuid
)
.fetch_one(conn)
@ -72,7 +72,7 @@ impl User {
pub async fn find_by_uuid(conn: &SqlitePool, uuid: &Uuid) -> anyhow::Result<Self> {
Ok(sqlx::query_as!(
Self,
"select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password
"select uuid as 'uuid: Uuid', user_id, display_name, password_hash
from users where uuid = ?",
uuid
)
@ -80,10 +80,10 @@ impl User {
.await?)
}
pub async fn find_by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<Self> {
pub async fn find_by_user_id(conn: &SqlitePool, user_id: &str) -> anyhow::Result<Self> {
Ok(sqlx::query_as!(
Self,
"select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password
"select uuid as 'uuid: Uuid', user_id, display_name, password_hash
from users where user_id = ?",
user_id
)
@ -92,28 +92,10 @@ impl User {
}
pub fn password_correct(&self, password: &str) -> anyhow::Result<bool> {
let password_hash = PasswordHash::new(self.password())?;
let password_hash = PasswordHash::new(&self.password_hash)?;
Ok(Argon2::default()
.verify_password(password.as_bytes(), &password_hash)
.is_ok())
}
/// Get the user's id.
#[must_use]
pub fn uuid(&self) -> &Uuid {
&self.uuid
}
/// Get a reference to the user's user id.
#[must_use]
pub fn user_id(&self) -> &UserId {
&self.user_id
}
/// Get a reference to the user's password.
#[must_use]
pub fn password(&self) -> &str {
self.password.as_ref()
}
}

View File

@ -1,29 +0,0 @@
use crate::types::user_id::UserId;
#[derive(Debug, serde::Deserialize)]
pub struct CreateRoomRequest {
/// Extra keys, such as `m.federate`, to be added to the content of the `m.room.create` event.
creation_content: Option<()>,
/// List of state events to set in the initial room. Used for overriding the default state
initial_state: Vec<()>,
/// List of user IDs to invite to the room
invite: Option<Vec<String>>,
/// List of thirdparty IDs to invite to the room
invite_3pid: Option<Vec<()>>,
/// Indicate if room is a direct chat room
is_direct: Option<bool>,
/// Set name of the room
name: Option<String>,
/// Used to override the default power level event
power_level_content_override: Option<()>,
/// Preset for room creation
preset: Option<()>,
/// Desired room alias local part
room_alias_name: Option<String>,
/// Version of room to create. Defaults to server default
room_version: Option<String>,
/// Sets rooms topic
topic: Option<String>,
/// Sets rooms visibility
visibility: ()
}

View File

@ -1,2 +0,0 @@
pub mod registration;
pub mod create_room_request;

View File

@ -1,48 +0,0 @@
use crate::types::{authentication_data::AuthenticationData, flow::Flow, identifier::Identifier};
#[derive(Debug, serde::Deserialize)]
pub struct RegistrationRequest {
auth: Option<AuthenticationData>,
device_id: Option<String>,
inhibit_login: Option<bool>,
initial_device_display_name: Option<String>,
password: Option<String>,
username: Option<String>,
}
impl RegistrationRequest {
#[must_use]
pub fn auth(&self) -> Option<&AuthenticationData> {
self.auth.as_ref()
}
/// Get a reference to the registration request's device id.
#[must_use]
pub fn device_id(&self) -> Option<&String> {
self.device_id.as_ref()
}
/// Get the registration request's inhibit login.
#[must_use]
pub fn inhibit_login(&self) -> Option<bool> {
self.inhibit_login
}
/// Get a reference to the registration request's initial device display name.
#[must_use]
pub fn initial_device_display_name(&self) -> Option<&String> {
self.initial_device_display_name.as_ref()
}
/// Get a reference to the registration request's password.
#[must_use]
pub fn password(&self) -> Option<&String> {
self.password.as_ref()
}
/// Get a reference to the registration request's username.
#[must_use]
pub fn username(&self) -> Option<&String> {
self.username.as_ref()
}
}

View File

@ -1,26 +0,0 @@
use axum::{response::IntoResponse, Json};
use crate::types::user_id::UserId;
#[derive(Debug, serde::Serialize)]
#[serde(untagged)]
pub enum AuthenticationResponse {
Success(AuthenticationSuccess),
}
#[derive(Debug, serde::Serialize)]
pub struct AuthenticationSuccess {
access_token: String,
device_id: String,
user_id: String,
}
impl AuthenticationSuccess {
pub fn new(access_token: &str, device_id: &str, user_id: &UserId) -> Self {
Self {
access_token: access_token.to_owned(),
device_id: device_id.to_owned(),
user_id: user_id.to_string(),
}
}
}

View File

@ -1,22 +0,0 @@
use crate::types::flow::Flow;
#[derive(Debug, Clone, serde::Serialize)]
struct FlowWrapper {
#[serde(rename = "type")]
_type: Flow,
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct Flows {
flows: Vec<FlowWrapper>,
}
impl Flows {
pub fn new() -> Self {
Self {
flows: vec![FlowWrapper {
_type: Flow::Password,
}],
}
}
}

View File

@ -1,5 +1 @@
pub mod authentication;
pub mod flow;
pub mod registration;
pub mod username_available;
pub mod versions;

View File

@ -3,7 +3,6 @@ use crate::types::user_interactive_authorization::UserInteractiveAuthorizationIn
#[derive(Debug, serde::Serialize)]
#[serde(untagged)]
pub enum RegistrationResponse {
Success(RegistrationSuccess),
UserInteractiveAuthorizationInfo(UserInteractiveAuthorizationInfo),
}
@ -13,22 +12,4 @@ impl RegistrationResponse {
UserInteractiveAuthorizationInfo::new(),
)
}
}
#[derive(Debug, serde::Serialize)]
pub struct RegistrationSuccess {
#[serde(skip_serializing_if = "Option::is_none")]
access_token: Option<String>,
device_id: String,
user_id: String,
}
impl RegistrationSuccess {
pub fn new(access_token: Option<&str>, device_id: &str, user_id: &str) -> Self {
Self {
access_token: access_token.map(|v| v.to_owned()),
device_id: device_id.to_owned(),
user_id: user_id.to_owned(),
}
}
}
}

View File

@ -1,10 +0,0 @@
#[derive(Debug, serde::Serialize)]
pub struct UsernameAvailable {
available: bool,
}
impl UsernameAvailable {
pub fn new(available: bool) -> Self {
Self { available }
}
}

View File

@ -1,17 +0,0 @@
use std::collections::HashMap;
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct Versions {
#[serde(skip_serializing_if = "Option::is_none")]
unstable_features: Option<HashMap<String, String>>,
versions: Vec<String>,
}
impl Default for Versions {
fn default() -> Self {
Self {
unstable_features: None,
versions: vec!["v1.2".into()],
}
}
}

View File

@ -1,5 +1,5 @@
use axum::{
body::{Bytes, HttpBody, Full},
body::{Bytes, Full, HttpBody},
extract::{FromRequest, Path},
response::IntoResponse,
BoxError,

View File

@ -1 +0,0 @@
mod v2;

View File

@ -1,102 +0,0 @@
use crate::models::events::Event;
use std::{
collections::{HashMap, HashSet},
future::Future,
};
use tracing::info;
type StateMap<T> = HashMap<StateTuple, T>;
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct StateTuple {
event_type: String,
state_key: String,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EventId {
id: Box<str>,
}
#[tracing::instrument(skip(state_sets, auth_chain_sets, get_event_callback))]
pub async fn resolve<F, Fut>(
room_id: &str, // TODO: own type
state_sets: Vec<StateMap<EventId>>,
auth_chain_sets: Vec<HashSet<EventId>>,
get_event_callback: F,
) -> StateMap<EventId>
where
F: Fn(&EventId) -> Fut,
Fut: Future<Output = Option<Event>>,
{
info!("Calculating conflicted state");
let (unconflicted_state, conflicted_state) = separate_state(&state_sets);
if conflicted_state.is_empty() {
return unconflicted_state;
}
info!("{} conflicted_state entries", conflicted_state.len());
info!("Calculating auth_chain differences");
let conflicted_set =
get_auth_chain_differences(auth_chain_sets).chain(conflicted_state.into_values().flatten());
let mut conflicted = HashSet::new();
for eid in conflicted_set {
if let Some(event) = get_event_callback(&eid).await {
conflicted.insert(event);
}
}
todo!()
}
/// separates states from multiple state_maps into unconflicted and conflicted state
///
/// For the set of all state_tuples find all event_ids.
/// If one event_id is found it is unconflicted, otherwise it is conflicted
fn separate_state(
state_sets: &[StateMap<EventId>],
) -> (StateMap<EventId>, StateMap<HashSet<EventId>>) {
let mut unconflicted_state: StateMap<EventId> = StateMap::new();
let mut conflicted_state: HashMap<StateTuple, HashSet<EventId>> = StateMap::new();
for key in state_sets
.iter()
.flat_map(HashMap::keys)
.map(ToOwned::to_owned)
.collect::<HashSet<StateTuple>>()
{
let mut event_ids: HashSet<EventId> = state_sets
.iter()
.filter_map(|state_set| state_set.get(&key))
.map(ToOwned::to_owned)
.collect();
if event_ids.len() == 1 {
unconflicted_state.insert(key, event_ids.into_iter().next().expect("len() is 1"));
} else {
conflicted_state.insert(key, event_ids);
}
}
(unconflicted_state, conflicted_state)
}
fn get_auth_chain_differences(
auth_chain_sets: Vec<HashSet<EventId>>,
) -> impl Iterator<Item = EventId> {
let num_sets = auth_chain_sets.len();
let mut id_counts: HashMap<EventId, usize> = HashMap::new();
for id in auth_chain_sets.into_iter().flatten() {
*id_counts.entry(id).or_default() += 1;
}
id_counts
.into_iter()
.filter_map(move |(id, count)| (count < num_sets).then(move || id))
}
fn is_control_event() {}

View File

@ -1,12 +0,0 @@
use super::{uuid::Uuid, user_id::UserId};
pub struct ClientEvent {
content: (),
event_id: Uuid,
origin_server_ts: u64,
room_id: String,
sender: UserId,
state_key: Option<String>,
r#type: String,
unsigned: ()
}

View File

@ -20,7 +20,7 @@ impl<'e> sqlx::Encode<'e, Sqlite> for EventType {
buf.push(sqlx::sqlite::SqliteArgumentValue::Text(
match self {
EventType::RoomCreate => "m.room.create",
EventType::Unknown => "???"
EventType::Unknown => "???",
}
.into(),
));
@ -47,7 +47,7 @@ impl serde::Serialize for EventType {
{
serializer.serialize_str(match self {
EventType::RoomCreate => "m.room.create",
EventType::Unknown => "dev.fuckwit.unknown_event"
EventType::Unknown => "dev.fuckwit.unknown_event",
})
}
}

View File

@ -1,12 +1,10 @@
pub mod authentication_data;
pub mod error_code;
pub mod event_type;
pub mod flow;
pub mod identifier;
pub mod identifier_type;
pub mod server_name;
pub mod user_id;
pub mod user_interactive_authorization;
pub mod uuid;
pub mod server_name;
pub mod client_event;
pub mod room_event;
pub mod event_type;

View File

@ -1,7 +0,0 @@
pub enum RoomEvent {
Create(RoomCreateEvent)
}
pub struct RoomCreateEvent {
}

View File

@ -1,15 +1,15 @@
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
#[derive(Debug, PartialEq, Eq)]
enum Hostname {
pub enum Hostname {
IPv4(Ipv4Addr),
IPv6(Ipv6Addr),
Fqdn(String),
}
pub struct ServerName {
hostname: Hostname,
port: Option<u16>,
pub hostname: Hostname,
pub port: Option<u16>,
}
impl ServerName {
@ -57,18 +57,6 @@ impl ServerName {
})
}
}
/// Get a reference to the server name's hostname.
#[must_use]
fn hostname(&self) -> &Hostname {
&self.hostname
}
/// Get the server name's port.
#[must_use]
fn port(&self) -> Option<u16> {
self.port
}
}
impl std::fmt::Display for ServerName {
@ -93,59 +81,59 @@ mod tests {
fn parse_ipv4_without_port() {
let server_name = ServerName::new("127.0.0.1").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv4("127.0.0.1".parse().unwrap())
server_name.hostname,
Hostname::IPv4("127.0.0.1".parse().unwrap())
);
assert_eq!(server_name.port(), None);
assert_eq!(server_name.port, None);
}
#[test]
fn parse_ipv4_with_port() {
let server_name = ServerName::new("127.0.0.1:8080").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv4("127.0.0.1".parse().unwrap())
server_name.hostname,
Hostname::IPv4("127.0.0.1".parse().unwrap())
);
assert_eq!(server_name.port(), Some(8080));
assert_eq!(server_name.port, Some(8080));
}
#[test]
fn parse_ipv6_without_port() {
let server_name = ServerName::new("[::1]").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv6("::1".parse().unwrap())
server_name.hostname,
Hostname::IPv6("::1".parse().unwrap())
);
assert_eq!(server_name.port(), None);
assert_eq!(server_name.port, None);
}
#[test]
fn parse_ipv6_with_port() {
let server_name = ServerName::new("[::1]:8080").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv6("::1".parse().unwrap())
server_name.hostname,
Hostname::IPv6("::1".parse().unwrap())
);
assert_eq!(server_name.port(), Some(8080));
assert_eq!(server_name.port, Some(8080));
}
#[test]
fn parse_fqdn_without_port() {
let server_name = ServerName::new("example.com").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::Fqdn("example.com".into())
server_name.hostname,
Hostname::Fqdn("example.com".into())
);
assert_eq!(server_name.port(), None);
assert_eq!(server_name.port, None);
}
#[test]
fn parse_fqdn_with_port() {
let server_name = ServerName::new("example.com:8080").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::Fqdn("example.com".into())
server_name.hostname,
Hostname::Fqdn("example.com".into())
);
assert_eq!(server_name.port(), Some(8080));
assert_eq!(server_name.port, Some(8080));
}
}