diff --git a/Cargo.lock b/Cargo.lock index 27c20fc..8a8a8fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,12 @@ dependencies = [ "password-hash", ] +[[package]] +name = "assign" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f093eed78becd229346bf859eec0aa4dd7ddde0757287b2b4107a1f09c80002" + [[package]] name = "async-trait" version = "0.1.53" @@ -110,6 +116,18 @@ dependencies = [ "mime", ] +[[package]] +name = "axum-macros" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cae774e664fd50bf80c9a7132d64ff70be3232a192ff4ba5619c3f72b9b4711f" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "base64" version = "0.13.0" @@ -419,9 +437,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "http" -version = "0.2.6" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31f4c6746584866f0feabcc69893c5b51beef3831656a968ed7ae254cdc4fd03" +checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399" dependencies = [ "bytes", "fnv", @@ -499,6 +517,7 @@ checksum = "0f647032dfaa1f8b6dc29bd3edb7bbef4861b8b8007ebb118d6db284fd59f6ee" dependencies = [ "autocfg", "hashbrown", + "serde", ] [[package]] @@ -534,6 +553,24 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "js_int" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d937f95470b270ce8b8950207715d71aa8e153c0d44c6684d59397ed4949160a" +dependencies = [ + "serde", +] + +[[package]] +name = "js_option" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68421373957a1593a767013698dbf206e2b221eefe97a44d98d18672ff38423c" +dependencies = [ + "serde", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -576,6 +613,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "matchers" version = "0.1.0" @@ -604,8 +647,12 @@ dependencies = [ "anyhow", "argon2", "axum", + "axum-macros", + "http", "rand", + "ruma", "serde", + "serde_json", "sqlx", "thiserror", "tokio", @@ -815,6 +862,16 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +[[package]] +name = "proc-macro-crate" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17d47ce914bf4de440332250b0edd23ce48c005f59fab39d3335866b114f11a" +dependencies = [ + "thiserror", + "toml", +] + [[package]] name = "proc-macro2" version = "1.0.37" @@ -911,6 +968,82 @@ dependencies = [ "winapi", ] +[[package]] +name = "ruma" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6602cb2ef70629013d1bfade5aeb775d971a5d87f008dd6a8c99566235fa1933" +dependencies = [ + "assign", + "js_int", + "ruma-client-api", + "ruma-common", +] + +[[package]] +name = "ruma-client-api" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4339827423dbb3b4f86cb191a38621f12daef73cb304ffd4e050c9ee553ecbd6" +dependencies = [ + "assign", + "bytes", + "http", + "js_int", + "maplit", + "percent-encoding", + "ruma-common", + "serde", + "serde_json", +] + +[[package]] +name = "ruma-common" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ec5360fd23ff56310f9eb571927614feeecceba91fe2d4937f031c236c0e86e" +dependencies = [ + "base64", + "bytes", + "form_urlencoded", + "http", + "indexmap", + "itoa", + "js_int", + "js_option", + "percent-encoding", + "ruma-identifiers-validation", + "ruma-macros", + "serde", + "serde_json", + "thiserror", + "tracing", + "url", + "wildmatch", +] + +[[package]] +name = "ruma-identifiers-validation" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74c3b1d01b5ddd8746f25d5971bc1cac5d7f1f455de839a2f817b9e04953a139" +dependencies = [ + "thiserror", +] + +[[package]] +name = "ruma-macros" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1a4faf04110071ce7ca438ad0763bdaa5514395593596320c0ca0936519656" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "ruma-identifiers-validation", + "syn", +] + [[package]] name = "rustls" version = "0.19.1" @@ -1298,6 +1431,15 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7" +dependencies = [ + "serde", +] + [[package]] name = "tower" version = "0.4.12" @@ -1605,6 +1747,12 @@ dependencies = [ "webpki", ] +[[package]] +name = "wildmatch" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6c48bd20df7e4ced539c12f570f937c6b4884928a87fee70a479d72f031d4e0" + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 4e81b08..afef320 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,10 +11,14 @@ axum = "0.5" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } serde = {version = "1.0", features = ["derive"] } +serde_json = "1.0" tower-http = { version = "0.2", features = ["cors", "trace"]} sqlx = { version = "0.5", features = ["sqlite", "macros", "runtime-tokio-rustls", "offline"] } anyhow = "1.0" thiserror = "1.0" argon2 = { version = "0.4", features = ["std"] } rand = { version = "0.8.5", features = ["std"] } -uuid = { version = "1.0", features = ["v4"] } \ No newline at end of file +uuid = { version = "1.0", features = ["v4"] } +ruma = { version = "0.6.4", features = ["client-api"] } +axum-macros = "0.2.2" +http = "0.2.8" diff --git a/migrations/20220507162217_add_rooms.sql b/migrations/20220507162217_add_rooms.sql new file mode 100644 index 0000000..d79eb0a --- /dev/null +++ b/migrations/20220507162217_add_rooms.sql @@ -0,0 +1,8 @@ +-- Add migration script here + +CREATE TABLE rooms( + uuid TEXT PRIMARY KEY NOT NULL, + name TEXT NOT NULL +); + +CREATE UNIQUE INDEX name_index ON rooms(name); \ No newline at end of file diff --git a/migrations/20220507162532_add_events.sql b/migrations/20220507162532_add_events.sql new file mode 100644 index 0000000..9711628 --- /dev/null +++ b/migrations/20220507162532_add_events.sql @@ -0,0 +1,13 @@ +-- Add migration script here + +CREATE TABLE events( + uuid TEXT PRIMARY KEY NOT NULL, + room_uuid TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT, + sender_uuid TEXT NOT NULL, + origin_server_ts INT NOT NULL, + content TEXT NOT NULL, + FOREIGN KEY(room_uuid) REFERENCES rooms(uuid), + FOREIGN KEY(sender_uuid) REFERENCES users(uuid) +); \ No newline at end of file diff --git a/src/api/client_server/r0/auth.rs b/src/api/client_server/r0/auth.rs index 53d1671..664a3b3 100644 --- a/src/api/client_server/r0/auth.rs +++ b/src/api/client_server/r0/auth.rs @@ -1,9 +1,7 @@ use std::{collections::HashMap, sync::Arc}; -use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher}; use axum::{ extract::Query, - http::StatusCode, routing::{get, post}, Extension, Json, }; @@ -19,19 +17,19 @@ use crate::{ authentication::{AuthenticationResponse, AuthenticationSuccess}, registration::RegistrationResponse, }, - types::uuid::Uuid, -}; -use crate::{ - models::devices::Device, - responses::{flow::Flows, registration::RegistrationSuccess}, + ruma_wrapper::RumaResponse, }; +use crate::{models::devices::Device, responses::registration::RegistrationSuccess}; use crate::{ models::users::User, requests::registration::RegistrationRequest, - responses::username_available::UsernameAvailable, types::{authentication_data::AuthenticationData, user_id::UserId}, Config, }; +use ruma::api::client::{ + account, + session::get_login_types::v3::{LoginType, PasswordLoginType}, +}; pub fn routes() -> axum::Router { axum::Router::new() @@ -40,9 +38,12 @@ pub fn routes() -> axum::Router { .route("/r0/register/available", get(get_username_available)) } +use ruma::api::client::session; #[tracing::instrument] -async fn get_login() -> Result, ApiError> { - Ok(Json(Flows::new())) +async fn get_login() -> Result, ApiError> { + Ok(RumaResponse(session::get_login_types::v3::Response::new( + vec![LoginType::Password(PasswordLoginType::new())], + ))) } #[tracing::instrument(skip_all)] @@ -92,7 +93,7 @@ async fn get_username_available( Extension(config): Extension>, Extension(db): Extension, Query(params): Query>, -) -> Result, ApiError> { +) -> Result, ApiError> { let username = params .get("username") .ok_or(RegistrationError::MissingUserId)?; @@ -101,7 +102,9 @@ async fn get_username_available( .ok_or(RegistrationError::InvalidUserId)?; let exists = User::exists(&db, &user_id).await?; - Ok(Json(UsernameAvailable::new(!exists))) + Ok(RumaResponse( + account::get_username_availability::v3::Response::new(!exists), + )) } #[tracing::instrument(skip_all)] @@ -110,7 +113,10 @@ async fn post_register( Extension(db): Extension, Json(body): Json, ) -> Result, ApiError> { - config.enable_registration().then(|| true).ok_or(RegistrationError::RegistrationDisabled)?; + config + .enable_registration() + .then(|| true) + .ok_or(RegistrationError::RegistrationDisabled)?; body.auth() .ok_or(RegistrationError::AdditionalAuthenticationInformation)?; diff --git a/src/api/client_server/r0/create_room.rs b/src/api/client_server/r0/create_room.rs new file mode 100644 index 0000000..d580997 --- /dev/null +++ b/src/api/client_server/r0/create_room.rs @@ -0,0 +1,17 @@ +use axum::routing::post; + +use crate::api::client_server::errors::api_error::ApiError; +use crate::ruma_wrapper::RumaRequest; + +use ruma::api::client::room; + +pub fn routes() -> axum::Router { + axum::Router::new() + .route("/r0/createRoom", post(post_create_room)) + .layer(axum::middleware::from_fn(super::authentication_middleware)) +} +async fn post_create_room( + RumaRequest(_req): RumaRequest, +) -> Result { + Ok("".into()) +} diff --git a/src/api/client_server/r0/mod.rs b/src/api/client_server/r0/mod.rs index bf3e108..de478da 100644 --- a/src/api/client_server/r0/mod.rs +++ b/src/api/client_server/r0/mod.rs @@ -14,6 +14,7 @@ use super::errors::ErrorResponse; pub mod auth; pub mod thirdparty; +pub mod create_room; async fn authentication_middleware(mut req: Request, next: Next) -> impl IntoResponse { let db: &SqlitePool = req.extensions().get().unwrap(); diff --git a/src/api/client_server/r0/thirdparty.rs b/src/api/client_server/r0/thirdparty.rs index ba16ba7..2cdbe4c 100644 --- a/src/api/client_server/r0/thirdparty.rs +++ b/src/api/client_server/r0/thirdparty.rs @@ -1,9 +1,14 @@ -use std::sync::Arc; +use std::{collections::BTreeMap, sync::Arc}; use axum::{routing::get, Extension}; -use crate::{api::client_server::errors::api_error::ApiError, models::users::User}; +use crate::{ + api::client_server::errors::api_error::ApiError, + models::users::User, + ruma_wrapper::{RumaRequest, RumaResponse}, +}; +use ruma::api::client::thirdparty; pub fn routes() -> axum::Router { axum::Router::new() @@ -12,6 +17,11 @@ pub fn routes() -> axum::Router { } #[tracing::instrument(skip_all)] -async fn get_thirdparty_protocols(Extension(user): Extension>) -> Result { - Ok("{}".into()) -} \ No newline at end of file +async fn get_thirdparty_protocols( + Extension(_user): Extension>, + RumaRequest(_req): RumaRequest, +) -> Result, ApiError> { + Ok(RumaResponse(thirdparty::get_protocols::v3::Response::new( + BTreeMap::new(), + ))) +} diff --git a/src/main.rs b/src/main.rs index 88c8818..83101b5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,3 @@ -#![allow(unused)] - use std::sync::Arc; use axum::{ @@ -9,17 +7,15 @@ use axum::{ Extension, Router, }; use config::Config; -use tower_http::{ - cors::CorsLayer, - trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}, -}; -use tracing::Level; +use tower_http::{cors::CorsLayer, trace::TraceLayer}; mod api; mod config; mod models; mod requests; mod responses; +mod ruma_wrapper; +mod state_resolution; mod types; #[tokio::main] @@ -44,7 +40,8 @@ async fn main() -> anyhow::Result<()> { let client_server = Router::new() .merge(api::client_server::versions::routes()) .merge(api::client_server::r0::auth::routes()) - .merge(api::client_server::r0::thirdparty::routes()); + .merge(api::client_server::r0::thirdparty::routes()) + .merge(api::client_server::r0::create_room::routes()); let router = Router::new() .nest("/_matrix/client", client_server) @@ -61,7 +58,7 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -async fn fallback(mut request: Request) -> StatusCode { - println!("{} {}", request.method(), request.uri()); +async fn fallback(request: Request) -> StatusCode { + tracing::error!("{} {}", request.method(), request.uri()); StatusCode::INTERNAL_SERVER_ERROR } diff --git a/src/models/events.rs b/src/models/events.rs new file mode 100644 index 0000000..1ea5dc9 --- /dev/null +++ b/src/models/events.rs @@ -0,0 +1,72 @@ +use sqlx::SqlitePool; + +use crate::types::{uuid::Uuid, event_type::EventType}; + +use super::{rooms::Room, users::User}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct Event { + uuid: Uuid, + room_uuid: Uuid, + r#type: EventType, + state_key: Option, + sender_uuid: Uuid, + origin_server_ts: i64, + content: String, +} + +impl Event { + fn new( + room: &Room, + r#type: EventType, + state_key: Option, + sender: &User, + origin_server_ts: i64, + content: &str, + ) -> anyhow::Result { + Ok(Self { + uuid: uuid::Uuid::new_v4().into(), + room_uuid: room.uuid().to_owned(), + state_key, + r#type, + sender_uuid: sender.uuid().to_owned(), + origin_server_ts, + content: content.to_owned(), + }) + } + + pub async fn create(&self, conn: &SqlitePool) -> anyhow::Result { + Ok(sqlx::query_as!( + Self, + "insert into events(uuid, room_uuid, type, state_key, sender_uuid, origin_server_ts, content) + values(?, ?, ?, ?, ?, ?, ?) + returning uuid as 'uuid: Uuid', room_uuid as 'room_uuid: Uuid', type as 'type: EventType', state_key, sender_uuid as 'sender_uuid: Uuid', origin_server_ts, content", + self.uuid, + self.room_uuid, + self.r#type, + self.state_key, + self.sender_uuid, + self.origin_server_ts, + self.content + ) + .fetch_one(conn) + .await?) + } + + pub async fn all_for_room(conn: &SqlitePool, room: &Room) -> anyhow::Result> { + let room_uuid = room.uuid(); + Ok(sqlx::query_as!( + Self, + "select uuid as 'uuid: Uuid', room_uuid as 'room_uuid: Uuid', type as 'type: EventType', state_key, sender_uuid as 'sender_uuid: Uuid', origin_server_ts, content + from events + where room_uuid = ?", + room_uuid + ) + .fetch_all(conn) + .await?) + } + + pub fn content(&self) -> serde_json::Value { + serde_json::from_str(&self.content).expect("has to be valid json") + } +} diff --git a/src/models/mod.rs b/src/models/mod.rs index 113c92a..4abd50f 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,3 +1,5 @@ pub mod devices; +pub mod events; +pub mod rooms; pub mod sessions; pub mod users; diff --git a/src/models/rooms.rs b/src/models/rooms.rs new file mode 100644 index 0000000..8fe6b44 --- /dev/null +++ b/src/models/rooms.rs @@ -0,0 +1,42 @@ +use sqlx::SqlitePool; + +use crate::types::uuid::Uuid; + +use super::events::Event; + +pub struct Room { + uuid: Uuid, + name: String, +} + +impl Room { + fn new(name: &str) -> anyhow::Result { + Ok(Self { + uuid: uuid::Uuid::new_v4().into(), + name: name.to_owned(), + }) + } + + pub async fn create(&self, conn: &SqlitePool) -> anyhow::Result { + Ok(sqlx::query_as!( + Self, + "insert into rooms(uuid, name) + values(?, ?) + returning uuid as 'uuid: Uuid', name", + self.uuid, + self.name + ) + .fetch_one(conn) + .await?) + } + + pub async fn events(&self, conn: &SqlitePool) -> anyhow::Result> { + Event::all_for_room(conn, self).await + } + + /// Get a reference to the room's uuid. + #[must_use] + pub fn uuid(&self) -> &Uuid { + &self.uuid + } +} diff --git a/src/requests/create_room_request.rs b/src/requests/create_room_request.rs new file mode 100644 index 0000000..5bbaf48 --- /dev/null +++ b/src/requests/create_room_request.rs @@ -0,0 +1,29 @@ +use crate::types::user_id::UserId; + +#[derive(Debug, serde::Deserialize)] +pub struct CreateRoomRequest { + /// Extra keys, such as `m.federate`, to be added to the content of the `m.room.create` event. + creation_content: Option<()>, + /// List of state events to set in the initial room. Used for overriding the default state + initial_state: Vec<()>, + /// List of user IDs to invite to the room + invite: Option>, + /// List of thirdparty IDs to invite to the room + invite_3pid: Option>, + /// Indicate if room is a direct chat room + is_direct: Option, + /// Set name of the room + name: Option, + /// Used to override the default power level event + power_level_content_override: Option<()>, + /// Preset for room creation + preset: Option<()>, + /// Desired room alias local part + room_alias_name: Option, + /// Version of room to create. Defaults to server default + room_version: Option, + /// Sets rooms topic + topic: Option, + /// Sets rooms visibility + visibility: () +} \ No newline at end of file diff --git a/src/requests/mod.rs b/src/requests/mod.rs index 038a1c7..6c83275 100644 --- a/src/requests/mod.rs +++ b/src/requests/mod.rs @@ -1 +1,2 @@ pub mod registration; +pub mod create_room_request; \ No newline at end of file diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs new file mode 100644 index 0000000..2ca287d --- /dev/null +++ b/src/ruma_wrapper.rs @@ -0,0 +1,65 @@ +use axum::{ + body::{Bytes, HttpBody, Full}, + extract::{FromRequest, Path}, + response::IntoResponse, + BoxError, +}; +use http::StatusCode; +use ruma::{ + api::{IncomingRequest, OutgoingResponse}, + exports::bytes::{BufMut, BytesMut}, + serde::CanonicalJsonValue, +}; + +use crate::api::client_server::errors::api_error::ApiError; + +pub struct RumaRequest(pub R) +where + R: IncomingRequest; + +#[axum::async_trait] +impl FromRequest for RumaRequest +where + R: IncomingRequest, + B: HttpBody + Send, + B::Data: Send, + B::Error: Into, +{ + type Rejection = ApiError; + + async fn from_request( + req: &mut axum::extract::RequestParts, + ) -> Result { + let path_params = Path::>::from_request(req) + .await + .map_err(|e| anyhow::anyhow!(e))?; + let body = Bytes::from_request(req) + .await + .map_err(|e| anyhow::anyhow!(e))?; + + let json = + serde_json::from_slice::(&body).map_err(|e| anyhow::anyhow!(e))?; + let mut buf = BytesMut::new().writer(); + serde_json::to_writer(&mut buf, &json).expect("can't fail"); + let body = buf.into_inner().freeze(); + + let builder = http::Request::builder().uri(req.uri()).method(req.method()); + let request = builder.body(body).map_err(|e| anyhow::anyhow!(e))?; + Ok(Self( + R::try_from_http_request(request, &path_params).map_err(|e| anyhow::anyhow!(e))?, + )) + } +} + +pub struct RumaResponse(pub R) +where + R: OutgoingResponse; + +impl IntoResponse for RumaResponse { + fn into_response(self) -> axum::response::Response { + match self.0.try_into_http_response::() { + Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), + Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + } + } +} diff --git a/src/state_resolution/mod.rs b/src/state_resolution/mod.rs new file mode 100644 index 0000000..7d88584 --- /dev/null +++ b/src/state_resolution/mod.rs @@ -0,0 +1 @@ +mod v2; \ No newline at end of file diff --git a/src/state_resolution/v2.rs b/src/state_resolution/v2.rs new file mode 100644 index 0000000..fb8630d --- /dev/null +++ b/src/state_resolution/v2.rs @@ -0,0 +1,102 @@ +use crate::models::events::Event; +use std::{ + collections::{HashMap, HashSet}, + future::Future, +}; +use tracing::info; + +type StateMap = HashMap; + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub struct StateTuple { + event_type: String, + state_key: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct EventId { + id: Box, +} + +#[tracing::instrument(skip(state_sets, auth_chain_sets, get_event_callback))] +pub async fn resolve( + room_id: &str, // TODO: own type + state_sets: Vec>, + auth_chain_sets: Vec>, + get_event_callback: F, +) -> StateMap +where + F: Fn(&EventId) -> Fut, + Fut: Future>, +{ + info!("Calculating conflicted state"); + + let (unconflicted_state, conflicted_state) = separate_state(&state_sets); + + if conflicted_state.is_empty() { + return unconflicted_state; + } + + info!("{} conflicted_state entries", conflicted_state.len()); + info!("Calculating auth_chain differences"); + + let conflicted_set = + get_auth_chain_differences(auth_chain_sets).chain(conflicted_state.into_values().flatten()); + let mut conflicted = HashSet::new(); + for eid in conflicted_set { + if let Some(event) = get_event_callback(&eid).await { + conflicted.insert(event); + } + } + + todo!() +} + +/// separates states from multiple state_maps into unconflicted and conflicted state +/// +/// For the set of all state_tuples find all event_ids. +/// If one event_id is found it is unconflicted, otherwise it is conflicted +fn separate_state( + state_sets: &[StateMap], +) -> (StateMap, StateMap>) { + let mut unconflicted_state: StateMap = StateMap::new(); + let mut conflicted_state: HashMap> = StateMap::new(); + + for key in state_sets + .iter() + .flat_map(HashMap::keys) + .map(ToOwned::to_owned) + .collect::>() + { + let mut event_ids: HashSet = state_sets + .iter() + .filter_map(|state_set| state_set.get(&key)) + .map(ToOwned::to_owned) + .collect(); + + if event_ids.len() == 1 { + unconflicted_state.insert(key, event_ids.into_iter().next().expect("len() is 1")); + } else { + conflicted_state.insert(key, event_ids); + } + } + + (unconflicted_state, conflicted_state) +} + +fn get_auth_chain_differences( + auth_chain_sets: Vec>, +) -> impl Iterator { + let num_sets = auth_chain_sets.len(); + + let mut id_counts: HashMap = HashMap::new(); + for id in auth_chain_sets.into_iter().flatten() { + *id_counts.entry(id).or_default() += 1; + } + + id_counts + .into_iter() + .filter_map(move |(id, count)| (count < num_sets).then(move || id)) +} + +fn is_control_event() {} diff --git a/src/types/client_event.rs b/src/types/client_event.rs new file mode 100644 index 0000000..eb5287c --- /dev/null +++ b/src/types/client_event.rs @@ -0,0 +1,12 @@ +use super::{uuid::Uuid, user_id::UserId}; + +pub struct ClientEvent { + content: (), + event_id: Uuid, + origin_server_ts: u64, + room_id: String, + sender: UserId, + state_key: Option, + r#type: String, + unsigned: () +} \ No newline at end of file diff --git a/src/types/event_type.rs b/src/types/event_type.rs new file mode 100644 index 0000000..046bcd3 --- /dev/null +++ b/src/types/event_type.rs @@ -0,0 +1,82 @@ +use sqlx::{encode::IsNull, Sqlite}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum EventType { + RoomCreate, + Unknown, +} + +impl sqlx::Type for EventType { + fn type_info() -> ::TypeInfo { + <&str as sqlx::Type>::type_info() + } +} + +impl<'e> sqlx::Encode<'e, Sqlite> for EventType { + fn encode_by_ref( + &self, + buf: &mut >::ArgumentBuffer, + ) -> sqlx::encode::IsNull { + buf.push(sqlx::sqlite::SqliteArgumentValue::Text( + match self { + EventType::RoomCreate => "m.room.create", + EventType::Unknown => "???" + } + .into(), + )); + + IsNull::No + } +} + +impl<'d> sqlx::Decode<'d, Sqlite> for EventType { + fn decode( + value: >::ValueRef, + ) -> Result { + Ok(match <&str as sqlx::Decode>::decode(value)? { + "m.room.create" => EventType::RoomCreate, + _ => EventType::Unknown, + }) + } +} + +impl serde::Serialize for EventType { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(match self { + EventType::RoomCreate => "m.room.create", + EventType::Unknown => "dev.fuckwit.unknown_event" + }) + } +} + +impl<'de> serde::Deserialize<'de> for EventType { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct IdentifierVisitor; + + impl<'de> serde::de::Visitor<'de> for IdentifierVisitor { + type Value = EventType; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("Identifier") + } + + fn visit_borrowed_str(self, v: &'de str) -> Result + where + E: serde::de::Error, + { + match v { + "m.id.user" => Ok(EventType::RoomCreate), + _ => Err(serde::de::Error::custom("Unknown identifier")), + } + } + } + + deserializer.deserialize_str(IdentifierVisitor {}) + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index a92ff25..9a6c61b 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -6,4 +6,7 @@ pub mod identifier_type; pub mod user_id; pub mod user_interactive_authorization; pub mod uuid; -pub mod server_name; \ No newline at end of file +pub mod server_name; +pub mod client_event; +pub mod room_event; +pub mod event_type; \ No newline at end of file diff --git a/src/types/room_event.rs b/src/types/room_event.rs new file mode 100644 index 0000000..d0fff51 --- /dev/null +++ b/src/types/room_event.rs @@ -0,0 +1,7 @@ +pub enum RoomEvent { + Create(RoomCreateEvent) +} + +pub struct RoomCreateEvent { + +} \ No newline at end of file diff --git a/src/types/uuid.rs b/src/types/uuid.rs index 36586ff..6f96f2d 100644 --- a/src/types/uuid.rs +++ b/src/types/uuid.rs @@ -1,6 +1,6 @@ use sqlx::{encode::IsNull, Sqlite, Type}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Uuid(pub uuid::Uuid); impl Uuid {