finish login route
Some checks failed
continuous-integration/drone/push Build is failing

This commit is contained in:
Patrick Michl 2022-05-01 21:11:06 +02:00
parent c20b4c6a23
commit 2c2ac27c26
16 changed files with 277 additions and 27 deletions

View File

@ -9,3 +9,9 @@ steps:
SQLX_OFFLINE: 'true' SQLX_OFFLINE: 'true'
commands: commands:
- cargo check - cargo check
- name: cargo test
image: rust:latest
environment:
SQLX_OFFLINE: 'true'
commands:
- cargo test

29
Cargo.lock generated
View File

@ -604,7 +604,7 @@ dependencies = [
"anyhow", "anyhow",
"argon2", "argon2",
"axum", "axum",
"rand_core", "rand",
"serde", "serde",
"sqlx", "sqlx",
"thiserror", "thiserror",
@ -809,6 +809,12 @@ version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
[[package]]
name = "ppv-lite86"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.37" version = "1.0.37"
@ -827,6 +833,27 @@ dependencies = [
"proc-macro2", "proc-macro2",
] ]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]] [[package]]
name = "rand_core" name = "rand_core"
version = "0.6.3" version = "0.6.3"

View File

@ -16,5 +16,5 @@ sqlx = { version = "0.5", features = ["sqlite", "macros", "runtime-tokio-rustls"
anyhow = "1.0" anyhow = "1.0"
thiserror = "1.0" thiserror = "1.0"
argon2 = { version = "0.4", features = ["std"] } argon2 = { version = "0.4", features = ["std"] }
rand_core = { version = "0.6", features = ["std"] } rand = { version = "0.8.5", features = ["std"] }
uuid = { version = "1.0", features = ["v4"] } uuid = { version = "1.0", features = ["v4"] }

View File

@ -14,6 +14,8 @@ pub enum RegistrationError {
InvalidUserId, InvalidUserId,
#[error("The desired user ID is already taken")] #[error("The desired user ID is already taken")]
UserIdTaken, UserIdTaken,
#[error("Registration is disabled")]
RegistrationDisabled,
} }
impl IntoResponse for RegistrationError { impl IntoResponse for RegistrationError {
@ -42,6 +44,15 @@ impl IntoResponse for RegistrationError {
)), )),
) )
.into_response(), .into_response(),
RegistrationError::RegistrationDisabled => (
StatusCode::FORBIDDEN,
Json(ErrorResponse::new(
ErrorCode::Forbidden,
&self.to_string(),
None,
)),
)
.into_response(),
} }
} }
} }

View File

@ -1,3 +1,3 @@
pub mod auth;
pub mod errors; pub mod errors;
pub mod r0;
pub mod versions; pub mod versions;

View File

