Compare commits
20 Commits
b8e1235396
...
master
Author | SHA1 | Date | |
---|---|---|---|
d98e9ea9e3 | |||
9510d9c765 | |||
e33a734199 | |||
35bde07b39 | |||
26e39c7c06 | |||
6ed7b16bf6 | |||
277d7111c8 | |||
8ada363a92 | |||
29093c51e3 | |||
71590d6c60 | |||
ba84efd384 | |||
54f67d435e | |||
2c2ac27c26 | |||
c20b4c6a23 | |||
3b8c529183 | |||
341c516fcb | |||
b0895145de | |||
304f82baa4 | |||
601b2d4f42 | |||
2c91e99a4d |
11
.drone.yml
Normal file
11
.drone.yml
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
kind: pipeline
|
||||||
|
type: docker
|
||||||
|
name: check
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: cargo check
|
||||||
|
image: rust:latest
|
||||||
|
environment:
|
||||||
|
SQLX_OFFLINE: 'true'
|
||||||
|
commands:
|
||||||
|
- cargo check
|
1862
Cargo.lock
generated
1862
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
23
Cargo.toml
23
Cargo.toml
@ -1,17 +1,8 @@
|
|||||||
[package]
|
[workspace]
|
||||||
name = "matrix"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2021"
|
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
members = [
|
||||||
|
"neo",
|
||||||
[dependencies]
|
"neo-entity",
|
||||||
tokio = { version = "1.17.0", features = ["full"] }
|
"neo-migration",
|
||||||
axum = "0.5.3"
|
"neo-util"
|
||||||
tracing = "0.1.34"
|
]
|
||||||
tracing-subscriber = { version = "0.3.11", features = ["env-filter"] }
|
|
||||||
serde = {version = "1.0.136", features = ["derive"] }
|
|
||||||
tower-http = { version = "0.2.5", features = ["cors", "trace"]}
|
|
||||||
sqlx = { version = "0.5.13", features = ["sqlite", "macros", "runtime-tokio-rustls"] }
|
|
||||||
anyhow = "1.0"
|
|
||||||
thiserror = "1.0"
|
|
||||||
|
@ -1,9 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
CREATE TABLE users(
|
|
||||||
id INTEGER PRIMARY KEY NOT NULL,
|
|
||||||
user_id CHAR(255) NOT NULL,
|
|
||||||
display_name TEXT NOT NULL,
|
|
||||||
password TEXT NOT NULL
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE INDEX user_id_index ON users (user_id);
|
|
@ -1,11 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
|
|
||||||
CREATE TABLE devices(
|
|
||||||
id INTEGER PRIMARY KEY NOT NULL,
|
|
||||||
user_id INT NOT NULL,
|
|
||||||
device_id TEXT NOT NULL,
|
|
||||||
display_name TEXT NOT NULL,
|
|
||||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE INDEX device_id_index ON devices (device_id);
|
|
@ -1,10 +0,0 @@
|
|||||||
-- Add migration script here
|
|
||||||
|
|
||||||
CREATE TABLE sessions(
|
|
||||||
id INTEGER PRIMARY KEY NOT NULL,
|
|
||||||
device_id INT NOT NULL,
|
|
||||||
value TEXT NOT NULL,
|
|
||||||
FOREIGN KEY(device_id) REFERENCES devices(id)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE INDEX value_index ON sessions (value);
|
|
2
neo-entity/.gitignore
vendored
Normal file
2
neo-entity/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
/target
|
||||||
|
/Cargo.lock
|
12
neo-entity/Cargo.toml
Normal file
12
neo-entity/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
[package]
|
||||||
|
name = "neo-entity"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
publish = false
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
chrono = {version = "0.4", features = ["serde"] }
|
||||||
|
sea-orm = { version = "^0.9", features = ["macros", "with-chrono", "with-uuid", "with-json"], default-features = false }
|
||||||
|
serde = "1.0"
|
||||||
|
serde_json = "1.0"
|
||||||
|
uuid = { version = "*", features = ["v4", "serde"]}
|
48
neo-entity/src/devices.rs
Normal file
48
neo-entity/src/devices.rs
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
use sea_orm::entity::prelude::*;
|
||||||
|
use sea_orm::Set;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||||
|
#[sea_orm(table_name = "devices")]
|
||||||
|
pub struct Model {
|
||||||
|
#[sea_orm(primary_key, auto_increment = false)]
|
||||||
|
pub uuid: Uuid,
|
||||||
|
pub user_uuid: Uuid,
|
||||||
|
pub device_id: String,
|
||||||
|
pub display_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||||
|
pub enum Relation {
|
||||||
|
#[sea_orm(
|
||||||
|
belongs_to = "super::users::Entity",
|
||||||
|
from = "Column::UserUuid",
|
||||||
|
to = "super::users::Column::Uuid",
|
||||||
|
on_update = "NoAction",
|
||||||
|
on_delete = "NoAction"
|
||||||
|
)]
|
||||||
|
Users,
|
||||||
|
#[sea_orm(has_many = "super::sessions::Entity")]
|
||||||
|
Sessions,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Related<super::users::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::Users.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Related<super::sessions::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::Sessions.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActiveModelBehavior for ActiveModel {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
uuid: Set(Uuid::new_v4()),
|
||||||
|
device_id: Set(Uuid::new_v4().to_string()),
|
||||||
|
..ActiveModelTrait::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
55
neo-entity/src/events.rs
Normal file
55
neo-entity/src/events.rs
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
use sea_orm::{entity::prelude::*, Set};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||||
|
#[sea_orm(table_name = "events")]
|
||||||
|
pub struct Model {
|
||||||
|
#[sea_orm(primary_key, auto_increment = false)]
|
||||||
|
pub uuid: Uuid,
|
||||||
|
pub room_uuid: Uuid,
|
||||||
|
pub r#type: String,
|
||||||
|
pub state_key: Option<String>,
|
||||||
|
pub sender_uuid: Uuid,
|
||||||
|
pub origin_server_ts: i64,
|
||||||
|
pub content: Json,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||||
|
pub enum Relation {
|
||||||
|
#[sea_orm(
|
||||||
|
belongs_to = "super::users::Entity",
|
||||||
|
from = "Column::SenderUuid",
|
||||||
|
to = "super::users::Column::Uuid",
|
||||||
|
on_update = "NoAction",
|
||||||
|
on_delete = "NoAction"
|
||||||
|
)]
|
||||||
|
Users,
|
||||||
|
#[sea_orm(
|
||||||
|
belongs_to = "super::rooms::Entity",
|
||||||
|
from = "Column::RoomUuid",
|
||||||
|
to = "super::rooms::Column::Uuid",
|
||||||
|
on_update = "NoAction",
|
||||||
|
on_delete = "NoAction"
|
||||||
|
)]
|
||||||
|
Rooms,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Related<super::users::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::Users.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Related<super::rooms::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::Rooms.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActiveModelBehavior for ActiveModel {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
uuid: Set(Uuid::new_v4()),
|
||||||
|
..ActiveModelTrait::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,3 +1,6 @@
|
|||||||
pub mod devices;
|
pub mod devices;
|
||||||
|
pub mod events;
|
||||||
|
pub mod prelude;
|
||||||
|
pub mod rooms;
|
||||||
pub mod sessions;
|
pub mod sessions;
|
||||||
pub mod users;
|
pub mod users;
|
6
neo-entity/src/prelude.rs
Normal file
6
neo-entity/src/prelude.rs
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
#[allow(unused_imports)]
|
||||||
|
pub use crate::{
|
||||||
|
devices::{self, Entity as Device, Model as DeviceModel},
|
||||||
|
sessions::{self, Entity as Session, Model as SessionModel},
|
||||||
|
users::{self, Entity as User, Model as UserModel},
|
||||||
|
};
|
30
neo-entity/src/rooms.rs
Normal file
30
neo-entity/src/rooms.rs
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
use sea_orm::{entity::prelude::*, Set};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||||
|
#[sea_orm(table_name = "rooms")]
|
||||||
|
pub struct Model {
|
||||||
|
#[sea_orm(primary_key, auto_increment = false)]
|
||||||
|
pub uuid: Uuid,
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||||
|
pub enum Relation {
|
||||||
|
#[sea_orm(has_many = "super::events::Entity")]
|
||||||
|
Events,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Related<super::events::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::Events.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActiveModelBehavior for ActiveModel {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
uuid: Set(Uuid::new_v4()),
|
||||||
|
..ActiveModelTrait::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
37
neo-entity/src/sessions.rs
Normal file
37
neo-entity/src/sessions.rs
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
use sea_orm::{entity::prelude::*, Set};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||||
|
#[sea_orm(table_name = "sessions")]
|
||||||
|
pub struct Model {
|
||||||
|
#[sea_orm(primary_key, auto_increment = false)]
|
||||||
|
pub uuid: Uuid,
|
||||||
|
pub device_uuid: Uuid,
|
||||||
|
pub key: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||||
|
pub enum Relation {
|
||||||
|
#[sea_orm(
|
||||||
|
belongs_to = "super::devices::Entity",
|
||||||
|
from = "Column::DeviceUuid",
|
||||||
|
to = "super::devices::Column::Uuid",
|
||||||
|
on_update = "NoAction",
|
||||||
|
on_delete = "NoAction"
|
||||||
|
)]
|
||||||
|
Devices,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Related<super::devices::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::Devices.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActiveModelBehavior for ActiveModel {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
uuid: Set(Uuid::new_v4()),
|
||||||
|
..ActiveModelTrait::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
33
neo-entity/src/users.rs
Normal file
33
neo-entity/src/users.rs
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
use sea_orm::entity::prelude::*;
|
||||||
|
use sea_orm::Set;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||||
|
#[sea_orm(table_name = "users")]
|
||||||
|
pub struct Model {
|
||||||
|
#[sea_orm(primary_key, auto_increment = false)]
|
||||||
|
pub uuid: Uuid,
|
||||||
|
pub user_id: String,
|
||||||
|
pub display_name: String,
|
||||||
|
pub password_hash: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||||
|
pub enum Relation {
|
||||||
|
#[sea_orm(has_many = "super::devices::Entity")]
|
||||||
|
Devices,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Related<super::devices::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::Devices.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActiveModelBehavior for ActiveModel {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
uuid: Set(Uuid::new_v4()),
|
||||||
|
..ActiveModelTrait::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
10
neo-migration/Cargo.toml
Normal file
10
neo-migration/Cargo.toml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
[package]
|
||||||
|
name = "neo-migration"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
publish = false
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
sea-orm-migration = "^0.9"
|
||||||
|
neo-entity = { path = "../neo-entity" }
|
||||||
|
automod = "1"
|
37
neo-migration/README.md
Normal file
37
neo-migration/README.md
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# Running Migrator CLI
|
||||||
|
|
||||||
|
- Apply all pending migrations
|
||||||
|
```sh
|
||||||
|
cargo run
|
||||||
|
```
|
||||||
|
```sh
|
||||||
|
cargo run -- up
|
||||||
|
```
|
||||||
|
- Apply first 10 pending migrations
|
||||||
|
```sh
|
||||||
|
cargo run -- up -n 10
|
||||||
|
```
|
||||||
|
- Rollback last applied migrations
|
||||||
|
```sh
|
||||||
|
cargo run -- down
|
||||||
|
```
|
||||||
|
- Rollback last 10 applied migrations
|
||||||
|
```sh
|
||||||
|
cargo run -- down -n 10
|
||||||
|
```
|
||||||
|
- Drop all tables from the database, then reapply all migrations
|
||||||
|
```sh
|
||||||
|
cargo run -- fresh
|
||||||
|
```
|
||||||
|
- Rollback all applied migrations, then reapply all migrations
|
||||||
|
```sh
|
||||||
|
cargo run -- refresh
|
||||||
|
```
|
||||||
|
- Rollback all applied migrations
|
||||||
|
```sh
|
||||||
|
cargo run -- reset
|
||||||
|
```
|
||||||
|
- Check the status of all migrations
|
||||||
|
```sh
|
||||||
|
cargo run -- status
|
||||||
|
```
|
18
neo-migration/src/lib.rs
Normal file
18
neo-migration/src/lib.rs
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
pub use sea_orm_migration::prelude::*;
|
||||||
|
|
||||||
|
automod::dir!("src");
|
||||||
|
|
||||||
|
pub struct Migrator;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl MigratorTrait for Migrator {
|
||||||
|
fn migrations() -> Vec<Box<dyn MigrationTrait>> {
|
||||||
|
vec![
|
||||||
|
Box::new(m20220707_092851_create_users::Migration),
|
||||||
|
Box::new(m20220707_112339_create_devices::Migration),
|
||||||
|
Box::new(m20220707_143304_create_sessions::Migration),
|
||||||
|
Box::new(m20220724_223253_create_rooms::Migration),
|
||||||
|
Box::new(m20220724_223335_create_events::Migration),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
45
neo-migration/src/m20220707_092851_create_users.rs
Normal file
45
neo-migration/src/m20220707_092851_create_users.rs
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
use neo_entity::users::{self, Entity as User};
|
||||||
|
use sea_orm_migration::prelude::*;
|
||||||
|
|
||||||
|
#[derive(DeriveMigrationName)]
|
||||||
|
pub struct Migration;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl MigrationTrait for Migration {
|
||||||
|
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||||
|
manager
|
||||||
|
.create_table(
|
||||||
|
Table::create()
|
||||||
|
.table(User)
|
||||||
|
.if_not_exists()
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(users::Column::Uuid)
|
||||||
|
.uuid()
|
||||||
|
.primary_key()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.col(ColumnDef::new(users::Column::UserId).string().not_null())
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(users::Column::DisplayName)
|
||||||
|
.string()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(users::Column::PasswordHash)
|
||||||
|
.string()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
manager
|
||||||
|
.create_index(
|
||||||
|
Index::create()
|
||||||
|
.name("user_id_index")
|
||||||
|
.table(User)
|
||||||
|
.col(users::Column::UserId)
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
54
neo-migration/src/m20220707_112339_create_devices.rs
Normal file
54
neo-migration/src/m20220707_112339_create_devices.rs
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
use neo_entity::devices::{self, Entity as Device};
|
||||||
|
use sea_orm_migration::prelude::*;
|
||||||
|
|
||||||
|
#[derive(DeriveMigrationName)]
|
||||||
|
pub struct Migration;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl MigrationTrait for Migration {
|
||||||
|
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||||
|
manager
|
||||||
|
.create_table(
|
||||||
|
Table::create()
|
||||||
|
.table(Device)
|
||||||
|
.if_not_exists()
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(devices::Column::Uuid)
|
||||||
|
.uuid()
|
||||||
|
.primary_key()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.col(ColumnDef::new(devices::Column::UserUuid).uuid().not_null())
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(devices::Column::DeviceId)
|
||||||
|
.string()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(devices::Column::DisplayName)
|
||||||
|
.string()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
manager
|
||||||
|
.create_index(
|
||||||
|
Index::create()
|
||||||
|
.name("device_id_index")
|
||||||
|
.table(Device)
|
||||||
|
.col(devices::Column::DeviceId)
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
manager
|
||||||
|
.create_index(
|
||||||
|
Index::create()
|
||||||
|
.name("user_uuid_index")
|
||||||
|
.table(Device)
|
||||||
|
.col(devices::Column::UserUuid)
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
40
neo-migration/src/m20220707_143304_create_sessions.rs
Normal file
40
neo-migration/src/m20220707_143304_create_sessions.rs
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
use neo_entity::sessions::{self, Entity as Session};
|
||||||
|
use sea_orm_migration::prelude::*;
|
||||||
|
|
||||||
|
#[derive(DeriveMigrationName)]
|
||||||
|
pub struct Migration;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl MigrationTrait for Migration {
|
||||||
|
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||||
|
manager
|
||||||
|
.create_table(
|
||||||
|
Table::create()
|
||||||
|
.table(Session)
|
||||||
|
.if_not_exists()
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(sessions::Column::Uuid)
|
||||||
|
.uuid()
|
||||||
|
.primary_key()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(sessions::Column::DeviceUuid)
|
||||||
|
.uuid()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.col(ColumnDef::new(sessions::Column::Key).string().not_null())
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
manager
|
||||||
|
.create_index(
|
||||||
|
Index::create()
|
||||||
|
.name("device_uuid_index")
|
||||||
|
.table(Session)
|
||||||
|
.col(sessions::Column::DeviceUuid)
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
26
neo-migration/src/m20220724_223253_create_rooms.rs
Normal file
26
neo-migration/src/m20220724_223253_create_rooms.rs
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
use neo_entity::rooms::{self, Entity as Room};
|
||||||
|
use sea_orm_migration::prelude::*;
|
||||||
|
|
||||||
|
#[derive(DeriveMigrationName)]
|
||||||
|
pub struct Migration;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl MigrationTrait for Migration {
|
||||||
|
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||||
|
manager
|
||||||
|
.create_table(
|
||||||
|
Table::create()
|
||||||
|
.table(Room)
|
||||||
|
.if_not_exists()
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(rooms::Column::Uuid)
|
||||||
|
.uuid()
|
||||||
|
.primary_key()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.col(ColumnDef::new(rooms::Column::Name).string().not_null())
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
72
neo-migration/src/m20220724_223335_create_events.rs
Normal file
72
neo-migration/src/m20220724_223335_create_events.rs
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
use neo_entity::events::{self, Entity as Event};
|
||||||
|
use sea_orm_migration::prelude::*;
|
||||||
|
|
||||||
|
#[derive(DeriveMigrationName)]
|
||||||
|
pub struct Migration;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl MigrationTrait for Migration {
|
||||||
|
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
|
||||||
|
manager
|
||||||
|
.create_table(
|
||||||
|
Table::create()
|
||||||
|
.table(Event)
|
||||||
|
.if_not_exists()
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(events::Column::Uuid)
|
||||||
|
.uuid()
|
||||||
|
.primary_key()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.col(ColumnDef::new(events::Column::RoomUuid).uuid().not_null())
|
||||||
|
.col(ColumnDef::new(events::Column::Type).string().not_null())
|
||||||
|
.col(ColumnDef::new(events::Column::StateKey).string())
|
||||||
|
.col(ColumnDef::new(events::Column::SenderUuid).uuid().not_null())
|
||||||
|
.col(
|
||||||
|
ColumnDef::new(events::Column::OriginServerTs)
|
||||||
|
.integer()
|
||||||
|
.not_null(),
|
||||||
|
)
|
||||||
|
.col(ColumnDef::new(events::Column::Content).json().not_null())
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
manager
|
||||||
|
.create_index(
|
||||||
|
Index::create()
|
||||||
|
.name("room_uuid_index")
|
||||||
|
.table(Event)
|
||||||
|
.col(events::Column::RoomUuid)
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
manager
|
||||||
|
.create_index(
|
||||||
|
Index::create()
|
||||||
|
.name("type_index")
|
||||||
|
.table(Event)
|
||||||
|
.col(events::Column::Type)
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
manager
|
||||||
|
.create_index(
|
||||||
|
Index::create()
|
||||||
|
.name("state_key_index")
|
||||||
|
.table(Event)
|
||||||
|
.col(events::Column::StateKey)
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
manager
|
||||||
|
.create_index(
|
||||||
|
Index::create()
|
||||||
|
.name("type_state_key_index")
|
||||||
|
.table(Event)
|
||||||
|
.col(events::Column::Type)
|
||||||
|
.col(events::Column::StateKey)
|
||||||
|
.to_owned(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
11
neo-util/Cargo.toml
Normal file
11
neo-util/Cargo.toml
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
[package]
|
||||||
|
name = "neo-util"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1.0"
|
||||||
|
argon2 = { version = "0.4", features = ["std"] }
|
||||||
|
rand = { version = "0.8.5", features = ["std"] }
|
20
neo-util/src/events.rs
Normal file
20
neo-util/src/events.rs
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
pub static STATE_EVENTS: &[&str] = &[
|
||||||
|
"m.room.create",
|
||||||
|
"m.room.canonical_alias",
|
||||||
|
"m.room.join_rules",
|
||||||
|
"m.room.member",
|
||||||
|
"m.room.power_levels",
|
||||||
|
];
|
||||||
|
|
||||||
|
pub enum EventCategory {
|
||||||
|
StateEvent,
|
||||||
|
MessageEvent,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn classify_event(event_type: &str) -> EventCategory {
|
||||||
|
if STATE_EVENTS.contains(&event_type) {
|
||||||
|
EventCategory::StateEvent
|
||||||
|
} else {
|
||||||
|
EventCategory::MessageEvent
|
||||||
|
}
|
||||||
|
}
|
2
neo-util/src/lib.rs
Normal file
2
neo-util/src/lib.rs
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
pub mod events;
|
||||||
|
pub mod password;
|
17
neo-util/src/password.rs
Normal file
17
neo-util/src/password.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
|
||||||
|
use rand::rngs::OsRng;
|
||||||
|
|
||||||
|
pub fn hash_password(password: &str) -> anyhow::Result<String> {
|
||||||
|
let argon2 = Argon2::default();
|
||||||
|
let salt = SaltString::generate(OsRng);
|
||||||
|
Ok(argon2
|
||||||
|
.hash_password(password.as_bytes(), &salt)?
|
||||||
|
.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn password_correct(password: &str, hash: &str) -> anyhow::Result<bool> {
|
||||||
|
let password_hash = PasswordHash::new(hash)?;
|
||||||
|
Ok(Argon2::default()
|
||||||
|
.verify_password(password.as_bytes(), &password_hash)
|
||||||
|
.is_ok())
|
||||||
|
}
|
23
neo/Cargo.toml
Normal file
23
neo/Cargo.toml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
[package]
|
||||||
|
name = "neo"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
tokio = { version = "1.17", features = ["full"] }
|
||||||
|
axum = "0.5"
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
serde = {version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
tower-http = { version = "0.2", features = ["cors", "trace"] }
|
||||||
|
anyhow = "1.0"
|
||||||
|
thiserror = "1.0"
|
||||||
|
rand = { version = "0.8.5", features = ["std"] }
|
||||||
|
uuid = { version = "1.0", features = ["v4"] }
|
||||||
|
ruma = { version = "0.6.4", features = ["client-api", "compat"] }
|
||||||
|
http = "0.2.8"
|
||||||
|
sea-orm = { version = "^0.9", features = ["sqlx-sqlite", "runtime-tokio-native-tls", "macros"], default-features = false }
|
||||||
|
neo-entity = { version = "*", path = "../neo-entity" }
|
||||||
|
neo-migration = { version = "*", path = "../neo-migration" }
|
||||||
|
neo-util = { version = "*", path = "../neo-util" }
|
82
neo/src/api/client_server/errors/api_error.rs
Normal file
82
neo/src/api/client_server/errors/api_error.rs
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use axum::Json;
|
||||||
|
|
||||||
|
use crate::types::error_code::ErrorCode;
|
||||||
|
|
||||||
|
use super::authentication_error::AuthenticationError;
|
||||||
|
use super::registration_error::RegistrationError;
|
||||||
|
use super::ErrorResponse;
|
||||||
|
|
||||||
|
macro_rules! map_err {
|
||||||
|
($err:ident, $($type:path => $target:path),+) => {
|
||||||
|
$(
|
||||||
|
if $err.is::<$type>() {
|
||||||
|
return $target($err.downcast().unwrap());
|
||||||
|
}
|
||||||
|
)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub enum ApiError {
|
||||||
|
#[error("Registration Error")]
|
||||||
|
RegistrationError(#[from] RegistrationError),
|
||||||
|
|
||||||
|
#[error("Authentication Error")]
|
||||||
|
AuthenticationError(#[from] AuthenticationError),
|
||||||
|
|
||||||
|
#[error("Database Error")]
|
||||||
|
DBError(#[from] sea_orm::DbErr),
|
||||||
|
|
||||||
|
#[error("Generic Error")]
|
||||||
|
Generic(anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<anyhow::Error> for ApiError {
|
||||||
|
fn from(err: anyhow::Error) -> Self {
|
||||||
|
map_err!(err,
|
||||||
|
sea_orm::DbErr => ApiError::DBError,
|
||||||
|
RegistrationError => ApiError::RegistrationError,
|
||||||
|
AuthenticationError => ApiError::AuthenticationError
|
||||||
|
);
|
||||||
|
|
||||||
|
ApiError::Generic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for ApiError {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
match self {
|
||||||
|
ApiError::RegistrationError(e) => e.into_response(),
|
||||||
|
ApiError::AuthenticationError(e) => e.into_response(),
|
||||||
|
ApiError::DBError(err) => {
|
||||||
|
tracing::error!("{}", err.to_string());
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::Unknown,
|
||||||
|
"Database error! If you are the application owner please take a look at your application logs.",
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
ApiError::Generic(err) => {
|
||||||
|
tracing::error!("{}", err.to_string());
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::Unknown,
|
||||||
|
"Fatal error occured! If you are the application owner please take a look at your application logs.",
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
#[allow(unreachable_patterns)]
|
||||||
|
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
53
neo/src/api/client_server/errors/authentication_error.rs
Normal file
53
neo/src/api/client_server/errors/authentication_error.rs
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
use axum::{http::StatusCode, response::IntoResponse, Json};
|
||||||
|
|
||||||
|
use crate::types::error_code::ErrorCode;
|
||||||
|
|
||||||
|
use super::ErrorResponse;
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub enum AuthenticationError {
|
||||||
|
#[error("UserId is missing")]
|
||||||
|
MissingUserId,
|
||||||
|
#[error("The user ID is not a valid user name")]
|
||||||
|
InvalidUserId,
|
||||||
|
#[error("The provided authentication data was incorrect")]
|
||||||
|
Forbidden,
|
||||||
|
#[error("The user has been deactivated")]
|
||||||
|
UserDeactivated,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for AuthenticationError {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
match self {
|
||||||
|
Self::InvalidUserId | Self::MissingUserId => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::InvalidUsername,
|
||||||
|
&self.to_string(),
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Self::Forbidden => (
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::Forbidden,
|
||||||
|
&self.to_string(),
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
Self::UserDeactivated => (
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::UserDeactivated,
|
||||||
|
&self.to_string(),
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
23
neo/src/api/client_server/errors/mod.rs
Normal file
23
neo/src/api/client_server/errors/mod.rs
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
use crate::types::error_code::ErrorCode;
|
||||||
|
|
||||||
|
pub mod api_error;
|
||||||
|
pub mod authentication_error;
|
||||||
|
pub mod registration_error;
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Serialize)]
|
||||||
|
pub struct ErrorResponse {
|
||||||
|
errcode: ErrorCode,
|
||||||
|
error: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
retry_after_ms: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ErrorResponse {
|
||||||
|
pub fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option<u64>) -> Self {
|
||||||
|
Self {
|
||||||
|
errcode,
|
||||||
|
error: error.to_owned(),
|
||||||
|
retry_after_ms,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
58
neo/src/api/client_server/errors/registration_error.rs
Normal file
58
neo/src/api/client_server/errors/registration_error.rs
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
use axum::{http::StatusCode, response::IntoResponse, Json};
|
||||||
|
|
||||||
|
use crate::{responses::registration::RegistrationResponse, types::error_code::ErrorCode};
|
||||||
|
|
||||||
|
use super::ErrorResponse;
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum RegistrationError {
|
||||||
|
#[error("The homeserver requires additional authentication information")]
|
||||||
|
AdditionalAuthenticationInformation,
|
||||||
|
#[error("UserId is missing")]
|
||||||
|
MissingUserId,
|
||||||
|
#[error("The desired user ID is not a valid user name")]
|
||||||
|
InvalidUserId,
|
||||||
|
#[error("The desired user ID is already taken")]
|
||||||
|
UserIdTaken,
|
||||||
|
#[error("Registration is disabled")]
|
||||||
|
RegistrationDisabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for RegistrationError {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
match self {
|
||||||
|
RegistrationError::AdditionalAuthenticationInformation => (
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
Json(RegistrationResponse::user_interactive_authorization_info()),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
RegistrationError::InvalidUserId | RegistrationError::MissingUserId => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::InvalidUsername,
|
||||||
|
&self.to_string(),
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
RegistrationError::UserIdTaken => (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::UserInUse,
|
||||||
|
&self.to_string(),
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
RegistrationError::RegistrationDisabled => (
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::Forbidden,
|
||||||
|
&self.to_string(),
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,3 +1,3 @@
|
|||||||
pub mod auth;
|
|
||||||
pub mod versions;
|
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
|
pub mod r0;
|
||||||
|
pub mod versions;
|
299
neo/src/api/client_server/r0/auth.rs
Normal file
299
neo/src/api/client_server/r0/auth.rs
Normal file
@ -0,0 +1,299 @@
|
|||||||
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::Query,
|
||||||
|
routing::{get, post},
|
||||||
|
Extension,
|
||||||
|
};
|
||||||
|
|
||||||
|
use neo_entity::prelude::*;
|
||||||
|
use neo_util::password;
|
||||||
|
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||||
|
use sea_orm::{ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, Set};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
api::client_server::errors::{
|
||||||
|
api_error::ApiError, authentication_error::AuthenticationError,
|
||||||
|
registration_error::RegistrationError,
|
||||||
|
},
|
||||||
|
ruma_wrapper::{RumaRequest, RumaResponse},
|
||||||
|
Config,
|
||||||
|
};
|
||||||
|
|
||||||
|
use ruma::api::client::{
|
||||||
|
account, session,
|
||||||
|
uiaa::{IncomingAuthData, IncomingUserIdentifier},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/r0/login", get(get_login_types).post(login))
|
||||||
|
.route("/r0/register", post(post_register))
|
||||||
|
.route("/r0/register/available", get(get_username_available))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_login_types() -> Result<RumaResponse<session::get_login_types::v3::Response>, ApiError>
|
||||||
|
{
|
||||||
|
use session::get_login_types::v3::*;
|
||||||
|
|
||||||
|
Ok(RumaResponse(session::get_login_types::v3::Response::new(
|
||||||
|
vec![LoginType::Password(PasswordLoginType::new())],
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn login(
|
||||||
|
Extension(config): Extension<Arc<Config>>,
|
||||||
|
Extension(db): Extension<DatabaseConnection>,
|
||||||
|
RumaRequest(req): RumaRequest<session::login::v3::IncomingRequest>,
|
||||||
|
) -> Result<RumaResponse<session::login::v3::Response>, ApiError> {
|
||||||
|
use session::login::v3::*;
|
||||||
|
|
||||||
|
match req.login_info {
|
||||||
|
IncomingLoginInfo::Password(incoming_password) => {
|
||||||
|
let password = incoming_password.password;
|
||||||
|
let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) =
|
||||||
|
incoming_password.identifier
|
||||||
|
{
|
||||||
|
ruma::UserId::parse_with_server_name(user_id, &config.server_name)
|
||||||
|
.map_err(|_| AuthenticationError::InvalidUserId)?
|
||||||
|
} else {
|
||||||
|
return Err(AuthenticationError::InvalidUserId.into());
|
||||||
|
};
|
||||||
|
|
||||||
|
let db_user = User::find()
|
||||||
|
.filter(users::Column::UserId.eq(user_id.as_str()))
|
||||||
|
.one(&db)
|
||||||
|
.await?
|
||||||
|
.ok_or(AuthenticationError::InvalidUserId)?;
|
||||||
|
|
||||||
|
if !password::password_correct(&password, &db_user.password_hash)
|
||||||
|
.map_err(|_| AuthenticationError::Forbidden)?
|
||||||
|
{
|
||||||
|
return Err(AuthenticationError::Forbidden.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let device = if let Some(device_id) = req.device_id {
|
||||||
|
Device::find()
|
||||||
|
.belongs_to(&db_user)
|
||||||
|
.filter(devices::Column::DeviceId.eq(device_id.as_str()))
|
||||||
|
.one(&db)
|
||||||
|
.await?
|
||||||
|
.unwrap()
|
||||||
|
} else {
|
||||||
|
let device_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let display_name = req
|
||||||
|
.initial_device_display_name
|
||||||
|
.unwrap_or_else(|| "Generic Device".into());
|
||||||
|
|
||||||
|
let device = devices::ActiveModel {
|
||||||
|
device_id: Set(device_id),
|
||||||
|
display_name: Set(display_name),
|
||||||
|
user_uuid: Set(db_user.uuid),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
device.insert(&db).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
let key = thread_rng()
|
||||||
|
.sample_iter(&Alphanumeric)
|
||||||
|
.take(32)
|
||||||
|
.map(char::from)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let session = sessions::ActiveModel {
|
||||||
|
device_uuid: Set(device.uuid),
|
||||||
|
key: Set(key),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let session = session.insert(&db).await?;
|
||||||
|
let response = Response::new(
|
||||||
|
user_id,
|
||||||
|
session.key,
|
||||||
|
ruma::OwnedDeviceId::from(device.device_id),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(RumaResponse(response))
|
||||||
|
}
|
||||||
|
_ => todo!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_username_available(
|
||||||
|
Extension(config): Extension<Arc<Config>>,
|
||||||
|
Extension(db): Extension<DatabaseConnection>,
|
||||||
|
Query(params): Query<HashMap<String, String>>,
|
||||||
|
) -> Result<RumaResponse<account::get_username_availability::v3::Response>, ApiError> {
|
||||||
|
use account::get_username_availability::v3::*;
|
||||||
|
tracing::debug!("username_available hit");
|
||||||
|
|
||||||
|
let username = params
|
||||||
|
.get("username")
|
||||||
|
.ok_or(RegistrationError::MissingUserId)?
|
||||||
|
.to_owned();
|
||||||
|
let user_id = ruma::UserId::parse_with_server_name(username, &config.server_name)
|
||||||
|
.map_err(|_| RegistrationError::InvalidUserId)?;
|
||||||
|
|
||||||
|
let available = User::find()
|
||||||
|
.filter(users::Column::UserId.eq(user_id.as_str()))
|
||||||
|
.one(&db)
|
||||||
|
.await?
|
||||||
|
.is_none();
|
||||||
|
|
||||||
|
Ok(RumaResponse(Response::new(available)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn post_register(
|
||||||
|
Extension(config): Extension<Arc<Config>>,
|
||||||
|
Extension(db): Extension<DatabaseConnection>,
|
||||||
|
RumaRequest(req): RumaRequest<account::register::v3::IncomingRequest>,
|
||||||
|
) -> Result<RumaResponse<account::register::v3::Response>, ApiError> {
|
||||||
|
use account::register::v3::*;
|
||||||
|
|
||||||
|
config
|
||||||
|
.enable_registration
|
||||||
|
.then(|| true)
|
||||||
|
.ok_or(RegistrationError::RegistrationDisabled)?;
|
||||||
|
|
||||||
|
match req.auth {
|
||||||
|
Some(auth) => match auth {
|
||||||
|
IncomingAuthData::Password(incoming_password) => {
|
||||||
|
let password = incoming_password.password;
|
||||||
|
let user_id = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) =
|
||||||
|
incoming_password.identifier
|
||||||
|
{
|
||||||
|
ruma::UserId::parse_with_server_name(user_id, &config.server_name)
|
||||||
|
.map_err(|_| AuthenticationError::InvalidUserId)?
|
||||||
|
} else {
|
||||||
|
return Err(AuthenticationError::InvalidUserId.into());
|
||||||
|
};
|
||||||
|
|
||||||
|
if User::find()
|
||||||
|
.filter(users::Column::UserId.eq(user_id.as_str()))
|
||||||
|
.one(&db)
|
||||||
|
.await?
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
return Err(RegistrationError::UserIdTaken.into());
|
||||||
|
};
|
||||||
|
|
||||||
|
let display_name = req
|
||||||
|
.initial_device_display_name
|
||||||
|
.unwrap_or_else(|| "Random Display Name".into());
|
||||||
|
|
||||||
|
let pw_hash = password::hash_password(&password)?;
|
||||||
|
|
||||||
|
let user = users::ActiveModel {
|
||||||
|
user_id: Set(user_id.to_string()),
|
||||||
|
display_name: Set(user_id.to_string()),
|
||||||
|
password_hash: Set(pw_hash),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.insert(&db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let device = devices::ActiveModel {
|
||||||
|
display_name: Set(display_name),
|
||||||
|
user_uuid: Set(user.uuid),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.insert(&db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut response = Response::new(
|
||||||
|
ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?,
|
||||||
|
);
|
||||||
|
if !req.inhibit_login {
|
||||||
|
let key = thread_rng()
|
||||||
|
.sample_iter(&Alphanumeric)
|
||||||
|
.take(32)
|
||||||
|
.map(char::from)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let session = sessions::ActiveModel {
|
||||||
|
device_uuid: Set(device.uuid),
|
||||||
|
key: Set(key),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.insert(&db)
|
||||||
|
.await?;
|
||||||
|
response.access_token = Some(session.key);
|
||||||
|
}
|
||||||
|
if !req.inhibit_login {
|
||||||
|
response.device_id = Some(device.device_id.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(RumaResponse(response))
|
||||||
|
}
|
||||||
|
_ => todo!(),
|
||||||
|
},
|
||||||
|
// For clients not following/using UIAA
|
||||||
|
None => {
|
||||||
|
let password = req
|
||||||
|
.password
|
||||||
|
.ok_or("password missing")
|
||||||
|
.map_err(|_| RegistrationError::AdditionalAuthenticationInformation)?;
|
||||||
|
let user_id = if let Some(username) = req.username {
|
||||||
|
ruma::UserId::parse_with_server_name(username, &config.server_name)
|
||||||
|
.map_err(|e| anyhow::anyhow!(e))?
|
||||||
|
} else {
|
||||||
|
return Err(AuthenticationError::InvalidUserId.into());
|
||||||
|
};
|
||||||
|
|
||||||
|
if User::find()
|
||||||
|
.filter(users::Column::UserId.eq(user_id.as_str()))
|
||||||
|
.one(&db)
|
||||||
|
.await?
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
return Err(RegistrationError::UserIdTaken.into());
|
||||||
|
};
|
||||||
|
|
||||||
|
let display_name = req
|
||||||
|
.initial_device_display_name
|
||||||
|
.unwrap_or_else(|| "Random Display Name".into());
|
||||||
|
|
||||||
|
let pw_hash = password::hash_password(&password)?;
|
||||||
|
|
||||||
|
let user = users::ActiveModel {
|
||||||
|
user_id: Set(user_id.to_string()),
|
||||||
|
display_name: Set(user_id.to_string()),
|
||||||
|
password_hash: Set(pw_hash),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.insert(&db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let device = devices::ActiveModel {
|
||||||
|
display_name: Set(display_name),
|
||||||
|
user_uuid: Set(user.uuid),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.insert(&db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut response =
|
||||||
|
Response::new(ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?);
|
||||||
|
if !req.inhibit_login {
|
||||||
|
let key = thread_rng()
|
||||||
|
.sample_iter(&Alphanumeric)
|
||||||
|
.take(32)
|
||||||
|
.map(char::from)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let session = sessions::ActiveModel {
|
||||||
|
device_uuid: Set(device.uuid),
|
||||||
|
key: Set(key),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
.insert(&db)
|
||||||
|
.await?;
|
||||||
|
response.access_token = Some(session.key);
|
||||||
|
}
|
||||||
|
if !req.inhibit_login {
|
||||||
|
response.device_id = Some(ruma::OwnedDeviceId::from(device.device_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(RumaResponse(response))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
25
neo/src/api/client_server/r0/filter.rs
Normal file
25
neo/src/api/client_server/r0/filter.rs
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::routing::post;
|
||||||
|
use axum::Extension;
|
||||||
|
|
||||||
|
use crate::api::client_server::errors::api_error::ApiError;
|
||||||
|
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||||
|
use neo_entity::prelude::*;
|
||||||
|
|
||||||
|
use ruma::api::client::filter;
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/r0/user/:user_id/filter", post(create_filter))
|
||||||
|
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_filter(
|
||||||
|
Extension(_user): Extension<Arc<UserModel>>,
|
||||||
|
RumaRequest(_req): RumaRequest<filter::create_filter::v3::IncomingRequest>,
|
||||||
|
) -> Result<RumaResponse<filter::create_filter::v3::Response>, ApiError> {
|
||||||
|
use filter::create_filter::v3::*;
|
||||||
|
|
||||||
|
Ok(RumaResponse(Response::new("a".into())))
|
||||||
|
}
|
25
neo/src/api/client_server/r0/keys.rs
Normal file
25
neo/src/api/client_server/r0/keys.rs
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::routing::post;
|
||||||
|
use axum::Extension;
|
||||||
|
|
||||||
|
use crate::api::client_server::errors::api_error::ApiError;
|
||||||
|
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||||
|
use neo_entity::prelude::*;
|
||||||
|
|
||||||
|
use ruma::api::client::keys;
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/r0/keys/query", post(get_keys))
|
||||||
|
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_keys(
|
||||||
|
Extension(_user): Extension<Arc<UserModel>>,
|
||||||
|
RumaRequest(_req): RumaRequest<keys::get_keys::v3::IncomingRequest>,
|
||||||
|
) -> Result<RumaResponse<keys::get_keys::v3::Response>, ApiError> {
|
||||||
|
use keys::get_keys::v3::*;
|
||||||
|
|
||||||
|
Ok(RumaResponse(Response::new()))
|
||||||
|
}
|
143
neo/src/api/client_server/r0/mod.rs
Normal file
143
neo/src/api/client_server/r0/mod.rs
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
http::{Request, StatusCode},
|
||||||
|
middleware::Next,
|
||||||
|
response::IntoResponse,
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use neo_entity::{
|
||||||
|
devices::Entity as Device,
|
||||||
|
sessions::{self, Entity as Session},
|
||||||
|
users::Entity as User,
|
||||||
|
};
|
||||||
|
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, ModelTrait, QueryFilter};
|
||||||
|
|
||||||
|
use crate::types::error_code::ErrorCode;
|
||||||
|
|
||||||
|
use super::errors::ErrorResponse;
|
||||||
|
|
||||||
|
pub mod auth;
|
||||||
|
pub mod filter;
|
||||||
|
pub mod keys;
|
||||||
|
pub mod presence;
|
||||||
|
pub mod push;
|
||||||
|
pub mod room;
|
||||||
|
pub mod sync;
|
||||||
|
pub mod thirdparty;
|
||||||
|
|
||||||
|
async fn authentication_middleware<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||||
|
let db: &DatabaseConnection = req.extensions().get().unwrap();
|
||||||
|
let auth_header = req
|
||||||
|
.headers()
|
||||||
|
.get(axum::http::header::AUTHORIZATION)
|
||||||
|
.and_then(|header| header.to_str().ok());
|
||||||
|
|
||||||
|
if auth_header.is_none() {
|
||||||
|
return (
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::Forbidden,
|
||||||
|
"Authorization Header not given",
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
let auth_header = auth_header.expect("Validated above");
|
||||||
|
let idx = auth_header.find(' ');
|
||||||
|
|
||||||
|
let idx = match idx {
|
||||||
|
Some(idx) => idx,
|
||||||
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::Forbidden,
|
||||||
|
"Invalid Authorization Header",
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let session = match Session::find()
|
||||||
|
.filter(sessions::Column::Key.eq(&auth_header[idx + 1..]))
|
||||||
|
.one(db)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(session) => session,
|
||||||
|
Err(_) => {
|
||||||
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse::new(
|
||||||
|
ErrorCode::Unknown,
|
||||||
|
"Internal Server Error",
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let session = match session {
|
||||||
|
Some(session) => session,
|
||||||
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let device = match session.find_related(Device).one(db).await {
|
||||||
|
Ok(device) => device,
|
||||||
|
Err(_) => {
|
||||||
|
return (
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let device = match device {
|
||||||
|
Some(device) => device,
|
||||||
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let user = match device.find_related(User).one(db).await {
|
||||||
|
Ok(user) => user,
|
||||||
|
Err(_) => {
|
||||||
|
return (
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let user = match user {
|
||||||
|
Some(user) => user,
|
||||||
|
None => {
|
||||||
|
return (
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
req.extensions_mut().insert(Arc::new(user));
|
||||||
|
|
||||||
|
next.run(req).await.into_response()
|
||||||
|
}
|
36
neo/src/api/client_server/r0/presence.rs
Normal file
36
neo/src/api/client_server/r0/presence.rs
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::routing::{get, put};
|
||||||
|
use axum::Extension;
|
||||||
|
|
||||||
|
use crate::api::client_server::errors::api_error::ApiError;
|
||||||
|
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||||
|
use neo_entity::prelude::*;
|
||||||
|
|
||||||
|
use ruma::api::client::presence;
|
||||||
|
use ruma::presence::PresenceState;
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/r0/presence/:user_id/status", put(set_presence))
|
||||||
|
.route("/r0/presence/:user_id/status", get(get_presence))
|
||||||
|
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn set_presence(
|
||||||
|
Extension(_user): Extension<Arc<UserModel>>,
|
||||||
|
RumaRequest(_req): RumaRequest<presence::set_presence::v3::IncomingRequest>,
|
||||||
|
) -> Result<RumaResponse<presence::set_presence::v3::Response>, ApiError> {
|
||||||
|
use presence::set_presence::v3::*;
|
||||||
|
|
||||||
|
Ok(RumaResponse(Response::new()))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_presence(
|
||||||
|
Extension(_user): Extension<Arc<UserModel>>,
|
||||||
|
RumaRequest(_req): RumaRequest<presence::get_presence::v3::IncomingRequest>,
|
||||||
|
) -> Result<RumaResponse<presence::get_presence::v3::Response>, ApiError> {
|
||||||
|
use presence::get_presence::v3::*;
|
||||||
|
|
||||||
|
Ok(RumaResponse(Response::new(PresenceState::Unavailable)))
|
||||||
|
}
|
26
neo/src/api/client_server/r0/push.rs
Normal file
26
neo/src/api/client_server/r0/push.rs
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::routing::get;
|
||||||
|
use axum::Extension;
|
||||||
|
|
||||||
|
use crate::api::client_server::errors::api_error::ApiError;
|
||||||
|
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||||
|
use neo_entity::prelude::*;
|
||||||
|
|
||||||
|
use ruma::api::client::push;
|
||||||
|
use ruma::push::Ruleset;
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/r0/pushrules/", get(get_pushrules))
|
||||||
|
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_pushrules(
|
||||||
|
Extension(_user): Extension<Arc<UserModel>>,
|
||||||
|
RumaRequest(_req): RumaRequest<push::get_pushrules_all::v3::IncomingRequest>,
|
||||||
|
) -> Result<RumaResponse<push::get_pushrules_all::v3::Response>, ApiError> {
|
||||||
|
use push::get_pushrules_all::v3::*;
|
||||||
|
|
||||||
|
Ok(RumaResponse(Response::new(Ruleset::new())))
|
||||||
|
}
|
19
neo/src/api/client_server/r0/room.rs
Normal file
19
neo/src/api/client_server/r0/room.rs
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
use axum::routing::post;
|
||||||
|
|
||||||
|
use crate::api::client_server::errors::api_error::ApiError;
|
||||||
|
use crate::ruma_wrapper::RumaRequest;
|
||||||
|
|
||||||
|
use ruma::api::client::room;
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/r0/createRoom", post(create_room))
|
||||||
|
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_room(
|
||||||
|
RumaRequest(req): RumaRequest<room::create_room::v3::IncomingRequest>,
|
||||||
|
) -> Result<String, ApiError> {
|
||||||
|
dbg!(req);
|
||||||
|
Ok("".into())
|
||||||
|
}
|
29
neo/src/api/client_server/r0/sync.rs
Normal file
29
neo/src/api/client_server/r0/sync.rs
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::routing::get;
|
||||||
|
use axum::Extension;
|
||||||
|
|
||||||
|
use crate::api::client_server::errors::api_error::ApiError;
|
||||||
|
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||||
|
use neo_entity::prelude::*;
|
||||||
|
|
||||||
|
use ruma::api::client::sync;
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/r0/sync", get(sync_events))
|
||||||
|
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn sync_events(
|
||||||
|
Extension(_user): Extension<Arc<UserModel>>,
|
||||||
|
RumaRequest(req): RumaRequest<sync::sync_events::v3::IncomingRequest>,
|
||||||
|
) -> Result<RumaResponse<sync::sync_events::v3::Response>, ApiError> {
|
||||||
|
use sync::sync_events::v3::*;
|
||||||
|
|
||||||
|
if let Some(timeout) = req.timeout {
|
||||||
|
tokio::time::sleep(timeout).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(RumaResponse(Response::new("todo".into())))
|
||||||
|
}
|
23
neo/src/api/client_server/r0/thirdparty.rs
Normal file
23
neo/src/api/client_server/r0/thirdparty.rs
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
use std::{collections::BTreeMap, sync::Arc};
|
||||||
|
|
||||||
|
use axum::{routing::get, Extension};
|
||||||
|
|
||||||
|
use neo_entity::prelude::*;
|
||||||
|
|
||||||
|
use crate::{api::client_server::errors::api_error::ApiError, ruma_wrapper::RumaResponse};
|
||||||
|
|
||||||
|
use ruma::api::client::thirdparty;
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/r0/thirdparty/protocols", get(get_thirdparty_protocols))
|
||||||
|
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_thirdparty_protocols(
|
||||||
|
Extension(_user): Extension<Arc<UserModel>>,
|
||||||
|
) -> Result<RumaResponse<thirdparty::get_protocols::v3::Response>, ApiError> {
|
||||||
|
Ok(RumaResponse(thirdparty::get_protocols::v3::Response::new(
|
||||||
|
BTreeMap::new(),
|
||||||
|
)))
|
||||||
|
}
|
20
neo/src/api/client_server/versions.rs
Normal file
20
neo/src/api/client_server/versions.rs
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
use axum::routing::get;
|
||||||
|
|
||||||
|
use crate::ruma_wrapper::RumaResponse;
|
||||||
|
use ruma::api::client::discovery;
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router {
|
||||||
|
axum::Router::new().route("/versions", get(get_client_versions))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument]
|
||||||
|
async fn get_client_versions() -> RumaResponse<discovery::get_supported_versions::Response> {
|
||||||
|
use discovery::get_supported_versions::*;
|
||||||
|
|
||||||
|
RumaResponse(Response::new(vec![
|
||||||
|
"r0.5.0".to_owned(),
|
||||||
|
"r0.6.0".to_owned(),
|
||||||
|
"v1.1".to_owned(),
|
||||||
|
"v1.2".to_owned(),
|
||||||
|
]))
|
||||||
|
}
|
17
neo/src/config.rs
Normal file
17
neo/src/config.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
use ruma::{OwnedServerName, ServerName};
|
||||||
|
|
||||||
|
pub struct Config {
|
||||||
|
pub db_path: String,
|
||||||
|
pub server_name: OwnedServerName,
|
||||||
|
pub enable_registration: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
db_path: "sqlite://db.sqlite3".into(),
|
||||||
|
server_name: ServerName::parse("fuckwit.dev").unwrap(),
|
||||||
|
enable_registration: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
77
neo/src/main.rs
Normal file
77
neo/src/main.rs
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
handler::Handler,
|
||||||
|
http::{Request, StatusCode},
|
||||||
|
Extension, Router,
|
||||||
|
};
|
||||||
|
use config::Config;
|
||||||
|
use neo_migration::{Migrator, MigratorTrait};
|
||||||
|
use sea_orm::Database;
|
||||||
|
use tower_http::{
|
||||||
|
cors::CorsLayer,
|
||||||
|
trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
|
||||||
|
};
|
||||||
|
use tracing::Level;
|
||||||
|
|
||||||
|
mod api;
|
||||||
|
mod config;
|
||||||
|
mod responses;
|
||||||
|
mod ruma_wrapper;
|
||||||
|
mod types;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
if std::env::var("RUST_LOG").is_err() {
|
||||||
|
std::env::set_var("RUST_LOG", "debug,sqlx=off");
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
|
let config = Arc::new(Config::default());
|
||||||
|
|
||||||
|
let pool = Database::connect(config.db_path.clone()).await?;
|
||||||
|
Migrator::up(&pool, None).await?;
|
||||||
|
|
||||||
|
// TODO: set correct CORS headers
|
||||||
|
let cors = CorsLayer::new()
|
||||||
|
.allow_origin(tower_http::cors::Any)
|
||||||
|
.allow_methods(tower_http::cors::Any)
|
||||||
|
.allow_headers(tower_http::cors::Any);
|
||||||
|
|
||||||
|
let tracing_layer = TraceLayer::new_for_http()
|
||||||
|
.make_span_with(DefaultMakeSpan::new().level(Level::INFO))
|
||||||
|
.on_request(DefaultOnRequest::default().level(Level::INFO))
|
||||||
|
.on_response(DefaultOnResponse::default().level(Level::INFO));
|
||||||
|
|
||||||
|
let client_server = Router::new()
|
||||||
|
.merge(api::client_server::versions::routes())
|
||||||
|
.merge(api::client_server::r0::auth::routes())
|
||||||
|
.merge(api::client_server::r0::thirdparty::routes())
|
||||||
|
.merge(api::client_server::r0::room::routes())
|
||||||
|
.merge(api::client_server::r0::presence::routes())
|
||||||
|
.merge(api::client_server::r0::push::routes())
|
||||||
|
.merge(api::client_server::r0::filter::routes())
|
||||||
|
.merge(api::client_server::r0::sync::routes())
|
||||||
|
.merge(api::client_server::r0::keys::routes());
|
||||||
|
|
||||||
|
let router = Router::new()
|
||||||
|
.nest("/_matrix/client", client_server)
|
||||||
|
.layer(cors)
|
||||||
|
.layer(tracing_layer)
|
||||||
|
.fallback(fallback.into_service())
|
||||||
|
.layer(Extension(pool))
|
||||||
|
.layer(Extension(config));
|
||||||
|
|
||||||
|
let _ = axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
|
||||||
|
.serve(router.into_make_service())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fallback(request: Request<Body>) -> StatusCode {
|
||||||
|
tracing::error!("{} {}", request.method(), request.uri());
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
}
|
15
neo/src/responses/registration.rs
Normal file
15
neo/src/responses/registration.rs
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
use crate::types::user_interactive_authorization::UserInteractiveAuthorizationInfo;
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum RegistrationResponse {
|
||||||
|
UserInteractiveAuthorizationInfo(UserInteractiveAuthorizationInfo),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RegistrationResponse {
|
||||||
|
pub fn user_interactive_authorization_info() -> Self {
|
||||||
|
RegistrationResponse::UserInteractiveAuthorizationInfo(
|
||||||
|
UserInteractiveAuthorizationInfo::new(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
66
neo/src/ruma_wrapper.rs
Normal file
66
neo/src/ruma_wrapper.rs
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
use axum::{
|
||||||
|
body::{Bytes, Full, HttpBody},
|
||||||
|
extract::{FromRequest, Path},
|
||||||
|
response::IntoResponse,
|
||||||
|
BoxError,
|
||||||
|
};
|
||||||
|
use http::StatusCode;
|
||||||
|
use ruma::{
|
||||||
|
api::{IncomingRequest, OutgoingResponse},
|
||||||
|
exports::bytes::{BufMut, BytesMut},
|
||||||
|
serde::CanonicalJsonValue,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::api::client_server::errors::api_error::ApiError;
|
||||||
|
|
||||||
|
pub struct RumaRequest<R>(pub R)
|
||||||
|
where
|
||||||
|
R: IncomingRequest;
|
||||||
|
|
||||||
|
#[axum::async_trait]
|
||||||
|
impl<R, B> FromRequest<B> for RumaRequest<R>
|
||||||
|
where
|
||||||
|
R: IncomingRequest,
|
||||||
|
B: HttpBody + Send,
|
||||||
|
B::Data: Send,
|
||||||
|
B::Error: Into<BoxError>,
|
||||||
|
{
|
||||||
|
type Rejection = ApiError;
|
||||||
|
|
||||||
|
async fn from_request(
|
||||||
|
req: &mut axum::extract::RequestParts<B>,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
let path_params = Path::<Vec<String>>::from_request(req)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!(e))?;
|
||||||
|
let body = Bytes::from_request(req)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!(e))?;
|
||||||
|
|
||||||
|
let json = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||||
|
|
||||||
|
let mut buf = BytesMut::new().writer();
|
||||||
|
serde_json::to_writer(&mut buf, &json).expect("can't fail");
|
||||||
|
let body = buf.into_inner().freeze();
|
||||||
|
|
||||||
|
let builder = http::Request::builder().uri(req.uri()).method(req.method());
|
||||||
|
let request = builder.body(body).map_err(|e| anyhow::anyhow!(e))?;
|
||||||
|
|
||||||
|
Ok(Self(
|
||||||
|
R::try_from_http_request(request, &path_params).map_err(|e| anyhow::anyhow!(e))?,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RumaResponse<R>(pub R)
|
||||||
|
where
|
||||||
|
R: OutgoingResponse;
|
||||||
|
|
||||||
|
impl<R: OutgoingResponse> IntoResponse for RumaResponse<R> {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
match self.0.try_into_http_response::<BytesMut>() {
|
||||||
|
Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(),
|
||||||
|
Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,3 +1,4 @@
|
|||||||
|
#[allow(unused)]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub enum ErrorCode {
|
pub enum ErrorCode {
|
||||||
@ -12,6 +13,7 @@ pub enum ErrorCode {
|
|||||||
UserInUse,
|
UserInUse,
|
||||||
InvalidUsername,
|
InvalidUsername,
|
||||||
Exclusive,
|
Exclusive,
|
||||||
|
UserDeactivated,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl serde::Serialize for ErrorCode {
|
impl serde::Serialize for ErrorCode {
|
||||||
@ -31,6 +33,7 @@ impl serde::Serialize for ErrorCode {
|
|||||||
ErrorCode::UserInUse => "M_USER_IN_USE",
|
ErrorCode::UserInUse => "M_USER_IN_USE",
|
||||||
ErrorCode::InvalidUsername => "M_INVALID_USERNAME",
|
ErrorCode::InvalidUsername => "M_INVALID_USERNAME",
|
||||||
ErrorCode::Exclusive => "M_EXCLUSIVE",
|
ErrorCode::Exclusive => "M_EXCLUSIVE",
|
||||||
|
ErrorCode::UserDeactivated => "M_USER_DEACTIVATED",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
82
neo/src/types/event_type.rs
Normal file
82
neo/src/types/event_type.rs
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
use sqlx::{encode::IsNull, Sqlite};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub enum EventType {
|
||||||
|
RoomCreate,
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl sqlx::Type<Sqlite> for EventType {
|
||||||
|
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
|
||||||
|
<&str as sqlx::Type<Sqlite>>::type_info()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'e> sqlx::Encode<'e, Sqlite> for EventType {
|
||||||
|
fn encode_by_ref(
|
||||||
|
&self,
|
||||||
|
buf: &mut <Sqlite as sqlx::database::HasArguments<'e>>::ArgumentBuffer,
|
||||||
|
) -> sqlx::encode::IsNull {
|
||||||
|
buf.push(sqlx::sqlite::SqliteArgumentValue::Text(
|
||||||
|
match self {
|
||||||
|
EventType::RoomCreate => "m.room.create",
|
||||||
|
EventType::Unknown => "???",
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
));
|
||||||
|
|
||||||
|
IsNull::No
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'d> sqlx::Decode<'d, Sqlite> for EventType {
|
||||||
|
fn decode(
|
||||||
|
value: <Sqlite as sqlx::database::HasValueRef<'d>>::ValueRef,
|
||||||
|
) -> Result<Self, sqlx::error::BoxDynError> {
|
||||||
|
Ok(match <&str as sqlx::Decode<Sqlite>>::decode(value)? {
|
||||||
|
"m.room.create" => EventType::RoomCreate,
|
||||||
|
_ => EventType::Unknown,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl serde::Serialize for EventType {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
serializer.serialize_str(match self {
|
||||||
|
EventType::RoomCreate => "m.room.create",
|
||||||
|
EventType::Unknown => "dev.fuckwit.unknown_event",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> serde::Deserialize<'de> for EventType {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
struct IdentifierVisitor;
|
||||||
|
|
||||||
|
impl<'de> serde::de::Visitor<'de> for IdentifierVisitor {
|
||||||
|
type Value = EventType;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
formatter.write_str("Identifier")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: serde::de::Error,
|
||||||
|
{
|
||||||
|
match v {
|
||||||
|
"m.id.user" => Ok(EventType::RoomCreate),
|
||||||
|
_ => Err(serde::de::Error::custom("Unknown identifier")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
deserializer.deserialize_str(IdentifierVisitor {})
|
||||||
|
}
|
||||||
|
}
|
4
neo/src/types/mod.rs
Normal file
4
neo/src/types/mod.rs
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
pub mod error_code;
|
||||||
|
//pub mod event_type;
|
||||||
|
pub mod flow;
|
||||||
|
pub mod user_interactive_authorization;
|
@ -0,0 +1,37 @@
|
|||||||
|
{
|
||||||
|
"query": "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where uuid = ?",
|
||||||
|
"describe": {
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"name": "uuid: Uuid",
|
||||||
|
"ordinal": 0,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "user_uuid: Uuid",
|
||||||
|
"ordinal": 1,
|
||||||
|
"type_info": "Int64"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "device_id",
|
||||||
|
"ordinal": 2,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "display_name",
|
||||||
|
"ordinal": 3,
|
||||||
|
"type_info": "Text"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"parameters": {
|
||||||
|
"Right": 1
|
||||||
|
},
|
||||||
|
"nullable": [
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"hash": "1843c2b3e548d1dd13694a65ca1ba123da38668c3fc5bc431fe5884a6fc25f71"
|
||||||
|
}
|
@ -1,107 +0,0 @@
|
|||||||
use std::{collections::HashMap, sync::Arc};
|
|
||||||
|
|
||||||
use axum::{
|
|
||||||
extract::Query,
|
|
||||||
http::StatusCode,
|
|
||||||
routing::{get, post},
|
|
||||||
Extension, Json,
|
|
||||||
};
|
|
||||||
use sqlx::SqlitePool;
|
|
||||||
|
|
||||||
use crate::responses::registration::RegistrationResponse;
|
|
||||||
use crate::{
|
|
||||||
models::devices::Device,
|
|
||||||
responses::{flow::Flows, registration::RegistrationSuccess},
|
|
||||||
};
|
|
||||||
use crate::{
|
|
||||||
models::users::User,
|
|
||||||
requests::registration::RegistrationRequest,
|
|
||||||
responses::username_available::UsernameAvailable,
|
|
||||||
types::{authentication_data::AuthenticationData, matrix_user_id::UserId},
|
|
||||||
Config,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::errors::{api_error::ApiError, registration_error::RegistrationError};
|
|
||||||
|
|
||||||
pub fn routes() -> axum::Router {
|
|
||||||
axum::Router::new()
|
|
||||||
.route("/r0/login", get(get_login).post(post_login))
|
|
||||||
.route("/r0/register", post(post_register))
|
|
||||||
.route("/r0/register/available", get(get_username_available))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument]
|
|
||||||
async fn get_login() -> Json<Flows> {
|
|
||||||
Json(Flows::new())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip_all)]
|
|
||||||
async fn post_login(body: String) -> StatusCode {
|
|
||||||
dbg!(body);
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip_all)]
|
|
||||||
async fn get_username_available(
|
|
||||||
Extension(config): Extension<Arc<Config>>,
|
|
||||||
Extension(db): Extension<SqlitePool>,
|
|
||||||
Query(params): Query<HashMap<String, String>>,
|
|
||||||
) -> Result<Json<UsernameAvailable>, ApiError> {
|
|
||||||
let username = params
|
|
||||||
.get("username")
|
|
||||||
.ok_or(RegistrationError::MissingUserId)?;
|
|
||||||
let user_id = UserId::new(username, &config.homeserver_name)
|
|
||||||
.ok()
|
|
||||||
.ok_or(RegistrationError::InvalidUserId)?;
|
|
||||||
let exists = User::exists(&db, &user_id).await?;
|
|
||||||
|
|
||||||
Ok(Json(UsernameAvailable::new(!exists)))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip_all)]
|
|
||||||
async fn post_register(
|
|
||||||
Extension(config): Extension<Arc<Config>>,
|
|
||||||
Extension(db): Extension<SqlitePool>,
|
|
||||||
Json(body): Json<RegistrationRequest>,
|
|
||||||
) -> Result<Json<RegistrationResponse>, ApiError> {
|
|
||||||
body.auth()
|
|
||||||
.ok_or(RegistrationError::AdditionalAuthenticationInformation)?;
|
|
||||||
|
|
||||||
let (user, device) = match &body.auth().expect("must be Some") {
|
|
||||||
AuthenticationData::Password(auth_data) => {
|
|
||||||
let username = body.username().ok_or(RegistrationError::MissingUserId)?;
|
|
||||||
let user_id = UserId::new(username, &config.homeserver_name)
|
|
||||||
.ok()
|
|
||||||
.ok_or(RegistrationError::InvalidUserId)?;
|
|
||||||
|
|
||||||
User::exists(&db, &user_id)
|
|
||||||
.await?
|
|
||||||
.then(|| ())
|
|
||||||
.ok_or(RegistrationError::UserIdTaken)?;
|
|
||||||
|
|
||||||
let password = auth_data.password();
|
|
||||||
|
|
||||||
let display_name = match body.initial_device_display_name() {
|
|
||||||
Some(display_name) => display_name.as_ref(),
|
|
||||||
None => "Random displayname",
|
|
||||||
};
|
|
||||||
|
|
||||||
let user = User::create(&db, &user_id, &user_id.to_string(), password).await?;
|
|
||||||
let device = Device::create(&db, &user, "test", display_name).await?;
|
|
||||||
|
|
||||||
(user, device)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if body.inhibit_login().unwrap_or(false) {
|
|
||||||
let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id());
|
|
||||||
|
|
||||||
Ok(Json(RegistrationResponse::Success(resp)))
|
|
||||||
} else {
|
|
||||||
let session = device.create_session(&db).await?;
|
|
||||||
let resp =
|
|
||||||
RegistrationSuccess::new(Some(session.value()), device.device_id(), user.user_id());
|
|
||||||
|
|
||||||
Ok(Json(RegistrationResponse::Success(resp)))
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,112 +0,0 @@
|
|||||||
use axum::http::StatusCode;
|
|
||||||
use axum::response::IntoResponse;
|
|
||||||
use axum::Json;
|
|
||||||
use sqlx::Statement;
|
|
||||||
|
|
||||||
use crate::responses::registration::RegistrationResponse;
|
|
||||||
use crate::types::error_code::ErrorCode;
|
|
||||||
|
|
||||||
use super::registration_error::RegistrationError;
|
|
||||||
|
|
||||||
macro_rules! map_err {
|
|
||||||
($err:ident, $($type:path => $target:path),+) => {
|
|
||||||
$(
|
|
||||||
if $err.is::<$type>() {
|
|
||||||
return $target($err.downcast().unwrap());
|
|
||||||
}
|
|
||||||
)*
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, serde::Serialize)]
|
|
||||||
struct ErrorResponse {
|
|
||||||
errcode: ErrorCode,
|
|
||||||
error: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
retry_after_ms: Option<u64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ErrorResponse {
|
|
||||||
fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option<u64>) -> Self {
|
|
||||||
Self {
|
|
||||||
errcode,
|
|
||||||
error: error.to_owned(),
|
|
||||||
retry_after_ms,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum ApiError {
|
|
||||||
#[error("Registration Error")]
|
|
||||||
RegistrationError(#[from] RegistrationError),
|
|
||||||
|
|
||||||
#[error("Database Error")]
|
|
||||||
DBError(#[from] sqlx::Error),
|
|
||||||
|
|
||||||
#[error("Generic Error")]
|
|
||||||
Generic(anyhow::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<anyhow::Error> for ApiError {
|
|
||||||
fn from(err: anyhow::Error) -> Self {
|
|
||||||
map_err!(err,
|
|
||||||
sqlx::Error => ApiError::DBError,
|
|
||||||
RegistrationError => ApiError::RegistrationError
|
|
||||||
);
|
|
||||||
|
|
||||||
ApiError::Generic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IntoResponse for ApiError {
|
|
||||||
fn into_response(self) -> axum::response::Response {
|
|
||||||
match self {
|
|
||||||
ApiError::RegistrationError(registration_error) => match registration_error {
|
|
||||||
RegistrationError::AdditionalAuthenticationInformation => (
|
|
||||||
StatusCode::UNAUTHORIZED,
|
|
||||||
Json(RegistrationResponse::user_interactive_authorization_info()),
|
|
||||||
).into_response(),
|
|
||||||
RegistrationError::InvalidUserId => (StatusCode::OK, Json(
|
|
||||||
ErrorResponse::new(
|
|
||||||
ErrorCode::InvalidUsername,
|
|
||||||
®istration_error.to_string(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
)).into_response(),
|
|
||||||
RegistrationError::MissingUserId => (StatusCode::OK, String::new()).into_response(),
|
|
||||||
RegistrationError::UserIdTaken => (
|
|
||||||
StatusCode::BAD_REQUEST,
|
|
||||||
Json(ErrorResponse::new(
|
|
||||||
ErrorCode::UserInUse,
|
|
||||||
®istration_error.to_string(),
|
|
||||||
None,
|
|
||||||
)),
|
|
||||||
)
|
|
||||||
.into_response(),
|
|
||||||
},
|
|
||||||
ApiError::DBError(err) => {
|
|
||||||
tracing::error!("{}", err.to_string());
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
Json(ErrorResponse::new(
|
|
||||||
ErrorCode::Unknown,
|
|
||||||
"Database error! If you are the application owner please take a look at your application logs.",
|
|
||||||
None,
|
|
||||||
)),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
ApiError::Generic(err) => (
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
Json(ErrorResponse::new(
|
|
||||||
ErrorCode::Unknown,
|
|
||||||
"Fatal error occured! If you are the application owner please take a look at your application logs.",
|
|
||||||
None,
|
|
||||||
)),
|
|
||||||
)
|
|
||||||
.into_response(),
|
|
||||||
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,2 +0,0 @@
|
|||||||
pub mod api_error;
|
|
||||||
pub mod registration_error;
|
|
@ -1,11 +0,0 @@
|
|||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum RegistrationError {
|
|
||||||
#[error("The homeserver requires additional authentication information")]
|
|
||||||
AdditionalAuthenticationInformation,
|
|
||||||
#[error("UserId is missing")]
|
|
||||||
MissingUserId,
|
|
||||||
#[error("The desired user ID is not a valid user name")]
|
|
||||||
InvalidUserId,
|
|
||||||
#[error("The desired user ID is already taken")]
|
|
||||||
UserIdTaken,
|
|
||||||
}
|
|
@ -1,12 +0,0 @@
|
|||||||
use axum::{routing::get, Json};
|
|
||||||
|
|
||||||
use crate::responses::versions::Versions;
|
|
||||||
|
|
||||||
pub fn routes() -> axum::Router {
|
|
||||||
axum::Router::new().route("/versions", get(get_client_versions))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument]
|
|
||||||
async fn get_client_versions() -> Json<Versions> {
|
|
||||||
Json(Versions::default())
|
|
||||||
}
|
|
78
src/main.rs
78
src/main.rs
@ -1,78 +0,0 @@
|
|||||||
#![allow(unused)]
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use axum::{
|
|
||||||
body::Body,
|
|
||||||
handler::Handler,
|
|
||||||
http::{Request, StatusCode},
|
|
||||||
Extension, Router,
|
|
||||||
};
|
|
||||||
use tower_http::{
|
|
||||||
cors::CorsLayer,
|
|
||||||
trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer},
|
|
||||||
};
|
|
||||||
use tracing::Level;
|
|
||||||
|
|
||||||
mod api;
|
|
||||||
mod models;
|
|
||||||
mod requests;
|
|
||||||
mod responses;
|
|
||||||
mod types;
|
|
||||||
|
|
||||||
struct Config {
|
|
||||||
db_path: String,
|
|
||||||
homeserver_name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for Config {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
db_path: "sqlite://db.sqlite3".into(),
|
|
||||||
homeserver_name: "fuckwit.dev".into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() {
|
|
||||||
if std::env::var("RUST_LOG").is_err() {
|
|
||||||
std::env::set_var("RUST_LOG", "debug");
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing_subscriber::fmt::init();
|
|
||||||
|
|
||||||
let config = Arc::new(Config::default());
|
|
||||||
|
|
||||||
let pool = sqlx::SqlitePool::connect(&config.db_path)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let cors = CorsLayer::new()
|
|
||||||
.allow_origin(tower_http::cors::Any)
|
|
||||||
.allow_methods(tower_http::cors::Any)
|
|
||||||
.allow_headers(tower_http::cors::Any);
|
|
||||||
|
|
||||||
let tracing_layer = TraceLayer::new_for_http();
|
|
||||||
|
|
||||||
let client_server = Router::new()
|
|
||||||
.merge(api::client_server::versions::routes())
|
|
||||||
.merge(api::client_server::auth::routes());
|
|
||||||
|
|
||||||
let router = Router::new()
|
|
||||||
.nest("/_matrix/client", client_server)
|
|
||||||
.layer(cors)
|
|
||||||
.layer(tracing_layer)
|
|
||||||
.layer(Extension(pool))
|
|
||||||
.layer(Extension(config))
|
|
||||||
.fallback(fallback.into_service());
|
|
||||||
|
|
||||||
let _ = axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
|
|
||||||
.serve(router.into_make_service())
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn fallback(request: Request<Body>) -> StatusCode {
|
|
||||||
dbg!(request);
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR
|
|
||||||
}
|
|
@ -1,61 +0,0 @@
|
|||||||
use sqlx::SqlitePool;
|
|
||||||
|
|
||||||
use super::{sessions::Session, users::User};
|
|
||||||
|
|
||||||
pub struct Device {
|
|
||||||
id: i64,
|
|
||||||
user_id: i64,
|
|
||||||
device_id: String,
|
|
||||||
display_name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Device {
|
|
||||||
pub async fn create(
|
|
||||||
conn: &SqlitePool,
|
|
||||||
user: &User,
|
|
||||||
device_id: &str,
|
|
||||||
display_name: &str,
|
|
||||||
) -> anyhow::Result<Self> {
|
|
||||||
let user_id = user.id();
|
|
||||||
Ok(sqlx::query_as!(Self, "insert into devices(user_id, device_id, display_name) values(?, ?, ?) returning id, user_id, device_id, display_name", user_id, device_id, display_name).fetch_one(conn).await?)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn by_user(conn: &SqlitePool, user: &User) -> anyhow::Result<Self> {
|
|
||||||
let user_id = user.id();
|
|
||||||
Ok(sqlx::query_as!(
|
|
||||||
Self,
|
|
||||||
"select id, user_id, device_id, display_name from devices where user_id = ?",
|
|
||||||
user_id
|
|
||||||
)
|
|
||||||
.fetch_one(conn)
|
|
||||||
.await?)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn create_session(&self, conn: &SqlitePool) -> anyhow::Result<Session> {
|
|
||||||
Ok(Session::create(conn, self, "random_session_id").await?)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the device's id.
|
|
||||||
#[must_use]
|
|
||||||
pub fn id(&self) -> i64 {
|
|
||||||
self.id
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the device's user id.
|
|
||||||
#[must_use]
|
|
||||||
pub fn user_id(&self) -> i64 {
|
|
||||||
self.user_id
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the device's device id.
|
|
||||||
#[must_use]
|
|
||||||
pub fn device_id(&self) -> &str {
|
|
||||||
self.device_id.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the device's display name.
|
|
||||||
#[must_use]
|
|
||||||
pub fn display_name(&self) -> &str {
|
|
||||||
self.display_name.as_ref()
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,41 +0,0 @@
|
|||||||
use sqlx::SqlitePool;
|
|
||||||
|
|
||||||
use super::devices::Device;
|
|
||||||
|
|
||||||
pub struct Session {
|
|
||||||
id: i64,
|
|
||||||
device_id: i64,
|
|
||||||
value: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Session {
|
|
||||||
pub async fn create(conn: &SqlitePool, device: &Device, value: &str) -> anyhow::Result<Self> {
|
|
||||||
let device_id = device.id();
|
|
||||||
Ok(sqlx::query_as!(
|
|
||||||
Self,
|
|
||||||
"insert into sessions(device_id, value) values(?, ?) returning id, device_id, value",
|
|
||||||
device_id,
|
|
||||||
value
|
|
||||||
)
|
|
||||||
.fetch_one(conn)
|
|
||||||
.await?)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the session's id.
|
|
||||||
#[must_use]
|
|
||||||
pub fn id(&self) -> i64 {
|
|
||||||
self.id
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the session's device id.
|
|
||||||
#[must_use]
|
|
||||||
pub fn device_id(&self) -> i64 {
|
|
||||||
self.device_id
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the session's value.
|
|
||||||
#[must_use]
|
|
||||||
pub fn value(&self) -> &str {
|
|
||||||
self.value.as_ref()
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,58 +0,0 @@
|
|||||||
use sqlx::SqlitePool;
|
|
||||||
|
|
||||||
use crate::types::matrix_user_id::UserId;
|
|
||||||
|
|
||||||
pub struct User {
|
|
||||||
id: i64,
|
|
||||||
user_id: String,
|
|
||||||
display_name: String,
|
|
||||||
password: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl User {
|
|
||||||
pub async fn exists(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<bool> {
|
|
||||||
Ok(
|
|
||||||
sqlx::query!("select user_id from users where user_id = ?", user_id)
|
|
||||||
.fetch_optional(conn)
|
|
||||||
.await?
|
|
||||||
.is_some(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn create(
|
|
||||||
conn: &SqlitePool,
|
|
||||||
user_id: &UserId,
|
|
||||||
display_name: &str,
|
|
||||||
password: &str,
|
|
||||||
) -> anyhow::Result<Self> {
|
|
||||||
Ok(sqlx::query_as!(Self, "insert into users(user_id, display_name, password) values (?, ?, ?) returning id, user_id, display_name, password", user_id, display_name, password).fetch_one(conn).await?)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<Self> {
|
|
||||||
Ok(sqlx::query_as!(
|
|
||||||
Self,
|
|
||||||
"select id, user_id, display_name, password from users where user_id = ?",
|
|
||||||
user_id
|
|
||||||
)
|
|
||||||
.fetch_one(conn)
|
|
||||||
.await?)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the user's id.
|
|
||||||
#[must_use]
|
|
||||||
pub fn id(&self) -> i64 {
|
|
||||||
self.id
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the user's user id.
|
|
||||||
#[must_use]
|
|
||||||
pub fn user_id(&self) -> &str {
|
|
||||||
self.user_id.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the user's password.
|
|
||||||
#[must_use]
|
|
||||||
pub fn password(&self) -> &str {
|
|
||||||
self.password.as_ref()
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,59 +0,0 @@
|
|||||||
use crate::types::{authentication_data::AuthenticationData, flow::Flow, identifier::Identifier};
|
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize)]
|
|
||||||
pub struct RegistrationRequest {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
auth: Option<AuthenticationData>,
|
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
device_id: Option<String>,
|
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
inhibit_login: Option<bool>,
|
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
initial_device_display_name: Option<String>,
|
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
password: Option<String>,
|
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
username: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RegistrationRequest {
|
|
||||||
#[must_use]
|
|
||||||
pub fn auth(&self) -> Option<&AuthenticationData> {
|
|
||||||
self.auth.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the registration request's device id.
|
|
||||||
#[must_use]
|
|
||||||
pub fn device_id(&self) -> Option<&String> {
|
|
||||||
self.device_id.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the registration request's inhibit login.
|
|
||||||
#[must_use]
|
|
||||||
pub fn inhibit_login(&self) -> Option<bool> {
|
|
||||||
self.inhibit_login
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the registration request's initial device display name.
|
|
||||||
#[must_use]
|
|
||||||
pub fn initial_device_display_name(&self) -> Option<&String> {
|
|
||||||
self.initial_device_display_name.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the registration request's password.
|
|
||||||
#[must_use]
|
|
||||||
pub fn password(&self) -> Option<&String> {
|
|
||||||
self.password.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the registration request's username.
|
|
||||||
#[must_use]
|
|
||||||
pub fn username(&self) -> Option<&String> {
|
|
||||||
self.username.as_ref()
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,22 +0,0 @@
|
|||||||
use crate::types::flow::Flow;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Serialize)]
|
|
||||||
struct FlowWrapper {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
_type: Flow,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Serialize)]
|
|
||||||
pub struct Flows {
|
|
||||||
flows: Vec<FlowWrapper>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Flows {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
flows: vec![FlowWrapper {
|
|
||||||
_type: Flow::Password,
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,4 +0,0 @@
|
|||||||
pub mod flow;
|
|
||||||
pub mod registration;
|
|
||||||
pub mod username_available;
|
|
||||||
pub mod versions;
|
|
@ -1,34 +0,0 @@
|
|||||||
use crate::types::user_interactive_authorization::UserInteractiveAuthorizationInfo;
|
|
||||||
|
|
||||||
#[derive(Debug, serde::Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum RegistrationResponse {
|
|
||||||
Success(RegistrationSuccess),
|
|
||||||
UserInteractiveAuthorizationInfo(UserInteractiveAuthorizationInfo),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RegistrationResponse {
|
|
||||||
pub fn user_interactive_authorization_info() -> Self {
|
|
||||||
RegistrationResponse::UserInteractiveAuthorizationInfo(
|
|
||||||
UserInteractiveAuthorizationInfo::new(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, serde::Serialize)]
|
|
||||||
pub struct RegistrationSuccess {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
access_token: Option<String>,
|
|
||||||
device_id: String,
|
|
||||||
user_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RegistrationSuccess {
|
|
||||||
pub fn new(access_token: Option<&str>, device_id: &str, user_id: &str) -> Self {
|
|
||||||
Self {
|
|
||||||
access_token: access_token.and_then(|v| Some(v.to_owned())),
|
|
||||||
device_id: device_id.to_owned(),
|
|
||||||
user_id: user_id.to_owned(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,10 +0,0 @@
|
|||||||
#[derive(Debug, serde::Serialize)]
|
|
||||||
pub struct UsernameAvailable {
|
|
||||||
available: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UsernameAvailable {
|
|
||||||
pub fn new(available: bool) -> Self {
|
|
||||||
Self { available }
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,17 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
|
||||||
pub struct Versions {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
unstable_features: Option<HashMap<String, String>>,
|
|
||||||
versions: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for Versions {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
unstable_features: None,
|
|
||||||
versions: vec!["v1.2".into()],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,36 +0,0 @@
|
|||||||
use super::{flow::Flow, identifier::Identifier};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum AuthenticationData {
|
|
||||||
Password(AuthenticationPassword),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
|
||||||
pub struct AuthenticationPassword {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
_type: Flow,
|
|
||||||
identifier: Identifier,
|
|
||||||
password: String,
|
|
||||||
user: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AuthenticationPassword {
|
|
||||||
/// Get a reference to the authentication password's identifier.
|
|
||||||
#[must_use]
|
|
||||||
pub fn identifier(&self) -> &Identifier {
|
|
||||||
&self.identifier
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the authentication password's password.
|
|
||||||
#[must_use]
|
|
||||||
pub fn password(&self) -> &str {
|
|
||||||
self.password.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a reference to the authentication password's user.
|
|
||||||
#[must_use]
|
|
||||||
pub fn user(&self) -> Option<&String> {
|
|
||||||
self.user.as_ref()
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,22 +0,0 @@
|
|||||||
use super::identifier_type::IdentifierType;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum Identifier {
|
|
||||||
User(IdentifierUser),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
|
||||||
pub struct IdentifierUser {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
_type: IdentifierType,
|
|
||||||
user: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl IdentifierUser {
|
|
||||||
/// Get a reference to the identifier user's user.
|
|
||||||
#[must_use]
|
|
||||||
pub fn user(&self) -> Option<&String> {
|
|
||||||
self.user.as_ref()
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,44 +0,0 @@
|
|||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum IdentifierType {
|
|
||||||
User,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl serde::Serialize for IdentifierType {
|
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
||||||
where
|
|
||||||
S: serde::Serializer,
|
|
||||||
{
|
|
||||||
serializer.serialize_str(match self {
|
|
||||||
IdentifierType::User => "m.id.user",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'de> serde::Deserialize<'de> for IdentifierType {
|
|
||||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
||||||
where
|
|
||||||
D: serde::Deserializer<'de>,
|
|
||||||
{
|
|
||||||
struct IdentifierVisitor;
|
|
||||||
|
|
||||||
impl<'de> serde::de::Visitor<'de> for IdentifierVisitor {
|
|
||||||
type Value = IdentifierType;
|
|
||||||
|
|
||||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
||||||
formatter.write_str("Identifier")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: serde::de::Error,
|
|
||||||
{
|
|
||||||
match v {
|
|
||||||
"m.id.user" => Ok(IdentifierType::User),
|
|
||||||
_ => Err(serde::de::Error::custom("Unknown identifier")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
deserializer.deserialize_str(IdentifierVisitor {})
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,52 +0,0 @@
|
|||||||
use std::fmt::Display;
|
|
||||||
|
|
||||||
#[derive(sqlx::Type)]
|
|
||||||
#[sqlx(transparent)]
|
|
||||||
#[repr(transparent)]
|
|
||||||
pub struct UserId(String);
|
|
||||||
|
|
||||||
impl UserId {
|
|
||||||
pub fn new(name: &str, server_name: &str) -> anyhow::Result<Self> {
|
|
||||||
let user_id = Self(format!("@{name}:{server_name}"));
|
|
||||||
|
|
||||||
user_id.is_valid()?;
|
|
||||||
|
|
||||||
Ok(user_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn local_part(&self) -> &str {
|
|
||||||
let col_idx = self.0.find(':').expect("Will always have at least one ':'");
|
|
||||||
self.0[1..col_idx].as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_valid(&self) -> anyhow::Result<()> {
|
|
||||||
(self.0.len() <= 255)
|
|
||||||
.then(|| ())
|
|
||||||
.ok_or(UserIdError::TooLong(self.0.len()))?;
|
|
||||||
|
|
||||||
let local_part = self.local_part();
|
|
||||||
local_part
|
|
||||||
.bytes()
|
|
||||||
.all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'z' | b'-' | b'.' | b'=' | b'_' | b'/'))
|
|
||||||
.then(|| ())
|
|
||||||
.ok_or(UserIdError::InvalidCharacters)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for UserId {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{}", self.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum UserIdError {
|
|
||||||
#[error("UserId too long {0} (expected < 255)")]
|
|
||||||
TooLong(usize),
|
|
||||||
#[error("Invalid character present in user id")]
|
|
||||||
InvalidCharacters,
|
|
||||||
#[error("Invalid UserId given")]
|
|
||||||
Invalid,
|
|
||||||
}
|
|
@ -1,7 +0,0 @@
|
|||||||
pub mod authentication_data;
|
|
||||||
pub mod flow;
|
|
||||||
pub mod identifier;
|
|
||||||
pub mod identifier_type;
|
|
||||||
pub mod matrix_user_id;
|
|
||||||
pub mod user_interactive_authorization;
|
|
||||||
pub mod error_code;
|
|
Reference in New Issue
Block a user