Compare commits
22 Commits
a789a8cc56
...
master
Author | SHA1 | Date | |
---|---|---|---|
d98e9ea9e3 | |||
9510d9c765 | |||
e33a734199 | |||
35bde07b39 | |||
26e39c7c06 | |||
6ed7b16bf6 | |||
277d7111c8 | |||
8ada363a92 | |||
29093c51e3 | |||
71590d6c60 | |||
ba84efd384 | |||
54f67d435e | |||
2c2ac27c26 | |||
c20b4c6a23 | |||
3b8c529183 | |||
341c516fcb | |||
b0895145de | |||
304f82baa4 | |||
601b2d4f42 | |||
2c91e99a4d | |||
b8e1235396 | |||
b4b4f837cf |
11
.drone.yml
Normal file
11
.drone.yml
Normal file
@ -0,0 +1,11 @@
|
||||
kind: pipeline
|
||||
type: docker
|
||||
name: check
|
||||
|
||||
steps:
|
||||
- name: cargo check
|
||||
image: rust:latest
|
||||
environment:
|
||||
SQLX_OFFLINE: 'true'
|
||||
commands:
|
||||
- cargo check
|
1862
Cargo.lock
generated
1862
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
23
Cargo.toml
23
Cargo.toml
@ -1,17 +1,8 @@
|
||||
[package]
|
||||
name = "matrix"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
[workspace]
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.17.0", features = ["full"] }
|
||||
axum = "0.5.3"
|
||||
tracing = "0.1.34"
|
||||
tracing-subscriber = { version = "0.3.11", features = ["env-filter"] }
|
||||
serde = {version = "1.0.136", features = ["derive"] }
|
||||
tower-http = { version = "0.2.5", features = ["cors", "trace"]}
|
||||
sqlx = { version = "0.5.13", features = ["sqlite", "macros", "runtime-tokio-rustls"] }
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
members = [
|
||||
"neo",
|
||||
"neo-entity",
|
||||
"neo-migration",
|
||||
"neo-util"
|
||||
]
|
||||
|
@ -1,9 +0,0 @@
|
||||
-- Add migration script here
|
||||
CREATE TABLE users(
|
||||
id INTEGER PRIMARY KEY NOT NULL,
|
||||
user_id CHAR(255) NOT NULL,
|
||||
display_name TEXT NOT NULL,
|
||||
password TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX user_id_index ON users (user_id);
|
@ -1,11 +0,0 @@
|
||||
-- Add migration script here
|
||||
|
||||
CREATE TABLE devices(
|
||||
id INTEGER PRIMARY KEY NOT NULL,
|
||||
user_id INT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
display_name TEXT NOT NULL,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);
|
||||
|
||||
CREATE INDEX device_id_index ON devices (device_id);
|
@ -1,10 +0,0 @@
|
||||
-- Add migration script here
|
||||
|
||||
CREATE TABLE sessions(
|
||||
id INTEGER PRIMARY KEY NOT NULL,
|
||||
device_id INT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
FOREIGN KEY(device_id) REFERENCES devices(id)
|
||||
);
|
||||
|
||||
CREATE INDEX value_index ON sessions (value);
|
2
neo-entity/.gitignore
vendored
Normal file
2
neo-entity/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
/target
|
||||
/Cargo.lock
|
12
neo-entity/Cargo.toml
Normal file
12
neo-entity/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "neo-entity"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[dependencies]
|
||||
chrono = {version = "0.4", features = ["serde"] }
|
||||
sea-orm = { version = "^0.9", features = ["macros", "with-chrono", "with-uuid", "with-json"], default-features = false }
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
uuid = { version = "*", features = ["v4", "serde"]}
|
48
neo-entity/src/devices.rs
Normal file
48
neo-entity/src/devices.rs
Normal file
@ -0,0 +1,48 @@
|
||||
use sea_orm::entity::prelude::*;
|
||||
use sea_orm::Set;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "devices")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub uuid: Uuid,
|
||||
pub user_uuid: Uuid,
|
||||
pub device_id: String,
|
||||
pub display_name: String,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::users::Entity",
|
||||
from = "Column::UserUuid",
|
||||
to = "super::users::Column::Uuid",
|
||||
on_update = "NoAction",
|
||||
on_delete = "NoAction"
|
||||
)]
|
||||
Users,
|
||||
#[sea_orm(has_many = "super::sessions::Entity")]
|
||||
Sessions,
|
||||
}
|
||||
|
||||
impl Related<super::users::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Users.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::sessions::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Sessions.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
uuid: Set(Uuid::new_v4()),
|
||||
device_id: Set(Uuid::new_v4().to_string()),
|
||||
..ActiveModelTrait::default()
|
||||
}
|
||||
}
|
||||
}
|
55
neo-entity/src/events.rs
Normal file
55
neo-entity/src/events.rs
Normal file
@ -0,0 +1,55 @@
|
||||
use sea_orm::{entity::prelude::*, Set};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "events")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub uuid: Uuid,
|
||||
pub room_uuid: Uuid,
|
||||
pub r#type: String,
|
||||
pub state_key: Option<String>,
|
||||
pub sender_uuid: Uuid,
|
||||
pub origin_server_ts: i64,
|
||||
pub content: Json,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::users::Entity",
|
||||
from = "Column::SenderUuid",
|
||||
to = "super::users::Column::Uuid",
|
||||
on_update = "NoAction",
|
||||
on_delete = "NoAction"
|
||||
)]
|
||||
Users,
|
||||
#[sea_orm(
|
||||
belongs_to = "super::rooms::Entity",
|
||||
from = "Column::RoomUuid",
|
||||
to = "super::rooms::Column::Uuid",
|
||||
on_update = "NoAction",
|
||||
on_delete = "NoAction"
|
||||
)]
|
||||
Rooms,
|
||||
}
|
||||
|
||||
impl Related<super::users::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Users.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::rooms::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Rooms.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
uuid: Set(Uuid::new_v4()),
|
||||
..ActiveModelTrait::default()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,3 +1,6 @@
|
||||
pub mod devices;
|
||||
pub mod events;
|
||||
pub mod prelude;
|
||||
pub mod rooms;
|
||||
pub mod sessions;
|
||||
pub mod users;
|
6
neo-entity/src/prelude.rs
Normal file
6
neo-entity/src/prelude.rs
Normal file
@ -0,0 +1,6 @@
|
||||
#[allow(unused_imports)]
|
||||
pub use crate::{
|
||||
devices::{self, Entity as Device, Model as DeviceModel},
|
||||
sessions::{self, Entity as Session, Model as SessionModel},
|
||||
users::{self, Entity as User, Model as UserModel},
|
||||
};
|
30
neo-entity/src/rooms.rs
Normal file
30
neo-entity/src/rooms.rs
Normal file
@ -0,0 +1,30 @@
|
||||
use sea_orm::{entity::prelude::*, Set};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "rooms")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub uuid: Uuid,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(has_many = "super::events::Entity")]
|
||||
Events,
|
||||
}
|
||||
|
||||
impl Related<super::events::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Events.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
uuid: Set(Uuid::new_v4()),
|
||||
..ActiveModelTrait::default()
|
||||
}
|
||||
}
|
||||
}
|
37
neo-entity/src/sessions.rs
Normal file
37
neo-entity/src/sessions.rs
Normal file
@ -0,0 +1,37 @@
|
||||
use sea_orm::{entity::prelude::*, Set};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "sessions")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub uuid: Uuid,
|
||||
pub device_uuid: Uuid,
|
||||
pub key: String,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::devices::Entity",
|
||||
from = "Column::DeviceUuid",
|
||||
to = "super::devices::Column::Uuid",
|
||||
on_update = "NoAction",
|
||||
on_delete = "NoAction"
|
||||
)]
|
||||
Devices,
|
||||
}
|
||||
|
||||
impl Related<super::devices::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Devices.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
uuid: Set(Uuid::new_v4()),
|
||||
..ActiveModelTrait::default()
|
||||
}
|
||||
}
|
||||
}
|
33
neo-entity/src/users.rs
Normal file
33
neo-entity/src/users.rs
Normal file
@ -0,0 +1,33 @@
|
||||
use sea_orm::entity::prelude::*;
|
||||
use sea_orm::Set;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "users")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub uuid: Uuid,
|
||||
pub user_id: String,
|
||||
pub display_name: String,
|
||||
pub password_hash: String,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(has_many = "super::devices::Entity")]
|
||||
Devices,
|
||||
}
|
||||
|
||||
impl Related<super::devices::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Devices.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
uuid: Set(Uuid::new_v4()),
|
||||
..ActiveModelTrait::default()
|
||||
}
|
||||
}
|
||||
}
|
10
neo-migration/Cargo.toml
Normal file
10
neo-migration/Cargo.toml
Normal file
@ -0,0 +1,10 @@
|
||||
[package]
|
||||
name = "neo-migration"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[dependencies]
|
||||
sea-orm-migration = "^0.9"
|
||||
neo-entity = { path = "../neo-entity" }
|
||||
automod = "1"
|
37
neo-migration/README.md
Normal file
37
neo-migration/README.md
Normal file
@ -0,0 +1,37 @@
|
||||
# Running Migrator CLI
|
||||
|
||||
- Apply all pending migrations
|
||||
```sh
|
||||
cargo run
|
||||
```
|
||||
```sh
|
||||
cargo run -- up
|
||||
```
|
||||
- Apply first 10 pending migrations
|
||||
```sh
|
||||
cargo run -- up -n 10
|
||||
```
|
||||
- Rollback last applied migrations
|
||||
```sh
|
||||
cargo run -- down
|
||||
```
|
||||
- Rollback last 10 applied migrations
|
||||
```sh
|
||||
cargo run -- down -n 10
|
||||
```
|
||||
- Drop all tables from the database, then reapply all migrations
|
||||
```sh
|
||||
cargo run -- fresh
|
||||
```
|
||||
- Rollback all applied migrations, then reapply all migrations
|
||||
```sh
|
||||
cargo run -- refresh
|
||||
```
|
||||
- Rollback all applied migrations
|
||||
```sh
|
||||
cargo run -- reset
|
||||
```
|
||||
- Check the status of all migrations
|
||||
```sh
|
||||
cargo run -- status
|
||||
```
|
18
neo-migration/src/lib.rs
Normal file
18
neo-migration/src/lib.rs
Normal file
@ -0,0 +1,18 @@
|
||||
pub use sea_orm_migration::prelude::*;
|
||||
|
||||
automod::dir!("src");
|
||||
|
||||
pub struct Migrator;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl MigratorTrait for Migrator {
|
||||
fn migrations() -> Vec<Box<dyn MigrationTrait>> {
|
||||
vec![
|
||||
Box::new(m20220707_092851_create_users::Migration),
|
||||
Box::new(m20220707_112339_create_devices::Migration),
|
||||
Box::new(m20220707_143304_create_sessions::Migration),
|
||||
Box::new(m20220724_223253_create_rooms::Migration),
|
||||
Box::new(m20220724_223335_create_events::Migration),
|
||||
]
|
||||
}
|
||||
}
|
45
neo-migration/src/m20220707_092851_create_users.rs
Normal file
45
neo-migration/src/m20220707_092851_create_users.rs
Normal file
@ -0,0 +1,45 @@
|
||||
use neo_entity::users::{self, Entity as User};
|
||||
use sea_orm_migration::prelude::*;
|
||||
|
||||
#[derive(DeriveMigrationName)]
|
||||
pub struct Migration;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl MigrationTrait for Migration {
|
||||
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager
|
||||
.create_table(
|
||||
Table::create()
|
||||
.table(User)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(users::Column::Uuid)
|
||||
.uuid()
|
||||
.primary_key()
|
||||
.not_null(),
|
||||
)
|
||||
.col(ColumnDef::new(users::Column::UserId).string().not_null())
|
||||
.col(
|
||||
ColumnDef::new(users::Column::DisplayName)
|
||||
.string()
|
||||
.not_null(),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(users::Column::PasswordHash)
|
||||
.string()
|
||||
.not_null(),
|
||||
)
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("user_id_index")
|
||||
.table(User)
|
||||
.col(users::Column::UserId)
|
||||
.to_owned(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
54
neo-migration/src/m20220707_112339_create_devices.rs
Normal file
54
neo-migration/src/m20220707_112339_create_devices.rs
Normal file
@ -0,0 +1,54 @@
|
||||
use neo_entity::devices::{self, Entity as Device};
|
||||
use sea_orm_migration::prelude::*;
|
||||
|
||||
#[derive(DeriveMigrationName)]
|
||||
pub struct Migration;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl MigrationTrait for Migration {
|
||||
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager
|
||||
.create_table(
|
||||
Table::create()
|
||||
.table(Device)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(devices::Column::Uuid)
|
||||
.uuid()
|
||||
.primary_key()
|
||||
.not_null(),
|
||||
)
|
||||
.col(ColumnDef::new(devices::Column::UserUuid).uuid().not_null())
|
||||
.col(
|
||||
ColumnDef::new(devices::Column::DeviceId)
|
||||
.string()
|
||||
.not_null(),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(devices::Column::DisplayName)
|
||||
.string()
|
||||
.not_null(),
|
||||
)
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("device_id_index")
|
||||
.table(Device)
|
||||
.col(devices::Column::DeviceId)
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("user_uuid_index")
|
||||
.table(Device)
|
||||
.col(devices::Column::UserUuid)
|
||||
.to_owned(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
40
neo-migration/src/m20220707_143304_create_sessions.rs
Normal file
40
neo-migration/src/m20220707_143304_create_sessions.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use neo_entity::sessions::{self, Entity as Session};
|
||||
use sea_orm_migration::prelude::*;
|
||||
|
||||
#[derive(DeriveMigrationName)]
|
||||
pub struct Migration;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl MigrationTrait for Migration {
|
||||
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager
|
||||
.create_table(
|
||||
Table::create()
|
||||
.table(Session)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(sessions::Column::Uuid)
|
||||
.uuid()
|
||||
.primary_key()
|
||||
.not_null(),
|
||||
)
|
||||
.col(
|
||||
ColumnDef::new(sessions::Column::DeviceUuid)
|
||||
.uuid()
|
||||
.not_null(),
|
||||
)
|
||||
.col(ColumnDef::new(sessions::Column::Key).string().not_null())
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("device_uuid_index")
|
||||
.table(Session)
|
||||
.col(sessions::Column::DeviceUuid)
|
||||
.to_owned(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
26
neo-migration/src/m20220724_223253_create_rooms.rs
Normal file
26
neo-migration/src/m20220724_223253_create_rooms.rs
Normal file
@ -0,0 +1,26 @@
|
||||
use neo_entity::rooms::{self, Entity as Room};
|
||||
use sea_orm_migration::prelude::*;
|
||||
|
||||
#[derive(DeriveMigrationName)]
|
||||
pub struct Migration;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl MigrationTrait for Migration {
|
||||
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager
|
||||
.create_table(
|
||||
Table::create()
|
||||
.table(Room)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(rooms::Column::Uuid)
|
||||
.uuid()
|
||||
.primary_key()
|
||||
.not_null(),
|
||||
)
|
||||
.col(ColumnDef::new(rooms::Column::Name).string().not_null())
|
||||
.to_owned(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
72
neo-migration/src/m20220724_223335_create_events.rs
Normal file
72
neo-migration/src/m20220724_223335_create_events.rs
Normal file
@ -0,0 +1,72 @@
|
||||
use neo_entity::events::{self, Entity as Event};
|
||||
use sea_orm_migration::prelude::*;
|
||||
|
||||
#[derive(DeriveMigrationName)]
|
||||
pub struct Migration;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl MigrationTrait for Migration {
|
||||
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||
manager
|
||||
.create_table(
|
||||
Table::create()
|
||||
.table(Event)
|
||||
.if_not_exists()
|
||||
.col(
|
||||
ColumnDef::new(events::Column::Uuid)
|
||||
.uuid()
|
||||
.primary_key()
|
||||
.not_null(),
|
||||
)
|
||||
.col(ColumnDef::new(events::Column::RoomUuid).uuid().not_null())
|
||||
.col(ColumnDef::new(events::Column::Type).string().not_null())
|
||||
.col(ColumnDef::new(events::Column::StateKey).string())
|
||||
.col(ColumnDef::new(events::Column::SenderUuid).uuid().not_null())
|
||||
.col(
|
||||
ColumnDef::new(events::Column::OriginServerTs)
|
||||
.integer()
|
||||
.not_null(),
|
||||
)
|
||||
.col(ColumnDef::new(events::Column::Content).json().not_null())
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("room_uuid_index")
|
||||
.table(Event)
|
||||
.col(events::Column::RoomUuid)
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("type_index")
|
||||
.table(Event)
|
||||
.col(events::Column::Type)
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("state_key_index")
|
||||
.table(Event)
|
||||
.col(events::Column::StateKey)
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("type_state_key_index")
|
||||
.table(Event)
|
||||
.col(events::Column::Type)
|
||||
.col(events::Column::StateKey)
|
||||
.to_owned(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
11
neo-util/Cargo.toml
Normal file
11
neo-util/Cargo.toml
Normal file
@ -0,0 +1,11 @@
|
||||
[package]
|
||||
name = "neo-util"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
argon2 = { version = "0.4", features = ["std"] }
|
||||
rand = { version = "0.8.5", features = ["std"] }
|
20
neo-util/src/events.rs
Normal file
20
neo-util/src/events.rs
Normal file
@ -0,0 +1,20 @@
|
||||
pub static STATE_EVENTS: &[&str] = &[
|
||||
"m.room.create",
|
||||
"m.room.canonical_alias",
|
||||
"m.room.join_rules",
|
||||
"m.room.member",
|
||||
"m.room.power_levels",
|
||||
];
|
||||
|
||||
pub enum EventCategory {
|
||||
StateEvent,
|
||||
MessageEvent,
|
||||
}
|
||||
|
||||
pub fn classify_event(event_type: &str) -> EventCategory {
|
||||
if STATE_EVENTS.contains(&event_type) {
|
||||
EventCategory::StateEvent
|
||||
} else {
|
||||
EventCategory::MessageEvent
|
||||
}
|
||||
}
|
2
neo-util/src/lib.rs
Normal file
2
neo-util/src/lib.rs
Normal file
@ -0,0 +1,2 @@
|
||||
pub mod events;
|
||||
pub mod password;
|
17
neo-util/src/password.rs
Normal file
17
neo-util/src/password.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
|
||||
use rand::rngs::OsRng;
|
||||
|
||||
pub fn hash_password(password: &str) -> anyhow::Result<String> {
|
||||
let argon2 = Argon2::default();
|
||||
let salt = SaltString::generate(OsRng);
|
||||
Ok(argon2
|
||||
.hash_password(password.as_bytes(), &salt)?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
pub fn password_correct(password: &str, hash: &str) -> anyhow::Result<bool> {
|
||||
let password_hash = PasswordHash::new(hash)?;
|
||||
Ok(Argon2::default()
|
||||
.verify_password(password.as_bytes(), &password_hash)
|
||||
.is_ok())
|
||||
}
|
23
neo/Cargo.toml
Normal file
23
neo/Cargo.toml
Normal file
@ -0,0 +1,23 @@
|
||||
[package]
|
||||
name = "neo"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.17", features = ["full"] }
|
||||
axum = "0.5"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
serde = {version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tower-http = { version = "0.2", features = ["cors", "trace"] }
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
rand = { version = "0.8.5", features = ["std"] }
|
||||
uuid = { version = "1.0", features = ["v4"] }
|
||||
ruma = { version = "0.6.4", features = ["client-api", "compat"] }
|
||||
http = "0.2.8"
|
||||
sea-orm = { version = "^0.9", features = ["sqlx-sqlite", "runtime-tokio-native-tls", "macros"], default-features = false }
|
||||
neo-entity = { version = "*", path = "../neo-entity" }
|
||||
neo-migration = { version = "*", path = "../neo-migration" }
|
||||
neo-util = { version = "*", path = "../neo-util" }
|
82
neo/src/api/client_server/errors/api_error.rs
Normal file
82
neo/src/api/client_server/errors/api_error.rs
Normal file
@ -0,0 +1,82 @@
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Json;
|
||||
|
||||
use crate::types::error_code::ErrorCode;
|
||||
|
||||
use super::authentication_error::AuthenticationError;
|
||||
use super::registration_error::RegistrationError;
|
||||
use super::ErrorResponse;
|
||||
|
||||
macro_rules! map_err {
|
||||
($err:ident, $($type:path => $target:path),+) => {
|
||||
$(
|
||||
if $err.is::<$type>() {
|
||||
return $target($err.downcast().unwrap());
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum ApiError {
|
||||
#[error("Registration Error")]
|
||||
RegistrationError(#[from] RegistrationError),
|
||||
|
||||
#[error("Authentication Error")]
|
||||
AuthenticationError(#[from] AuthenticationError),
|
||||
|
||||
#[error("Database Error")]
|
||||
DBError(#[from] sea_orm::DbErr),
|
||||
|
||||
#[error("Generic Error")]
|
||||
Generic(anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for ApiError {
|
||||
fn from(err: anyhow::Error) -> Self {
|
||||
map_err!(err,
|
||||
sea_orm::DbErr => ApiError::DBError,
|
||||
RegistrationError => ApiError::RegistrationError,
|
||||
AuthenticationError => ApiError::AuthenticationError
|
||||
);
|
||||
|
||||
ApiError::Generic(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
ApiError::RegistrationError(e) => e.into_response(),
|
||||
ApiError::AuthenticationError(e) => e.into_response(),
|
||||
ApiError::DBError(err) => {
|
||||
tracing::error!("{}", err.to_string());
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::Unknown,
|
||||
"Database error! If you are the application owner please take a look at your application logs.",
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
ApiError::Generic(err) => {
|
||||
tracing::error!("{}", err.to_string());
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::Unknown,
|
||||
"Fatal error occured! If you are the application owner please take a look at your application logs.",
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
53
neo/src/api/client_server/errors/authentication_error.rs
Normal file
53
neo/src/api/client_server/errors/authentication_error.rs
Normal file
@ -0,0 +1,53 @@
|
||||
use axum::{http::StatusCode, response::IntoResponse, Json};
|
||||
|
||||
use crate::types::error_code::ErrorCode;
|
||||
|
||||
use super::ErrorResponse;
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[non_exhaustive]
|
||||
pub enum AuthenticationError {
|
||||
#[error("UserId is missing")]
|
||||
MissingUserId,
|
||||
#[error("The user ID is not a valid user name")]
|
||||
InvalidUserId,
|
||||
#[error("The provided authentication data was incorrect")]
|
||||
Forbidden,
|
||||
#[error("The user has been deactivated")]
|
||||
UserDeactivated,
|
||||
}
|
||||
|
||||
impl IntoResponse for AuthenticationError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::InvalidUserId | Self::MissingUserId => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::InvalidUsername,
|
||||
&self.to_string(),
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response(),
|
||||
Self::Forbidden => (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::Forbidden,
|
||||
&self.to_string(),
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response(),
|
||||
Self::UserDeactivated => (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::UserDeactivated,
|
||||
&self.to_string(),
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
23
neo/src/api/client_server/errors/mod.rs
Normal file
23
neo/src/api/client_server/errors/mod.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use crate::types::error_code::ErrorCode;
|
||||
|
||||
pub mod api_error;
|
||||
pub mod authentication_error;
|
||||
pub mod registration_error;
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct ErrorResponse {
|
||||
errcode: ErrorCode,
|
||||
error: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
retry_after_ms: Option<u64>,
|
||||
}
|
||||
|
||||
impl ErrorResponse {
|
||||
pub fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option<u64>) -> Self {
|
||||
Self {
|
||||
errcode,
|
||||
error: error.to_owned(),
|
||||
retry_after_ms,
|
||||
}
|
||||
}
|
||||
}
|
58
neo/src/api/client_server/errors/registration_error.rs
Normal file
58
neo/src/api/client_server/errors/registration_error.rs
Normal file
@ -0,0 +1,58 @@
|
||||
use axum::{http::StatusCode, response::IntoResponse, Json};
|
||||
|
||||
use crate::{responses::registration::RegistrationResponse, types::error_code::ErrorCode};
|
||||
|
||||
use super::ErrorResponse;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RegistrationError {
|
||||
#[error("The homeserver requires additional authentication information")]
|
||||
AdditionalAuthenticationInformation,
|
||||
#[error("UserId is missing")]
|
||||
MissingUserId,
|
||||
#[error("The desired user ID is not a valid user name")]
|
||||
InvalidUserId,
|
||||
#[error("The desired user ID is already taken")]
|
||||
UserIdTaken,
|
||||
#[error("Registration is disabled")]
|
||||
RegistrationDisabled,
|
||||
}
|
||||
|
||||
impl IntoResponse for RegistrationError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
RegistrationError::AdditionalAuthenticationInformation => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(RegistrationResponse::user_interactive_authorization_info()),
|
||||
)
|
||||
.into_response(),
|
||||
RegistrationError::InvalidUserId | RegistrationError::MissingUserId => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::InvalidUsername,
|
||||
&self.to_string(),
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response(),
|
||||
RegistrationError::UserIdTaken => (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::UserInUse,
|
||||
&self.to_string(),
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response(),
|
||||
RegistrationError::RegistrationDisabled => (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::Forbidden,
|
||||
&self.to_string(),
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,3 +1,3 @@
|
||||
pub mod auth;
|
||||
pub mod versions;
|
||||
pub mod errors;
|
||||
pub mod r0;
|
||||
pub mod versions;
|
299
neo/src/api/client_server/r0/auth.rs
Normal file
299
neo/src/api/client_server/r0/auth.rs
Normal file
@ -0,0 +1,299 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
extract::Query,
|
||||
routing::{get, post},
|
||||
Extension,
|
||||
};
|
||||
|
||||
use neo_entity::prelude::*;
|
||||
use neo_util::password;
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
use sea_orm::{ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, Set};
|
||||
|
||||
use crate::{
|
||||
api::client_server::errors::{
|
||||
api_error::ApiError, authentication_error::AuthenticationError,
|
||||
registration_error::RegistrationError,
|
||||
},
|
||||
ruma_wrapper::{RumaRequest, RumaResponse},
|
||||
Config,
|
||||
};
|
||||
|
||||
use ruma::api::client::{
|
||||
account, session,
|
||||
uiaa::{IncomingAuthData, IncomingUserIdentifier},
|
||||
};
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/login", get(get_login_types).post(login))
|
||||
.route("/r0/register", post(post_register))
|
||||
.route("/r0/register/available", get(get_username_available))
|
||||
}
|
||||
|
||||
async fn get_login_types() -> Result<RumaResponse<session::get_login_types::v3::Response>, ApiError>
|
||||
{
|
||||
use session::get_login_types::v3::*;
|
||||
|
||||
Ok(RumaResponse(session::get_login_types::v3::Response::new(
|
||||
vec![LoginType::Password(PasswordLoginType::new())],
|
||||
)))
|
||||
}
|
||||
|
||||
async fn login(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(db): Extension<DatabaseConnection>,
|
||||
RumaRequest(req): RumaRequest<session::login::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<session::login::v3::Response>, ApiError> {
|
||||
use session::login::v3::*;
|
||||
|
||||
match req.login_info {
|
||||
IncomingLoginInfo::Password(incoming_password) => {
|
||||
let password = incoming_password.password;
|
||||
let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) =
|
||||
incoming_password.identifier
|
||||
{
|
||||
ruma::UserId::parse_with_server_name(user_id, &config.server_name)
|
||||
.map_err(|_| AuthenticationError::InvalidUserId)?
|
||||
} else {
|
||||
return Err(AuthenticationError::InvalidUserId.into());
|
||||
};
|
||||
|
||||
let db_user = User::find()
|
||||
.filter(users::Column::UserId.eq(user_id.as_str()))
|
||||
.one(&db)
|
||||
.await?
|
||||
.ok_or(AuthenticationError::InvalidUserId)?;
|
||||
|
||||
if !password::password_correct(&password, &db_user.password_hash)
|
||||
.map_err(|_| AuthenticationError::Forbidden)?
|
||||
{
|
||||
return Err(AuthenticationError::Forbidden.into());
|
||||
}
|
||||
|
||||
let device = if let Some(device_id) = req.device_id {
|
||||
Device::find()
|
||||
.belongs_to(&db_user)
|
||||
.filter(devices::Column::DeviceId.eq(device_id.as_str()))
|
||||
.one(&db)
|
||||
.await?
|
||||
.unwrap()
|
||||
} else {
|
||||
let device_id = uuid::Uuid::new_v4().to_string();
|
||||
let display_name = req
|
||||
.initial_device_display_name
|
||||
.unwrap_or_else(|| "Generic Device".into());
|
||||
|
||||
let device = devices::ActiveModel {
|
||||
device_id: Set(device_id),
|
||||
display_name: Set(display_name),
|
||||
user_uuid: Set(db_user.uuid),
|
||||
..Default::default()
|
||||
};
|
||||
device.insert(&db).await?
|
||||
};
|
||||
|
||||
let key = thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let session = sessions::ActiveModel {
|
||||
device_uuid: Set(device.uuid),
|
||||
key: Set(key),
|
||||
..Default::default()
|
||||
};
|
||||
let session = session.insert(&db).await?;
|
||||
let response = Response::new(
|
||||
user_id,
|
||||
session.key,
|
||||
ruma::OwnedDeviceId::from(device.device_id),
|
||||
);
|
||||
|
||||
Ok(RumaResponse(response))
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_username_available(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(db): Extension<DatabaseConnection>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Result<RumaResponse<account::get_username_availability::v3::Response>, ApiError> {
|
||||
use account::get_username_availability::v3::*;
|
||||
tracing::debug!("username_available hit");
|
||||
|
||||
let username = params
|
||||
.get("username")
|
||||
.ok_or(RegistrationError::MissingUserId)?
|
||||
.to_owned();
|
||||
let user_id = ruma::UserId::parse_with_server_name(username, &config.server_name)
|
||||
.map_err(|_| RegistrationError::InvalidUserId)?;
|
||||
|
||||
let available = User::find()
|
||||
.filter(users::Column::UserId.eq(user_id.as_str()))
|
||||
.one(&db)
|
||||
.await?
|
||||
.is_none();
|
||||
|
||||
Ok(RumaResponse(Response::new(available)))
|
||||
}
|
||||
|
||||
async fn post_register(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(db): Extension<DatabaseConnection>,
|
||||
RumaRequest(req): RumaRequest<account::register::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<account::register::v3::Response>, ApiError> {
|
||||
use account::register::v3::*;
|
||||
|
||||
config
|
||||
.enable_registration
|
||||
.then(|| true)
|
||||
.ok_or(RegistrationError::RegistrationDisabled)?;
|
||||
|
||||
match req.auth {
|
||||
Some(auth) => match auth {
|
||||
IncomingAuthData::Password(incoming_password) => {
|
||||
let password = incoming_password.password;
|
||||
let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) =
|
||||
incoming_password.identifier
|
||||
{
|
||||
ruma::UserId::parse_with_server_name(user_id, &config.server_name)
|
||||
.map_err(|_| AuthenticationError::InvalidUserId)?
|
||||
} else {
|
||||
return Err(AuthenticationError::InvalidUserId.into());
|
||||
};
|
||||
|
||||
if User::find()
|
||||
.filter(users::Column::UserId.eq(user_id.as_str()))
|
||||
.one(&db)
|
||||
.await?
|
||||
.is_some()
|
||||
{
|
||||
return Err(RegistrationError::UserIdTaken.into());
|
||||
};
|
||||
|
||||
let display_name = req
|
||||
.initial_device_display_name
|
||||
.unwrap_or_else(|| "Random Display Name".into());
|
||||
|
||||
let pw_hash = password::hash_password(&password)?;
|
||||
|
||||
let user = users::ActiveModel {
|
||||
user_id: Set(user_id.to_string()),
|
||||
display_name: Set(user_id.to_string()),
|
||||
password_hash: Set(pw_hash),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&db)
|
||||
.await?;
|
||||
|
||||
let device = devices::ActiveModel {
|
||||
display_name: Set(display_name),
|
||||
user_uuid: Set(user.uuid),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&db)
|
||||
.await?;
|
||||
|
||||
let mut response = Response::new(
|
||||
ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?,
|
||||
);
|
||||
if !req.inhibit_login {
|
||||
let key = thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let session = sessions::ActiveModel {
|
||||
device_uuid: Set(device.uuid),
|
||||
key: Set(key),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&db)
|
||||
.await?;
|
||||
response.access_token = Some(session.key);
|
||||
}
|
||||
if !req.inhibit_login {
|
||||
response.device_id = Some(device.device_id.into());
|
||||
}
|
||||
|
||||
Ok(RumaResponse(response))
|
||||
}
|
||||
_ => todo!(),
|
||||
},
|
||||
// For clients not following/using UIAA
|
||||
None => {
|
||||
let password = req
|
||||
.password
|
||||
.ok_or("password missing")
|
||||
.map_err(|_| RegistrationError::AdditionalAuthenticationInformation)?;
|
||||
let user_id = if let Some(username) = req.username {
|
||||
ruma::UserId::parse_with_server_name(username, &config.server_name)
|
||||
.map_err(|e| anyhow::anyhow!(e))?
|
||||
} else {
|
||||
return Err(AuthenticationError::InvalidUserId.into());
|
||||
};
|
||||
|
||||
if User::find()
|
||||
.filter(users::Column::UserId.eq(user_id.as_str()))
|
||||
.one(&db)
|
||||
.await?
|
||||
.is_some()
|
||||
{
|
||||
return Err(RegistrationError::UserIdTaken.into());
|
||||
};
|
||||
|
||||
let display_name = req
|
||||
.initial_device_display_name
|
||||
.unwrap_or_else(|| "Random Display Name".into());
|
||||
|
||||
let pw_hash = password::hash_password(&password)?;
|
||||
|
||||
let user = users::ActiveModel {
|
||||
user_id: Set(user_id.to_string()),
|
||||
display_name: Set(user_id.to_string()),
|
||||
password_hash: Set(pw_hash),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&db)
|
||||
.await?;
|
||||
|
||||
let device = devices::ActiveModel {
|
||||
display_name: Set(display_name),
|
||||
user_uuid: Set(user.uuid),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&db)
|
||||
.await?;
|
||||
|
||||
let mut response =
|
||||
Response::new(ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?);
|
||||
if !req.inhibit_login {
|
||||
let key = thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let session = sessions::ActiveModel {
|
||||
device_uuid: Set(device.uuid),
|
||||
key: Set(key),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&db)
|
||||
.await?;
|
||||
response.access_token = Some(session.key);
|
||||
}
|
||||
if !req.inhibit_login {
|
||||
response.device_id = Some(ruma::OwnedDeviceId::from(device.device_id));
|
||||
}
|
||||
|
||||
Ok(RumaResponse(response))
|
||||
}
|
||||
}
|
||||
}
|
25
neo/src/api/client_server/r0/filter.rs
Normal file
25
neo/src/api/client_server/r0/filter.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::routing::post;
|
||||
use axum::Extension;
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use ruma::api::client::filter;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/user/:user_id/filter", post(create_filter))
|
||||
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||
}
|
||||
|
||||
async fn create_filter(
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(_req): RumaRequest<filter::create_filter::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<filter::create_filter::v3::Response>, ApiError> {
|
||||
use filter::create_filter::v3::*;
|
||||
|
||||
Ok(RumaResponse(Response::new("a".into())))
|
||||
}
|
25
neo/src/api/client_server/r0/keys.rs
Normal file
25
neo/src/api/client_server/r0/keys.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::routing::post;
|
||||
use axum::Extension;
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use ruma::api::client::keys;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/keys/query", post(get_keys))
|
||||
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||
}
|
||||
|
||||
async fn get_keys(
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(_req): RumaRequest<keys::get_keys::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<keys::get_keys::v3::Response>, ApiError> {
|
||||
use keys::get_keys::v3::*;
|
||||
|
||||
Ok(RumaResponse(Response::new()))
|
||||
}
|
143
neo/src/api/client_server/r0/mod.rs
Normal file
143
neo/src/api/client_server/r0/mod.rs
Normal file
@ -0,0 +1,143 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
http::{Request, StatusCode},
|
||||
middleware::Next,
|
||||
response::IntoResponse,
|
||||
Json,
|
||||
};
|
||||
use neo_entity::{
|
||||
devices::Entity as Device,
|
||||
sessions::{self, Entity as Session},
|
||||
users::Entity as User,
|
||||
};
|
||||
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, ModelTrait, QueryFilter};
|
||||
|
||||
use crate::types::error_code::ErrorCode;
|
||||
|
||||
use super::errors::ErrorResponse;
|
||||
|
||||
pub mod auth;
|
||||
pub mod filter;
|
||||
pub mod keys;
|
||||
pub mod presence;
|
||||
pub mod push;
|
||||
pub mod room;
|
||||
pub mod sync;
|
||||
pub mod thirdparty;
|
||||
|
||||
async fn authentication_middleware<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||
let db: &DatabaseConnection = req.extensions().get().unwrap();
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|header| header.to_str().ok());
|
||||
|
||||
if auth_header.is_none() {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::Forbidden,
|
||||
"Authorization Header not given",
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let auth_header = auth_header.expect("Validated above");
|
||||
let idx = auth_header.find(' ');
|
||||
|
||||
let idx = match idx {
|
||||
Some(idx) => idx,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::Forbidden,
|
||||
"Invalid Authorization Header",
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let session = match Session::find()
|
||||
.filter(sessions::Column::Key.eq(&auth_header[idx + 1..]))
|
||||
.one(db)
|
||||
.await
|
||||
{
|
||||
Ok(session) => session,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::Unknown,
|
||||
"Internal Server Error",
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let session = match session {
|
||||
Some(session) => session,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let device = match session.find_related(Device).one(db).await {
|
||||
Ok(device) => device,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let device = match device {
|
||||
Some(device) => device,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let user = match device.find_related(User).one(db).await {
|
||||
Ok(user) => user,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let user = match user {
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
req.extensions_mut().insert(Arc::new(user));
|
||||
|
||||
next.run(req).await.into_response()
|
||||
}
|
36
neo/src/api/client_server/r0/presence.rs
Normal file
36
neo/src/api/client_server/r0/presence.rs
Normal file
@ -0,0 +1,36 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::routing::{get, put};
|
||||
use axum::Extension;
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use ruma::api::client::presence;
|
||||
use ruma::presence::PresenceState;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/presence/:user_id/status", put(set_presence))
|
||||
.route("/r0/presence/:user_id/status", get(get_presence))
|
||||
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||
}
|
||||
|
||||
async fn set_presence(
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(_req): RumaRequest<presence::set_presence::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<presence::set_presence::v3::Response>, ApiError> {
|
||||
use presence::set_presence::v3::*;
|
||||
|
||||
Ok(RumaResponse(Response::new()))
|
||||
}
|
||||
|
||||
async fn get_presence(
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(_req): RumaRequest<presence::get_presence::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<presence::get_presence::v3::Response>, ApiError> {
|
||||
use presence::get_presence::v3::*;
|
||||
|
||||
Ok(RumaResponse(Response::new(PresenceState::Unavailable)))
|
||||
}
|
26
neo/src/api/client_server/r0/push.rs
Normal file
26
neo/src/api/client_server/r0/push.rs
Normal file
@ -0,0 +1,26 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::routing::get;
|
||||
use axum::Extension;
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use ruma::api::client::push;
|
||||
use ruma::push::Ruleset;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/pushrules/", get(get_pushrules))
|
||||
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||
}
|
||||
|
||||
async fn get_pushrules(
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(_req): RumaRequest<push::get_pushrules_all::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<push::get_pushrules_all::v3::Response>, ApiError> {
|
||||
use push::get_pushrules_all::v3::*;
|
||||
|
||||
Ok(RumaResponse(Response::new(Ruleset::new())))
|
||||
}
|
19
neo/src/api/client_server/r0/room.rs
Normal file
19
neo/src/api/client_server/r0/room.rs
Normal file
@ -0,0 +1,19 @@
|
||||
use axum::routing::post;
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
use crate::ruma_wrapper::RumaRequest;
|
||||
|
||||
use ruma::api::client::room;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/createRoom", post(create_room))
|
||||
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||
}
|
||||
|
||||
async fn create_room(
|
||||
RumaRequest(req): RumaRequest<room::create_room::v3::IncomingRequest>,
|
||||
) -> Result<String, ApiError> {
|
||||
dbg!(req);
|
||||
Ok("".into())
|
||||
}
|
29
neo/src/api/client_server/r0/sync.rs
Normal file
29
neo/src/api/client_server/r0/sync.rs
Normal file
@ -0,0 +1,29 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::routing::get;
|
||||
use axum::Extension;
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use ruma::api::client::sync;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/sync", get(sync_events))
|
||||
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||
}
|
||||
|
||||
async fn sync_events(
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(req): RumaRequest<sync::sync_events::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<sync::sync_events::v3::Response>, ApiError> {
|
||||
use sync::sync_events::v3::*;
|
||||
|
||||
if let Some(timeout) = req.timeout {
|
||||
tokio::time::sleep(timeout).await;
|
||||
}
|
||||
|
||||
Ok(RumaResponse(Response::new("todo".into())))
|
||||
}
|
23
neo/src/api/client_server/r0/thirdparty.rs
Normal file
23
neo/src/api/client_server/r0/thirdparty.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
|
||||
use axum::{routing::get, Extension};
|
||||
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use crate::{api::client_server::errors::api_error::ApiError, ruma_wrapper::RumaResponse};
|
||||
|
||||
use ruma::api::client::thirdparty;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/thirdparty/protocols", get(get_thirdparty_protocols))
|
||||
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||
}
|
||||
|
||||
async fn get_thirdparty_protocols(
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
) -> Result<RumaResponse<thirdparty::get_protocols::v3::Response>, ApiError> {
|
||||
Ok(RumaResponse(thirdparty::get_protocols::v3::Response::new(
|
||||
BTreeMap::new(),
|
||||
)))
|
||||
}
|
20
neo/src/api/client_server/versions.rs
Normal file
20
neo/src/api/client_server/versions.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use axum::routing::get;
|
||||
|
||||
use crate::ruma_wrapper::RumaResponse;
|
||||
use ruma::api::client::discovery;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new().route("/versions", get(get_client_versions))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn get_client_versions() -> RumaResponse<discovery::get_supported_versions::Response> {
|
||||
use discovery::get_supported_versions::*;
|
||||
|
||||
RumaResponse(Response::new(vec![
|
||||
"r0.5.0".to_owned(),
|
||||
"r0.6.0".to_owned(),
|
||||
"v1.1".to_owned(),
|
||||
"v1.2".to_owned(),
|
||||
]))
|
||||
}
|
17
neo/src/config.rs
Normal file
17
neo/src/config.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use ruma::{OwnedServerName, ServerName};
|
||||
|
||||
pub struct Config {
|
||||
pub db_path: String,
|
||||
pub server_name: OwnedServerName,
|
||||
pub enable_registration: bool,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
db_path: "sqlite://db.sqlite3".into(),
|
||||
server_name: ServerName::parse("fuckwit.dev").unwrap(),
|
||||
enable_registration: true,
|
||||
}
|
||||
}
|
||||
}
|
77
neo/src/main.rs
Normal file
77
neo/src/main.rs
Normal file
@ -0,0 +1,77 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
handler::Handler,
|
||||
http::{Request, StatusCode},
|
||||
Extension, Router,
|
||||
};
|
||||
use config::Config;
|
||||
use neo_migration::{Migrator, MigratorTrait};
|
||||
use sea_orm::Database;
|
||||
use tower_http::{
|
||||
cors::CorsLayer,
|
||||
trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
|
||||
};
|
||||
use tracing::Level;
|
||||
|
||||
mod api;
|
||||
mod config;
|
||||
mod responses;
|
||||
mod ruma_wrapper;
|
||||
mod types;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
if std::env::var("RUST_LOG").is_err() {
|
||||
std::env::set_var("RUST_LOG", "debug,sqlx=off");
|
||||
}
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let config = Arc::new(Config::default());
|
||||
|
||||
let pool = Database::connect(config.db_path.clone()).await?;
|
||||
Migrator::up(&pool, None).await?;
|
||||
|
||||
// TODO: set correct CORS headers
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(tower_http::cors::Any)
|
||||
.allow_methods(tower_http::cors::Any)
|
||||
.allow_headers(tower_http::cors::Any);
|
||||
|
||||
let tracing_layer = TraceLayer::new_for_http()
|
||||
.make_span_with(DefaultMakeSpan::new().level(Level::INFO))
|
||||
.on_request(DefaultOnRequest::default().level(Level::INFO))
|
||||
.on_response(DefaultOnResponse::default().level(Level::INFO));
|
||||
|
||||
let client_server = Router::new()
|
||||
.merge(api::client_server::versions::routes())
|
||||
.merge(api::client_server::r0::auth::routes())
|
||||
.merge(api::client_server::r0::thirdparty::routes())
|
||||
.merge(api::client_server::r0::room::routes())
|
||||
.merge(api::client_server::r0::presence::routes())
|
||||
.merge(api::client_server::r0::push::routes())
|
||||
.merge(api::client_server::r0::filter::routes())
|
||||
.merge(api::client_server::r0::sync::routes())
|
||||
.merge(api::client_server::r0::keys::routes());
|
||||
|
||||
let router = Router::new()
|
||||
.nest("/_matrix/client", client_server)
|
||||
.layer(cors)
|
||||
.layer(tracing_layer)
|
||||
.fallback(fallback.into_service())
|
||||
.layer(Extension(pool))
|
||||
.layer(Extension(config));
|
||||
|
||||
let _ = axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
|
||||
.serve(router.into_make_service())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fallback(request: Request<Body>) -> StatusCode {
|
||||
tracing::error!("{} {}", request.method(), request.uri());
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
15
neo/src/responses/registration.rs
Normal file
15
neo/src/responses/registration.rs
Normal file
@ -0,0 +1,15 @@
|
||||
use crate::types::user_interactive_authorization::UserInteractiveAuthorizationInfo;
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum RegistrationResponse {
|
||||
UserInteractiveAuthorizationInfo(UserInteractiveAuthorizationInfo),
|
||||
}
|
||||
|
||||
impl RegistrationResponse {
|
||||
pub fn user_interactive_authorization_info() -> Self {
|
||||
RegistrationResponse::UserInteractiveAuthorizationInfo(
|
||||
UserInteractiveAuthorizationInfo::new(),
|
||||
)
|
||||
}
|
||||
}
|
66
neo/src/ruma_wrapper.rs
Normal file
66
neo/src/ruma_wrapper.rs
Normal file
@ -0,0 +1,66 @@
|
||||
use axum::{
|
||||
body::{Bytes, Full, HttpBody},
|
||||
extract::{FromRequest, Path},
|
||||
response::IntoResponse,
|
||||
BoxError,
|
||||
};
|
||||
use http::StatusCode;
|
||||
use ruma::{
|
||||
api::{IncomingRequest, OutgoingResponse},
|
||||
exports::bytes::{BufMut, BytesMut},
|
||||
serde::CanonicalJsonValue,
|
||||
};
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
|
||||
pub struct RumaRequest<R>(pub R)
|
||||
where
|
||||
R: IncomingRequest;
|
||||
|
||||
#[axum::async_trait]
|
||||
impl<R, B> FromRequest<B> for RumaRequest<R>
|
||||
where
|
||||
R: IncomingRequest,
|
||||
B: HttpBody + Send,
|
||||
B::Data: Send,
|
||||
B::Error: Into<BoxError>,
|
||||
{
|
||||
type Rejection = ApiError;
|
||||
|
||||
async fn from_request(
|
||||
req: &mut axum::extract::RequestParts<B>,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let path_params = Path::<Vec<String>>::from_request(req)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
let body = Bytes::from_request(req)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
let json = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||
|
||||
let mut buf = BytesMut::new().writer();
|
||||
serde_json::to_writer(&mut buf, &json).expect("can't fail");
|
||||
let body = buf.into_inner().freeze();
|
||||
|
||||
let builder = http::Request::builder().uri(req.uri()).method(req.method());
|
||||
let request = builder.body(body).map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
Ok(Self(
|
||||
R::try_from_http_request(request, &path_params).map_err(|e| anyhow::anyhow!(e))?,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RumaResponse<R>(pub R)
|
||||
where
|
||||
R: OutgoingResponse;
|
||||
|
||||
impl<R: OutgoingResponse> IntoResponse for RumaResponse<R> {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self.0.try_into_http_response::<BytesMut>() {
|
||||
Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(),
|
||||
Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
39
neo/src/types/error_code.rs
Normal file
39
neo/src/types/error_code.rs
Normal file
@ -0,0 +1,39 @@
|
||||
#[allow(unused)]
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ErrorCode {
|
||||
Forbidden,
|
||||
UnknownToken,
|
||||
MissingToken,
|
||||
BadJson,
|
||||
NotJson,
|
||||
NotFound,
|
||||
LimitExceeded,
|
||||
Unknown,
|
||||
UserInUse,
|
||||
InvalidUsername,
|
||||
Exclusive,
|
||||
UserDeactivated,
|
||||
}
|
||||
|
||||
impl serde::Serialize for ErrorCode {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(match self {
|
||||
ErrorCode::Forbidden => "M_FORBIDDEN",
|
||||
ErrorCode::UnknownToken => "M_UNKNOWN_TOKEN",
|
||||
ErrorCode::MissingToken => "M_MISSING_TOKEN",
|
||||
ErrorCode::BadJson => "M_BAD_JSON",
|
||||
ErrorCode::NotJson => "M_NOT_JSON",
|
||||
ErrorCode::NotFound => "M_NOT_FOUND",
|
||||
ErrorCode::LimitExceeded => "M_LIMIT_EXCEEDED",
|
||||
ErrorCode::Unknown => "M_UNKNOWN",
|
||||
ErrorCode::UserInUse => "M_USER_IN_USE",
|
||||
ErrorCode::InvalidUsername => "M_INVALID_USERNAME",
|
||||
ErrorCode::Exclusive => "M_EXCLUSIVE",
|
||||
ErrorCode::UserDeactivated => "M_USER_DEACTIVATED",
|
||||
})
|
||||
}
|
||||
}
|
82
neo/src/types/event_type.rs
Normal file
82
neo/src/types/event_type.rs
Normal file
@ -0,0 +1,82 @@
|
||||
use sqlx::{encode::IsNull, Sqlite};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum EventType {
|
||||
RoomCreate,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl sqlx::Type<Sqlite> for EventType {
|
||||
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
|
||||
<&str as sqlx::Type<Sqlite>>::type_info()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'e> sqlx::Encode<'e, Sqlite> for EventType {
|
||||
fn encode_by_ref(
|
||||
&self,
|
||||
buf: &mut <Sqlite as sqlx::database::HasArguments<'e>>::ArgumentBuffer,
|
||||
) -> sqlx::encode::IsNull {
|
||||
buf.push(sqlx::sqlite::SqliteArgumentValue::Text(
|
||||
match self {
|
||||
EventType::RoomCreate => "m.room.create",
|
||||
EventType::Unknown => "???",
|
||||
}
|
||||
.into(),
|
||||
));
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl<'d> sqlx::Decode<'d, Sqlite> for EventType {
|
||||
fn decode(
|
||||
value: <Sqlite as sqlx::database::HasValueRef<'d>>::ValueRef,
|
||||
) -> Result<Self, sqlx::error::BoxDynError> {
|
||||
Ok(match <&str as sqlx::Decode<Sqlite>>::decode(value)? {
|
||||
"m.room.create" => EventType::RoomCreate,
|
||||
_ => EventType::Unknown,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for EventType {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(match self {
|
||||
EventType::RoomCreate => "m.room.create",
|
||||
EventType::Unknown => "dev.fuckwit.unknown_event",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for EventType {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct IdentifierVisitor;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for IdentifierVisitor {
|
||||
type Value = EventType;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("Identifier")
|
||||
}
|
||||
|
||||
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
match v {
|
||||
"m.id.user" => Ok(EventType::RoomCreate),
|
||||
_ => Err(serde::de::Error::custom("Unknown identifier")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(IdentifierVisitor {})
|
||||
}
|
||||
}
|
4
neo/src/types/mod.rs
Normal file
4
neo/src/types/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub mod error_code;
|
||||
//pub mod event_type;
|
||||
pub mod flow;
|
||||
pub mod user_interactive_authorization;
|
@ -0,0 +1,37 @@
|
||||
{
|
||||
"query": "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where uuid = ?",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "uuid: Uuid",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "user_uuid: Uuid",
|
||||
"ordinal": 1,
|
||||
"type_info": "Int64"
|
||||
},
|
||||
{
|
||||
"name": "device_id",
|
||||
"ordinal": 2,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "display_name",
|
||||
"ordinal": 3,
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 1
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "1843c2b3e548d1dd13694a65ca1ba123da38668c3fc5bc431fe5884a6fc25f71"
|
||||
}
|
@ -1,116 +0,0 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
extract::Query,
|
||||
http::StatusCode,
|
||||
routing::{get, post},
|
||||
Extension, Json,
|
||||
};
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use crate::responses::registration::RegistrationResponse;
|
||||
use crate::{
|
||||
models::devices::Device,
|
||||
responses::{flow::Flows, registration::RegistrationSuccess},
|
||||
};
|
||||
use crate::{
|
||||
models::users::User,
|
||||
requests::registration::RegistrationRequest,
|
||||
responses::username_available::UsernameAvailable,
|
||||
types::{
|
||||
authentication_data::AuthenticationData, identifier::Identifier, matrix_user_id::UserId,
|
||||
},
|
||||
Config,
|
||||
};
|
||||
|
||||
use super::errors::{api_error::ApiError, registration_error::RegistrationError};
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/login", get(get_login).post(post_login))
|
||||
.route("/r0/register", post(post_register))
|
||||
.route("/r0/register/available", get(get_username_available))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn get_login() -> Json<Flows> {
|
||||
Json(Flows::new())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn post_login(body: String) -> StatusCode {
|
||||
dbg!(body);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_username_available(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(db): Extension<SqlitePool>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Result<Json<UsernameAvailable>, ApiError> {
|
||||
let username = params
|
||||
.get("username")
|
||||
.ok_or(RegistrationError::MissingUserId)?;
|
||||
let user_id = UserId::new(username, &config.homeserver_name)?;
|
||||
let exists = User::exists(&db, &user_id).await?;
|
||||
|
||||
Ok(Json(UsernameAvailable::new(!exists)))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn post_register(
|
||||
Extension(config): Extension<Arc<Config>>,
|
||||
Extension(db): Extension<SqlitePool>,
|
||||
Json(body): Json<RegistrationRequest>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Result<(StatusCode, Json<RegistrationResponse>), ApiError> {
|
||||
// Client tries to get available flows
|
||||
if body.auth().is_none() {
|
||||
return Ok((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(RegistrationResponse::user_interactive_authorization_info()),
|
||||
));
|
||||
}
|
||||
|
||||
let (user, device) = match &body.auth().unwrap() {
|
||||
AuthenticationData::Password(auth_data) => {
|
||||
let username = body.username().ok_or(RegistrationError::MissingUserId)?;
|
||||
let user_id = UserId::new(username, &config.homeserver_name)
|
||||
.ok()
|
||||
.ok_or(RegistrationError::InvalidUserId)?;
|
||||
|
||||
if User::exists(&db, &user_id).await.unwrap() {
|
||||
todo!("Error out")
|
||||
}
|
||||
|
||||
let password = auth_data.password();
|
||||
|
||||
let display_name = match body.initial_device_display_name() {
|
||||
Some(display_name) => display_name.as_ref(),
|
||||
None => "Random displayname",
|
||||
};
|
||||
|
||||
let user = User::create(&db, &user_id, &user_id.to_string(), password)
|
||||
.await
|
||||
.unwrap();
|
||||
let device = Device::create(&db, &user, "test", display_name)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
(user, device)
|
||||
}
|
||||
};
|
||||
|
||||
if body.inhibit_login().unwrap_or(false) {
|
||||
let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id());
|
||||
|
||||
Ok((StatusCode::OK, Json(RegistrationResponse::Success(resp))))
|
||||
} else {
|
||||
let session = device.create_session(&db).await.unwrap();
|
||||
let resp =
|
||||
RegistrationSuccess::new(Some(session.value()), device.device_id(), user.user_id());
|
||||
|
||||
Ok((StatusCode::OK, Json(RegistrationResponse::Success(resp))))
|
||||
}
|
||||
}
|
@ -1,53 +0,0 @@
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
|
||||
use super::registration_error::RegistrationError;
|
||||
|
||||
macro_rules! map_err {
|
||||
($err:ident, $($type:path => $target:path),+) => {
|
||||
$(
|
||||
if $err.is::<$type>() {
|
||||
return $target($err.downcast().unwrap());
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ApiError {
|
||||
#[error("Registration Error")]
|
||||
RegistrationError(#[from] RegistrationError),
|
||||
|
||||
#[error("Database Error")]
|
||||
DBError(#[from] sqlx::Error),
|
||||
|
||||
#[error("Generic Error")]
|
||||
Generic(anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for ApiError {
|
||||
fn from(err: anyhow::Error) -> Self {
|
||||
map_err!(err,
|
||||
sqlx::Error => ApiError::DBError,
|
||||
RegistrationError => ApiError::RegistrationError
|
||||
);
|
||||
|
||||
ApiError::Generic(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
ApiError::RegistrationError(registration_error) => match registration_error {
|
||||
RegistrationError::InvalidUserId => {
|
||||
(StatusCode::OK, String::new()).into_response()
|
||||
}
|
||||
RegistrationError::MissingUserId => {
|
||||
(StatusCode::OK, String::new()).into_response()
|
||||
}
|
||||
},
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,2 +0,0 @@
|
||||
pub mod api_error;
|
||||
pub mod registration_error;
|
@ -1,7 +0,0 @@
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RegistrationError {
|
||||
#[error("UserId is missing")]
|
||||
MissingUserId,
|
||||
#[error("UserId is invalid")]
|
||||
InvalidUserId,
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
use axum::{routing::get, Json};
|
||||
|
||||
use crate::responses::versions::Versions;
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new().route("/versions", get(get_client_versions))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn get_client_versions() -> Json<Versions> {
|
||||
Json(Versions::default())
|
||||
}
|
78
src/main.rs
78
src/main.rs
@ -1,78 +0,0 @@
|
||||
#![allow(unused)]
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
handler::Handler,
|
||||
http::{Request, StatusCode},
|
||||
Extension, Router,
|
||||
};
|
||||
use tower_http::{
|
||||
cors::CorsLayer,
|
||||
trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
|
||||
};
|
||||
use tracing::Level;
|
||||
|
||||
mod api;
|
||||
mod models;
|
||||
mod requests;
|
||||
mod responses;
|
||||
mod types;
|
||||
|
||||
struct Config {
|
||||
db_path: String,
|
||||
homeserver_name: String,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
db_path: "sqlite://db.sqlite3".into(),
|
||||
homeserver_name: "fuckwit.dev".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
if std::env::var("RUST_LOG").is_err() {
|
||||
std::env::set_var("RUST_LOG", "debug");
|
||||
}
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let config = Arc::new(Config::default());
|
||||
|
||||
let pool = sqlx::SqlitePool::connect(&config.db_path)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(tower_http::cors::Any)
|
||||
.allow_methods(tower_http::cors::Any)
|
||||
.allow_headers(tower_http::cors::Any);
|
||||
|
||||
let tracing_layer = TraceLayer::new_for_http();
|
||||
|
||||
let client_server = Router::new()
|
||||
.merge(api::client_server::versions::routes())
|
||||
.merge(api::client_server::auth::routes());
|
||||
|
||||
let router = Router::new()
|
||||
.nest("/_matrix/client", client_server)
|
||||
.layer(cors)
|
||||
.layer(tracing_layer)
|
||||
.layer(Extension(pool))
|
||||
.layer(Extension(config))
|
||||
.fallback(fallback.into_service());
|
||||
|
||||
let _ = axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
|
||||
.serve(router.into_make_service())
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn fallback(request: Request<Body>) -> StatusCode {
|
||||
dbg!(request);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use super::{sessions::Session, users::User};
|
||||
|
||||
pub struct Device {
|
||||
id: i64,
|
||||
user_id: i64,
|
||||
device_id: String,
|
||||
display_name: String,
|
||||
}
|
||||
|
||||
impl Device {
|
||||
pub async fn create(
|
||||
conn: &SqlitePool,
|
||||
user: &User,
|
||||
device_id: &str,
|
||||
display_name: &str,
|
||||
) -> anyhow::Result<Self> {
|
||||
let user_id = user.id();
|
||||
Ok(sqlx::query_as!(Self, "insert into devices(user_id, device_id, display_name) values(?, ?, ?) returning id, user_id, device_id, display_name", user_id, device_id, display_name).fetch_one(conn).await?)
|
||||
}
|
||||
|
||||
pub async fn by_user(conn: &SqlitePool, user: &User) -> anyhow::Result<Self> {
|
||||
let user_id = user.id();
|
||||
Ok(sqlx::query_as!(
|
||||
Self,
|
||||
"select id, user_id, device_id, display_name from devices where user_id = ?",
|
||||
user_id
|
||||
)
|
||||
.fetch_one(conn)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn create_session(&self, conn: &SqlitePool) -> anyhow::Result<Session> {
|
||||
Ok(Session::create(conn, self, "random_session_id").await?)
|
||||
}
|
||||
|
||||
/// Get the device's id.
|
||||
#[must_use]
|
||||
pub fn id(&self) -> i64 {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Get the device's user id.
|
||||
#[must_use]
|
||||
pub fn user_id(&self) -> i64 {
|
||||
self.user_id
|
||||
}
|
||||
|
||||
/// Get a reference to the device's device id.
|
||||
#[must_use]
|
||||
pub fn device_id(&self) -> &str {
|
||||
self.device_id.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the device's display name.
|
||||
#[must_use]
|
||||
pub fn display_name(&self) -> &str {
|
||||
self.display_name.as_ref()
|
||||
}
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use super::devices::Device;
|
||||
|
||||
pub struct Session {
|
||||
id: i64,
|
||||
device_id: i64,
|
||||
value: String,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub async fn create(conn: &SqlitePool, device: &Device, value: &str) -> anyhow::Result<Self> {
|
||||
let device_id = device.id();
|
||||
Ok(sqlx::query_as!(
|
||||
Self,
|
||||
"insert into sessions(device_id, value) values(?, ?) returning id, device_id, value",
|
||||
device_id,
|
||||
value
|
||||
)
|
||||
.fetch_one(conn)
|
||||
.await?)
|
||||
}
|
||||
|
||||
/// Get the session's id.
|
||||
#[must_use]
|
||||
pub fn id(&self) -> i64 {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Get the session's device id.
|
||||
#[must_use]
|
||||
pub fn device_id(&self) -> i64 {
|
||||
self.device_id
|
||||
}
|
||||
|
||||
/// Get a reference to the session's value.
|
||||
#[must_use]
|
||||
pub fn value(&self) -> &str {
|
||||
self.value.as_ref()
|
||||
}
|
||||
}
|
@ -1,58 +0,0 @@
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use crate::types::matrix_user_id::UserId;
|
||||
|
||||
pub struct User {
|
||||
id: i64,
|
||||
user_id: String,
|
||||
display_name: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub async fn exists(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<bool> {
|
||||
Ok(
|
||||
sqlx::query!("select user_id from users where user_id = ?", user_id)
|
||||
.fetch_optional(conn)
|
||||
.await?
|
||||
.is_some(),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
conn: &SqlitePool,
|
||||
user_id: &UserId,
|
||||
display_name: &str,
|
||||
password: &str,
|
||||
) -> anyhow::Result<Self> {
|
||||
Ok(sqlx::query_as!(Self, "insert into users(user_id, display_name, password) values (?, ?, ?) returning id, user_id, display_name, password", user_id, display_name, password).fetch_one(conn).await?)
|
||||
}
|
||||
|
||||
pub async fn by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<Self> {
|
||||
Ok(sqlx::query_as!(
|
||||
Self,
|
||||
"select id, user_id, display_name, password from users where user_id = ?",
|
||||
user_id
|
||||
)
|
||||
.fetch_one(conn)
|
||||
.await?)
|
||||
}
|
||||
|
||||
/// Get the user's id.
|
||||
#[must_use]
|
||||
pub fn id(&self) -> i64 {
|
||||
self.id
|
||||
}
|
||||
|
||||
/// Get a reference to the user's user id.
|
||||
#[must_use]
|
||||
pub fn user_id(&self) -> &str {
|
||||
self.user_id.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the user's password.
|
||||
#[must_use]
|
||||
pub fn password(&self) -> &str {
|
||||
self.password.as_ref()
|
||||
}
|
||||
}
|
@ -1,59 +0,0 @@
|
||||
use crate::types::{authentication_data::AuthenticationData, flow::Flow, identifier::Identifier};
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct RegistrationRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
auth: Option<AuthenticationData>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
device_id: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
inhibit_login: Option<bool>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
initial_device_display_name: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
password: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
username: Option<String>,
|
||||
}
|
||||
|
||||
impl RegistrationRequest {
|
||||
#[must_use]
|
||||
pub fn auth(&self) -> Option<&AuthenticationData> {
|
||||
self.auth.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the registration request's device id.
|
||||
#[must_use]
|
||||
pub fn device_id(&self) -> Option<&String> {
|
||||
self.device_id.as_ref()
|
||||
}
|
||||
|
||||
/// Get the registration request's inhibit login.
|
||||
#[must_use]
|
||||
pub fn inhibit_login(&self) -> Option<bool> {
|
||||
self.inhibit_login
|
||||
}
|
||||
|
||||
/// Get a reference to the registration request's initial device display name.
|
||||
#[must_use]
|
||||
pub fn initial_device_display_name(&self) -> Option<&String> {
|
||||
self.initial_device_display_name.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the registration request's password.
|
||||
#[must_use]
|
||||
pub fn password(&self) -> Option<&String> {
|
||||
self.password.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the registration request's username.
|
||||
#[must_use]
|
||||
pub fn username(&self) -> Option<&String> {
|
||||
self.username.as_ref()
|
||||
}
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
use crate::types::flow::Flow;
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
struct FlowWrapper {
|
||||
#[serde(rename = "type")]
|
||||
_type: Flow,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize)]
|
||||
pub struct Flows {
|
||||
flows: Vec<FlowWrapper>,
|
||||
}
|
||||
|
||||
impl Flows {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
flows: vec![FlowWrapper {
|
||||
_type: Flow::Password,
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
pub mod flow;
|
||||
pub mod registration;
|
||||
pub mod username_available;
|
||||
pub mod versions;
|
@ -1,34 +0,0 @@
|
||||
use crate::types::user_interactive_authorization::UserInteractiveAuthorizationInfo;
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum RegistrationResponse {
|
||||
Success(RegistrationSuccess),
|
||||
UserInteractiveAuthorizationInfo(UserInteractiveAuthorizationInfo),
|
||||
}
|
||||
|
||||
impl RegistrationResponse {
|
||||
pub fn user_interactive_authorization_info() -> Self {
|
||||
RegistrationResponse::UserInteractiveAuthorizationInfo(
|
||||
UserInteractiveAuthorizationInfo::new(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct RegistrationSuccess {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
access_token: Option<String>,
|
||||
device_id: String,
|
||||
user_id: String,
|
||||
}
|
||||
|
||||
impl RegistrationSuccess {
|
||||
pub fn new(access_token: Option<&str>, device_id: &str, user_id: &str) -> Self {
|
||||
Self {
|
||||
access_token: access_token.and_then(|v| Some(v.to_owned())),
|
||||
device_id: device_id.to_owned(),
|
||||
user_id: user_id.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,10 +0,0 @@
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct UsernameAvailable {
|
||||
available: bool,
|
||||
}
|
||||
|
||||
impl UsernameAvailable {
|
||||
pub fn new(available: bool) -> Self {
|
||||
Self { available }
|
||||
}
|
||||
}
|
@ -1,17 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct Versions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
unstable_features: Option<HashMap<String, String>>,
|
||||
versions: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for Versions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
unstable_features: None,
|
||||
versions: vec!["v1.2".into()],
|
||||
}
|
||||
}
|
||||
}
|
@ -1,36 +0,0 @@
|
||||
use super::{flow::Flow, identifier::Identifier};
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AuthenticationData {
|
||||
Password(AuthenticationPassword),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct AuthenticationPassword {
|
||||
#[serde(rename = "type")]
|
||||
_type: Flow,
|
||||
identifier: Identifier,
|
||||
password: String,
|
||||
user: Option<String>,
|
||||
}
|
||||
|
||||
impl AuthenticationPassword {
|
||||
/// Get a reference to the authentication password's identifier.
|
||||
#[must_use]
|
||||
pub fn identifier(&self) -> &Identifier {
|
||||
&self.identifier
|
||||
}
|
||||
|
||||
/// Get a reference to the authentication password's password.
|
||||
#[must_use]
|
||||
pub fn password(&self) -> &str {
|
||||
self.password.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the authentication password's user.
|
||||
#[must_use]
|
||||
pub fn user(&self) -> Option<&String> {
|
||||
self.user.as_ref()
|
||||
}
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
use super::identifier_type::IdentifierType;
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum Identifier {
|
||||
User(IdentifierUser),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct IdentifierUser {
|
||||
#[serde(rename = "type")]
|
||||
_type: IdentifierType,
|
||||
user: Option<String>,
|
||||
}
|
||||
|
||||
impl IdentifierUser {
|
||||
/// Get a reference to the identifier user's user.
|
||||
#[must_use]
|
||||
pub fn user(&self) -> Option<&String> {
|
||||
self.user.as_ref()
|
||||
}
|
||||
}
|
@ -1,44 +0,0 @@
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum IdentifierType {
|
||||
User,
|
||||
}
|
||||
|
||||
impl serde::Serialize for IdentifierType {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(match self {
|
||||
IdentifierType::User => "m.id.user",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for IdentifierType {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct IdentifierVisitor;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for IdentifierVisitor {
|
||||
type Value = IdentifierType;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("Identifier")
|
||||
}
|
||||
|
||||
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
match v {
|
||||
"m.id.user" => Ok(IdentifierType::User),
|
||||
_ => Err(serde::de::Error::custom("Unknown identifier")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_str(IdentifierVisitor {})
|
||||
}
|
||||
}
|
@ -1,52 +0,0 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(sqlx::Type)]
|
||||
#[sqlx(transparent)]
|
||||
#[repr(transparent)]
|
||||
pub struct UserId(String);
|
||||
|
||||
impl UserId {
|
||||
pub fn new(name: &str, server_name: &str) -> anyhow::Result<Self> {
|
||||
let user_id = Self(format!("@{name}:{server_name}"));
|
||||
|
||||
user_id.is_valid()?;
|
||||
|
||||
Ok(user_id)
|
||||
}
|
||||
|
||||
pub fn local_part(&self) -> &str {
|
||||
let col_idx = self.0.find(':').expect("Will always have at least one ':'");
|
||||
self.0[1..col_idx].as_ref()
|
||||
}
|
||||
|
||||
fn is_valid(&self) -> anyhow::Result<()> {
|
||||
(self.0.len() <= 255)
|
||||
.then(|| ())
|
||||
.ok_or(UserIdError::TooLong(self.0.len()))?;
|
||||
|
||||
let local_part = self.local_part();
|
||||
local_part
|
||||
.bytes()
|
||||
.all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'z' | b'-' | b'.' | b'=' | b'_' | b'/'))
|
||||
.then(|| ())
|
||||
.ok_or(UserIdError::InvalidCharacters)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for UserId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum UserIdError {
|
||||
#[error("UserId too long {0} (expected < 255)")]
|
||||
TooLong(usize),
|
||||
#[error("Invalid character present in user id")]
|
||||
InvalidCharacters,
|
||||
#[error("Invalid UserId given")]
|
||||
Invalid,
|
||||
}
|
@ -1,6 +0,0 @@
|
||||
pub mod authentication_data;
|
||||
pub mod flow;
|
||||
pub mod identifier;
|
||||
pub mod identifier_type;
|
||||
pub mod matrix_user_id;
|
||||
pub mod user_interactive_authorization;
|
Reference in New Issue
Block a user