initial commit
This commit is contained in:
commit
d17c3a202a
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
/target
|
||||
/db.sqlite*
|
1572
Cargo.lock
generated
Normal file
1572
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
17
Cargo.toml
Normal file
17
Cargo.toml
Normal file
@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "matrix"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.17.0", features = ["full"] }
|
||||
axum = "0.5.3"
|
||||
tracing = "0.1.34"
|
||||
tracing-subscriber = { version = "0.3.11", features = ["env-filter"] }
|
||||
serde = {version = "1.0.136", features = ["derive"] }
|
||||
tower-http = { version = "0.2.5", features = ["cors", "trace"]}
|
||||
sqlx = { version = "0.5.13", features = ["sqlite", "macros", "runtime-tokio-rustls"] }
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
9
migrations/20220423204756_add_user.sql
Normal file
9
migrations/20220423204756_add_user.sql
Normal file
@ -0,0 +1,9 @@
|
||||
-- Add migration script here
|
||||
CREATE TABLE users(
|
||||
id INTEGER PRIMARY KEY NOT NULL,
|
||||
user_id CHAR(255) NOT NULL,
|
||||
display_name TEXT NOT NULL,
|
||||
password TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX user_id_index ON users (user_id);
|
11
migrations/20220424172900_add_devices.sql
Normal file
11
migrations/20220424172900_add_devices.sql
Normal file
@ -0,0 +1,11 @@
|
||||
-- Add migration script here
|
||||
|
||||
CREATE TABLE devices(
|
||||
id INTEGER PRIMARY KEY NOT NULL,
|
||||
user_id INT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
display_name TEXT NOT NULL,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);
|
||||
|
||||
CREATE INDEX device_id_index ON devices (device_id);
|
10
migrations/20220424175554_add_sessions.sql
Normal file
10
migrations/20220424175554_add_sessions.sql
Normal file
@ -0,0 +1,10 @@
|
||||
-- Add migration script here
|
||||
|
||||
CREATE TABLE sessions(
|
||||
id INTEGER PRIMARY KEY NOT NULL,
|
||||
device_id INT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
FOREIGN KEY(device_id) REFERENCES devices(id)
|
||||
);
|
||||
|
||||
CREATE INDEX value_index ON sessions (value);
|
110
src/api/client_server/auth.rs
Normal file
110
src/api/client_server/auth.rs
Normal file
@ -0,0 +1,110 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
extract::Query,
|
||||
http::StatusCode,
|
||||
routing::{get, post},
|
||||
Extension, Json,
|
||||
};
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use crate::responses::registration::RegistrationResponse;
|
||||
use crate::{
|
||||
models::devices::Device,
|
||||
responses::{flow::Flows, registration::RegistrationSuccess},
|
||||
};
|
||||
use crate::{
|
||||
models::users::User,
|
||||
requests::registration::RegistrationRequest,
|
||||
responses::username_available::UsernameAvailable,
|
||||
types::{
|
||||
authentication_data::AuthenticationData, identifier::Identifier, matrix_user_id::UserId,
|
||||
},
|
||||
Config,
|
||||
};
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/login", get(get_login).post(post_login))
|
||||
.route("/r0/register", post(post_register))
|
||||
.route("/r0/register/available", get(get_username_available))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn get_login() -> Json<Flows> {
|
||||
Json(Flows::new())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn post_login(body: String) -> StatusCode {
|
||||
dbg!(body);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_username_available(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(db): Extension<SqlitePool>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Json<UsernameAvailable> {
|
||||
let username = params.get("username").unwrap();
|
||||
let user_id = UserId::new(&username, &config.homeserver_name).unwrap();
|
||||
let exists = User::exists(&db, &user_id).await.unwrap();
|
||||
Json(UsernameAvailable::new(!exists))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn post_register(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(db): Extension<SqlitePool>,
|
||||
Json(body): Json<RegistrationRequest>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> (StatusCode, Json<RegistrationResponse>) {
|
||||
// Client tries to get available flows
|
||||
if *&body.auth().is_none() {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(RegistrationResponse::user_interactive_authorization_info()),
|
||||
);
|
||||
}
|
||||
|
||||
let (user, device) = match &body.auth().unwrap() {
|
||||
AuthenticationData::Password(auth_data) => {
|
||||
// let username = match auth_data.identifier() {
|
||||
// Identifier::User(user_identifier) => user_identifier.user().unwrap(),
|
||||
// };
|
||||
let username = body.username().unwrap();
|
||||
let user_id = UserId::new(&username, &config.homeserver_name).unwrap();
|
||||
if User::exists(&db, &user_id).await.unwrap() {
|
||||
todo!("Error out")
|
||||
}
|
||||
let password = auth_data.password();
|
||||
let display_name = match *&body.initial_device_display_name() {
|
||||
Some(display_name) => display_name.as_ref(),
|
||||
None => "Random displayname",
|
||||
};
|
||||
|
||||
let user = User::create(&db, &user_id, &user_id.to_string(), &password)
|
||||
.await
|
||||
.unwrap();
|
||||
let device = Device::create(&db, &user, "test", display_name)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
(user, device)
|
||||
}
|
||||
};
|
||||
|
||||
// dont log in the user after registration
|
||||
if *&body.inhibit_login().is_some() && *&body.inhibit_login().unwrap() {
|
||||
let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id());
|
||||
|
||||
(StatusCode::OK, Json(RegistrationResponse::Success(resp)))
|
||||
} else {
|
||||
let session = device.create_session(&db).await.unwrap();
|
||||
let resp =
|
||||
RegistrationSuccess::new(Some(session.value()), device.device_id(), user.user_id());
|
||||
|
||||
(StatusCode::OK, Json(RegistrationResponse::Success(resp)))
|
||||
}
|
||||
}
|
2
src/api/client_server/mod.rs
Normal file
2
src/api/client_server/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
pub mod auth;
|
||||
pub mod versions;
|
12
src/api/client_server/versions.rs
Normal file
12
src/api/client_server/versions.rs
Normal file
@ -0,0 +1,12 @@
|
||||
use axum::{routing::get, Json};
|
||||
|
||||
use crate::responses::versions::Versions;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new().route("/versions", get(get_client_versions))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn get_client_versions() -> Json<Versions> {
|
||||
Json(Versions::default())
|
||||
}
|
1
src/api/mod.rs
Normal file
1
src/api/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod client_server;
|
73
src/main.rs
Normal file
73
src/main.rs
Normal file
@ -0,0 +1,73 @@
|
||||
#![allow(unused)]
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
handler::Handler,
|
||||
http::{Request, StatusCode},
|
||||
Router, Extension,
|
||||
};
|
||||
use tower_http::{
|
||||
cors::CorsLayer,
|
||||
trace::{DefaultOnRequest, TraceLayer, DefaultOnResponse},
|
||||
};
|
||||
use tracing::Level;
|
||||
|
||||
mod api;
|
||||
mod responses;
|
||||
mod requests;
|
||||
mod types;
|
||||
mod models;
|
||||
|
||||
struct Config {
|
||||
db_path: String,
|
||||
homeserver_name: String,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self { db_path: "sqlite://db.sqlite3".into(), homeserver_name: "fuckwit.dev".into() }
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
if std::env::var("RUST_LOG").is_err() {
|
||||
std::env::set_var("RUST_LOG", "debug");
|
||||
}
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
// init config
|
||||
let config = Arc::new(Config::default());
|
||||
|
||||
let pool = sqlx::SqlitePool::connect("sqlite://db.sqlite3").await.unwrap();
|
||||
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(tower_http::cors::Any)
|
||||
.allow_methods(tower_http::cors::Any)
|
||||
.allow_headers(tower_http::cors::Any);
|
||||
|
||||
let tracing_layer = TraceLayer::new_for_http();
|
||||
|
||||
let client_server = Router::new()
|
||||
.merge(api::client_server::versions::routes())
|
||||
.merge(api::client_server::auth::routes());
|
||||
|
||||
let router = Router::new()
|
||||
.nest("/_matrix/client", client_server)
|
||||
.layer(cors)
|
||||
.layer(tracing_layer)
|
||||
.layer(Extension(pool))
|
||||
.layer(Extension(config))
|
||||
.fallback(fallback.into_service());
|
||||
|
||||
let _ = axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
|
||||
.serve(router.into_make_service())
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn fallback(request: Request<Body>) -> StatusCode {
|
||||
dbg!(request);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
50
src/models/devices.rs
Normal file
50
src/models/devices.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use super::{users::User, sessions::Session};
|
||||
|
||||
pub struct Device {
|
||||
id: i64,
|
||||
user_id: i64,
|
||||
device_id: String,
|
||||
display_name: String
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub async fn create(conn: &SqlitePool, user: &User, device_id: &str, display_name: &str) -> anyhow::Result<Self> {
|
||||
let user_id = user.id();
|
||||
Ok(sqlx::query_as!(Self, "insert into devices(user_id, device_id, display_name) values(?, ?, ?) returning id, user_id, device_id, display_name", user_id, device_id, display_name).fetch_one(conn).await?)
|
||||
}
|
||||
|
||||
pub async fn by_user(conn: &SqlitePool, user: &User) -> anyhow::Result<Self> {
|
||||
let user_id = user.id();
|
||||
Ok(sqlx::query_as!(Self, "select id, user_id, device_id, display_name from devices where user_id = ?", user_id).fetch_one(conn).await?)
|
||||
}
|
||||
|
||||
pub async fn create_session(&self, conn: &SqlitePool) -> anyhow::Result<Session> {
|
||||
Ok(Session::create(conn, self, "random_session_id").await?)
|
||||
}
|
||||
|
||||
/// Get the device's id.
|
||||
#[must_use]
|
||||
pub fn id(&self) -> i64 {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Get the device's user id.
|
||||
#[must_use]
|
||||
pub fn user_id(&self) -> i64 {
|
||||
self.user_id
|
||||
}
|
||||
|
||||
/// Get a reference to the device's device id.
|
||||
#[must_use]
|
||||
pub fn device_id(&self) -> &str {
|
||||
self.device_id.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the device's display name.
|
||||
#[must_use]
|
||||
pub fn display_name(&self) -> &str {
|
||||
self.display_name.as_ref()
|
||||
}
|
||||
}
|
3
src/models/mod.rs
Normal file
3
src/models/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod users;
|
||||
pub mod devices;
|
||||
pub mod sessions;
|
34
src/models/sessions.rs
Normal file
34
src/models/sessions.rs
Normal file
@ -0,0 +1,34 @@
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use super::devices::Device;
|
||||
|
||||
pub struct Session {
|
||||
id: i64,
|
||||
device_id: i64,
|
||||
value: String,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub async fn create(conn: &SqlitePool, device: &Device, value: &str) -> anyhow::Result<Self> {
|
||||
let device_id = device.id();
|
||||
Ok(sqlx::query_as!(Self, "insert into sessions(device_id, value) values(?, ?) returning id, device_id, value", device_id, value).fetch_one(conn).await?)
|
||||
}
|
||||
|
||||
/// Get the session's id.
|
||||
#[must_use]
|
||||
pub fn id(&self) -> i64 {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Get the session's device id.
|
||||
#[must_use]
|
||||
pub fn device_id(&self) -> i64 {
|
||||
self.device_id
|
||||
}
|
||||
|
||||
/// Get a reference to the session's value.
|
||||
#[must_use]
|
||||
pub fn value(&self) -> &str {
|
||||
self.value.as_ref()
|
||||
}
|
||||
}
|
42
src/models/users.rs
Normal file
42
src/models/users.rs
Normal file
@ -0,0 +1,42 @@
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use crate::types::matrix_user_id::UserId;
|
||||
|
||||
pub struct User {
|
||||
id: i64,
|
||||
user_id: String,
|
||||
display_name: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub async fn exists(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<bool> {
|
||||
Ok(sqlx::query!("select user_id from users where user_id = ?", user_id).fetch_optional(conn).await?.is_some())
|
||||
}
|
||||
|
||||
pub async fn create(conn: &SqlitePool, user_id: &UserId, display_name: &str, password: &str) -> anyhow::Result<Self> {
|
||||
Ok(sqlx::query_as!(Self, "insert into users(user_id, display_name, password) values (?, ?, ?) returning id, user_id, display_name, password", user_id, display_name, password).fetch_one(conn).await?)
|
||||
}
|
||||
|
||||
pub async fn by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<Self> {
|
||||
Ok(sqlx::query_as!(Self, "select id, user_id, display_name, password from users where user_id = ?", user_id).fetch_one(conn).await?)
|
||||
}
|
||||
|
||||
/// Get the user's id.
|
||||
#[must_use]
|
||||
pub fn id(&self) -> i64 {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Get a reference to the user's user id.
|
||||
#[must_use]
|
||||
pub fn user_id(&self) -> &str {
|
||||
self.user_id.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the user's password.
|
||||
#[must_use]
|
||||
pub fn password(&self) -> &str {
|
||||
self.password.as_ref()
|
||||
}
|
||||
}
|
1
src/requests/mod.rs
Normal file
1
src/requests/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod registration;
|
59
src/requests/registration.rs
Normal file
59
src/requests/registration.rs
Normal file
@ -0,0 +1,59 @@
|
||||
use crate::types::{flow::Flow, identifier::Identifier, authentication_data::AuthenticationData};
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct RegistrationRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
auth: Option<AuthenticationData>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
device_id: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
inhibit_login: Option<bool>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
initial_device_display_name: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
password: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
username: Option<String>,
|
||||
}
|
||||
|
||||
impl RegistrationRequest {
|
||||
#[must_use]
|
||||
pub fn auth(&self) -> Option<&AuthenticationData> {
|
||||
self.auth.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the registration request's device id.
|
||||
#[must_use]
|
||||
pub fn device_id(&self) -> Option<&String> {
|
||||
self.device_id.as_ref()
|
||||
}
|
||||
|
||||
/// Get the registration request's inhibit login.
|
||||
#[must_use]
|
||||
pub fn inhibit_login(&self) -> Option<bool> {
|
||||
self.inhibit_login
|
||||
}
|
||||
|
||||
/// Get a reference to the registration request's initial device display name.
|
||||
#[must_use]
|
||||
pub fn initial_device_display_name(&self) -> Option<&String> {
|
||||
self.initial_device_display_name.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the registration request's password.
|
||||
#[must_use]
|
||||
pub fn password(&self) -> Option<&String> {
|
||||
self.password.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the registration request's username.
|
||||
#[must_use]
|
||||
pub fn username(&self) -> Option<&String> {
|
||||
self.username.as_ref()
|
||||
}
|
||||
}
|
20
src/responses/flow.rs
Normal file
20
src/responses/flow.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use crate::types::flow::Flow;
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
struct FlowWrapper {
|
||||
#[serde(rename = "type")]
|
||||
_type: Flow
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct Flows {
|
||||
flows: Vec<FlowWrapper>,
|
||||
}
|
||||
|
||||
impl Flows {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
flows: vec![FlowWrapper{ _type: Flow::Password }],
|
||||
}
|
||||
}
|
||||
}
|
4
src/responses/mod.rs
Normal file
4
src/responses/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub mod flow;
|
||||
pub mod registration;
|
||||
pub mod username_available;
|
||||
pub mod versions;
|
34
src/responses/registration.rs
Normal file
34
src/responses/registration.rs
Normal file
@ -0,0 +1,34 @@
|
||||
use crate::types::user_interactive_authorization::UserInteractiveAuthorizationInfo;
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum RegistrationResponse {
|
||||
Success(RegistrationSuccess),
|
||||
UserInteractiveAuthorizationInfo(UserInteractiveAuthorizationInfo),
|
||||
}
|
||||
|
||||
impl RegistrationResponse {
|
||||
pub fn user_interactive_authorization_info() -> Self {
|
||||
RegistrationResponse::UserInteractiveAuthorizationInfo(
|
||||
UserInteractiveAuthorizationInfo::new(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct RegistrationSuccess {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
access_token: Option<String>,
|
||||
device_id: String,
|
||||
user_id: String,
|
||||
}
|
||||
|
||||
impl RegistrationSuccess {
|
||||
pub fn new(access_token: Option<&str>, device_id: &str, user_id: &str) -> Self {
|
||||
Self {
|
||||
access_token: access_token.and_then(|v| Some(v.to_owned())),
|
||||
device_id: device_id.to_owned(),
|
||||
user_id: user_id.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
10
src/responses/username_available.rs
Normal file
10
src/responses/username_available.rs
Normal file
@ -0,0 +1,10 @@
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct UsernameAvailable {
|
||||
available: bool,
|
||||
}
|
||||
|
||||
impl UsernameAvailable {
|
||||
pub fn new(available: bool) -> Self {
|
||||
Self { available }
|
||||
}
|
||||
}
|
17
src/responses/versions.rs
Normal file
17
src/responses/versions.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Versions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
unstable_features: Option<HashMap<String, String>>,
|
||||
versions: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for Versions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
unstable_features: None,
|
||||
versions: vec!["v1.2".into()],
|
||||
}
|
||||
}
|
||||
}
|
36
src/types/authentication_data.rs
Normal file
36
src/types/authentication_data.rs
Normal file
@ -0,0 +1,36 @@
|
||||
use super::{flow::Flow, identifier::Identifier};
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AuthenticationData {
|
||||
Password(AuthenticationPassword)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct AuthenticationPassword {
|
||||
#[serde(rename = "type")]
|
||||
_type: Flow,
|
||||
identifier: Identifier,
|
||||
password: String,
|
||||
user: Option<String>
|
||||
}
|
||||
|
||||
impl AuthenticationPassword {
|
||||
/// Get a reference to the authentication password's identifier.
|
||||
#[must_use]
|
||||
pub fn identifier(&self) -> &Identifier {
|
||||
&self.identifier
|
||||
}
|
||||
|
||||
/// Get a reference to the authentication password's password.
|
||||
#[must_use]
|
||||
pub fn password(&self) -> &str {
|
||||
self.password.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the authentication password's user.
|
||||
#[must_use]
|
||||
pub fn user(&self) -> Option<&String> {
|
||||
self.user.as_ref()
|
||||
}
|
||||
}
|
44
src/types/flow.rs
Normal file
44
src/types/flow.rs
Normal file
@ -0,0 +1,44 @@
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Flow {
|
||||
Password,
|
||||
}
|
||||
|
||||
impl serde::Serialize for Flow {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(match self {
|
||||
Flow::Password => "m.login.password",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for Flow {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct FlowVisitor;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for FlowVisitor {
|
||||
type Value = Flow;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("Flow")
|
||||
}
|
||||
|
||||
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
match v {
|
||||
"m.login.password" => Ok(Flow::Password),
|
||||
_ => Err(serde::de::Error::custom("Unknown flow")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(FlowVisitor {})
|
||||
}
|
||||
}
|
22
src/types/identifier.rs
Normal file
22
src/types/identifier.rs
Normal file
@ -0,0 +1,22 @@
|
||||
use super::identifier_type::IdentifierType;
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum Identifier {
|
||||
User(IdentifierUser)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct IdentifierUser {
|
||||
#[serde(rename = "type")]
|
||||
_type: IdentifierType,
|
||||
user: Option<String>
|
||||
}
|
||||
|
||||
impl IdentifierUser {
|
||||
/// Get a reference to the identifier user's user.
|
||||
#[must_use]
|
||||
pub fn user(&self) -> Option<&String> {
|
||||
self.user.as_ref()
|
||||
}
|
||||
}
|
44
src/types/identifier_type.rs
Normal file
44
src/types/identifier_type.rs
Normal file
@ -0,0 +1,44 @@
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum IdentifierType {
|
||||
User,
|
||||
}
|
||||
|
||||
impl serde::Serialize for IdentifierType {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(match self {
|
||||
IdentifierType::User => "m.id.user",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for IdentifierType {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct IdentifierVisitor;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for IdentifierVisitor {
|
||||
type Value = IdentifierType;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("Identifier")
|
||||
}
|
||||
|
||||
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
match v {
|
||||
"m.id.user" => Ok(IdentifierType::User),
|
||||
_ => Err(serde::de::Error::custom("Unknown identifier")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(IdentifierVisitor {})
|
||||
}
|
||||
}
|
35
src/types/matrix_user_id.rs
Normal file
35
src/types/matrix_user_id.rs
Normal file
@ -0,0 +1,35 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
pub struct UserId(String);
|
||||
|
||||
impl UserId {
|
||||
pub fn new(name: &str, server_name: &str) -> anyhow::Result<Self> {
|
||||
let user_id = Self(format!("@{name}:{server_name}"));
|
||||
|
||||
user_id.is_valid()?;
|
||||
|
||||
Ok(user_id)
|
||||
}
|
||||
|
||||
fn is_valid(&self) -> anyhow::Result<()> {
|
||||
(self.0.len() <= 255).then(|| ()).ok_or(UserIdError::TooLong(self.0.len()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for UserId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum UserIdError {
|
||||
#[error("UserId too long {0} (expected < 255)")]
|
||||
TooLong(usize),
|
||||
#[error("Invalid UserId given")]
|
||||
Invalid
|
||||
}
|
6
src/types/mod.rs
Normal file
6
src/types/mod.rs
Normal file
@ -0,0 +1,6 @@
|
||||
pub mod authentication_data;
|
||||
pub mod flow;
|
||||
pub mod identifier;
|
||||
pub mod identifier_type;
|
||||
pub mod user_interactive_authorization;
|
||||
pub mod matrix_user_id;
|
25
src/types/user_interactive_authorization.rs
Normal file
25
src/types/user_interactive_authorization.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use crate::types::flow::Flow;
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct UserInteractiveAuthorizationInfo {
|
||||
flows: Vec<FlowStage>,
|
||||
completed: Vec<Flow>,
|
||||
session: Option<String>,
|
||||
auth_error: Option<String>
|
||||
}
|
||||
|
||||
impl UserInteractiveAuthorizationInfo {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
flows: vec![FlowStage { stages: vec![Flow::Password]}],
|
||||
completed: vec![],
|
||||
session: None,
|
||||
auth_error: None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
struct FlowStage {
|
||||
stages: Vec<Flow>,
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user