initial commit

This commit is contained in:
Patrick Michl 2022-04-24 22:04:15 +02:00
commit d17c3a202a
30 changed files with 2306 additions and 0 deletions

1
.env Normal file
View File

@ -0,0 +1 @@
DATABASE_URL=sqlite://db.sqlite3

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
/target
/db.sqlite*

1572
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

17
Cargo.toml Normal file
View File

@ -0,0 +1,17 @@
[package]
name = "matrix"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1.17.0", features = ["full"] }
axum = "0.5.3"
tracing = "0.1.34"
tracing-subscriber = { version = "0.3.11", features = ["env-filter"] }
serde = {version = "1.0.136", features = ["derive"] }
tower-http = { version = "0.2.5", features = ["cors", "trace"]}
sqlx = { version = "0.5.13", features = ["sqlite", "macros", "runtime-tokio-rustls"] }
anyhow = "1.0"
thiserror = "1.0"

View File

@ -0,0 +1,9 @@
-- Add migration script here
CREATE TABLE users(
id INTEGER PRIMARY KEY NOT NULL,
user_id CHAR(255) NOT NULL,
display_name TEXT NOT NULL,
password TEXT NOT NULL
);
CREATE INDEX user_id_index ON users (user_id);

View File

@ -0,0 +1,11 @@
-- Add migration script here
CREATE TABLE devices(
id INTEGER PRIMARY KEY NOT NULL,
user_id INT NOT NULL,
device_id TEXT NOT NULL,
display_name TEXT NOT NULL,
FOREIGN KEY(user_id) REFERENCES users(id)
);
CREATE INDEX device_id_index ON devices (device_id);

View File

@ -0,0 +1,10 @@
-- Add migration script here
CREATE TABLE sessions(
id INTEGER PRIMARY KEY NOT NULL,
device_id INT NOT NULL,
value TEXT NOT NULL,
FOREIGN KEY(device_id) REFERENCES devices(id)
);
CREATE INDEX value_index ON sessions (value);

View File

@ -0,0 +1,110 @@
use std::{collections::HashMap, sync::Arc};
use axum::{
extract::Query,
http::StatusCode,
routing::{get, post},
Extension, Json,
};
use sqlx::SqlitePool;
use crate::responses::registration::RegistrationResponse;
use crate::{
models::devices::Device,
responses::{flow::Flows, registration::RegistrationSuccess},
};
use crate::{
models::users::User,
requests::registration::RegistrationRequest,
responses::username_available::UsernameAvailable,
types::{
authentication_data::AuthenticationData, identifier::Identifier, matrix_user_id::UserId,
},
Config,
};
pub fn routes() -> axum::Router {
axum::Router::new()
.route("/r0/login", get(get_login).post(post_login))
.route("/r0/register", post(post_register))
.route("/r0/register/available", get(get_username_available))
}
#[tracing::instrument]
async fn get_login() -> Json<Flows> {
Json(Flows::new())
}
#[tracing::instrument(skip_all)]
async fn post_login(body: String) -> StatusCode {
dbg!(body);
StatusCode::INTERNAL_SERVER_ERROR
}
#[tracing::instrument(skip_all)]
async fn get_username_available(
Extension(config): Extension<Arc<Config>>,
Extension(db): Extension<SqlitePool>,
Query(params): Query<HashMap<String, String>>,
) -> Json<UsernameAvailable> {
let username = params.get("username").unwrap();
let user_id = UserId::new(&username, &config.homeserver_name).unwrap();
let exists = User::exists(&db, &user_id).await.unwrap();
Json(UsernameAvailable::new(!exists))
}
#[tracing::instrument(skip_all)]
async fn post_register(
Extension(config): Extension<Arc<Config>>,
Extension(db): Extension<SqlitePool>,
Json(body): Json<RegistrationRequest>,
Query(params): Query<HashMap<String, String>>,
) -> (StatusCode, Json<RegistrationResponse>) {
// Client tries to get available flows
if *&body.auth().is_none() {
return (
StatusCode::UNAUTHORIZED,
Json(RegistrationResponse::user_interactive_authorization_info()),
);
}
let (user, device) = match &body.auth().unwrap() {
AuthenticationData::Password(auth_data) => {
// let username = match auth_data.identifier() {
// Identifier::User(user_identifier) => user_identifier.user().unwrap(),
// };
let username = body.username().unwrap();
let user_id = UserId::new(&username, &config.homeserver_name).unwrap();
if User::exists(&db, &user_id).await.unwrap() {
todo!("Error out")
}
let password = auth_data.password();
let display_name = match *&body.initial_device_display_name() {
Some(display_name) => display_name.as_ref(),
None => "Random displayname",
};
let user = User::create(&db, &user_id, &user_id.to_string(), &password)
.await
.unwrap();
let device = Device::create(&db, &user, "test", display_name)
.await
.unwrap();
(user, device)
}
};
// dont log in the user after registration
if *&body.inhibit_login().is_some() && *&body.inhibit_login().unwrap() {
let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id());
(StatusCode::OK, Json(RegistrationResponse::Success(resp)))
} else {
let session = device.create_session(&db).await.unwrap();
let resp =
RegistrationSuccess::new(Some(session.value()), device.device_id(), user.user_id());
(StatusCode::OK, Json(RegistrationResponse::Success(resp)))
}
}

