add thirdparty_protocols get and authorization middleware
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Patrick Michl 2022-05-02 21:25:20 +02:00
parent 54f67d435e
commit ba84efd384
12 changed files with 289 additions and 22 deletions

View File

@ -36,6 +36,42 @@
}, },
"query": "insert into devices(uuid, user_uuid, device_id, display_name)\n values(?, ?, ?, ?)\n returning uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name" "query": "insert into devices(uuid, user_uuid, device_id, display_name)\n values(?, ?, ?, ?)\n returning uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name"
}, },
"1843c2b3e548d1dd13694a65ca1ba123da38668c3fc5bc431fe5884a6fc25f71": {
"describe": {
"columns": [
{
"name": "uuid: Uuid",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "user_uuid: Uuid",
"ordinal": 1,
"type_info": "Int64"
},
{
"name": "device_id",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "display_name",
"ordinal": 3,
"type_info": "Text"
}
],
"nullable": [
false,
false,
false,
false
],
"parameters": {
"Right": 1
}
},
"query": "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where uuid = ?"
},
"221d0935dff8911fe58ac047d39e11b0472d2180d7c297291a5dc440e00efb80": { "221d0935dff8911fe58ac047d39e11b0472d2180d7c297291a5dc440e00efb80": {
"describe": { "describe": {
"columns": [ "columns": [
@ -66,6 +102,42 @@
}, },
"query": "insert into sessions(uuid, device_uuid, key)\n values(?, ?, ?)\n returning uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key" "query": "insert into sessions(uuid, device_uuid, key)\n values(?, ?, ?)\n returning uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key"
}, },
"383949b72c69bca95bf23ef06900cd1ac5a136cdd4a525cbb624d327ce0cdefb": {
"describe": {
"columns": [
{
"name": "uuid: Uuid",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "user_id: UserId",
"ordinal": 1,
"type_info": "Text"
},
{
"name": "display_name",
"ordinal": 2,
"type_info": "Text"
},
{
"name": "password",
"ordinal": 3,
"type_info": "Text"
}
],
"nullable": [
false,
false,
false,
false
],
"parameters": {
"Right": 1
}
},
"query": "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password\n from users where uuid = ?"
},
"3fead3dac0e110757bc30be40bb0c6c2bc02127b6d9b6145bfc40fa5fe22ad06": { "3fead3dac0e110757bc30be40bb0c6c2bc02127b6d9b6145bfc40fa5fe22ad06": {
"describe": { "describe": {
"columns": [ "columns": [
@ -156,6 +228,36 @@
}, },
"query": "update users set uuid = ?, user_id = ?, display_name = ?, password = ?\n where uuid = ?\n returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password" "query": "update users set uuid = ?, user_id = ?, display_name = ?, password = ?\n where uuid = ?\n returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password"
}, },
"b38fd90504bea0c63e6517738c2354e6b057fcc6c643283019b27689e286bf2d": {
"describe": {
"columns": [
{
"name": "uuid: Uuid",
"ordinal": 0,
"type_info": "Text"
},
{
"name": "device_uuid: Uuid",
"ordinal": 1,
"type_info": "Int64"
},
{
"name": "key",
"ordinal": 2,
"type_info": "Text"
}
],
"nullable": [
false,
false,
false
],
"parameters": {
"Right": 1
}
},
"query": "select uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key\n from sessions where key = ?"
},
"ddcc531c080b2a1c70166d29a940aaada6701abe2933c305a879e7f18baeaf3a": { "ddcc531c080b2a1c70166d29a940aaada6701abe2933c305a879e7f18baeaf3a": {
"describe": { "describe": {
"columns": [ "columns": [

View File

@ -13,7 +13,7 @@ pub struct ErrorResponse {
} }
impl ErrorResponse { impl ErrorResponse {
fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option<u64>) -> Self { pub fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option<u64>) -> Self {
Self { Self {
errcode, errcode,
error: error.to_owned(), error: error.to_owned(),

View File

@ -54,7 +54,7 @@ async fn post_login(
match body { match body {
AuthenticationData::Password(auth_data) => { AuthenticationData::Password(auth_data) => {
let user = auth_data.user().unwrap(); let user = auth_data.user().unwrap();
let user_id = UserId::new(&user, config.server_name()) let user_id = UserId::new(user, config.server_name())
.ok() .ok()
.ok_or(AuthenticationError::InvalidUserId)?; .ok_or(AuthenticationError::InvalidUserId)?;
@ -96,7 +96,7 @@ async fn get_username_available(
let username = params let username = params
.get("username") .get("username")
.ok_or(RegistrationError::MissingUserId)?; .ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, &config.server_name()) let user_id = UserId::new(username, config.server_name())
.ok() .ok()
.ok_or(RegistrationError::InvalidUserId)?; .ok_or(RegistrationError::InvalidUserId)?;
let exists = User::exists(&db, &user_id).await?; let exists = User::exists(&db, &user_id).await?;
@ -117,7 +117,7 @@ async fn post_register(
let (user, device) = match &body.auth().expect("must be Some") { let (user, device) = match &body.auth().expect("must be Some") {
AuthenticationData::Password(auth_data) => { AuthenticationData::Password(auth_data) => {
let username = body.username().ok_or(RegistrationError::MissingUserId)?; let username = body.username().ok_or(RegistrationError::MissingUserId)?;
let user_id = UserId::new(username, &config.server_name()) let user_id = UserId::new(username, config.server_name())
.ok() .ok()
.ok_or(RegistrationError::InvalidUserId)?; .ok_or(RegistrationError::InvalidUserId)?;

View File

@ -1 +1,106 @@
use std::sync::Arc;
use axum::{
http::{Request, StatusCode},
middleware::Next,
response::IntoResponse,
Json,
};
use sqlx::SqlitePool;
use crate::{models::sessions::Session, types::error_code::ErrorCode};
use super::errors::ErrorResponse;
pub mod auth; pub mod auth;
pub mod thirdparty;
async fn authentication_middleware<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
let db: &SqlitePool = req.extensions().get().unwrap();
let auth_header = req
.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());
if auth_header.is_none() {
return (
StatusCode::FORBIDDEN,
Json(ErrorResponse::new(
ErrorCode::Forbidden,
"Authorization Header not given",
None,
)),
)
.into_response();
}
let auth_header = auth_header.expect("Validated above");
let idx = auth_header.find(' ');
let idx = match idx {
Some(idx) => idx,
None => {
return (
StatusCode::FORBIDDEN,
Json(ErrorResponse::new(
ErrorCode::Forbidden,
"Invalid Authorization Header",
None,
)),
)
.into_response()
}
};
let session = match Session::find_by_key(db, &auth_header[idx + 1..]).await {
Ok(session) => session,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse::new(
ErrorCode::Unknown,
"Internal Server Error",
None,
)),
)
.into_response()
}
};
let session = match session {
Some(session) => session,
None => {
return (
StatusCode::FORBIDDEN,
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
)
.into_response()
}
};
let device = match session.device(db).await {
Ok(device) => device,
Err(_) => {
return (
StatusCode::FORBIDDEN,
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
)
.into_response()
}
};
let user = match device.user(db).await {
Ok(user) => user,
Err(_) => {
return (
StatusCode::FORBIDDEN,
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
)
.into_response()
}
};
req.extensions_mut().insert(Arc::new(user));
next.run(req).await.into_response()
}

View File

@ -0,0 +1,17 @@
use std::sync::Arc;
use axum::{routing::get, Extension};
use crate::{api::client_server::errors::api_error::ApiError, models::users::User};
pub fn routes() -> axum::Router {
axum::Router::new()
.route("/r0/thirdparty/protocols", get(get_thirdparty_protocols))
.layer(axum::middleware::from_fn(super::authentication_middleware))
}
#[tracing::instrument(skip_all)]
async fn get_thirdparty_protocols(Extension(user): Extension<Arc<User>>) -> Result<String, ApiError> {
Ok("{}".into())
}

View File

@ -16,11 +16,11 @@ use tower_http::{
use tracing::Level; use tracing::Level;
mod api; mod api;
mod config;
mod models; mod models;
mod requests; mod requests;
mod responses; mod responses;
mod types; mod types;
mod config;
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
@ -32,7 +32,7 @@ async fn main() -> anyhow::Result<()> {
let config = Arc::new(Config::default()); let config = Arc::new(Config::default());
let pool = sqlx::SqlitePool::connect(&config.db_path()).await?; let pool = sqlx::SqlitePool::connect(config.db_path()).await?;
let cors = CorsLayer::new() let cors = CorsLayer::new()
.allow_origin(tower_http::cors::Any) .allow_origin(tower_http::cors::Any)
@ -43,7 +43,8 @@ async fn main() -> anyhow::Result<()> {
let client_server = Router::new() let client_server = Router::new()
.merge(api::client_server::versions::routes()) .merge(api::client_server::versions::routes())
.merge(api::client_server::r0::auth::routes()); .merge(api::client_server::r0::auth::routes())
.merge(api::client_server::r0::thirdparty::routes());
let router = Router::new() let router = Router::new()
.nest("/_matrix/client", client_server) .nest("/_matrix/client", client_server)
@ -60,7 +61,7 @@ async fn main() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
async fn fallback(request: Request<Body>) -> StatusCode { async fn fallback(mut request: Request<Body>) -> StatusCode {
dbg!(request); println!("{} {}", request.method(), request.uri());
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
} }

View File

@ -49,6 +49,19 @@ impl Device {
) )
} }
pub async fn find_by_uuid(conn: &SqlitePool, uuid: &Uuid) -> anyhow::Result<Self> {
Ok(sqlx::query_as!(
Self,
"select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where uuid = ?",
uuid)
.fetch_one(conn).await?
)
}
pub async fn user(&self, conn: &SqlitePool) -> anyhow::Result<User> {
User::find_by_uuid(conn, &self.user_uuid).await
}
/// Get the device's id. /// Get the device's id.
#[must_use] #[must_use]
pub fn uuid(&self) -> &Uuid { pub fn uuid(&self) -> &Uuid {

View File

@ -1,4 +1,4 @@
use rand::{thread_rng, Rng, distributions::Alphanumeric}; use rand::{distributions::Alphanumeric, thread_rng, Rng};
use sqlx::SqlitePool; use sqlx::SqlitePool;
use crate::types::uuid::Uuid; use crate::types::uuid::Uuid;
@ -14,7 +14,7 @@ pub struct Session {
impl Session { impl Session {
pub fn new(device: &Device) -> anyhow::Result<Self> { pub fn new(device: &Device) -> anyhow::Result<Self> {
let mut rng = thread_rng(); let mut rng = thread_rng();
let key: String = (0..32).map(|_| rng.sample(Alphanumeric) as char).collect(); let key: String = (0..256).map(|_| rng.sample(Alphanumeric) as char).collect();
Ok(Self { Ok(Self {
uuid: uuid::Uuid::new_v4().into(), uuid: uuid::Uuid::new_v4().into(),
@ -37,6 +37,21 @@ impl Session {
.await?) .await?)
} }
pub async fn find_by_key(conn: &SqlitePool, key: &str) -> anyhow::Result<Option<Self>> {
Ok(sqlx::query_as!(
Self,
"select uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key
from sessions where key = ?",
key
)
.fetch_optional(conn)
.await?)
}
pub async fn device(&self, conn: &SqlitePool) -> anyhow::Result<Device> {
Device::find_by_uuid(conn, &self.device_uuid).await
}
/// Get the session's id. /// Get the session's id.
#[must_use] #[must_use]
pub fn uuid(&self) -> &Uuid { pub fn uuid(&self) -> &Uuid {

View File

@ -1,10 +1,11 @@
use crate::types::uuid::Uuid; use crate::types::uuid::Uuid;
use argon2::{password_hash::SaltString, Argon2, PasswordHasher, PasswordHash, PasswordVerifier}; use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use rand::rngs::OsRng; use rand::rngs::OsRng;
use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool}; use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool};
use crate::types::user_id::UserId; use crate::types::user_id::UserId;
#[derive(Debug)]
pub struct User { pub struct User {
uuid: Uuid, uuid: Uuid,
user_id: UserId, user_id: UserId,
@ -68,6 +69,17 @@ impl User {
.await?) .await?)
} }
pub async fn find_by_uuid(conn: &SqlitePool, uuid: &Uuid) -> anyhow::Result<Self> {
Ok(sqlx::query_as!(
Self,
"select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password
from users where uuid = ?",
uuid
)
.fetch_one(conn)
.await?)
}
pub async fn find_by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<Self> { pub async fn find_by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<Self> {
Ok(sqlx::query_as!( Ok(sqlx::query_as!(
Self, Self,
@ -82,7 +94,9 @@ impl User {
pub fn password_correct(&self, password: &str) -> anyhow::Result<bool> { pub fn password_correct(&self, password: &str) -> anyhow::Result<bool> {
let password_hash = PasswordHash::new(self.password())?; let password_hash = PasswordHash::new(self.password())?;
Ok(Argon2::default().verify_password(password.as_bytes(), &password_hash).is_ok()) Ok(Argon2::default()
.verify_password(password.as_bytes(), &password_hash)
.is_ok())
} }
/// Get the user's id. /// Get the user's id.

View File

@ -26,7 +26,7 @@ pub struct RegistrationSuccess {
impl RegistrationSuccess { impl RegistrationSuccess {
pub fn new(access_token: Option<&str>, device_id: &str, user_id: &str) -> Self { pub fn new(access_token: Option<&str>, device_id: &str, user_id: &str) -> Self {
Self { Self {
access_token: access_token.and_then(|v| Some(v.to_owned())), access_token: access_token.map(|v| v.to_owned()),
device_id: device_id.to_owned(), device_id: device_id.to_owned(),
user_id: user_id.to_owned(), user_id: user_id.to_owned(),
} }

View File

@ -4,7 +4,7 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
enum Hostname { enum Hostname {
IPv4(Ipv4Addr), IPv4(Ipv4Addr),
IPv6(Ipv6Addr), IPv6(Ipv6Addr),
FQDN(String), Fqdn(String),
} }
pub struct ServerName { pub struct ServerName {
@ -47,12 +47,12 @@ impl ServerName {
if let Some(idx) = server_name.find(':') { if let Some(idx) = server_name.find(':') {
let (hostname, port) = server_name.split_at(idx); let (hostname, port) = server_name.split_at(idx);
Ok(Self { Ok(Self {
hostname: Hostname::FQDN(hostname.to_owned()), hostname: Hostname::Fqdn(hostname.to_owned()),
port: Some(port[1..].parse()?), port: Some(port[1..].parse()?),
}) })
} else { } else {
Ok(Self { Ok(Self {
hostname: Hostname::FQDN(server_name.to_owned()), hostname: Hostname::Fqdn(server_name.to_owned()),
port: None, port: None,
}) })
} }
@ -76,7 +76,7 @@ impl std::fmt::Display for ServerName {
match &self.hostname { match &self.hostname {
Hostname::IPv4(hostname) => write!(f, "{hostname}"), Hostname::IPv4(hostname) => write!(f, "{hostname}"),
Hostname::IPv6(hostname) => write!(f, "[{hostname}]"), Hostname::IPv6(hostname) => write!(f, "[{hostname}]"),
Hostname::FQDN(hostname) => write!(f, "{hostname}"), Hostname::Fqdn(hostname) => write!(f, "{hostname}"),
}?; }?;
if let Some(port) = self.port { if let Some(port) = self.port {
@ -134,7 +134,7 @@ mod tests {
let server_name = ServerName::new("example.com").unwrap(); let server_name = ServerName::new("example.com").unwrap();
assert_eq!( assert_eq!(
server_name.hostname(), server_name.hostname(),
&Hostname::FQDN("example.com".into()) &Hostname::Fqdn("example.com".into())
); );
assert_eq!(server_name.port(), None); assert_eq!(server_name.port(), None);
} }
@ -144,7 +144,7 @@ mod tests {
let server_name = ServerName::new("example.com:8080").unwrap(); let server_name = ServerName::new("example.com:8080").unwrap();
assert_eq!( assert_eq!(
server_name.hostname(), server_name.hostname(),
&Hostname::FQDN("example.com".into()) &Hostname::Fqdn("example.com".into())
); );
assert_eq!(server_name.port(), Some(8080)); assert_eq!(server_name.port(), Some(8080));
} }

View File

@ -4,7 +4,7 @@ use sqlx::{encode::IsNull, Sqlite};
use super::server_name::ServerName; use super::server_name::ServerName;
#[derive(Clone)] #[derive(Debug, Clone)]
#[repr(transparent)] #[repr(transparent)]
pub struct UserId(String); pub struct UserId(String);