get back to where we were. but now with sea_orm and a more sane structure
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Patrick Michl 2022-07-14 22:29:35 +02:00
parent 35bde07b39
commit e33a734199
26 changed files with 126 additions and 312 deletions

11
Cargo.lock generated
View File

@ -1175,12 +1175,12 @@ name = "neo"
version = "0.1.0"
dependencies = [
"anyhow",
"argon2",
"axum",
"axum-macros",
"http",
"neo-entity",
"neo-migration",
"neo-util",
"rand 0.8.5",
"ruma",
"sea-orm",
@ -1213,6 +1213,15 @@ dependencies = [
"sea-orm-migration",
]
[[package]]
name = "neo-util"
version = "0.1.0"
dependencies = [
"anyhow",
"argon2",
"rand 0.8.5",
]
[[package]]
name = "nom"
version = "7.1.1"

View File

@ -3,5 +3,6 @@
members = [
"neo",
"neo-entity",
"neo-migration"
"neo-migration",
"neo-util"
]

View File

@ -1,4 +1,4 @@
pub mod users;
pub mod devices;
pub mod sessions;
pub mod prelude;
pub mod sessions;
pub mod users;

View File

@ -1,6 +1,6 @@
#[allow(unused_imports)]
pub use crate::{
devices::{self, Entity as Device},
sessions::{self, Entity as Session},
users::{self, Entity as User},
devices::{self, Entity as Device, Model as DeviceModel},
sessions::{self, Entity as Session, Model as SessionModel},
users::{self, Entity as User, Model as UserModel},
};

View File

@ -1,4 +1,4 @@
use sea_orm::entity::prelude::*;
use sea_orm::{entity::prelude::*, Set};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "sessions")]
@ -27,4 +27,11 @@ impl Related<super::devices::Entity> for Entity {
}
}
impl ActiveModelBehavior for ActiveModel {}
impl ActiveModelBehavior for ActiveModel {
fn new() -> Self {
Self {
uuid: Set(Uuid::new_v4()),
..ActiveModelTrait::default()
}
}
}

View File

