use std::{ collections::HashMap, env, num::NonZeroUsize, str::FromStr, sync::{Arc, Mutex}, }; use crate::{ error::BridgeError, server::ServerContext, }; use async_trait::async_trait; use axum::{ extract::{FromRef, FromRequestParts}, http::request::Parts, response::{IntoResponse, Response}, }; use chrono::{DateTime, Utc}; use lru::LruCache; use openidconnect::{ core::{ CoreClient, CoreProviderMetadata, CoreResponseType, CoreUserInfoClaims, }, reqwest::async_http_client, url::Url, AccessToken, AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, OAuth2TokenResponse, RedirectUrl, RefreshToken, Scope, TokenResponse, }; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use tower_cookies::Cookies; use tracing::{debug, error, info}; use uuid::Uuid; pub struct LoginState { csrf_token: CsrfToken, nonce: Nonce, } #[derive(Debug)] pub struct AuthenticatedSession { pub player_id: String, pub session_id: SessionId, expiration: DateTime, access_token: AccessToken, refresh_token: RefreshToken, } impl AuthenticatedSession { pub fn new( player_id: String, session_id: SessionId, expiration: DateTime, access_token: AccessToken, refresh_token: RefreshToken, ) -> Self { Self { player_id, session_id, expiration, access_token, refresh_token, } } } #[async_trait] pub trait Authenticator { async fn user_info( &self, session: &mut AuthenticatedSession, ) -> Result; async fn authenticate( &self, pool: &PgPool, session_id: SessionId, auth_params: HashMap, ) -> Result; async fn get_login_url(&self) -> (SessionId, Url); } pub struct OauthAuthenticator { pub client: CoreClient, pub login_cache: Arc>>, pub db: PgPool, } #[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)] pub struct SessionId(Uuid); impl SessionId { pub fn new() -> Self { Self(Uuid::new_v4()) } } impl Default for SessionId { fn default() -> Self { Self::new() } } impl ToString for SessionId { fn to_string(&self) -> String { self.0.to_string() } } impl FromStr for SessionId { type Err = BridgeError; fn from_str(s: &str) -> Result { Ok(SessionId(Uuid::from_str(s)?)) } } const LOGIN_CACHE_SIZE: usize = 50; fn token_safe_time() -> chrono::Duration { chrono::Duration::seconds(30) } pub const LOGIN_CALLBACK: &str = "/api/login_callback"; fn redirect_url(app_url: &str) -> RedirectUrl { RedirectUrl::new(format!("{}{}", app_url, LOGIN_CALLBACK)).unwrap() } impl OauthAuthenticator { pub async fn new( db: PgPool, issuer_url: IssuerUrl, client_id: ClientId, client_secret: ClientSecret, redirect_uri: RedirectUrl, ) -> Self { // Use OpenID Connect Discovery to fetch the provider metadata. let provider_metadata = CoreProviderMetadata::discover_async(issuer_url, async_http_client) .await .unwrap(); let client = CoreClient::from_provider_metadata( provider_metadata, client_id, Some(client_secret), ) // Set the URL the user will be redirected to after the authorization process. .set_redirect_uri(redirect_uri); Self { db, client, login_cache: Arc::new(Mutex::new(LruCache::new( NonZeroUsize::new(LOGIN_CACHE_SIZE).unwrap(), ))), } } pub async fn from_env(db: PgPool) -> Self { let app_url = env::var("APP_URL").unwrap(); OauthAuthenticator::new( db, IssuerUrl::new(env::var("OPENID_ISSUER_URL").unwrap()).unwrap(), ClientId::new(env::var("OPENID_CLIENT_ID").unwrap()), ClientSecret::new(env::var("OPENID_CLIENT_SECRET").unwrap()), redirect_url(&app_url), ) .await } async fn maybe_refresh_token( &self, session: &mut AuthenticatedSession, ) -> Result<(), BridgeError> { if session.expiration > Utc::now() + token_safe_time() { return Ok(()); } info!("Refreshing expiring token: {}", session.expiration); let refresh_start = Utc::now(); let new_token = self .client .exchange_refresh_token(&session.refresh_token) .request_async(async_http_client) .await?; debug!("Got new token: {new_token:#?}"); // TODO: Validate token? if let Some(expires_in) = new_token.expires_in() { session.expiration = refresh_start + chrono::Duration::from_std(expires_in)?; } else { error!( "Token is missing expiration! Will refresh token every time." ); } if let Some(refresh_token) = new_token.refresh_token() { session.refresh_token = refresh_token.clone(); } session.access_token = new_token.access_token().clone(); store_authenticated_session(&self.db, session).await?; Ok(()) } } #[async_trait] impl Authenticator for OauthAuthenticator { async fn get_login_url(&self) -> (SessionId, Url) { let (auth_url, csrf_token, nonce) = self .client .authorize_url( AuthenticationFlow::::AuthorizationCode, CsrfToken::new_random, Nonce::new_random, ) .add_scope(Scope::new("email".to_string())) .add_scope(Scope::new("profile".to_string())) .url(); let user_id = SessionId::new(); self.login_cache .lock() .unwrap() .put(user_id.clone(), LoginState { csrf_token, nonce }); (user_id, auth_url) } async fn user_info( &self, session: &mut AuthenticatedSession, ) -> Result { self.maybe_refresh_token(session).await?; let user_info: CoreUserInfoClaims = self .client .user_info(session.access_token.clone(), None)? .request_async(async_http_client) .await?; debug!("Resolved user info: {user_info:#?}"); Ok(user_info .preferred_username() .ok_or(BridgeError::Internal( "missing preferred username".to_string(), ))? .to_string()) } async fn authenticate( &self, pool: &PgPool, session_id: SessionId, auth_params: HashMap, ) -> Result { // TODO: If the token is missing from the cache, client should retry logging in. let state = self.login_cache.lock().unwrap().pop(&session_id).ok_or( BridgeError::InvalidRequest("token missing".to_string()), )?; if Some(state.csrf_token.secret()) != auth_params.get("state") { return Err(BridgeError::InvalidRequest( "token validation failed".to_string(), )); } let authorization_code = AuthorizationCode::new( auth_params .get("code") .ok_or(BridgeError::InvalidRequest( "missing 'code' param".to_string(), ))? .to_string(), ); let token = self .client .exchange_code(authorization_code) .request_async(async_http_client) .await?; let id_token = token.id_token().ok_or(BridgeError::InvalidRequest( "Server did not return an IdToken".to_string(), ))?; let claims = id_token.claims(&self.client.id_token_verifier(), &state.nonce)?; // Verify access token hash. if let Some(expected_access_token_hash) = claims.access_token_hash() { let actual_access_token_hash = AccessTokenHash::from_token( token.access_token(), &id_token.signing_alg()?, )?; if actual_access_token_hash != *expected_access_token_hash { return Err(BridgeError::InvalidRequest( "Invalid access token".to_string(), )); } } if claims.expiration() < Utc::now() { return Err(BridgeError::Internal(format!( "Token expired at {}", claims.expiration() ))); } let refresh_token = token.refresh_token().ok_or( BridgeError::Internal("Expected refresh token".to_string()), )?; let mut session = AuthenticatedSession { player_id: claims.subject().to_string(), session_id, expiration: claims.expiration(), access_token: token.access_token().clone(), refresh_token: refresh_token.clone(), }; store_authenticated_session(pool, &mut session).await?; Ok(session) } } pub async fn store_authenticated_session( pool: &PgPool, session: &mut AuthenticatedSession, ) -> Result<(), BridgeError> { debug!( "Refresh token length: {}", session.refresh_token.secret().len() ); sqlx::query!( r#" insert into players (id) values ($1) on conflict do nothing "#, session.player_id ) .execute(pool) .await?; let record = sqlx::query!( r#" insert into sessions ( id, player_id, access_token, access_token_expiration, refresh_token ) values ($1, $2, $3, $4, $5) on conflict (id) do update set access_token = EXCLUDED.access_token, access_token_expiration = EXCLUDED.access_token_expiration, refresh_token = EXCLUDED.refresh_token, last_refresh = now() returning * "#, session.session_id.0, session.player_id, session.access_token.secret(), session.expiration, session.refresh_token.secret() ) .fetch_one(pool) .await?; session.player_id = record.player_id; session.session_id = SessionId(record.id); session.access_token = AccessToken::new(record.access_token); session.expiration = record.access_token_expiration; session.refresh_token = RefreshToken::new(record.refresh_token); Ok(()) } pub async fn fetch_authenticated_session( pool: &PgPool, session_id: &SessionId, ) -> Result, BridgeError> { let record = sqlx::query!( r#" select * from sessions where id = $1 "#, session_id.0, ) .fetch_optional(pool) .await?; match record { None => Ok(None), Some(record) => Ok(Some(AuthenticatedSession { player_id: record.player_id, session_id: SessionId(record.id), access_token: AccessToken::new(record.access_token), expiration: record.access_token_expiration, refresh_token: RefreshToken::new(record.refresh_token), })), } } #[async_trait] impl FromRequestParts for AuthenticatedSession where S: Send + Sync, Arc: FromRef, { type Rejection = Response; async fn from_request_parts( parts: &mut Parts, state: &S, ) -> Result { let cookies = Cookies::from_request_parts(parts, state) .await .map_err(|e| e.into_response())?; let state = Arc::::from_ref(state); let cookie = match cookies.get("user-id") { None => return Err(BridgeError::NotLoggedIn.into_response()), Some(v) => v, }; let session_id: SessionId = match SessionId::from_str(cookie.value()) { Err(e) => { info!("Clearing cookie that failed to parse {cookie:?}: {e}"); cookies.remove(cookie.into_owned()); return Err(BridgeError::NotLoggedIn.into_response()); } Ok(s) => s, }; let session = match crate::auth::fetch_authenticated_session( &state.db, &session_id, ) .await .map_err(|e| e.into_response())? { None => return Err(BridgeError::NotLoggedIn.into_response()), Some(v) => v, }; Ok(session) } }