From b045bfb321f44df8c6a1f070cfa78c42a4c3900b Mon Sep 17 00:00:00 2001 From: Patrick Michl Date: Sun, 24 Apr 2022 22:56:56 +0200 Subject: [PATCH] validate UserId more --- src/api/client_server/auth.rs | 4 ---- src/main.rs | 4 ++-- src/types/matrix_user_id.rs | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/api/client_server/auth.rs b/src/api/client_server/auth.rs index 35bcc59..8b4c8db 100644 --- a/src/api/client_server/auth.rs +++ b/src/api/client_server/auth.rs @@ -70,9 +70,6 @@ async fn post_register( let (user, device) = match &body.auth().unwrap() { AuthenticationData::Password(auth_data) => { - // let username = match auth_data.identifier() { - // Identifier::User(user_identifier) => user_identifier.user().unwrap(), - // }; let username = body.username().unwrap(); let user_id = UserId::new(&username, &config.homeserver_name).unwrap(); if User::exists(&db, &user_id).await.unwrap() { @@ -95,7 +92,6 @@ async fn post_register( } }; - // dont log in the user after registration if *&body.inhibit_login().is_some() && *&body.inhibit_login().unwrap() { let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id()); diff --git a/src/main.rs b/src/main.rs index 88101c4..f5257ee 100644 --- a/src/main.rs +++ b/src/main.rs @@ -39,12 +39,12 @@ async fn main() { if std::env::var("RUST_LOG").is_err() { std::env::set_var("RUST_LOG", "debug"); } + tracing_subscriber::fmt::init(); - // init config let config = Arc::new(Config::default()); - let pool = sqlx::SqlitePool::connect("sqlite://db.sqlite3") + let pool = sqlx::SqlitePool::connect(&config.db_path) .await .unwrap(); diff --git a/src/types/matrix_user_id.rs b/src/types/matrix_user_id.rs index 007315a..6f724f1 100644 --- a/src/types/matrix_user_id.rs +++ b/src/types/matrix_user_id.rs @@ -2,6 +2,7 @@ use std::fmt::Display; #[derive(sqlx::Type)] #[sqlx(transparent)] +#[repr(transparent)] pub struct UserId(String); impl UserId { @@ -13,11 +14,23 @@ impl UserId { Ok(user_id) } + pub fn local_part(&self) -> &str { + let col_idx = self.0.find(':').expect("Will always have at least one ':'"); + self.0[1..col_idx].as_ref() + } + fn is_valid(&self) -> anyhow::Result<()> { (self.0.len() <= 255) .then(|| ()) .ok_or(UserIdError::TooLong(self.0.len()))?; + let local_part = self.local_part(); + local_part + .bytes() + .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'z' | b'-' | b'.' | b'=' | b'_' | b'/')) + .then(|| ()) + .ok_or(UserIdError::InvalidCharacters)?; + Ok(()) } } @@ -32,6 +45,8 @@ impl Display for UserId { pub enum UserIdError { #[error("UserId too long {0} (expected < 255)")] TooLong(usize), + #[error("Invalid character present in user id")] + InvalidCharacters, #[error("Invalid UserId given")] Invalid, }