diff --git a/src/api/client_server/auth.rs b/src/api/client_server/auth.rs index a47269e..5965d61 100644 --- a/src/api/client_server/auth.rs +++ b/src/api/client_server/auth.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, sync::Arc}; -use argon2::{password_hash::SaltString, Argon2, PasswordHasher}; +use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher}; use axum::{ extract::Query, http::StatusCode, @@ -50,12 +50,23 @@ async fn post_login( Extension(db): Extension, Json(body): Json, ) -> Result, ApiError> { - let user = UserId::new("name", "server_name") - .ok() - .ok_or(AuthenticationError::InvalidUserId)?; - let resp = AuthenticationSuccess::new("", "", &user); + match body { + AuthenticationData::Password(auth_data) => { + let user = auth_data.user().unwrap(); + let user_id = UserId::new(&user, config.homeserver_name()) + .ok() + .ok_or(AuthenticationError::InvalidUserId)?; + + let user = User::find_by_user_id(&db, &user_id).await?; - Ok(Json(AuthenticationResponse::Success(resp))) + user.password_correct(auth_data.password()).ok().ok_or(AuthenticationError::Forbidden)?; + + todo!("find_or_create device for user and create a session"); + + let resp = AuthenticationSuccess::new("", "", &user_id); + Ok(Json(AuthenticationResponse::Success(resp))) + } + } } #[tracing::instrument(skip_all)] @@ -67,7 +78,7 @@ async fn get_username_available( let username = params .get("username") .ok_or(RegistrationError::MissingUserId)?; - let user_id = UserId::new(username, &config.homeserver_name) + let user_id = UserId::new(username, &config.homeserver_name()) .ok() .ok_or(RegistrationError::InvalidUserId)?; let exists = User::exists(&db, &user_id).await?; @@ -87,7 +98,7 @@ async fn post_register( let (user, device) = match &body.auth().expect("must be Some") { AuthenticationData::Password(auth_data) => { let username = body.username().ok_or(RegistrationError::MissingUserId)?; - let user_id = UserId::new(username, &config.homeserver_name) + let user_id = UserId::new(username, &config.homeserver_name()) .ok() .ok_or(RegistrationError::InvalidUserId)?; diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..606722e --- /dev/null +++ b/src/config.rs @@ -0,0 +1,27 @@ +pub struct Config { + db_path: String, + homeserver_name: String, +} + +impl Config { + /// Get a reference to the config's db path. + #[must_use] + pub fn db_path(&self) -> &str { + self.db_path.as_ref() + } + + /// Get a reference to the config's homeserver name. + #[must_use] + pub fn homeserver_name(&self) -> &str { + self.homeserver_name.as_ref() + } +} + +impl Default for Config { + fn default() -> Self { + Self { + db_path: "sqlite://db.sqlite3".into(), + homeserver_name: "fuckwit.dev".into(), + } + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 650d1f6..38af7dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ use axum::{ http::{Request, StatusCode}, Extension, Router, }; +use config::Config; use tower_http::{ cors::CorsLayer, trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer}, @@ -19,20 +20,7 @@ mod models; mod requests; mod responses; mod types; - -struct Config { - db_path: String, - homeserver_name: String, -} - -impl Default for Config { - fn default() -> Self { - Self { - db_path: "sqlite://db.sqlite3".into(), - homeserver_name: "fuckwit.dev".into(), - } - } -} +mod config; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -44,7 +32,7 @@ async fn main() -> anyhow::Result<()> { let config = Arc::new(Config::default()); - let pool = sqlx::SqlitePool::connect(&config.db_path).await?; + let pool = sqlx::SqlitePool::connect(&config.db_path()).await?; let cors = CorsLayer::new() .allow_origin(tower_http::cors::Any) diff --git a/src/models/users.rs b/src/models/users.rs index 8e2020c..6d92c92 100644 --- a/src/models/users.rs +++ b/src/models/users.rs @@ -1,5 +1,5 @@ use crate::types::uuid::Uuid; -use argon2::{password_hash::SaltString, Argon2, PasswordHasher}; +use argon2::{password_hash::SaltString, Argon2, PasswordHasher, PasswordHash, PasswordVerifier}; use rand_core::OsRng; use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool}; @@ -68,7 +68,7 @@ impl User { .await?) } - pub async fn by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result { + pub async fn find_by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result { Ok(sqlx::query_as!( Self, "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password @@ -79,6 +79,12 @@ impl User { .await?) } + pub fn password_correct(&self, password: &str) -> anyhow::Result { + let password_hash = PasswordHash::new(self.password())?; + + Ok(Argon2::default().verify_password(password.as_bytes(), &password_hash).is_ok()) + } + /// Get the user's id. #[must_use] pub fn uuid(&self) -> &Uuid {