validate UserId more
This commit is contained in:
parent
9687490c8e
commit
b045bfb321
@ -70,9 +70,6 @@ async fn post_register(
|
|||||||
|
|
||||||
let (user, device) = match &body.auth().unwrap() {
|
let (user, device) = match &body.auth().unwrap() {
|
||||||
AuthenticationData::Password(auth_data) => {
|
AuthenticationData::Password(auth_data) => {
|
||||||
// let username = match auth_data.identifier() {
|
|
||||||
// Identifier::User(user_identifier) => user_identifier.user().unwrap(),
|
|
||||||
// };
|
|
||||||
let username = body.username().unwrap();
|
let username = body.username().unwrap();
|
||||||
let user_id = UserId::new(&username, &config.homeserver_name).unwrap();
|
let user_id = UserId::new(&username, &config.homeserver_name).unwrap();
|
||||||
if User::exists(&db, &user_id).await.unwrap() {
|
if User::exists(&db, &user_id).await.unwrap() {
|
||||||
@ -95,7 +92,6 @@ async fn post_register(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// dont log in the user after registration
|
|
||||||
if *&body.inhibit_login().is_some() && *&body.inhibit_login().unwrap() {
|
if *&body.inhibit_login().is_some() && *&body.inhibit_login().unwrap() {
|
||||||
let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id());
|
let resp = RegistrationSuccess::new(None, device.device_id(), user.user_id());
|
||||||
|
|
||||||
|
@ -39,12 +39,12 @@ async fn main() {
|
|||||||
if std::env::var("RUST_LOG").is_err() {
|
if std::env::var("RUST_LOG").is_err() {
|
||||||
std::env::set_var("RUST_LOG", "debug");
|
std::env::set_var("RUST_LOG", "debug");
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing_subscriber::fmt::init();
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
// init config
|
|
||||||
let config = Arc::new(Config::default());
|
let config = Arc::new(Config::default());
|
||||||
|
|
||||||
let pool = sqlx::SqlitePool::connect("sqlite://db.sqlite3")
|
let pool = sqlx::SqlitePool::connect(&config.db_path)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ use std::fmt::Display;
|
|||||||
|
|
||||||
#[derive(sqlx::Type)]
|
#[derive(sqlx::Type)]
|
||||||
#[sqlx(transparent)]
|
#[sqlx(transparent)]
|
||||||
|
#[repr(transparent)]
|
||||||
pub struct UserId(String);
|
pub struct UserId(String);
|
||||||
|
|
||||||
impl UserId {
|
impl UserId {
|
||||||
@ -13,11 +14,23 @@ impl UserId {
|
|||||||
Ok(user_id)
|
Ok(user_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn local_part(&self) -> &str {
|
||||||
|
let col_idx = self.0.find(':').expect("Will always have at least one ':'");
|
||||||
|
self.0[1..col_idx].as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
fn is_valid(&self) -> anyhow::Result<()> {
|
fn is_valid(&self) -> anyhow::Result<()> {
|
||||||
(self.0.len() <= 255)
|
(self.0.len() <= 255)
|
||||||
.then(|| ())
|
.then(|| ())
|
||||||
.ok_or(UserIdError::TooLong(self.0.len()))?;
|
.ok_or(UserIdError::TooLong(self.0.len()))?;
|
||||||
|
|
||||||
|
let local_part = self.local_part();
|
||||||
|
local_part
|
||||||
|
.bytes()
|
||||||
|
.all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'z' | b'-' | b'.' | b'=' | b'_' | b'/'))
|
||||||
|
.then(|| ())
|
||||||
|
.ok_or(UserIdError::InvalidCharacters)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -32,6 +45,8 @@ impl Display for UserId {
|
|||||||
pub enum UserIdError {
|
pub enum UserIdError {
|
||||||
#[error("UserId too long {0} (expected < 255)")]
|
#[error("UserId too long {0} (expected < 255)")]
|
||||||
TooLong(usize),
|
TooLong(usize),
|
||||||
|
#[error("Invalid character present in user id")]
|
||||||
|
InvalidCharacters,
|
||||||
#[error("Invalid UserId given")]
|
#[error("Invalid UserId given")]
|
||||||
Invalid,
|
Invalid,
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user