use std::{collections::HashMap, env, str::FromStr, sync::Arc}; use auth::{AuthenticatedSession, LoggedInUser}; use axum::{ extract::{Extension, Query}, response::Redirect, routing::get, Json, Router, }; use protocol::{Table, UserInfo}; use tower_cookies::{Cookie, CookieManagerLayer, Cookies}; use tower_http::trace::TraceLayer; use tracing::info; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod auth; mod error; use crate::auth::{Authenticator, SessionId}; use crate::error::BridgeError; use sqlx::{postgres::PgPoolOptions, PgPool}; pub struct ServerContext { pub app_url: String, pub authenticator: Authenticator, pub db: PgPool, } type ContextExtension = Extension>; #[tokio::main] async fn main() { dotenv::dotenv().ok(); tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::new( std::env::var("RUST_LOG").unwrap_or_else(|_| "".into()), )) .with(tracing_subscriber::fmt::layer()) .init(); info!("Opening database connection"); let db_url = env::var("DATABASE_URL").unwrap(); let db_pool: PgPool = PgPoolOptions::new() .max_connections(10) .connect(&db_url) .await .expect("db connection"); info!("Running db migrations"); sqlx::migrate!().run(&db_pool).await.expect("db migration"); let bind_address = env::var("BIND_ADDRESS").unwrap(); info!("Starting server on {}", bind_address); let app_url = env::var("APP_URL").unwrap(); let state = Arc::new(ServerContext { app_url, authenticator: Authenticator::from_env(db_pool.clone()).await, db: db_pool, }); let app = Router::new() .route("/api/user/info", get(user_info)) .route("/api/login", get(login)) .route(auth::LOGIN_CALLBACK, get(login_callback)) .layer(CookieManagerLayer::new()) .layer(Extension(state)) .layer(TraceLayer::new_for_http()); axum::Server::bind(&bind_address.parse().unwrap()) .serve(app.into_make_service()) .await .unwrap(); } async fn user_info( _user: LoggedInUser, cookies: Cookies, extension: ContextExtension, ) -> Result>, BridgeError> { let cookie = match cookies.get("user-id") { None => return Ok(Json(None)), Some(v) => v, }; let session_id: SessionId = match SessionId::from_str(cookie.value()) { Err(e) => { info!("Clearing cookie that failed to parse {cookie:?}: {e}"); cookies.remove(cookie.into_owned()); return Ok(Json(None)); } Ok(s) => s, }; let mut session = match crate::auth::fetch_authenticated_session(&extension.db, &session_id).await? { None => return Ok(Json(None)), Some(v) => v, }; Ok(Json(Some(UserInfo { username: extension.authenticator.user_info(&mut session).await?, table: get_table(&extension.db, &session).await?, }))) } async fn get_table( db: &PgPool, session: &AuthenticatedSession, ) -> Result, BridgeError> { Ok(sqlx::query_as!( Table, r#" select tables.id from table_players players natural join active_tables tables where player_id = $1 "#, session.player_id ) .fetch_optional(db) .await?) } async fn login_callback( cookies: Cookies, Query(params): Query>, extension: ContextExtension, ) -> Result { let cookie = cookies.get("user-id").unwrap(); let user_id: SessionId = SessionId::from_str(cookie.value())?; let session = extension .authenticator .authenticate(&extension.db, user_id, params) .await?; info!("Logged in session: {session:?}"); Ok(Redirect::temporary(&extension.app_url)) } async fn login(cookies: Cookies, extension: ContextExtension) -> Redirect { let (user_id, auth_url) = extension.authenticator.get_login_url().await; info!("Creating auth url for {user_id:?}"); let user_id = serde_json::to_string(&user_id).unwrap(); let mut cookie = Cookie::new("user-id", user_id.to_string()); cookie.set_http_only(true); cookie.set_secure(true); cookie.set_same_site(cookie::SameSite::Lax); cookies.add(cookie); Redirect::temporary(auth_url.as_str()) }