get back to where we were. but now with sea_orm and a more sane structure
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
35bde07b39
commit
e33a734199
11
Cargo.lock
generated
11
Cargo.lock
generated
@ -1175,12 +1175,12 @@ name = "neo"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"argon2",
|
||||
"axum",
|
||||
"axum-macros",
|
||||
"http",
|
||||
"neo-entity",
|
||||
"neo-migration",
|
||||
"neo-util",
|
||||
"rand 0.8.5",
|
||||
"ruma",
|
||||
"sea-orm",
|
||||
@ -1213,6 +1213,15 @@ dependencies = [
|
||||
"sea-orm-migration",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "neo-util"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"argon2",
|
||||
"rand 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.1"
|
||||
|
@ -3,5 +3,6 @@
|
||||
members = [
|
||||
"neo",
|
||||
"neo-entity",
|
||||
"neo-migration"
|
||||
"neo-migration",
|
||||
"neo-util"
|
||||
]
|
||||
|
@ -1,4 +1,4 @@
|
||||
pub mod users;
|
||||
pub mod devices;
|
||||
pub mod sessions;
|
||||
pub mod prelude;
|
||||
pub mod sessions;
|
||||
pub mod users;
|
||||
|
@ -1,6 +1,6 @@
|
||||
#[allow(unused_imports)]
|
||||
pub use crate::{
|
||||
devices::{self, Entity as Device},
|
||||
sessions::{self, Entity as Session},
|
||||
users::{self, Entity as User},
|
||||
devices::{self, Entity as Device, Model as DeviceModel},
|
||||
sessions::{self, Entity as Session, Model as SessionModel},
|
||||
users::{self, Entity as User, Model as UserModel},
|
||||
};
|
||||
|
@ -1,4 +1,4 @@
|
||||
use sea_orm::entity::prelude::*;
|
||||
use sea_orm::{entity::prelude::*, Set};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "sessions")]
|
||||
@ -27,4 +27,11 @@ impl Related<super::devices::Entity> for Entity {
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
impl ActiveModelBehavior for ActiveModel {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
uuid: Set(Uuid::new_v4()),
|
||||
..ActiveModelTrait::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -40,6 +40,7 @@ impl MigrationTrait for Migration {
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("user_id_index")
|
||||
.table(User)
|
||||
.col(users::Column::UserId)
|
||||
.to_owned(),
|
||||
|
@ -40,6 +40,7 @@ impl MigrationTrait for Migration {
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("device_id_index")
|
||||
.table(Device)
|
||||
.col(devices::Column::DeviceId)
|
||||
.to_owned(),
|
||||
@ -48,6 +49,7 @@ impl MigrationTrait for Migration {
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("user_uuid_index")
|
||||
.table(Device)
|
||||
.col(devices::Column::UserUuid)
|
||||
.to_owned(),
|
||||
|
@ -35,6 +35,7 @@ impl MigrationTrait for Migration {
|
||||
manager
|
||||
.create_index(
|
||||
Index::create()
|
||||
.name("device_uuid_index")
|
||||
.table(Session)
|
||||
.col(sessions::Column::DeviceUuid)
|
||||
.to_owned(),
|
||||
|
11
neo-util/Cargo.toml
Normal file
11
neo-util/Cargo.toml
Normal file
@ -0,0 +1,11 @@
|
||||
[package]
|
||||
name = "neo-util"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
argon2 = { version = "0.4", features = ["std"] }
|
||||
rand = { version = "0.8.5", features = ["std"] }
|
1
neo-util/src/lib.rs
Normal file
1
neo-util/src/lib.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod password;
|
17
neo-util/src/password.rs
Normal file
17
neo-util/src/password.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
|
||||
use rand::rngs::OsRng;
|
||||
|
||||
pub fn hash_password(password: &str) -> anyhow::Result<String> {
|
||||
let argon2 = Argon2::default();
|
||||
let salt = SaltString::generate(OsRng);
|
||||
Ok(argon2
|
||||
.hash_password(password.as_bytes(), &salt)?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
pub fn password_correct(password: &str, hash: &str) -> anyhow::Result<bool> {
|
||||
let password_hash = PasswordHash::new(hash)?;
|
||||
Ok(Argon2::default()
|
||||
.verify_password(password.as_bytes(), &password_hash)
|
||||
.is_ok())
|
||||
}
|
@ -13,7 +13,6 @@ serde_json = "1.0"
|
||||
tower-http = { version = "0.2", features = ["cors", "trace"]}
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
argon2 = { version = "0.4", features = ["std"] }
|
||||
rand = { version = "0.8.5", features = ["std"] }
|
||||
uuid = { version = "1.0", features = ["v4"] }
|
||||
ruma = { version = "0.6.4", features = ["client-api", "compat"] }
|
||||
@ -22,3 +21,4 @@ http = "0.2.8"
|
||||
sea-orm = { version = "^0.8", features = ["sqlx-sqlite", "runtime-tokio-native-tls", "macros"], default-features = false }
|
||||
neo-entity = { version = "*", path = "../neo-entity" }
|
||||
neo-migration = { version = "*", path = "../neo-migration" }
|
||||
neo-util = { version = "*", path = "../neo-util" }
|
||||
|
@ -5,13 +5,10 @@ use axum::{
|
||||
routing::{get, post},
|
||||
Extension,
|
||||
};
|
||||
// use neo_entity::{
|
||||
// devices::{self, Entity as Device},
|
||||
// sessions::{self, Entity as Session},
|
||||
// users::{self, Entity as User},
|
||||
// };
|
||||
|
||||
use neo_entity::prelude::*;
|
||||
use neo_util::password;
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
use sea_orm::{ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, Set};
|
||||
|
||||
use crate::{
|
||||
@ -68,27 +65,26 @@ async fn login(
|
||||
.one(&db)
|
||||
.await?
|
||||
.ok_or(AuthenticationError::InvalidUserId)?;
|
||||
// TODO: check password
|
||||
//db_user
|
||||
// .password_correct(&password)
|
||||
// .map_err(|_| AuthenticationError::Forbidden)?;
|
||||
|
||||
if !password::password_correct(&password, &db_user.password_hash)
|
||||
.map_err(|_| AuthenticationError::Forbidden)?
|
||||
{
|
||||
return Err(AuthenticationError::Forbidden.into());
|
||||
}
|
||||
|
||||
let device = if let Some(device_id) = req.device_id {
|
||||
Device::find()
|
||||
.filter(
|
||||
devices::Column::DeviceId
|
||||
.eq(device_id.as_str())
|
||||
.and(devices::Column::UserUuid.eq(db_user.uuid)),
|
||||
)
|
||||
.belongs_to(&db_user)
|
||||
.filter(devices::Column::DeviceId.eq(device_id.as_str()))
|
||||
.one(&db)
|
||||
.await?
|
||||
.unwrap()
|
||||
//Device::find_for_user(&db, &db_user, device_id.as_str()).await?
|
||||
} else {
|
||||
let device_id = uuid::Uuid::new_v4().to_string();
|
||||
let display_name = req
|
||||
.initial_device_display_name
|
||||
.unwrap_or_else(|| "Generic Device".into());
|
||||
|
||||
let device = devices::ActiveModel {
|
||||
device_id: Set(device_id),
|
||||
display_name: Set(display_name),
|
||||
@ -98,8 +94,15 @@ async fn login(
|
||||
device.insert(&db).await?
|
||||
};
|
||||
|
||||
let key = thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let session = sessions::ActiveModel {
|
||||
device_uuid: Set(device.uuid),
|
||||
key: Set(key),
|
||||
..Default::default()
|
||||
};
|
||||
let session = session.insert(&db).await?;
|
||||
@ -164,21 +167,25 @@ async fn post_register(
|
||||
return Err(AuthenticationError::InvalidUserId.into());
|
||||
};
|
||||
|
||||
User::find()
|
||||
if User::find()
|
||||
.filter(users::Column::UserId.eq(user_id.as_str()))
|
||||
.one(&db)
|
||||
.await?
|
||||
.ok_or(RegistrationError::UserIdTaken)?;
|
||||
.is_some()
|
||||
{
|
||||
return Err(RegistrationError::UserIdTaken.into());
|
||||
};
|
||||
|
||||
let display_name = req
|
||||
.initial_device_display_name
|
||||
.unwrap_or_else(|| "Random Display Name".into());
|
||||
|
||||
// TODO: Hash password
|
||||
let pw_hash = password::hash_password(&password)?;
|
||||
|
||||
let user = users::ActiveModel {
|
||||
user_id: Set(user_id.to_string()),
|
||||
display_name: Set(user_id.to_string()),
|
||||
password_hash: Set(password),
|
||||
password_hash: Set(pw_hash),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&db)
|
||||
@ -196,8 +203,15 @@ async fn post_register(
|
||||
ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?,
|
||||
);
|
||||
if !req.inhibit_login {
|
||||
let key = thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let session = sessions::ActiveModel {
|
||||
device_uuid: Set(device.uuid),
|
||||
key: Set(key),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&db)
|
||||
@ -212,7 +226,7 @@ async fn post_register(
|
||||
}
|
||||
_ => todo!(),
|
||||
},
|
||||
// For clients not following using UIAA
|
||||
// For clients not following/using UIAA
|
||||
None => {
|
||||
let password = req
|
||||
.password
|
||||
@ -225,20 +239,25 @@ async fn post_register(
|
||||
return Err(AuthenticationError::InvalidUserId.into());
|
||||
};
|
||||
|
||||
User::find()
|
||||
if User::find()
|
||||
.filter(users::Column::UserId.eq(user_id.as_str()))
|
||||
.one(&db)
|
||||
.await?
|
||||
.ok_or(RegistrationError::UserIdTaken)?;
|
||||
.is_some()
|
||||
{
|
||||
return Err(RegistrationError::UserIdTaken.into());
|
||||
};
|
||||
|
||||
let display_name = req
|
||||
.initial_device_display_name
|
||||
.unwrap_or_else(|| "Random Display Name".into());
|
||||
|
||||
let pw_hash = password::hash_password(&password)?;
|
||||
|
||||
let user = users::ActiveModel {
|
||||
user_id: Set(user_id.to_string()),
|
||||
display_name: Set(user_id.to_string()),
|
||||
password_hash: Set(password),
|
||||
password_hash: Set(pw_hash),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&db)
|
||||
@ -255,8 +274,15 @@ async fn post_register(
|
||||
let mut response =
|
||||
Response::new(ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?);
|
||||
if !req.inhibit_login {
|
||||
let key = thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let session = sessions::ActiveModel {
|
||||
device_uuid: Set(device.uuid),
|
||||
key: Set(key),
|
||||
..Default::default()
|
||||
}
|
||||
.insert(&db)
|
||||
|
@ -4,8 +4,8 @@ use axum::routing::post;
|
||||
use axum::Extension;
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
use neo_entity::prelude::*;
|
||||
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use ruma::api::client::filter;
|
||||
|
||||
@ -16,7 +16,7 @@ pub fn routes() -> axum::Router {
|
||||
}
|
||||
|
||||
async fn create_filter(
|
||||
Extension(_user): Extension<Arc<User>>,
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(_req): RumaRequest<filter::create_filter::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<filter::create_filter::v3::Response>, ApiError> {
|
||||
use filter::create_filter::v3::*;
|
||||
|
@ -4,8 +4,8 @@ use axum::routing::post;
|
||||
use axum::Extension;
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
use neo_entity::prelude::*;
|
||||
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use ruma::api::client::keys;
|
||||
|
||||
@ -16,7 +16,7 @@ pub fn routes() -> axum::Router {
|
||||
}
|
||||
|
||||
async fn get_keys(
|
||||
Extension(_user): Extension<Arc<User>>,
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(_req): RumaRequest<keys::get_keys::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<keys::get_keys::v3::Response>, ApiError> {
|
||||
use keys::get_keys::v3::*;
|
||||
|
@ -11,7 +11,7 @@ use neo_entity::{
|
||||
sessions::{self, Entity as Session},
|
||||
users::Entity as User,
|
||||
};
|
||||
use sea_orm::{ColumnTrait, EntityTrait, ModelTrait, QueryFilter, DatabaseConnection};
|
||||
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, ModelTrait, QueryFilter};
|
||||
|
||||
use crate::types::error_code::ErrorCode;
|
||||
|
||||
|
@ -4,8 +4,8 @@ use axum::routing::{get, put};
|
||||
use axum::Extension;
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
use neo_entity::prelude::*;
|
||||
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use ruma::api::client::presence;
|
||||
use ruma::presence::PresenceState;
|
||||
@ -18,7 +18,7 @@ pub fn routes() -> axum::Router {
|
||||
}
|
||||
|
||||
async fn set_presence(
|
||||
Extension(_user): Extension<Arc<User>>,
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(_req): RumaRequest<presence::set_presence::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<presence::set_presence::v3::Response>, ApiError> {
|
||||
use presence::set_presence::v3::*;
|
||||
@ -27,7 +27,7 @@ async fn set_presence(
|
||||
}
|
||||
|
||||
async fn get_presence(
|
||||
Extension(_user): Extension<Arc<User>>,
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(_req): RumaRequest<presence::get_presence::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<presence::get_presence::v3::Response>, ApiError> {
|
||||
use presence::get_presence::v3::*;
|
||||
|
@ -4,8 +4,8 @@ use axum::routing::get;
|
||||
use axum::Extension;
|
||||
|
||||
use crate::api::client_server::errors::api_error::ApiError;
|
||||
use neo_entity::prelude::*;
|
||||
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use ruma::api::client::push;
|
||||
use ruma::push::Ruleset;
|
||||
@ -17,8 +17,8 @@ pub fn routes() -> axum::Router {
|
||||
}
|
||||
|
||||
async fn get_pushrules(
|
||||
Extension(_user): Extension<Arc<User>>,
|
||||
RumaRequest(_req): RumaRequest<push::get_pushrules_all::v3::IncomingRequest>
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(_req): RumaRequest<push::get_pushrules_all::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<push::get_pushrules_all::v3::Response>, ApiError> {
|
||||
use push::get_pushrules_all::v3::*;
|
||||
|
||||
|
@ -16,7 +16,7 @@ pub fn routes() -> axum::Router {
|
||||
}
|
||||
|
||||
async fn sync_events(
|
||||
Extension(_user): Extension<Arc<User>>,
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
RumaRequest(req): RumaRequest<sync::sync_events::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<sync::sync_events::v3::Response>, ApiError> {
|
||||
use sync::sync_events::v3::*;
|
||||
|
@ -4,10 +4,7 @@ use axum::{routing::get, Extension};
|
||||
|
||||
use neo_entity::prelude::*;
|
||||
|
||||
use crate::{
|
||||
api::client_server::errors::api_error::ApiError,
|
||||
ruma_wrapper::RumaResponse,
|
||||
};
|
||||
use crate::{api::client_server::errors::api_error::ApiError, ruma_wrapper::RumaResponse};
|
||||
|
||||
use ruma::api::client::thirdparty;
|
||||
|
||||
@ -18,7 +15,7 @@ pub fn routes() -> axum::Router {
|
||||
}
|
||||
|
||||
async fn get_thirdparty_protocols(
|
||||
Extension(_user): Extension<Arc<User>>,
|
||||
Extension(_user): Extension<Arc<UserModel>>,
|
||||
) -> Result<RumaResponse<thirdparty::get_protocols::v3::Response>, ApiError> {
|
||||
Ok(RumaResponse(thirdparty::get_protocols::v3::Response::new(
|
||||
BTreeMap::new(),
|
||||
|
@ -1,4 +1,4 @@
|
||||
use ruma::{ServerName, OwnedServerName};
|
||||
use ruma::{OwnedServerName, ServerName};
|
||||
|
||||
pub struct Config {
|
||||
pub db_path: String,
|
||||
|
@ -7,13 +7,13 @@ use axum::{
|
||||
Extension, Router,
|
||||
};
|
||||
use config::Config;
|
||||
use neo_migration::{Migrator, MigratorTrait};
|
||||
use sea_orm::Database;
|
||||
use tower_http::{
|
||||
cors::CorsLayer,
|
||||
trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
|
||||
};
|
||||
use tracing::Level;
|
||||
use neo_migration::{Migrator, MigratorTrait};
|
||||
|
||||
mod api;
|
||||
mod config;
|
||||
@ -24,7 +24,7 @@ mod types;
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
if std::env::var("RUST_LOG").is_err() {
|
||||
std::env::set_var("RUST_LOG", "info,sqlx=off");
|
||||
std::env::set_var("RUST_LOG", "debug");
|
||||
}
|
||||
|
||||
tracing_subscriber::fmt::init();
|
||||
|
@ -1,7 +1,4 @@
|
||||
pub mod error_code;
|
||||
//pub mod event_type;
|
||||
pub mod flow;
|
||||
pub mod server_name;
|
||||
//pub mod user_id;
|
||||
pub mod user_interactive_authorization;
|
||||
//pub mod uuid;
|
||||
|
@ -1,127 +0,0 @@
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum Hostname {
|
||||
IPv4(Ipv4Addr),
|
||||
IPv6(Ipv6Addr),
|
||||
Fqdn(String),
|
||||
}
|
||||
|
||||
pub struct ServerName {
|
||||
pub hostname: Hostname,
|
||||
pub port: Option<u16>,
|
||||
}
|
||||
|
||||
impl ServerName {
|
||||
pub fn new(server_name: &str) -> anyhow::Result<Self> {
|
||||
if let Ok(addr) = server_name.parse::<Ipv4Addr>() {
|
||||
return Ok(Self {
|
||||
hostname: Hostname::IPv4(addr),
|
||||
port: None,
|
||||
});
|
||||
};
|
||||
|
||||
if let Ok(addr) = server_name.parse::<SocketAddrV4>() {
|
||||
return Ok(Self {
|
||||
hostname: Hostname::IPv4(addr.ip().to_owned()),
|
||||
port: Some(addr.port()),
|
||||
});
|
||||
};
|
||||
|
||||
if server_name.ends_with(']') {
|
||||
if let Ok(addr) = server_name[1..server_name.chars().count() - 1].parse::<Ipv6Addr>() {
|
||||
return Ok(Self {
|
||||
hostname: Hostname::IPv6(addr),
|
||||
port: None,
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
if let Ok(addr) = server_name.parse::<SocketAddrV6>() {
|
||||
return Ok(Self {
|
||||
hostname: Hostname::IPv6(addr.ip().to_owned()),
|
||||
port: Some(addr.port()),
|
||||
});
|
||||
};
|
||||
|
||||
if let Some(idx) = server_name.find(':') {
|
||||
let (hostname, port) = server_name.split_at(idx);
|
||||
Ok(Self {
|
||||
hostname: Hostname::Fqdn(hostname.to_owned()),
|
||||
port: Some(port[1..].parse()?),
|
||||
})
|
||||
} else {
|
||||
Ok(Self {
|
||||
hostname: Hostname::Fqdn(server_name.to_owned()),
|
||||
port: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ServerName {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.hostname {
|
||||
Hostname::IPv4(hostname) => write!(f, "{hostname}"),
|
||||
Hostname::IPv6(hostname) => write!(f, "[{hostname}]"),
|
||||
Hostname::Fqdn(hostname) => write!(f, "{hostname}"),
|
||||
}?;
|
||||
|
||||
if let Some(port) = self.port {
|
||||
write!(f, ":{port}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_ipv4_without_port() {
|
||||
let server_name = ServerName::new("127.0.0.1").unwrap();
|
||||
assert_eq!(
|
||||
server_name.hostname,
|
||||
Hostname::IPv4("127.0.0.1".parse().unwrap())
|
||||
);
|
||||
assert_eq!(server_name.port, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_ipv4_with_port() {
|
||||
let server_name = ServerName::new("127.0.0.1:8080").unwrap();
|
||||
assert_eq!(
|
||||
server_name.hostname,
|
||||
Hostname::IPv4("127.0.0.1".parse().unwrap())
|
||||
);
|
||||
assert_eq!(server_name.port, Some(8080));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_ipv6_without_port() {
|
||||
let server_name = ServerName::new("[::1]").unwrap();
|
||||
assert_eq!(server_name.hostname, Hostname::IPv6("::1".parse().unwrap()));
|
||||
assert_eq!(server_name.port, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_ipv6_with_port() {
|
||||
let server_name = ServerName::new("[::1]:8080").unwrap();
|
||||
assert_eq!(server_name.hostname, Hostname::IPv6("::1".parse().unwrap()));
|
||||
assert_eq!(server_name.port, Some(8080));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_fqdn_without_port() {
|
||||
let server_name = ServerName::new("example.com").unwrap();
|
||||
assert_eq!(server_name.hostname, Hostname::Fqdn("example.com".into()));
|
||||
assert_eq!(server_name.port, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_fqdn_with_port() {
|
||||
let server_name = ServerName::new("example.com:8080").unwrap();
|
||||
assert_eq!(server_name.hostname, Hostname::Fqdn("example.com".into()));
|
||||
assert_eq!(server_name.port, Some(8080));
|
||||
}
|
||||
}
|
@ -1,84 +0,0 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
use sqlx::{encode::IsNull, Sqlite};
|
||||
|
||||
use super::server_name::ServerName;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct UserId(String);
|
||||
|
||||
impl sqlx::Type<Sqlite> for UserId {
|
||||
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
|
||||
<&str as sqlx::Type<Sqlite>>::type_info()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'e> sqlx::Encode<'e, Sqlite> for UserId {
|
||||
fn encode_by_ref(
|
||||
&self,
|
||||
buf: &mut <Sqlite as sqlx::database::HasArguments<'e>>::ArgumentBuffer,
|
||||
) -> sqlx::encode::IsNull {
|
||||
buf.push(sqlx::sqlite::SqliteArgumentValue::Text(
|
||||
self.0.to_string().into(),
|
||||
));
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl<'d> sqlx::Decode<'d, Sqlite> for UserId {
|
||||
fn decode(
|
||||
value: <Sqlite as sqlx::database::HasValueRef<'d>>::ValueRef,
|
||||
) -> Result<Self, sqlx::error::BoxDynError> {
|
||||
let value = <String as sqlx::Decode<Sqlite>>::decode(value)?;
|
||||
|
||||
Ok(UserId(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl UserId {
|
||||
pub fn new(name: &str, server_name: &ServerName) -> anyhow::Result<Self> {
|
||||
let user_id = Self(format!("@{name}:{server_name}"));
|
||||
|
||||
user_id.is_valid()?;
|
||||
|
||||
Ok(user_id)
|
||||
}
|
||||
|
||||
pub fn local_part(&self) -> &str {
|
||||
let col_idx = self.0.find(':').expect("Will always have at least one ':'");
|
||||
self.0[1..col_idx].as_ref()
|
||||
}
|
||||
|
||||
fn is_valid(&self) -> anyhow::Result<()> {
|
||||
(self.0.len() <= 255)
|
||||
.then(|| ())
|
||||
.ok_or(UserIdError::TooLong(self.0.len()))?;
|
||||
|
||||
let local_part = self.local_part();
|
||||
local_part
|
||||
.bytes()
|
||||
.all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'z' | b'-' | b'.' | b'=' | b'_' | b'/'))
|
||||
.then(|| ())
|
||||
.ok_or(UserIdError::InvalidCharacters)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for UserId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum UserIdError {
|
||||
#[error("UserId too long {0} (expected < 255)")]
|
||||
TooLong(usize),
|
||||
#[error("Invalid character present in user id")]
|
||||
InvalidCharacters,
|
||||
#[error("Invalid UserId given")]
|
||||
Invalid,
|
||||
}
|
@ -1,45 +0,0 @@
|
||||
use sqlx::{encode::IsNull, Sqlite, Type};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct Uuid(pub uuid::Uuid);
|
||||
|
||||
impl Uuid {
|
||||
pub fn new_v4() -> Self {
|
||||
Uuid(uuid::Uuid::new_v4())
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Sqlite> for Uuid {
|
||||
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
|
||||
<&str as Type<Sqlite>>::type_info()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'e> sqlx::Encode<'e, Sqlite> for Uuid {
|
||||
fn encode_by_ref(
|
||||
&self,
|
||||
buf: &mut <Sqlite as sqlx::database::HasArguments<'e>>::ArgumentBuffer,
|
||||
) -> sqlx::encode::IsNull {
|
||||
buf.push(sqlx::sqlite::SqliteArgumentValue::Text(
|
||||
self.0.to_string().into(),
|
||||
));
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl<'d> sqlx::Decode<'d, Sqlite> for Uuid {
|
||||
fn decode(
|
||||
value: <Sqlite as sqlx::database::HasValueRef<'d>>::ValueRef,
|
||||
) -> Result<Self, sqlx::error::BoxDynError> {
|
||||
let value = <&str as sqlx::Decode<Sqlite>>::decode(value)?;
|
||||
|
||||
Ok(Uuid(uuid::Uuid::parse_str(value)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<uuid::Uuid> for Uuid {
|
||||
fn from(uuid: uuid::Uuid) -> Self {
|
||||
Uuid(uuid)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user