@ -7,16 +7,19 @@ use axum::{
routing::{get, post}, routing::{get, post},
Extension, Json, Extension, Json,
}; };
use rand_core::OsRng;
use sqlx::SqlitePool; use sqlx::SqlitePool;
use crate::{ use crate::{
api::client_server::errors::authentication_error::AuthenticationError, api::client_server::errors::{
api_error::ApiError, authentication_error::AuthenticationError,
registration_error::RegistrationError,
},
models::sessions::Session, models::sessions::Session,
responses::{ responses::{
authentication::{AuthenticationResponse, AuthenticationSuccess}, authentication::{AuthenticationResponse, AuthenticationSuccess},
registration::RegistrationResponse, registration::RegistrationResponse,
}, },
types::uuid::Uuid,
}; };
use crate::{ use crate::{
models::devices::Device, models::devices::Device,
@ -30,8 +33,6 @@ use crate::{
Config, Config,
}; };
use super::errors::{api_error::ApiError, registration_error::RegistrationError};
pub fn routes() -> axum::Router { pub fn routes() -> axum::Router {
axum::Router::new() axum::Router::new()
.route("/r0/login", get(get_login).post(post_login)) .route("/r0/login", get(get_login).post(post_login))
@ -53,17 +54,34 @@ async fn post_login(
match body { match body {
AuthenticationData::Password(auth_data) => { AuthenticationData::Password(auth_data) => {
let user = auth_data.user().unwrap(); let user = auth_data.user().unwrap();
let user_id = UserId::new(&user, config.homeserver_name()) let user_id = UserId::new(&user, config.server_name())
.ok() .ok()
.ok_or(AuthenticationError::InvalidUserId)?; .ok_or(AuthenticationError::InvalidUserId)?;
let user = User::find_by_user_id(&db, &user_id).await?; let user = User::find_by_user_id(&db, &user_id).await?;
user.password_correct(auth_data.password()).ok().ok_or(AuthenticationError::Forbidden)?; user.password_correct(auth_data.password())
.ok()
.ok_or(AuthenticationError::Forbidden)?;
todo!("find_or_create device for user and create a session"); let device = if let Some(device_id) = auth_data.device_id() {
Device::find_for_user(&db, &user, device_id).await?
} else {
let device_id = uuid::Uuid::new_v4().to_string();
let display_name =
if let Some(display_name) = auth_data.initial_device_display_name() {
display_name.as_ref()
} else {
"Generic Device"
};
Device::new(&user, &device_id, display_name)?
.create(&db)
.await?
};
let resp = AuthenticationSuccess::new("", "", &user_id); let session = Session::new(&device)?.create(&db).await?;
let resp = AuthenticationSuccess::new(session.key(), device.device_id(), &user_id);
Ok(Json(AuthenticationResponse::Success(resp))) Ok(Json(AuthenticationResponse::Success(resp)))
} }
} }
@ -78,7 +96,7 @@ async fn get_username_available(
let username = params let username = params
.get("username") .get("username")
.ok_or(RegistrationError::MissingUserId)?; .ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, &config.homeserver_name()) let user_id = UserId::new(username, &config.server_name())
.ok() .ok()
.ok_or(RegistrationError::InvalidUserId)?; .ok_or(RegistrationError::InvalidUserId)?;
let exists = User::exists(&db, &user_id).await?; let exists = User::exists(&db, &user_id).await?;
@ -92,13 +110,14 @@ async fn post_register(
Extension(db): Extension<SqlitePool>, Extension(db): Extension<SqlitePool>,
Json(body): Json<RegistrationRequest>, Json(body): Json<RegistrationRequest>,
) -> Result<Json<RegistrationResponse>, ApiError> { ) -> Result<Json<RegistrationResponse>, ApiError> {
config.enable_registration().then(|| true).ok_or(RegistrationError::RegistrationDisabled)?;
body.auth() body.auth()
.ok_or(RegistrationError::AdditionalAuthenticationInformation)?; .ok_or(RegistrationError::AdditionalAuthenticationInformation)?;
let (user, device) = match &body.auth().expect("must be Some") { let (user, device) = match &body.auth().expect("must be Some") {
AuthenticationData::Password(auth_data) => { AuthenticationData::Password(auth_data) => {
let username = body.username().ok_or(RegistrationError::MissingUserId)?; let username = body.username().ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, &config.homeserver_name()) let user_id = UserId::new(username, &config.server_name())
.ok() .ok()
.ok_or(RegistrationError::InvalidUserId)?; .ok_or(RegistrationError::InvalidUserId)?;
@ -115,7 +134,11 @@ async fn post_register(
.create(&db) .create(&db)
.await?; .await?;
let device = Device::new(&user, "test", display_name)? let device = Device::new(
&user,
uuid::Uuid::new_v4().to_string().as_ref(),
display_name,
)?
.create(&db) .create(&db)
.await?; .await?;

View File

@ -0,0 +1 @@
pub mod auth;

View File

@ -1,6 +1,9 @@
use crate::types::server_name::ServerName;
pub struct Config { pub struct Config {
db_path: String, db_path: String,
homeserver_name: String, server_name: ServerName,
enable_registration: bool
} }
impl Config { impl Config {
@ -12,8 +15,14 @@ impl Config {
/// Get a reference to the config's homeserver name. /// Get a reference to the config's homeserver name.
#[must_use] #[must_use]
pub fn homeserver_name(&self) -> &str { pub fn server_name(&self) -> &ServerName {
self.homeserver_name.as_ref() &self.server_name
}
/// Get the config's enable registration.
#[must_use]
pub fn enable_registration(&self) -> bool {
self.enable_registration
} }
} }
@ -21,7 +30,8 @@ impl Default for Config {
fn default() -> Self { fn default() -> Self {
Self { Self {
db_path: "sqlite://db.sqlite3".into(), db_path: "sqlite://db.sqlite3".into(),
homeserver_name: "fuckwit.dev".into(), server_name: ServerName::new("fuckwit.dev").unwrap(),
enable_registration: true
} }
} }
} }

View File

@ -43,7 +43,7 @@ async fn main() -> anyhow::Result<()> {
let client_server = Router::new() let client_server = Router::new()
.merge(api::client_server::versions::routes()) .merge(api::client_server::versions::routes())
.merge(api::client_server::auth::routes()); .merge(api::client_server::r0::auth::routes());
let router = Router::new() let router = Router::new()
.nest("/_matrix/client", client_server) .nest("/_matrix/client", client_server)

View File

@ -35,6 +35,20 @@ impl Device {
) )
} }
pub async fn find_for_user(
conn: &SqlitePool,
user: &User,
device_id: &str,
) -> anyhow::Result<Self> {
let user_uuid = user.uuid();
Ok(sqlx::query_as!(
Self,
"select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where user_uuid = ?",
user_uuid)
.fetch_one(conn).await?
)
}
/// Get the device's id. /// Get the device's id.
#[must_use] #[must_use]
pub fn uuid(&self) -> &Uuid { pub fn uuid(&self) -> &Uuid {

View File

@ -1,3 +1,4 @@
use rand::{thread_rng, Rng, distributions::Alphanumeric};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use crate::types::uuid::Uuid; use crate::types::uuid::Uuid;
@ -12,10 +13,13 @@ pub struct Session {
impl Session { impl Session {
pub fn new(device: &Device) -> anyhow::Result<Self> { pub fn new(device: &Device) -> anyhow::Result<Self> {
let mut rng = thread_rng();
let key: String = (0..32).map(|_| rng.sample(Alphanumeric) as char).collect();
Ok(Self { Ok(Self {
uuid: uuid::Uuid::new_v4().into(), uuid: uuid::Uuid::new_v4().into(),
device_uuid: device.uuid().clone(), device_uuid: device.uuid().clone(),
key: String::new(), key,
}) })
} }

View File

@ -1,6 +1,6 @@
use crate::types::uuid::Uuid; use crate::types::uuid::Uuid;
use argon2::{password_hash::SaltString, Argon2, PasswordHasher, PasswordHash, PasswordVerifier}; use argon2::{password_hash::SaltString, Argon2, PasswordHasher, PasswordHash, PasswordVerifier};
use rand_core::OsRng; use rand::rngs::OsRng;
use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool}; use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool};
use crate::types::user_id::UserId; use crate::types::user_id::UserId;

View File

@ -44,7 +44,7 @@ impl AuthenticationPassword {
/// Get a mutable reference to the authentication password's initial device display name. /// Get a mutable reference to the authentication password's initial device display name.
#[must_use] #[must_use]
pub fn initial_device_display_name_mut(&self) -> Option<&String> { pub fn initial_device_display_name(&self) -> Option<&String> {
self.initial_device_display_name.as_ref() self.initial_device_display_name.as_ref()
} }
} }

View File

@ -6,3 +6,4 @@ pub mod identifier_type;
pub mod user_id; pub mod user_id;
pub mod user_interactive_authorization; pub mod user_interactive_authorization;
pub mod uuid; pub mod uuid;
pub mod server_name;

151
src/types/server_name.rs Normal file
View File

@ -0,0 +1,151 @@
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
#[derive(Debug, PartialEq, Eq)]
enum Hostname {
IPv4(Ipv4Addr),
IPv6(Ipv6Addr),
FQDN(String),
}
pub struct ServerName {
hostname: Hostname,
port: Option<u16>,
}
impl ServerName {
pub fn new(server_name: &str) -> anyhow::Result<Self> {
if let Ok(addr) = server_name.parse::<Ipv4Addr>() {
return Ok(Self {
hostname: Hostname::IPv4(addr),
port: None,
});
};
if let Ok(addr) = server_name.parse::<SocketAddrV4>() {
return Ok(Self {
hostname: Hostname::IPv4(addr.ip().to_owned()),
port: Some(addr.port()),
});
};
if server_name.ends_with(']') {
if let Ok(addr) = server_name[1..server_name.chars().count() - 1].parse::<Ipv6Addr>() {
return Ok(Self {
hostname: Hostname::IPv6(addr),
port: None,
});
};
}
if let Ok(addr) = server_name.parse::<SocketAddrV6>() {
return Ok(Self {
hostname: Hostname::IPv6(addr.ip().to_owned()),
port: Some(addr.port()),
});
};
if let Some(idx) = server_name.find(':') {
let (hostname, port) = server_name.split_at(idx);
Ok(Self {
hostname: Hostname::FQDN(hostname.to_owned()),
port: Some(port[1..].parse()?),
})
} else {
Ok(Self {
hostname: Hostname::FQDN(server_name.to_owned()),
port: None,
})
}
}
/// Get a reference to the server name's hostname.
#[must_use]
fn hostname(&self) -> &Hostname {
&self.hostname
}
/// Get the server name's port.
#[must_use]
fn port(&self) -> Option<u16> {
self.port
}
}
impl std::fmt::Display for ServerName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.hostname {
Hostname::IPv4(hostname) => write!(f, "{hostname}"),
Hostname::IPv6(hostname) => write!(f, "[{hostname}]"),
Hostname::FQDN(hostname) => write!(f, "{hostname}"),
}?;
if let Some(port) = self.port {
write!(f, ":{port}")?;
}
Ok(())
}
}
mod tests {
use super::*;
#[test]
fn parse_ipv4_without_port() {
let server_name = ServerName::new("127.0.0.1").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv4("127.0.0.1".parse().unwrap())
);
assert_eq!(server_name.port(), None);
}
#[test]
fn parse_ipv4_with_port() {
let server_name = ServerName::new("127.0.0.1:8080").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv4("127.0.0.1".parse().unwrap())
);
assert_eq!(server_name.port(), Some(8080));
}
#[test]
fn parse_ipv6_without_port() {
let server_name = ServerName::new("[::1]").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv6("::1".parse().unwrap())
);
assert_eq!(server_name.port(), None);
}
#[test]
fn parse_ipv6_with_port() {
let server_name = ServerName::new("[::1]:8080").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::IPv6("::1".parse().unwrap())
);
assert_eq!(server_name.port(), Some(8080));
}
#[test]
fn parse_fqdn_without_port() {
let server_name = ServerName::new("example.com").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::FQDN("example.com".into())
);
assert_eq!(server_name.port(), None);
}
#[test]
fn parse_fqdn_with_port() {
let server_name = ServerName::new("example.com:8080").unwrap();
assert_eq!(
server_name.hostname(),
&Hostname::FQDN("example.com".into())
);
assert_eq!(server_name.port(), Some(8080));
}
}

View File

@ -2,6 +2,8 @@ use std::fmt::Display;
use sqlx::{encode::IsNull, Sqlite}; use sqlx::{encode::IsNull, Sqlite};
use super::server_name::ServerName;
#[derive(Clone)] #[derive(Clone)]
#[repr(transparent)] #[repr(transparent)]
pub struct UserId(String); pub struct UserId(String);
@ -36,7 +38,7 @@ impl<'d> sqlx::Decode<'d, Sqlite> for UserId {
} }
impl UserId { impl UserId {
pub fn new(name: &str, server_name: &str) -> anyhow::Result<Self> { pub fn new(name: &str, server_name: &ServerName) -> anyhow::Result<Self> {
let user_id = Self(format!("@{name}:{server_name}")); let user_id = Self(format!("@{name}:{server_name}"));
user_id.is_valid()?; user_id.is_valid()?;