View File

@ -0,0 +1,2 @@
pub mod auth;
pub mod versions;

View File

@ -0,0 +1,12 @@
use axum::{routing::get, Json};
use crate::responses::versions::Versions;
pub fn routes() -> axum::Router {
axum::Router::new().route("/versions", get(get_client_versions))
}
#[tracing::instrument]
async fn get_client_versions() -> Json<Versions> {
Json(Versions::default())
}

1
src/api/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod client_server;

73
src/main.rs Normal file
View File

@ -0,0 +1,73 @@
#![allow(unused)]
use std::sync::Arc;
use axum::{
body::Body,
handler::Handler,
http::{Request, StatusCode},
Router, Extension,
};
use tower_http::{
cors::CorsLayer,
trace::{DefaultOnRequest, TraceLayer, DefaultOnResponse},
};
use tracing::Level;
mod api;
mod responses;
mod requests;
mod types;
mod models;
struct Config {
db_path: String,
homeserver_name: String,
}
impl Default for Config {
fn default() -> Self {
Self { db_path: "sqlite://db.sqlite3".into(), homeserver_name: "fuckwit.dev".into() }
}
}
#[tokio::main]
async fn main() {
if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "debug");
}
tracing_subscriber::fmt::init();
// init config
let config = Arc::new(Config::default());
let pool = sqlx::SqlitePool::connect("sqlite://db.sqlite3").await.unwrap();
let cors = CorsLayer::new()
.allow_origin(tower_http::cors::Any)
.allow_methods(tower_http::cors::Any)
.allow_headers(tower_http::cors::Any);
let tracing_layer = TraceLayer::new_for_http();
let client_server = Router::new()
.merge(api::client_server::versions::routes())
.merge(api::client_server::auth::routes());
let router = Router::new()
.nest("/_matrix/client", client_server)
.layer(cors)
.layer(tracing_layer)
.layer(Extension(pool))
.layer(Extension(config))
.fallback(fallback.into_service());
let _ = axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
.serve(router.into_make_service())
.await;
}
async fn fallback(request: Request<Body>) -> StatusCode {
dbg!(request);
StatusCode::INTERNAL_SERVER_ERROR
}

50
src/models/devices.rs Normal file
View File