@ -40,6 +40,7 @@ impl MigrationTrait for Migration {
manager
.create_index(
Index::create()
.name("user_id_index")
.table(User)
.col(users::Column::UserId)
.to_owned(),

View File

@ -40,6 +40,7 @@ impl MigrationTrait for Migration {
manager
.create_index(
Index::create()
.name("device_id_index")
.table(Device)
.col(devices::Column::DeviceId)
.to_owned(),
@ -48,6 +49,7 @@ impl MigrationTrait for Migration {
manager
.create_index(
Index::create()
.name("user_uuid_index")
.table(Device)
.col(devices::Column::UserUuid)
.to_owned(),

View File

@ -35,6 +35,7 @@ impl MigrationTrait for Migration {
manager
.create_index(
Index::create()
.name("device_uuid_index")
.table(Session)
.col(sessions::Column::DeviceUuid)
.to_owned(),

11
neo-util/Cargo.toml Normal file
View File

@ -0,0 +1,11 @@
[package]
name = "neo-util"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0"
argon2 = { version = "0.4", features = ["std"] }
rand = { version = "0.8.5", features = ["std"] }

1
neo-util/src/lib.rs Normal file
View File

@ -0,0 +1 @@
pub mod password;

17
neo-util/src/password.rs Normal file
View File

@ -0,0 +1,17 @@
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use rand::rngs::OsRng;
pub fn hash_password(password: &str) -> anyhow::Result<String> {
let argon2 = Argon2::default();
let salt = SaltString::generate(OsRng);
Ok(argon2
.hash_password(password.as_bytes(), &salt)?
.to_string())
}
pub fn password_correct(password: &str, hash: &str) -> anyhow::Result<bool> {
let password_hash = PasswordHash::new(hash)?;
Ok(Argon2::default()
.verify_password(password.as_bytes(), &password_hash)
.is_ok())
}

View File

@ -13,7 +13,6 @@ serde_json = "1.0"
tower-http = { version = "0.2", features = ["cors", "trace"]}
anyhow = "1.0"
thiserror = "1.0"
argon2 = { version = "0.4", features = ["std"] }
rand = { version = "0.8.5", features = ["std"] }
uuid = { version = "1.0", features = ["v4"] }
ruma = { version = "0.6.4", features = ["client-api", "compat"] }
@ -22,3 +21,4 @@ http = "0.2.8"
sea-orm = { version = "^0.8", features = ["sqlx-sqlite", "runtime-tokio-native-tls", "macros"], default-features = false }
neo-entity = { version = "*", path = "../neo-entity" }
neo-migration = { version = "*", path = "../neo-migration" }
neo-util = { version = "*", path = "../neo-util" }

View File

@ -5,13 +5,10 @@ use axum::{
routing::{get, post},
Extension,
};
// use neo_entity::{
// devices::{self, Entity as Device},
// sessions::{self, Entity as Session},
// users::{self, Entity as User},
// };
use neo_entity::prelude::*;
use neo_util::password;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use sea_orm::{ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, Set};
use crate::{
@ -68,27 +65,26 @@ async fn login(
.one(&db)
.await?
.ok_or(AuthenticationError::InvalidUserId)?;
// TODO: check password
//db_user
// .password_correct(&password)
// .map_err(|_| AuthenticationError::Forbidden)?;
if !password::password_correct(&password, &db_user.password_hash)
.map_err(|_| AuthenticationError::Forbidden)?
{
return Err(AuthenticationError::Forbidden.into());
}
let device = if let Some(device_id) = req.device_id {
Device::find()
.filter(
devices::Column::DeviceId
.eq(device_id.as_str())
.and(devices::Column::UserUuid.eq(db_user.uuid)),
)
.belongs_to(&db_user)
.filter(devices::Column::DeviceId.eq(device_id.as_str()))
.one(&db)
.await?
.unwrap()
//Device::find_for_user(&db, &db_user, device_id.as_str()).await?
} else {
let device_id = uuid::Uuid::new_v4().to_string();
let display_name = req
.initial_device_display_name
.unwrap_or_else(|| "Generic Device".into());
let device = devices::ActiveModel {
device_id: Set(device_id),
display_name: Set(display_name),
@ -98,8 +94,15 @@ async fn login(
device.insert(&db).await?
};
let key = thread_rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
let session = sessions::ActiveModel {
device_uuid: Set(device.uuid),
key: Set(key),
..Default::default()
};
let session = session.insert(&db).await?;
@ -164,21 +167,25 @@ async fn post_register(
return Err(AuthenticationError::InvalidUserId.into());
};
User::find()
if User::find()
.filter(users::Column::UserId.eq(user_id.as_str()))
.one(&db)
.await?
.ok_or(RegistrationError::UserIdTaken)?;
.is_some()
{
return Err(RegistrationError::UserIdTaken.into());
};
let display_name = req
.initial_device_display_name
.unwrap_or_else(|| "Random Display Name".into());
// TODO: Hash password
let pw_hash = password::hash_password(&password)?;
let user = users::ActiveModel {
user_id: Set(user_id.to_string()),
display_name: Set(user_id.to_string()),
password_hash: Set(password),
password_hash: Set(pw_hash),
..Default::default()
}
.insert(&db)
@ -196,8 +203,15 @@ async fn post_register(
ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?,
);
if !req.inhibit_login {
let key = thread_rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
let session = sessions::ActiveModel {
device_uuid: Set(device.uuid),
key: Set(key),
..Default::default()
}
.insert(&db)
@ -212,7 +226,7 @@ async fn post_register(
}
_ => todo!(),
},
// For clients not following using UIAA
// For clients not following/using UIAA
None => {
let password = req
.password
@ -225,20 +239,25 @@ async fn post_register(
return Err(AuthenticationError::InvalidUserId.into());
};
User::find()
if User::find()
.filter(users::Column::UserId.eq(user_id.as_str()))
.one(&db)
.await?
.ok_or(RegistrationError::UserIdTaken)?;
.is_some()
{
return Err(RegistrationError::UserIdTaken.into());
};
let display_name = req
.initial_device_display_name
.unwrap_or_else(|| "Random Display Name".into());
let pw_hash = password::hash_password(&password)?;
let user = users::ActiveModel {
user_id: Set(user_id.to_string()),
display_name: Set(user_id.to_string()),
password_hash: Set(password),
password_hash: Set(pw_hash),
..Default::default()
}
.insert(&db)
@ -255,8 +274,15 @@ async fn post_register(
let mut response =
Response::new(ruma::UserId::parse(&user.user_id).map_err(|e| anyhow::anyhow!(e))?);
if !req.inhibit_login {
let key = thread_rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
let session = sessions::ActiveModel {
device_uuid: Set(device.uuid),
key: Set(key),
..Default::default()
}
.insert(&db)

View File

@ -4,8 +4,8 @@ use axum::routing::post;
use axum::Extension;
use crate::api::client_server::errors::api_error::ApiError;
use neo_entity::prelude::*;
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
use neo_entity::prelude::*;
use ruma::api::client::filter;
@ -16,7 +16,7 @@ pub fn routes() -> axum::Router {
}
async fn create_filter(
Extension(_user): Extension<Arc<User>>,
Extension(_user): Extension<Arc<UserModel>>,
RumaRequest(_req): RumaRequest<filter::create_filter::v3::IncomingRequest>,
) -> Result<RumaResponse<filter::create_filter::v3::Response>, ApiError> {
use filter::create_filter::v3::*;

View File

@ -4,8 +4,8 @@ use axum::routing::post;
use axum::Extension;
use crate::api::client_server::errors::api_error::ApiError;
use neo_entity::prelude::*;
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
use neo_entity::prelude::*;
use ruma::api::client::keys;
@ -16,7 +16,7 @@ pub fn routes() -> axum::Router {
}
async fn get_keys(
Extension(_user): Extension<Arc<User>>,
Extension(_user): Extension<Arc<UserModel>>,
RumaRequest(_req): RumaRequest<keys::get_keys::v3::IncomingRequest>,
) -> Result<RumaResponse<keys::get_keys::v3::Response>, ApiError> {
use keys::get_keys::v3::*;

View File

@ -11,7 +11,7 @@ use neo_entity::{
sessions::{self, Entity as Session},
users::Entity as User,
};
use sea_orm::{ColumnTrait, EntityTrait, ModelTrait, QueryFilter, DatabaseConnection};
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, ModelTrait, QueryFilter};
use crate::types::error_code::ErrorCode;

View File

@ -4,8 +4,8 @@ use axum::routing::{get, put};
use axum::Extension;
use crate::api::client_server::errors::api_error::ApiError;
use neo_entity::prelude::*;
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
use neo_entity::prelude::*;
use ruma::api::client::presence;
use ruma::presence::PresenceState;
@ -18,7 +18,7 @@ pub fn routes() -> axum::Router {
}
async fn set_presence(
Extension(_user): Extension<Arc<User>>,
Extension(_user): Extension<Arc<UserModel>>,
RumaRequest(_req): RumaRequest<presence::set_presence::v3::IncomingRequest>,
) -> Result<RumaResponse<presence::set_presence::v3::Response>, ApiError> {
use presence::set_presence::v3::*;
@ -27,7 +27,7 @@ async fn set_presence(
}
async fn get_presence(
Extension(_user): Extension<Arc<User>>,
Extension(_user): Extension<Arc<UserModel>>,
RumaRequest(_req): RumaRequest<presence::get_presence::v3::IncomingRequest>,
) -> Result<RumaResponse<presence::get_presence::v3::Response>, ApiError> {
use presence::get_presence::v3::*;

View File

@ -4,8 +4,8 @@ use axum::routing::get;
use axum::Extension;
use crate::api::client_server::errors::api_error::ApiError;
use neo_entity::prelude::*;
use crate::ruma_wrapper::{RumaRequest, RumaResponse};
use neo_entity::prelude::*;
use ruma::api::client::push;
use ruma::push::Ruleset;
@ -17,8 +17,8 @@ pub fn routes() -> axum::Router {
}
async fn get_pushrules(
Extension(_user): Extension<Arc<User>>,
RumaRequest(_req): RumaRequest<push::get_pushrules_all::v3::IncomingRequest>
Extension(_user): Extension<Arc<UserModel>>,
RumaRequest(_req): RumaRequest<push::get_pushrules_all::v3::IncomingRequest>,
) -> Result<RumaResponse<push::get_pushrules_all::v3::Response>, ApiError> {
use push::get_pushrules_all::v3::*;

View File

@ -16,7 +16,7 @@ pub fn routes() -> axum::Router {
}
async fn sync_events(
Extension(_user): Extension<Arc<User>>,
Extension(_user): Extension<Arc<UserModel>>,
RumaRequest(req): RumaRequest<sync::sync_events::v3::IncomingRequest>,
) -> Result<RumaResponse<sync::sync_events::v3::Response>, ApiError> {
use sync::sync_events::v3::*;

View File

@ -4,10 +4,7 @@ use axum::{routing::get, Extension};
use neo_entity::prelude::*;
use crate::{
api::client_server::errors::api_error::ApiError,
ruma_wrapper::RumaResponse,
};
use crate::{api::client_server::errors::api_error::ApiError, ruma_wrapper::RumaResponse};
use ruma::api::client::thirdparty;
@ -18,7 +15,7 @@ pub fn routes() -> axum::Router {
}
async fn get_thirdparty_protocols(
Extension(_user): Extension<Arc<User>>,
Extension(_user): Extension<Arc<UserModel>>,
) -> Result<RumaResponse<thirdparty::get_protocols::v3::Response>, ApiError> {
Ok(RumaResponse(thirdparty::get_protocols::v3::Response::new(
BTreeMap::new(),

View File

@ -1,4 +1,4 @@
use ruma::{ServerName, OwnedServerName};
use ruma::{OwnedServerName, ServerName};
pub struct Config {
pub db_path: String,

View File

@ -7,13 +7,13 @@ use axum::{
Extension, Router,
};
use config::Config;
use neo_migration::{Migrator, MigratorTrait};
use sea_orm::Database;
use tower_http::{
cors::CorsLayer,
trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
};
use tracing::Level;
use neo_migration::{Migrator, MigratorTrait};
mod api;
mod config;
@ -24,7 +24,7 @@ mod types;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "info,sqlx=off");
std::env::set_var("RUST_LOG", "debug");
}
tracing_subscriber::fmt::init();

View File

@ -1,7 +1,4 @@
pub mod error_code;
//pub mod event_type;
pub mod flow;
pub mod server_name;
//pub mod user_id;
pub mod user_interactive_authorization;
//pub mod uuid;

View File

@ -1,127 +0,0 @@
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
#[derive(Debug, PartialEq, Eq)]
pub enum Hostname {
IPv4(Ipv4Addr),
IPv6(Ipv6Addr),
Fqdn(String),
}
pub struct ServerName {
pub hostname: Hostname,
pub port: Option<u16>,
}
impl ServerName {
pub fn new(server_name: &str) -> anyhow::Result<Self> {
if let Ok(addr) = server_name.parse::<Ipv4Addr>() {
return Ok(Self {
hostname: Hostname::IPv4(addr),
port: None,
});
};
if let Ok(addr) = server_name.parse::<SocketAddrV4>() {
return Ok(Self {
hostname: Hostname::IPv4(addr.ip().to_owned()),
port: Some(addr.port()),
});
};
if server_name.ends_with(']') {
if let Ok(addr) = server_name[1..server_name.chars().count() - 1].parse::<Ipv6Addr>() {
return Ok(Self {
hostname: Hostname::IPv6(addr),
port: None,
});
};
}
if let Ok(addr) = server_name.parse::<SocketAddrV6>() {
return Ok(Self {
hostname: Hostname::IPv6(addr.ip().to_owned()),
port: Some(addr.port()),
});
};
if let Some(idx) = server_name.find(':') {
let (hostname, port) = server_name.split_at(idx);
Ok(Self {
hostname: Hostname::Fqdn(hostname.to_owned()),
port: Some(port[1..].parse()?),
})
} else {
Ok(Self {
hostname: Hostname::Fqdn(server_name.to_owned()),
port: None,
})
}
}
}
impl std::fmt::Display for ServerName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.hostname {
Hostname::IPv4(hostname) => write!(f, "{hostname}"),
Hostname::IPv6(hostname) => write!(f, "[{hostname}]"),
Hostname::Fqdn(hostname) => write!(f, "{hostname}"),
}?;
if let Some(port) = self.port {
write!(f, ":{port}")?;
}
Ok(())
}
}
mod tests {
use super::*;
#[test]
fn parse_ipv4_without_port() {
let server_name = ServerName::new("127.0.0.1").unwrap();
assert_eq!(
server_name.hostname,
Hostname::IPv4("127.0.0.1".parse().unwrap())
);
assert_eq!(server_name.port, None);
}
#[test]
fn parse_ipv4_with_port() {
let server_name = ServerName::new("127.0.0.1:8080").unwrap();
assert_eq!(
server_name.hostname,
Hostname::IPv4("127.0.0.1".parse().unwrap())
);
assert_eq!(server_name.port, Some(8080));
}
#[test]
fn parse_ipv6_without_port() {
let server_name = ServerName::new("[::1]").unwrap();
assert_eq!(server_name.hostname, Hostname::IPv6("::1".parse().unwrap()));
assert_eq!(server_name.port, None);
}
#[test]
fn parse_ipv6_with_port() {
let server_name = ServerName::new("[::1]:8080").unwrap();
assert_eq!(server_name.hostname, Hostname::IPv6("::1".parse().unwrap()));
assert_eq!(server_name.port, Some(8080));
}
#[test]
fn parse_fqdn_without_port() {
let server_name = ServerName::new("example.com").unwrap();
assert_eq!(server_name.hostname, Hostname::Fqdn("example.com".into()));
assert_eq!(server_name.port, None);
}
#[test]
fn parse_fqdn_with_port() {
let server_name = ServerName::new("example.com:8080").unwrap();
assert_eq!(server_name.hostname, Hostname::Fqdn("example.com".into()));
assert_eq!(server_name.port, Some(8080));
}
}

View File

@ -1,84 +0,0 @@
use std::fmt::Display;
use sqlx::{encode::IsNull, Sqlite};
use super::server_name::ServerName;
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct UserId(String);
impl sqlx::Type<Sqlite> for UserId {
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
<&str as sqlx::Type<Sqlite>>::type_info()
}
}
impl<'e> sqlx::Encode<'e, Sqlite> for UserId {
fn encode_by_ref(
&self,
buf: &mut <Sqlite as sqlx::database::HasArguments<'e>>::ArgumentBuffer,
) -> sqlx::encode::IsNull {
buf.push(sqlx::sqlite::SqliteArgumentValue::Text(
self.0.to_string().into(),
));
IsNull::No
}
}
impl<'d> sqlx::Decode<'d, Sqlite> for UserId {
fn decode(
value: <Sqlite as sqlx::database::HasValueRef<'d>>::ValueRef,
) -> Result<Self, sqlx::error::BoxDynError> {
let value = <String as sqlx::Decode<Sqlite>>::decode(value)?;
Ok(UserId(value))
}
}
impl UserId {
pub fn new(name: &str, server_name: &ServerName) -> anyhow::Result<Self> {
let user_id = Self(format!("@{name}:{server_name}"));
user_id.is_valid()?;
Ok(user_id)
}
pub fn local_part(&self) -> &str {
let col_idx = self.0.find(':').expect("Will always have at least one ':'");
self.0[1..col_idx].as_ref()
}
fn is_valid(&self) -> anyhow::Result<()> {
(self.0.len() <= 255)
.then(|| ())
.ok_or(UserIdError::TooLong(self.0.len()))?;
let local_part = self.local_part();
local_part
.bytes()
.all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'z' | b'-' | b'.' | b'=' | b'_' | b'/'))
.then(|| ())
.ok_or(UserIdError::InvalidCharacters)?;
Ok(())
}
}
impl Display for UserId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Debug, thiserror::Error)]
pub enum UserIdError {
#[error("UserId too long {0} (expected < 255)")]
TooLong(usize),
#[error("Invalid character present in user id")]
InvalidCharacters,
#[error("Invalid UserId given")]
Invalid,
}

View File

@ -1,45 +0,0 @@
use sqlx::{encode::IsNull, Sqlite, Type};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Uuid(pub uuid::Uuid);
impl Uuid {
pub fn new_v4() -> Self {
Uuid(uuid::Uuid::new_v4())
}
}
impl Type<Sqlite> for Uuid {
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
<&str as Type<Sqlite>>::type_info()
}
}
impl<'e> sqlx::Encode<'e, Sqlite> for Uuid {
fn encode_by_ref(
&self,
buf: &mut <Sqlite as sqlx::database::HasArguments<'e>>::ArgumentBuffer,
) -> sqlx::encode::IsNull {
buf.push(sqlx::sqlite::SqliteArgumentValue::Text(
self.0.to_string().into(),
));
IsNull::No
}
}
impl<'d> sqlx::Decode<'d, Sqlite> for Uuid {
fn decode(
value: <Sqlite as sqlx::database::HasValueRef<'d>>::ValueRef,
) -> Result<Self, sqlx::error::BoxDynError> {
let value = <&str as sqlx::Decode<Sqlite>>::decode(value)?;
Ok(Uuid(uuid::Uuid::parse_str(value)?))
}
}
impl From<uuid::Uuid> for Uuid {
fn from(uuid: uuid::Uuid) -> Self {
Uuid(uuid)
}
}