diff options
author | Kjetil Orbekk <kj@orbekk.com> | 2022-10-08 17:22:48 -0400 |
---|---|---|
committer | Kjetil Orbekk <kj@orbekk.com> | 2022-10-08 17:22:48 -0400 |
commit | 30102e5da48b53806b33f04041a46bec4c3b2fa3 (patch) | |
tree | cf9fd3ce1f8c449cb4cb1b8837015c7b514b916b /server/src/auth.rs | |
parent | 1cbf881835fc33859a31645f886c5d3787ed48f8 (diff) |
Add token refresh and persist sessions in the db
Diffstat (limited to 'server/src/auth.rs')
-rw-r--r-- | server/src/auth.rs | 210 |
1 files changed, 185 insertions, 25 deletions
diff --git a/server/src/auth.rs b/server/src/auth.rs index 01ee467..44f16ea 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -2,20 +2,24 @@ use std::{ collections::HashMap, env, num::NonZeroUsize, + str::FromStr, sync::{Arc, Mutex}, }; use crate::error::BridgeError; -use chrono::Utc; +use chrono::{DateTime, Utc}; use lru::LruCache; use openidconnect::{ - core::{CoreClient, CoreProviderMetadata, CoreResponseType}, + core::{CoreClient, CoreProviderMetadata, CoreResponseType, CoreUserInfoClaims}, reqwest::async_http_client, url::Url, - AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, CsrfToken, - IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RedirectUrl, Scope, TokenResponse, + AccessToken, AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, + CsrfToken, IssuerUrl, Nonce, OAuth2TokenResponse, RedirectUrl, RefreshToken, Scope, + TokenResponse, }; +use protocol::UserInfo; use serde::{Deserialize, Serialize}; +use sqlx::PgPool; use tracing::info; use uuid::Uuid; @@ -24,22 +28,49 @@ pub struct LoginState { nonce: Nonce, } +#[derive(Debug)] +pub struct AuthenticatedSession { + pub session_id: SessionId, + expiration: DateTime<Utc>, + access_token: AccessToken, + refresh_token: RefreshToken, +} + pub struct Authenticator { pub client: CoreClient, - pub login_cache: Arc<Mutex<LruCache<EndUserId, LoginState>>>, + pub login_cache: Arc<Mutex<LruCache<SessionId, LoginState>>>, + pub db: PgPool, } #[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)] -pub struct EndUserId(Uuid); +pub struct SessionId(Uuid); -impl EndUserId { +impl SessionId { pub fn new() -> Self { Self(Uuid::new_v4()) } } +impl ToString for SessionId { + fn to_string(&self) -> String { + self.0.to_string() + } +} + +impl FromStr for SessionId { + type Err = BridgeError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + Ok(SessionId(Uuid::from_str(s)?)) + } +} + const LOGIN_CACHE_SIZE: usize = 50; +fn token_safe_time() -> chrono::Duration { + chrono::Duration::seconds(30) +} + pub const LOGIN_CALLBACK: &'static str = "/api/login_callback"; fn redirect_url(app_url: &str) -> RedirectUrl { RedirectUrl::new(format!("{}{}", app_url, LOGIN_CALLBACK)).unwrap() @@ -47,6 +78,7 @@ fn redirect_url(app_url: &str) -> RedirectUrl { impl Authenticator { pub async fn new( + db: PgPool, issuer_url: IssuerUrl, client_id: ClientId, client_secret: ClientSecret, @@ -63,6 +95,7 @@ impl Authenticator { .set_redirect_uri(redirect_uri); Self { + db, client, login_cache: Arc::new(Mutex::new(LruCache::new( NonZeroUsize::new(LOGIN_CACHE_SIZE).unwrap(), @@ -70,9 +103,10 @@ impl Authenticator { } } - pub async fn from_env() -> Self { + pub async fn from_env(db: PgPool) -> Self { let app_url = env::var("APP_URL").unwrap(); Authenticator::new( + db, IssuerUrl::new(env::var("OPENID_ISSUER_URL").unwrap()).unwrap(), ClientId::new(env::var("OPENID_CLIENT_ID").unwrap()), ClientSecret::new(env::var("OPENID_CLIENT_SECRET").unwrap()), @@ -81,7 +115,7 @@ impl Authenticator { .await } - pub async fn get_login_url(&self) -> (EndUserId, Url) { + pub async fn get_login_url(&self) -> (SessionId, Url) { let (auth_url, csrf_token, nonce) = self .client .authorize_url( @@ -92,7 +126,7 @@ impl Authenticator { .add_scope(Scope::new("email".to_string())) .add_scope(Scope::new("profile".to_string())) .url(); - let user_id = EndUserId::new(); + let user_id = SessionId::new(); self.login_cache .lock() .unwrap() @@ -100,23 +134,41 @@ impl Authenticator { (user_id, auth_url) } + pub async fn maybe_refresh_token( + &self, + session: &mut AuthenticatedSession, + ) -> Result<(), BridgeError> { + if session.expiration > Utc::now() + token_safe_time() { + return Ok(()); + } + info!("Refreshing expiring token: {}", session.expiration); + let new_token = self + .client + .exchange_refresh_token(&session.refresh_token) + .request_async(async_http_client) + .await?; + info!("Got new token: {new_token:#?}"); + if let Some(refresh_token) = new_token.refresh_token() { + session.refresh_token = refresh_token.clone(); + } + session.access_token = new_token.access_token().clone(); + store_authenticated_session(&self.db, session).await?; + Ok(()) + } + pub async fn authenticate( &self, - user_id: EndUserId, + pool: &PgPool, + session_id: SessionId, auth_params: HashMap<String, String>, - ) -> Result<(), BridgeError> { + ) -> Result<AuthenticatedSession, BridgeError> { // TODO: If the token is missing from the cache, client should retry logging in. let state = self .login_cache .lock() .unwrap() - .pop(&user_id) + .pop(&session_id) .ok_or(BridgeError::InvalidRequest("token missing".to_string()))?; - info!( - "state: {:?}, {:?}", - state.csrf_token.secret(), - state.nonce.secret() - ); if Some(state.csrf_token.secret()) != auth_params.get("state") { return Err(BridgeError::InvalidRequest( "token validation failed".to_string(), @@ -136,16 +188,124 @@ impl Authenticator { .exchange_code(authorization_code) .request_async(async_http_client) .await?; - info!("Got token {token:#?}"); - let id_token = token - .id_token() - .ok_or(BridgeError::InvalidRequest("Server did not return an IdToken".to_string()))?; + let id_token = token.id_token().ok_or(BridgeError::InvalidRequest( + "Server did not return an IdToken".to_string(), + ))?; let claims = id_token.claims(&self.client.id_token_verifier(), &state.nonce)?; + + // Verify access token hash. + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = + AccessTokenHash::from_token(token.access_token(), &id_token.signing_alg()?)?; + if actual_access_token_hash != *expected_access_token_hash { + return Err(BridgeError::InvalidRequest( + "Invalid access token".to_string(), + )); + } + } + + if claims.expiration() < Utc::now() { + return Err(BridgeError::Internal(format!( + "Token expired at {}", + claims.expiration() + ))); + } + + let refresh_token = token + .refresh_token() + .ok_or(BridgeError::Internal("Expected refresh token".to_string()))?; - info!("Got claims {claims:#?}"); + let mut session = AuthenticatedSession { + session_id, + expiration: claims.expiration(), + access_token: token.access_token().clone(), + refresh_token: refresh_token.clone(), + }; + + store_authenticated_session(pool, &mut session).await?; + Ok(session) + } + + pub async fn user_info( + &self, + session: &mut AuthenticatedSession, + ) -> Result<UserInfo, BridgeError> { + self.maybe_refresh_token(session).await?; + let user_info: CoreUserInfoClaims = self + .client + .user_info(session.access_token.clone(), None)? + .request_async(async_http_client) + .await?; + info!("Resolved user info: {user_info:#?}"); + Ok(UserInfo { + username: user_info + .preferred_username() + .ok_or(BridgeError::Internal( + "missing preferred username".to_string(), + ))? + .to_string(), + }) + } +} + +async fn store_authenticated_session( + pool: &PgPool, + session: &mut AuthenticatedSession, +) -> Result<(), BridgeError> { + info!( + "Refresh token length: {}", + session.refresh_token.secret().len() + ); + let record = sqlx::query!( + r#" + insert into sessions ( + id, + access_token, + access_token_expiration, + refresh_token + ) values ($1, $2, $3, $4) + on conflict (id) do update set + access_token = EXCLUDED.access_token, + access_token_expiration = EXCLUDED.access_token_expiration, + refresh_token = EXCLUDED.refresh_token, + last_refresh = now() + returning * + "#, + session.session_id.0, + session.access_token.secret(), + session.expiration, + session.refresh_token.secret() + ) + .fetch_one(pool) + .await?; + session.session_id = SessionId(record.id); + session.access_token = AccessToken::new(record.access_token); + session.expiration = record.access_token_expiration; + session.refresh_token = RefreshToken::new(record.refresh_token); + Ok(()) +} - // params: {"session_state": "909b9959-041b-4a98-84d0-5f978bc8a679", "code": "2b4e95d1-0000-4b28-b49d-7a9de731e82b.909b9959-041b-4a98-84d0-5f978bc8a679.a382d869-4e34-42f1-a64d-24a224b9d338", "state": "a7Hff_hF_FOCqPCxmA1ZXg - Err(BridgeError::Internal("todo".to_string())) +pub async fn fetch_authenticated_session( + pool: &PgPool, + session_id: &SessionId, +) -> Result<Option<AuthenticatedSession>, BridgeError> { + let record = sqlx::query!( + r#" + select * from sessions + where id = $1 + "#, + session_id.0, + ) + .fetch_optional(pool) + .await?; + match record { + None => Ok(None), + Some(record) => Ok(Some(AuthenticatedSession { + session_id: SessionId(record.id), + access_token: AccessToken::new(record.access_token), + expiration: record.access_token_expiration, + refresh_token: RefreshToken::new(record.refresh_token), + })), } } |