@ -0,0 +1,50 @@
use sqlx::SqlitePool;
use super::{users::User, sessions::Session};
pub struct Device {
id: i64,
user_id: i64,
device_id: String,
display_name: String
}
impl Device {
pub async fn create(conn: &SqlitePool, user: &User, device_id: &str, display_name: &str) -> anyhow::Result<Self> {
let user_id = user.id();
Ok(sqlx::query_as!(Self, "insert into devices(user_id, device_id, display_name) values(?, ?, ?) returning id, user_id, device_id, display_name", user_id, device_id, display_name).fetch_one(conn).await?)
}
pub async fn by_user(conn: &SqlitePool, user: &User) -> anyhow::Result<Self> {
let user_id = user.id();
Ok(sqlx::query_as!(Self, "select id, user_id, device_id, display_name from devices where user_id = ?", user_id).fetch_one(conn).await?)
}
pub async fn create_session(&self, conn: &SqlitePool) -> anyhow::Result<Session> {
Ok(Session::create(conn, self, "random_session_id").await?)
}
/// Get the device's id.
#[must_use]
pub fn id(&self) -> i64 {
self.id
}
/// Get the device's user id.
#[must_use]
pub fn user_id(&self) -> i64 {
self.user_id
}
/// Get a reference to the device's device id.
#[must_use]
pub fn device_id(&self) -> &str {
self.device_id.as_ref()
}
/// Get a reference to the device's display name.
#[must_use]
pub fn display_name(&self) -> &str {
self.display_name.as_ref()
}
}

3
src/models/mod.rs Normal file
View File

@ -0,0 +1,3 @@
pub mod users;
pub mod devices;
pub mod sessions;

34
src/models/sessions.rs Normal file
View File

@ -0,0 +1,34 @@
use sqlx::SqlitePool;
use super::devices::Device;
pub struct Session {
id: i64,
device_id: i64,
value: String,
}
impl Session {
pub async fn create(conn: &SqlitePool, device: &Device, value: &str) -> anyhow::Result<Self> {
let device_id = device.id();
Ok(sqlx::query_as!(Self, "insert into sessions(device_id, value) values(?, ?) returning id, device_id, value", device_id, value).fetch_one(conn).await?)
}
/// Get the session's id.
#[must_use]
pub fn id(&self) -> i64 {
self.id
}
/// Get the session's device id.
#[must_use]
pub fn device_id(&self) -> i64 {
self.device_id
}
/// Get a reference to the session's value.
#[must_use]
pub fn value(&self) -> &str {
self.value.as_ref()
}
}

42
src/models/users.rs Normal file
View File

