add thirdparty_protocols get and authorization middleware
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
54f67d435e
commit
ba84efd384
102
sqlx-data.json
102
sqlx-data.json
@ -36,6 +36,42 @@
|
||||
},
|
||||
"query": "insert into devices(uuid, user_uuid, device_id, display_name)\n values(?, ?, ?, ?)\n returning uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name"
|
||||
},
|
||||
"1843c2b3e548d1dd13694a65ca1ba123da38668c3fc5bc431fe5884a6fc25f71": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "uuid: Uuid",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "user_uuid: Uuid",
|
||||
"ordinal": 1,
|
||||
"type_info": "Int64"
|
||||
},
|
||||
{
|
||||
"name": "device_id",
|
||||
"ordinal": 2,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "display_name",
|
||||
"ordinal": 3,
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 1
|
||||
}
|
||||
},
|
||||
"query": "select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where uuid = ?"
|
||||
},
|
||||
"221d0935dff8911fe58ac047d39e11b0472d2180d7c297291a5dc440e00efb80": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
@ -66,6 +102,42 @@
|
||||
},
|
||||
"query": "insert into sessions(uuid, device_uuid, key)\n values(?, ?, ?)\n returning uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key"
|
||||
},
|
||||
"383949b72c69bca95bf23ef06900cd1ac5a136cdd4a525cbb624d327ce0cdefb": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "uuid: Uuid",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "user_id: UserId",
|
||||
"ordinal": 1,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "display_name",
|
||||
"ordinal": 2,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "password",
|
||||
"ordinal": 3,
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 1
|
||||
}
|
||||
},
|
||||
"query": "select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password\n from users where uuid = ?"
|
||||
},
|
||||
"3fead3dac0e110757bc30be40bb0c6c2bc02127b6d9b6145bfc40fa5fe22ad06": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
@ -156,6 +228,36 @@
|
||||
},
|
||||
"query": "update users set uuid = ?, user_id = ?, display_name = ?, password = ?\n where uuid = ?\n returning uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password"
|
||||
},
|
||||
"b38fd90504bea0c63e6517738c2354e6b057fcc6c643283019b27689e286bf2d": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"name": "uuid: Uuid",
|
||||
"ordinal": 0,
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"name": "device_uuid: Uuid",
|
||||
"ordinal": 1,
|
||||
"type_info": "Int64"
|
||||
},
|
||||
{
|
||||
"name": "key",
|
||||
"ordinal": 2,
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false
|
||||
],
|
||||
"parameters": {
|
||||
"Right": 1
|
||||
}
|
||||
},
|
||||
"query": "select uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key\n from sessions where key = ?"
|
||||
},
|
||||
"ddcc531c080b2a1c70166d29a940aaada6701abe2933c305a879e7f18baeaf3a": {
|
||||
"describe": {
|
||||
"columns": [
|
||||
|
@ -13,7 +13,7 @@ pub struct ErrorResponse {
|
||||
}
|
||||
|
||||
impl ErrorResponse {
|
||||
fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option<u64>) -> Self {
|
||||
pub fn new(errcode: ErrorCode, error: &str, retry_after_ms: Option<u64>) -> Self {
|
||||
Self {
|
||||
errcode,
|
||||
error: error.to_owned(),
|
||||
|
@ -54,7 +54,7 @@ async fn post_login(
|
||||
match body {
|
||||
AuthenticationData::Password(auth_data) => {
|
||||
let user = auth_data.user().unwrap();
|
||||
let user_id = UserId::new(&user, config.server_name())
|
||||
let user_id = UserId::new(user, config.server_name())
|
||||
.ok()
|
||||
.ok_or(AuthenticationError::InvalidUserId)?;
|
||||
|
||||
@ -96,7 +96,7 @@ async fn get_username_available(
|
||||
let username = params
|
||||
.get("username")
|
||||
.ok_or(RegistrationError::MissingUserId)?;
|
||||
let user_id = UserId::new(username, &config.server_name())
|
||||
let user_id = UserId::new(username, config.server_name())
|
||||
.ok()
|
||||
.ok_or(RegistrationError::InvalidUserId)?;
|
||||
let exists = User::exists(&db, &user_id).await?;
|
||||
@ -117,7 +117,7 @@ async fn post_register(
|
||||
let (user, device) = match &body.auth().expect("must be Some") {
|
||||
AuthenticationData::Password(auth_data) => {
|
||||
let username = body.username().ok_or(RegistrationError::MissingUserId)?;
|
||||
let user_id = UserId::new(username, &config.server_name())
|
||||
let user_id = UserId::new(username, config.server_name())
|
||||
.ok()
|
||||
.ok_or(RegistrationError::InvalidUserId)?;
|
||||
|
||||
|
@ -1 +1,106 @@
|
||||
pub mod auth;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
http::{Request, StatusCode},
|
||||
middleware::Next,
|
||||
response::IntoResponse,
|
||||
Json,
|
||||
};
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use crate::{models::sessions::Session, types::error_code::ErrorCode};
|
||||
|
||||
use super::errors::ErrorResponse;
|
||||
|
||||
pub mod auth;
|
||||
pub mod thirdparty;
|
||||
|
||||
async fn authentication_middleware<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||
let db: &SqlitePool = req.extensions().get().unwrap();
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|header| header.to_str().ok());
|
||||
|
||||
if auth_header.is_none() {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::Forbidden,
|
||||
"Authorization Header not given",
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let auth_header = auth_header.expect("Validated above");
|
||||
let idx = auth_header.find(' ');
|
||||
|
||||
let idx = match idx {
|
||||
Some(idx) => idx,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::Forbidden,
|
||||
"Invalid Authorization Header",
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let session = match Session::find_by_key(db, &auth_header[idx + 1..]).await {
|
||||
Ok(session) => session,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse::new(
|
||||
ErrorCode::Unknown,
|
||||
"Internal Server Error",
|
||||
None,
|
||||
)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let session = match session {
|
||||
Some(session) => session,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let device = match session.device(db).await {
|
||||
Ok(device) => device,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
let user = match device.user(db).await {
|
||||
Ok(user) => user,
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse::new(ErrorCode::Forbidden, "Forbidden", None)),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
req.extensions_mut().insert(Arc::new(user));
|
||||
|
||||
next.run(req).await.into_response()
|
||||
}
|
||||
|
17
src/api/client_server/r0/thirdparty.rs
Normal file
17
src/api/client_server/r0/thirdparty.rs
Normal file
@ -0,0 +1,17 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{routing::get, Extension};
|
||||
|
||||
use crate::{api::client_server::errors::api_error::ApiError, models::users::User};
|
||||
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/r0/thirdparty/protocols", get(get_thirdparty_protocols))
|
||||
.layer(axum::middleware::from_fn(super::authentication_middleware))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_thirdparty_protocols(Extension(user): Extension<Arc<User>>) -> Result<String, ApiError> {
|
||||
Ok("{}".into())
|
||||
}
|
11
src/main.rs
11
src/main.rs
@ -16,11 +16,11 @@ use tower_http::{
|
||||
use tracing::Level;
|
||||
|
||||
mod api;
|
||||
mod config;
|
||||
mod models;
|
||||
mod requests;
|
||||
mod responses;
|
||||
mod types;
|
||||
mod config;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
@ -32,7 +32,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let config = Arc::new(Config::default());
|
||||
|
||||
let pool = sqlx::SqlitePool::connect(&config.db_path()).await?;
|
||||
let pool = sqlx::SqlitePool::connect(config.db_path()).await?;
|
||||
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(tower_http::cors::Any)
|
||||
@ -43,7 +43,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let client_server = Router::new()
|
||||
.merge(api::client_server::versions::routes())
|
||||
.merge(api::client_server::r0::auth::routes());
|
||||
.merge(api::client_server::r0::auth::routes())
|
||||
.merge(api::client_server::r0::thirdparty::routes());
|
||||
|
||||
let router = Router::new()
|
||||
.nest("/_matrix/client", client_server)
|
||||
@ -60,7 +61,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fallback(request: Request<Body>) -> StatusCode {
|
||||
dbg!(request);
|
||||
async fn fallback(mut request: Request<Body>) -> StatusCode {
|
||||
println!("{} {}", request.method(), request.uri());
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
|
@ -49,6 +49,19 @@ impl Device {
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn find_by_uuid(conn: &SqlitePool, uuid: &Uuid) -> anyhow::Result<Self> {
|
||||
Ok(sqlx::query_as!(
|
||||
Self,
|
||||
"select uuid as 'uuid: Uuid', user_uuid as 'user_uuid: Uuid', device_id, display_name from devices where uuid = ?",
|
||||
uuid)
|
||||
.fetch_one(conn).await?
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn user(&self, conn: &SqlitePool) -> anyhow::Result<User> {
|
||||
User::find_by_uuid(conn, &self.user_uuid).await
|
||||
}
|
||||
|
||||
/// Get the device's id.
|
||||
#[must_use]
|
||||
pub fn uuid(&self) -> &Uuid {
|
||||
|
@ -1,4 +1,4 @@
|
||||
use rand::{thread_rng, Rng, distributions::Alphanumeric};
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
use sqlx::SqlitePool;
|
||||
|
||||
use crate::types::uuid::Uuid;
|
||||
@ -14,7 +14,7 @@ pub struct Session {
|
||||
impl Session {
|
||||
pub fn new(device: &Device) -> anyhow::Result<Self> {
|
||||
let mut rng = thread_rng();
|
||||
let key: String = (0..32).map(|_| rng.sample(Alphanumeric) as char).collect();
|
||||
let key: String = (0..256).map(|_| rng.sample(Alphanumeric) as char).collect();
|
||||
|
||||
Ok(Self {
|
||||
uuid: uuid::Uuid::new_v4().into(),
|
||||
@ -37,6 +37,21 @@ impl Session {
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn find_by_key(conn: &SqlitePool, key: &str) -> anyhow::Result<Option<Self>> {
|
||||
Ok(sqlx::query_as!(
|
||||
Self,
|
||||
"select uuid as 'uuid: Uuid', device_uuid as 'device_uuid: Uuid', key
|
||||
from sessions where key = ?",
|
||||
key
|
||||
)
|
||||
.fetch_optional(conn)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn device(&self, conn: &SqlitePool) -> anyhow::Result<Device> {
|
||||
Device::find_by_uuid(conn, &self.device_uuid).await
|
||||
}
|
||||
|
||||
/// Get the session's id.
|
||||
#[must_use]
|
||||
pub fn uuid(&self) -> &Uuid {
|
||||
|
@ -1,10 +1,11 @@
|
||||
use crate::types::uuid::Uuid;
|
||||
use argon2::{password_hash::SaltString, Argon2, PasswordHasher, PasswordHash, PasswordVerifier};
|
||||
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
|
||||
use rand::rngs::OsRng;
|
||||
use sqlx::{encode::IsNull, sqlite::SqliteTypeInfo, FromRow, Sqlite, SqlitePool};
|
||||
|
||||
use crate::types::user_id::UserId;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct User {
|
||||
uuid: Uuid,
|
||||
user_id: UserId,
|
||||
@ -68,6 +69,17 @@ impl User {
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn find_by_uuid(conn: &SqlitePool, uuid: &Uuid) -> anyhow::Result<Self> {
|
||||
Ok(sqlx::query_as!(
|
||||
Self,
|
||||
"select uuid as 'uuid: Uuid', user_id as 'user_id: UserId', display_name, password
|
||||
from users where uuid = ?",
|
||||
uuid
|
||||
)
|
||||
.fetch_one(conn)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn find_by_user_id(conn: &SqlitePool, user_id: &UserId) -> anyhow::Result<Self> {
|
||||
Ok(sqlx::query_as!(
|
||||
Self,
|
||||
@ -82,7 +94,9 @@ impl User {
|
||||
pub fn password_correct(&self, password: &str) -> anyhow::Result<bool> {
|
||||
let password_hash = PasswordHash::new(self.password())?;
|
||||
|
||||
Ok(Argon2::default().verify_password(password.as_bytes(), &password_hash).is_ok())
|
||||
Ok(Argon2::default()
|
||||
.verify_password(password.as_bytes(), &password_hash)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
/// Get the user's id.
|
||||
|
@ -26,7 +26,7 @@ pub struct RegistrationSuccess {
|
||||
impl RegistrationSuccess {
|
||||
pub fn new(access_token: Option<&str>, device_id: &str, user_id: &str) -> Self {
|
||||
Self {
|
||||
access_token: access_token.and_then(|v| Some(v.to_owned())),
|
||||
access_token: access_token.map(|v| v.to_owned()),
|
||||
device_id: device_id.to_owned(),
|
||||
user_id: user_id.to_owned(),
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
|
||||
enum Hostname {
|
||||
IPv4(Ipv4Addr),
|
||||
IPv6(Ipv6Addr),
|
||||
FQDN(String),
|
||||
Fqdn(String),
|
||||
}
|
||||
|
||||
pub struct ServerName {
|
||||
@ -47,12 +47,12 @@ impl ServerName {
|
||||
if let Some(idx) = server_name.find(':') {
|
||||
let (hostname, port) = server_name.split_at(idx);
|
||||
Ok(Self {
|
||||
hostname: Hostname::FQDN(hostname.to_owned()),
|
||||
hostname: Hostname::Fqdn(hostname.to_owned()),
|
||||
port: Some(port[1..].parse()?),
|
||||
})
|
||||
} else {
|
||||
Ok(Self {
|
||||
hostname: Hostname::FQDN(server_name.to_owned()),
|
||||
hostname: Hostname::Fqdn(server_name.to_owned()),
|
||||
port: None,
|
||||
})
|
||||
}
|
||||
@ -76,7 +76,7 @@ impl std::fmt::Display for ServerName {
|
||||
match &self.hostname {
|
||||
Hostname::IPv4(hostname) => write!(f, "{hostname}"),
|
||||
Hostname::IPv6(hostname) => write!(f, "[{hostname}]"),
|
||||
Hostname::FQDN(hostname) => write!(f, "{hostname}"),
|
||||
Hostname::Fqdn(hostname) => write!(f, "{hostname}"),
|
||||
}?;
|
||||
|
||||
if let Some(port) = self.port {
|
||||
@ -134,7 +134,7 @@ mod tests {
|
||||
let server_name = ServerName::new("example.com").unwrap();
|
||||
assert_eq!(
|
||||
server_name.hostname(),
|
||||
&Hostname::FQDN("example.com".into())
|
||||
&Hostname::Fqdn("example.com".into())
|
||||
);
|
||||
assert_eq!(server_name.port(), None);
|
||||
}
|
||||
@ -144,7 +144,7 @@ mod tests {
|
||||
let server_name = ServerName::new("example.com:8080").unwrap();
|
||||
assert_eq!(
|
||||
server_name.hostname(),
|
||||
&Hostname::FQDN("example.com".into())
|
||||
&Hostname::Fqdn("example.com".into())
|
||||
);
|
||||
assert_eq!(server_name.port(), Some(8080));
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ use sqlx::{encode::IsNull, Sqlite};
|
||||
|
||||
use super::server_name::ServerName;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct UserId(String);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user