validate UserId more

This commit is contained in:
Patrick Michl 2022-04-24 22:56:56 +02:00
parent 9687490c8e
commit b045bfb321
3 changed files with 17 additions and 6 deletions

View File

@ -70,9 +70,6 @@ async fn post_register(
let (user, device) = match &body.auth().unwrap() {
AuthenticationData::Password(auth_data) => {
// let username = match auth_data.identifier() {
// Identifier::User(user_identifier) => user_identifier.user().unwrap(),
// };
let username = body.username().unwrap();
let user_id = UserId::new(&username, &config.homeserver_name).unwrap();
if User::exists(&db, &user_id).await.unwrap() {
@ -95,7 +92,6 @@ async fn post_register(
}
};
// dont log in the user after registration
if *&body.inhibit_login().is_some() && *&body.inhibit_login().unwrap() {
let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id());

View File

@ -39,12 +39,12 @@ async fn main() {
if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "debug");
}
tracing_subscriber::fmt::init();
// init config
let config = Arc::new(Config::default());
let pool = sqlx::SqlitePool::connect("sqlite://db.sqlite3")
let pool = sqlx::SqlitePool::connect(&config.db_path)
.await
.unwrap();

View File

@ -2,6 +2,7 @@ use std::fmt::Display;
#[derive(sqlx::Type)]
#[sqlx(transparent)]
#[repr(transparent)]
pub struct UserId(String);
impl UserId {
@ -13,11 +14,23 @@ impl UserId {
Ok(user_id)
}
pub fn local_part(&self) -> &str {
let col_idx = self.0.find(':').expect("Will always have at least one ':'");
self.0[1..col_idx].as_ref()
}
fn is_valid(&self) -> anyhow::Result<()> {
(self.0.len() <= 255)
.then(|| ())
.ok_or(UserIdError::TooLong(self.0.len()))?;
let local_part = self.local_part();
local_part
.bytes()
.all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'z' | b'-' | b'.' | b'=' | b'_' | b'/'))
.then(|| ())
.ok_or(UserIdError::InvalidCharacters)?;
Ok(())
}
}
@ -32,6 +45,8 @@ impl Display for UserId {
pub enum UserIdError {
#[error("UserId too long {0} (expected < 255)")]
TooLong(usize),
#[error("Invalid character present in user id")]
InvalidCharacters,
#[error("Invalid UserId given")]
Invalid,
}