@ -0,0 +1,42 @@
use sqlx::SqlitePool;
use crate::types::matrix_user_id::UserId;
pub struct User {
id: i64,
user_id: String,
display_name: String,
password: String,
}
impl User {
pub async fn exists(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<bool> {
Ok(sqlx::query!("select user_id from users where user_id = ?", user_id).fetch_optional(conn).await?.is_some())
}
pub async fn create(conn: &SqlitePool, user_id: &UserId, display_name: &str, password: &str) -> anyhow::Result<Self> {
Ok(sqlx::query_as!(Self, "insert into users(user_id, display_name, password) values (?, ?, ?) returning id, user_id, display_name, password", user_id, display_name, password).fetch_one(conn).await?)
}
pub async fn by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<Self> {
Ok(sqlx::query_as!(Self, "select id, user_id, display_name, password from users where user_id = ?", user_id).fetch_one(conn).await?)
}
/// Get the user's id.
#[must_use]
pub fn id(&self) -> i64 {
self.id
}
/// Get a reference to the user's user id.
#[must_use]
pub fn user_id(&self) -> &str {
self.user_id.as_ref()
}
/// Get a reference to the user's password.
#[must_use]
pub fn password(&self) -> &str {
self.password.as_ref()
}
}

1
src/requests/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod registration;

View File

@ -0,0 +1,59 @@
use crate::types::{flow::Flow, identifier::Identifier, authentication_data::AuthenticationData};
#[derive(Debug, serde::Deserialize)]
pub struct RegistrationRequest {
#[serde(skip_serializing_if = "Option::is_none")]
auth: Option<AuthenticationData>,
#[serde(skip_serializing_if = "Option::is_none")]
device_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
inhibit_login: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
initial_device_display_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
password: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
username: Option<String>,
}
impl RegistrationRequest {
#[must_use]
pub fn auth(&self) -> Option<&AuthenticationData> {
self.auth.as_ref()
}
/// Get a reference to the registration request's device id.
#[must_use]
pub fn device_id(&self) -> Option<&String> {
self.device_id.as_ref()
}
/// Get the registration request's inhibit login.
#[must_use]
pub fn inhibit_login(&self) -> Option<bool> {
self.inhibit_login
}
/// Get a reference to the registration request's initial device display name.
#[must_use]
pub fn initial_device_display_name(&self) -> Option<&String> {
self.initial_device_display_name.as_ref()
}
/// Get a reference to the registration request's password.
#[must_use]
pub fn password(&self) -> Option<&String> {
self.password.as_ref()
}
/// Get a reference to the registration request's username.
#[must_use]
pub fn username(&self) -> Option<&String> {
self.username.as_ref()
}
}

20
src/responses/flow.rs Normal file
View File

@ -0,0 +1,20 @@
use crate::types::flow::Flow;
#[derive(Debug, Clone, serde::Serialize)]
struct FlowWrapper {
#[serde(rename = "type")]
_type: Flow
}
#[derive(Debug, Clone, serde::Serialize)]
pub struct Flows {
flows: Vec<FlowWrapper>,
}
impl Flows {
pub fn new() -> Self {
Self {
flows: vec![FlowWrapper{ _type: Flow::Password }],
}
}
}

4
src/responses/mod.rs Normal file
View File

@ -0,0 +1,4 @@
pub mod flow;
pub mod registration;
pub mod username_available;
pub mod versions;

View File

@ -0,0 +1,34 @@
use crate::types::user_interactive_authorization::UserInteractiveAuthorizationInfo;
#[derive(Debug, serde::Serialize)]
#[serde(untagged)]
pub enum RegistrationResponse {
Success(RegistrationSuccess),
UserInteractiveAuthorizationInfo(UserInteractiveAuthorizationInfo),
}
impl RegistrationResponse {
pub fn user_interactive_authorization_info() -> Self {
RegistrationResponse::UserInteractiveAuthorizationInfo(
UserInteractiveAuthorizationInfo::new(),
)
}
}
#[derive(Debug, serde::Serialize)]
pub struct RegistrationSuccess {
#[serde(skip_serializing_if = "Option::is_none")]
access_token: Option<String>,
device_id: String,
user_id: String,
}
impl RegistrationSuccess {
pub fn new(access_token: Option<&str>, device_id: &str, user_id: &str) -> Self {
Self {
access_token: access_token.and_then(|v| Some(v.to_owned())),
device_id: device_id.to_owned(),
user_id: user_id.to_owned(),
}
}
}

View File

@ -0,0 +1,10 @@
#[derive(Debug, serde::Serialize)]
pub struct UsernameAvailable {
available: bool,
}
impl UsernameAvailable {
pub fn new(available: bool) -> Self {
Self { available }
}
}

17
src/responses/versions.rs Normal file
View File

@ -0,0 +1,17 @@
use std::collections::HashMap;
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct Versions {
#[serde(skip_serializing_if = "Option::is_none")]
unstable_features: Option<HashMap<String, String>>,
versions: Vec<String>,
}
impl Default for Versions {
fn default() -> Self {
Self {
unstable_features: None,
versions: vec!["v1.2".into()],
}
}
}

View File

@ -0,0 +1,36 @@
use super::{flow::Flow, identifier::Identifier};
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(untagged)]
pub enum AuthenticationData {
Password(AuthenticationPassword)
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct AuthenticationPassword {
#[serde(rename = "type")]
_type: Flow,
identifier: Identifier,
password: String,
user: Option<String>
}
impl AuthenticationPassword {
/// Get a reference to the authentication password's identifier.
#[must_use]
pub fn identifier(&self) -> &Identifier {
&self.identifier
}
/// Get a reference to the authentication password's password.
#[must_use]
pub fn password(&self) -> &str {
self.password.as_ref()
}
/// Get a reference to the authentication password's user.
#[must_use]
pub fn user(&self) -> Option<&String> {
self.user.as_ref()
}
}

44
src/types/flow.rs Normal file
View File

@ -0,0 +1,44 @@
#[derive(Debug, Clone)]
pub enum Flow {
Password,
}
impl serde::Serialize for Flow {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(match self {
Flow::Password => "m.login.password",
})
}
}
impl<'de> serde::Deserialize<'de> for Flow {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct FlowVisitor;
impl<'de> serde::de::Visitor<'de> for FlowVisitor {
type Value = Flow;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("Flow")
}
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
match v {
"m.login.password" => Ok(Flow::Password),
_ => Err(serde::de::Error::custom("Unknown flow")),
}
}
}
deserializer.deserialize_str(FlowVisitor {})
}
}

