From 30102e5da48b53806b33f04041a46bec4c3b2fa3 Mon Sep 17 00:00:00 2001 From: Kjetil Orbekk Date: Sat, 8 Oct 2022 17:22:48 -0400 Subject: Add token refresh and persist sessions in the db --- server/.env | 2 +- server/Cargo.toml | 3 +- server/migrations/20221008120534_init.up.sql | 3 +- server/src/auth.rs | 210 +++++++++++++++++++++++---- server/src/error.rs | 28 +++- server/src/main.rs | 68 ++++++--- 6 files changed, 260 insertions(+), 54 deletions(-) (limited to 'server') diff --git a/server/.env b/server/.env index 61c7e89..e575250 100644 --- a/server/.env +++ b/server/.env @@ -1,4 +1,4 @@ -RUST_LOG=info +RUST_LOG=info,tower_http=debug,server=debug BIND_ADDRESS=[::]:11121 RUST_BACKTRACE=1 OPENID_ISSUER_URL=https://auth.orbekk.com/realms/test diff --git a/server/Cargo.toml b/server/Cargo.toml index 0651fef..423ada2 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -21,8 +21,9 @@ uuid = { version = "1.1.2", features = ["serde", "fast-rng", "v4"] } tower-cookies = "0.7.0" tower = { version = "0.4.13", features = ["full"] } urlencoding = "2.1.2" -sqlx = { version = "0.6", features = [ "runtime-tokio-native-tls" , "postgres" ] } +sqlx = { version = "0.6", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono"] } anyhow = "1.0.65" chrono = { version = "0.4.22", features = ["serde"] } thiserror = "1.0.37" reqwest = "0.11.12" +cookie = "0.16.1" diff --git a/server/migrations/20221008120534_init.up.sql b/server/migrations/20221008120534_init.up.sql index 301d2eb..b3527eb 100644 --- a/server/migrations/20221008120534_init.up.sql +++ b/server/migrations/20221008120534_init.up.sql @@ -3,5 +3,6 @@ create table sessions ( id uuid primary key, access_token varchar(2048) not null, access_token_expiration timestamp with time zone not null, - refresh_token varchar(512) not null + refresh_token varchar(1024) not null, + last_refresh timestamp with time zone not null default now() ); diff --git a/server/src/auth.rs b/server/src/auth.rs index 01ee467..44f16ea 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -2,20 +2,24 @@ use std::{ collections::HashMap, env, num::NonZeroUsize, + str::FromStr, sync::{Arc, Mutex}, }; use crate::error::BridgeError; -use chrono::Utc; +use chrono::{DateTime, Utc}; use lru::LruCache; use openidconnect::{ - core::{CoreClient, CoreProviderMetadata, CoreResponseType}, + core::{CoreClient, CoreProviderMetadata, CoreResponseType, CoreUserInfoClaims}, reqwest::async_http_client, url::Url, - AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, CsrfToken, - IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RedirectUrl, Scope, TokenResponse, + AccessToken, AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, + CsrfToken, IssuerUrl, Nonce, OAuth2TokenResponse, RedirectUrl, RefreshToken, Scope, + TokenResponse, }; +use protocol::UserInfo; use serde::{Deserialize, Serialize}; +use sqlx::PgPool; use tracing::info; use uuid::Uuid; @@ -24,22 +28,49 @@ pub struct LoginState { nonce: Nonce, } +#[derive(Debug)] +pub struct AuthenticatedSession { + pub session_id: SessionId, + expiration: DateTime, + access_token: AccessToken, + refresh_token: RefreshToken, +} + pub struct Authenticator { pub client: CoreClient, - pub login_cache: Arc>>, + pub login_cache: Arc>>, + pub db: PgPool, } #[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)] -pub struct EndUserId(Uuid); +pub struct SessionId(Uuid); -impl EndUserId { +impl SessionId { pub fn new() -> Self { Self(Uuid::new_v4()) } } +impl ToString for SessionId { + fn to_string(&self) -> String { + self.0.to_string() + } +} + +impl FromStr for SessionId { + type Err = BridgeError; + + fn from_str(s: &str) -> Result { + Ok(SessionId(Uuid::from_str(s)?)) + } +} + const LOGIN_CACHE_SIZE: usize = 50; +fn token_safe_time() -> chrono::Duration { + chrono::Duration::seconds(30) +} + pub const LOGIN_CALLBACK: &'static str = "/api/login_callback"; fn redirect_url(app_url: &str) -> RedirectUrl { RedirectUrl::new(format!("{}{}", app_url, LOGIN_CALLBACK)).unwrap() @@ -47,6 +78,7 @@ fn redirect_url(app_url: &str) -> RedirectUrl { impl Authenticator { pub async fn new( + db: PgPool, issuer_url: IssuerUrl, client_id: ClientId, client_secret: ClientSecret, @@ -63,6 +95,7 @@ impl Authenticator { .set_redirect_uri(redirect_uri); Self { + db, client, login_cache: Arc::new(Mutex::new(LruCache::new( NonZeroUsize::new(LOGIN_CACHE_SIZE).unwrap(), @@ -70,9 +103,10 @@ impl Authenticator { } } - pub async fn from_env() -> Self { + pub async fn from_env(db: PgPool) -> Self { let app_url = env::var("APP_URL").unwrap(); Authenticator::new( + db, IssuerUrl::new(env::var("OPENID_ISSUER_URL").unwrap()).unwrap(), ClientId::new(env::var("OPENID_CLIENT_ID").unwrap()), ClientSecret::new(env::var("OPENID_CLIENT_SECRET").unwrap()), @@ -81,7 +115,7 @@ impl Authenticator { .await } - pub async fn get_login_url(&self) -> (EndUserId, Url) { + pub async fn get_login_url(&self) -> (SessionId, Url) { let (auth_url, csrf_token, nonce) = self .client .authorize_url( @@ -92,7 +126,7 @@ impl Authenticator { .add_scope(Scope::new("email".to_string())) .add_scope(Scope::new("profile".to_string())) .url(); - let user_id = EndUserId::new(); + let user_id = SessionId::new(); self.login_cache .lock() .unwrap() @@ -100,23 +134,41 @@ impl Authenticator { (user_id, auth_url) } + pub async fn maybe_refresh_token( + &self, + session: &mut AuthenticatedSession, + ) -> Result<(), BridgeError> { + if session.expiration > Utc::now() + token_safe_time() { + return Ok(()); + } + info!("Refreshing expiring token: {}", session.expiration); + let new_token = self + .client + .exchange_refresh_token(&session.refresh_token) + .request_async(async_http_client) + .await?; + info!("Got new token: {new_token:#?}"); + if let Some(refresh_token) = new_token.refresh_token() { + session.refresh_token = refresh_token.clone(); + } + session.access_token = new_token.access_token().clone(); + store_authenticated_session(&self.db, session).await?; + Ok(()) + } + pub async fn authenticate( &self, - user_id: EndUserId, + pool: &PgPool, + session_id: SessionId, auth_params: HashMap, - ) -> Result<(), BridgeError> { + ) -> Result { // TODO: If the token is missing from the cache, client should retry logging in. let state = self .login_cache .lock() .unwrap() - .pop(&user_id) + .pop(&session_id) .ok_or(BridgeError::InvalidRequest("token missing".to_string()))?; - info!( - "state: {:?}, {:?}", - state.csrf_token.secret(), - state.nonce.secret() - ); if Some(state.csrf_token.secret()) != auth_params.get("state") { return Err(BridgeError::InvalidRequest( "token validation failed".to_string(), @@ -136,16 +188,124 @@ impl Authenticator { .exchange_code(authorization_code) .request_async(async_http_client) .await?; - info!("Got token {token:#?}"); - let id_token = token - .id_token() - .ok_or(BridgeError::InvalidRequest("Server did not return an IdToken".to_string()))?; + let id_token = token.id_token().ok_or(BridgeError::InvalidRequest( + "Server did not return an IdToken".to_string(), + ))?; let claims = id_token.claims(&self.client.id_token_verifier(), &state.nonce)?; + + // Verify access token hash. + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = + AccessTokenHash::from_token(token.access_token(), &id_token.signing_alg()?)?; + if actual_access_token_hash != *expected_access_token_hash { + return Err(BridgeError::InvalidRequest( + "Invalid access token".to_string(), + )); + } + } + + if claims.expiration() < Utc::now() { + return Err(BridgeError::Internal(format!( + "Token expired at {}", + claims.expiration() + ))); + } + + let refresh_token = token + .refresh_token() + .ok_or(BridgeError::Internal("Expected refresh token".to_string()))?; - info!("Got claims {claims:#?}"); + let mut session = AuthenticatedSession { + session_id, + expiration: claims.expiration(), + access_token: token.access_token().clone(), + refresh_token: refresh_token.clone(), + }; + + store_authenticated_session(pool, &mut session).await?; + Ok(session) + } + + pub async fn user_info( + &self, + session: &mut AuthenticatedSession, + ) -> Result { + self.maybe_refresh_token(session).await?; + let user_info: CoreUserInfoClaims = self + .client + .user_info(session.access_token.clone(), None)? + .request_async(async_http_client) + .await?; + info!("Resolved user info: {user_info:#?}"); + Ok(UserInfo { + username: user_info + .preferred_username() + .ok_or(BridgeError::Internal( + "missing preferred username".to_string(), + ))? + .to_string(), + }) + } +} + +async fn store_authenticated_session( + pool: &PgPool, + session: &mut AuthenticatedSession, +) -> Result<(), BridgeError> { + info!( + "Refresh token length: {}", + session.refresh_token.secret().len() + ); + let record = sqlx::query!( + r#" + insert into sessions ( + id, + access_token, + access_token_expiration, + refresh_token + ) values ($1, $2, $3, $4) + on conflict (id) do update set + access_token = EXCLUDED.access_token, + access_token_expiration = EXCLUDED.access_token_expiration, + refresh_token = EXCLUDED.refresh_token, + last_refresh = now() + returning * + "#, + session.session_id.0, + session.access_token.secret(), + session.expiration, + session.refresh_token.secret() + ) + .fetch_one(pool) + .await?; + session.session_id = SessionId(record.id); + session.access_token = AccessToken::new(record.access_token); + session.expiration = record.access_token_expiration; + session.refresh_token = RefreshToken::new(record.refresh_token); + Ok(()) +} - // params: {"session_state": "909b9959-041b-4a98-84d0-5f978bc8a679", "code": "2b4e95d1-0000-4b28-b49d-7a9de731e82b.909b9959-041b-4a98-84d0-5f978bc8a679.a382d869-4e34-42f1-a64d-24a224b9d338", "state": "a7Hff_hF_FOCqPCxmA1ZXg - Err(BridgeError::Internal("todo".to_string())) +pub async fn fetch_authenticated_session( + pool: &PgPool, + session_id: &SessionId, +) -> Result, BridgeError> { + let record = sqlx::query!( + r#" + select * from sessions + where id = $1 + "#, + session_id.0, + ) + .fetch_optional(pool) + .await?; + match record { + None => Ok(None), + Some(record) => Ok(Some(AuthenticatedSession { + session_id: SessionId(record.id), + access_token: AccessToken::new(record.access_token), + expiration: record.access_token_expiration, + refresh_token: RefreshToken::new(record.refresh_token), + })), } } diff --git a/server/src/error.rs b/server/src/error.rs index 439e81b..1a45e96 100644 --- a/server/src/error.rs +++ b/server/src/error.rs @@ -1,5 +1,9 @@ use axum::{http::StatusCode, response::IntoResponse}; -use openidconnect::{core::CoreErrorResponseType, StandardErrorResponse, ClaimsVerificationError}; +use openidconnect::{core::CoreErrorResponseType, ClaimsVerificationError, StandardErrorResponse}; +use tracing::error; + +type UserInfoError = + openidconnect::UserInfoError>; type RequestTokenError = openidconnect::RequestTokenError< openidconnect::reqwest::Error, @@ -11,18 +15,34 @@ pub enum BridgeError { #[error("Invalid request: {0}")] InvalidRequest(String), - #[error("Backend request failed")] - Backend(#[from] RequestTokenError), + #[error("Requesting token failed")] + OpenidRequestTokenError(#[from] RequestTokenError), + + #[error("Requesting user info failed")] + OpenidUserInfoError(#[from] UserInfoError), + + #[error("Failed to configure OpenId request")] + OpenIdConfigurationError(#[from] openidconnect::ConfigurationError), #[error("Unexpected authorization error")] UnexpectedInvalidAuthorization(#[from] ClaimsVerificationError), - + + #[error("Authentication error")] + SigningFailed(#[from] openidconnect::SigningError), + + #[error("Database error")] + SqlxError(#[from] sqlx::Error), + + #[error("Uuid parse failed")] + UuidError(#[from] uuid::Error), + #[error("Internal server error: {0}")] Internal(String), } impl IntoResponse for BridgeError { fn into_response(self) -> axum::response::Response { + error!("Error occurred: {self:?}"); (StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {self}")).into_response() } } diff --git a/server/src/main.rs b/server/src/main.rs index 4183abb..87f95e4 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,10 +1,10 @@ -use std::{collections::HashMap, env, sync::Arc}; +use std::{collections::HashMap, env, str::FromStr, sync::Arc}; use axum::{ extract::{Extension, Query}, - response::{Redirect, IntoResponse}, + response::Redirect, routing::get, - Json, Router, http::StatusCode, + Json, Router, }; use protocol::UserInfo; use tower_cookies::{Cookie, CookieManagerLayer, Cookies}; @@ -13,9 +13,9 @@ use tracing::info; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod auth; mod error; -use crate::auth::{Authenticator, EndUserId}; -use sqlx::{postgres::PgPoolOptions, PgPool}; +use crate::auth::{Authenticator, SessionId}; use crate::error::BridgeError; +use sqlx::{postgres::PgPoolOptions, PgPool}; pub struct ServerContext { pub app_url: String, @@ -39,7 +39,9 @@ async fn main() { let db_url = env::var("DATABASE_URL").unwrap(); let db_pool: PgPool = PgPoolOptions::new() .max_connections(10) - .connect(&db_url).await.expect("db connection"); + .connect(&db_url) + .await + .expect("db connection"); info!("Running db migrations"); sqlx::migrate!().run(&db_pool).await.expect("db migration"); @@ -51,8 +53,8 @@ async fn main() { let state = Arc::new(ServerContext { app_url: app_url, - authenticator: Authenticator::from_env().await, - db: db_pool, + authenticator: Authenticator::from_env(db_pool.clone()).await, + db: db_pool, }); let app = Router::new() @@ -69,32 +71,54 @@ async fn main() { .unwrap(); } -async fn user_info() -> Json> { - Json(None) +async fn user_info( + cookies: Cookies, + extension: ContextExtension, +) -> Result>, BridgeError> { + let cookie = match cookies.get("user-id") { + None => return Ok(Json(None)), + Some(v) => v, + }; + + let session_id: SessionId = match SessionId::from_str(cookie.value()) { + Err(e) => { + info!("Clearing cookie that failed to parse {cookie:?}: {e}"); + cookies.remove(cookie.into_owned()); + return Ok(Json(None)); + } + Ok(s) => s, + }; + let mut session = + match crate::auth::fetch_authenticated_session(&extension.db, &session_id).await? { + None => return Ok(Json(None)), + Some(v) => v, + }; + Ok(Json(Some(extension.authenticator.user_info(&mut session).await?))) } async fn login_callback( cookies: Cookies, Query(params): Query>, extension: ContextExtension, -) -> Result<(), BridgeError> { +) -> Result { let cookie = cookies.get("user-id").unwrap(); - let user_id: EndUserId = - serde_json::from_str(&urlencoding::decode(cookie.value()).unwrap()).unwrap(); - info!("cookie: {cookie:?}"); - info!("params: {params:?}"); - extension.authenticator.authenticate(user_id, params).await?; - Ok(()) + let user_id: SessionId = SessionId::from_str(cookie.value())?; + let session = extension + .authenticator + .authenticate(&extension.db, user_id, params) + .await?; + info!("Logged in session: {session:?}"); + Ok(Redirect::temporary(&extension.app_url)) } async fn login(cookies: Cookies, extension: ContextExtension) -> Redirect { let (user_id, auth_url) = extension.authenticator.get_login_url().await; info!("Creating auth url for {user_id:?}"); let user_id = serde_json::to_string(&user_id).unwrap(); - cookies.add(Cookie::new( - "user-id", - urlencoding::encode(&user_id).to_string(), - )); + let mut cookie = Cookie::new("user-id", user_id.to_string()); + cookie.set_http_only(true); + cookie.set_secure(true); + cookie.set_same_site(cookie::SameSite::Lax); + cookies.add(cookie); Redirect::temporary(auth_url.as_str()) } - -- cgit v1.2.3