22
src/types/identifier.rs Normal file
View File

@ -0,0 +1,22 @@
use super::identifier_type::IdentifierType;
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(untagged)]
pub enum Identifier {
User(IdentifierUser)
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct IdentifierUser {
#[serde(rename = "type")]
_type: IdentifierType,
user: Option<String>
}
impl IdentifierUser {
/// Get a reference to the identifier user's user.
#[must_use]
pub fn user(&self) -> Option<&String> {
self.user.as_ref()
}
}

View File

@ -0,0 +1,44 @@
#[derive(Debug, Clone)]
pub enum IdentifierType {
User,
}
impl serde::Serialize for IdentifierType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(match self {
IdentifierType::User => "m.id.user",
})
}
}
impl<'de> serde::Deserialize<'de> for IdentifierType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct IdentifierVisitor;
impl<'de> serde::de::Visitor<'de> for IdentifierVisitor {
type Value = IdentifierType;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("Identifier")
}
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
match v {
"m.id.user" => Ok(IdentifierType::User),
_ => Err(serde::de::Error::custom("Unknown identifier")),
}
}
}
deserializer.deserialize_str(IdentifierVisitor {})
}
}

View File

@ -0,0 +1,35 @@
use std::fmt::Display;
#[derive(sqlx::Type)]
#[sqlx(transparent)]
pub struct UserId(String);
impl UserId {
pub fn new(name: &str, server_name: &str) -> anyhow::Result<Self> {
let user_id = Self(format!("@{name}:{server_name}"));
user_id.is_valid()?;
Ok(user_id)
}
fn is_valid(&self) -> anyhow::Result<()> {
(self.0.len() <= 255).then(|| ()).ok_or(UserIdError::TooLong(self.0.len()))?;
Ok(())
}
}
impl Display for UserId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Debug, thiserror::Error)]
pub enum UserIdError {
#[error("UserId too long {0} (expected < 255)")]
TooLong(usize),
#[error("Invalid UserId given")]
Invalid
}

6
src/types/mod.rs Normal file
View File

@ -0,0 +1,6 @@
pub mod authentication_data;
pub mod flow;
pub mod identifier;
pub mod identifier_type;
pub mod user_interactive_authorization;
pub mod matrix_user_id;

View File

@ -0,0 +1,25 @@
use crate::types::flow::Flow;
#[derive(Debug, serde::Serialize)]
pub struct UserInteractiveAuthorizationInfo {
flows: Vec<FlowStage>,
completed: Vec<Flow>,
session: Option<String>,
auth_error: Option<String>
}
impl UserInteractiveAuthorizationInfo {
pub fn new() -> Self {
Self {
flows: vec![FlowStage { stages: vec![Flow::Password]}],
completed: vec![],
session: None,
auth_error: None
}
}
}
#[derive(Debug, serde::Serialize)]
struct FlowStage {
stages: Vec<Flow>,
}