summaryrefslogtreecommitdiff
path: root/server/src/auth.rs
diff options
context:
space:
mode:
authorKjetil Orbekk <kj@orbekk.com>2022-10-08 17:22:48 -0400
committerKjetil Orbekk <kj@orbekk.com>2022-10-08 17:22:48 -0400
commit30102e5da48b53806b33f04041a46bec4c3b2fa3 (patch)
treecf9fd3ce1f8c449cb4cb1b8837015c7b514b916b /server/src/auth.rs
parent1cbf881835fc33859a31645f886c5d3787ed48f8 (diff)
Add token refresh and persist sessions in the db
Diffstat (limited to 'server/src/auth.rs')
-rw-r--r--server/src/auth.rs210
1 files changed, 185 insertions, 25 deletions
diff --git a/server/src/auth.rs b/server/src/auth.rs
index 01ee467..44f16ea 100644
--- a/server/src/auth.rs
+++ b/server/src/auth.rs
@@ -2,20 +2,24 @@ use std::{
collections::HashMap,
env,
num::NonZeroUsize,
+ str::FromStr,
sync::{Arc, Mutex},
};
use crate::error::BridgeError;
-use chrono::Utc;
+use chrono::{DateTime, Utc};
use lru::LruCache;
use openidconnect::{
- core::{CoreClient, CoreProviderMetadata, CoreResponseType},
+ core::{CoreClient, CoreProviderMetadata, CoreResponseType, CoreUserInfoClaims},
reqwest::async_http_client,
url::Url,
- AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
- IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RedirectUrl, Scope, TokenResponse,
+ AccessToken, AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret,
+ CsrfToken, IssuerUrl, Nonce, OAuth2TokenResponse, RedirectUrl, RefreshToken, Scope,
+ TokenResponse,
};
+use protocol::UserInfo;
use serde::{Deserialize, Serialize};
+use sqlx::PgPool;
use tracing::info;
use uuid::Uuid;
@@ -24,22 +28,49 @@ pub struct LoginState {
nonce: Nonce,
}
+#[derive(Debug)]
+pub struct AuthenticatedSession {
+ pub session_id: SessionId,
+ expiration: DateTime<Utc>,
+ access_token: AccessToken,
+ refresh_token: RefreshToken,
+}
+
pub struct Authenticator {
pub client: CoreClient,
- pub login_cache: Arc<Mutex<LruCache<EndUserId, LoginState>>>,
+ pub login_cache: Arc<Mutex<LruCache<SessionId, LoginState>>>,
+ pub db: PgPool,
}
#[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)]
-pub struct EndUserId(Uuid);
+pub struct SessionId(Uuid);
-impl EndUserId {
+impl SessionId {
pub fn new() -> Self {
Self(Uuid::new_v4())
}
}
+impl ToString for SessionId {
+ fn to_string(&self) -> String {
+ self.0.to_string()
+ }
+}
+
+impl FromStr for SessionId {
+ type Err = BridgeError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ Ok(SessionId(Uuid::from_str(s)?))
+ }
+}
+
const LOGIN_CACHE_SIZE: usize = 50;
+fn token_safe_time() -> chrono::Duration {
+ chrono::Duration::seconds(30)
+}
+
pub const LOGIN_CALLBACK: &'static str = "/api/login_callback";
fn redirect_url(app_url: &str) -> RedirectUrl {
RedirectUrl::new(format!("{}{}", app_url, LOGIN_CALLBACK)).unwrap()
@@ -47,6 +78,7 @@ fn redirect_url(app_url: &str) -> RedirectUrl {
impl Authenticator {
pub async fn new(
+ db: PgPool,
issuer_url: IssuerUrl,
client_id: ClientId,
client_secret: ClientSecret,
@@ -63,6 +95,7 @@ impl Authenticator {
.set_redirect_uri(redirect_uri);
Self {
+ db,
client,
login_cache: Arc::new(Mutex::new(LruCache::new(
NonZeroUsize::new(LOGIN_CACHE_SIZE).unwrap(),
@@ -70,9 +103,10 @@ impl Authenticator {
}
}
- pub async fn from_env() -> Self {
+ pub async fn from_env(db: PgPool) -> Self {
let app_url = env::var("APP_URL").unwrap();
Authenticator::new(
+ db,
IssuerUrl::new(env::var("OPENID_ISSUER_URL").unwrap()).unwrap(),
ClientId::new(env::var("OPENID_CLIENT_ID").unwrap()),
ClientSecret::new(env::var("OPENID_CLIENT_SECRET").unwrap()),
@@ -81,7 +115,7 @@ impl Authenticator {
.await
}
- pub async fn get_login_url(&self) -> (EndUserId, Url) {
+ pub async fn get_login_url(&self) -> (SessionId, Url) {
let (auth_url, csrf_token, nonce) = self
.client
.authorize_url(
@@ -92,7 +126,7 @@ impl Authenticator {
.add_scope(Scope::new("email".to_string()))
.add_scope(Scope::new("profile".to_string()))
.url();
- let user_id = EndUserId::new();
+ let user_id = SessionId::new();
self.login_cache
.lock()
.unwrap()
@@ -100,23 +134,41 @@ impl Authenticator {
(user_id, auth_url)
}
+ pub async fn maybe_refresh_token(
+ &self,
+ session: &mut AuthenticatedSession,
+ ) -> Result<(), BridgeError> {
+ if session.expiration > Utc::now() + token_safe_time() {
+ return Ok(());
+ }
+ info!("Refreshing expiring token: {}", session.expiration);
+ let new_token = self
+ .client
+ .exchange_refresh_token(&session.refresh_token)
+ .request_async(async_http_client)
+ .await?;
+ info!("Got new token: {new_token:#?}");
+ if let Some(refresh_token) = new_token.refresh_token() {
+ session.refresh_token = refresh_token.clone();
+ }
+ session.access_token = new_token.access_token().clone();
+ store_authenticated_session(&self.db, session).await?;
+ Ok(())
+ }
+
pub async fn authenticate(
&self,
- user_id: EndUserId,
+ pool: &PgPool,
+ session_id: SessionId,
auth_params: HashMap<String, String>,
- ) -> Result<(), BridgeError> {
+ ) -> Result<AuthenticatedSession, BridgeError> {
// TODO: If the token is missing from the cache, client should retry logging in.
let state = self
.login_cache
.lock()
.unwrap()
- .pop(&user_id)
+ .pop(&session_id)
.ok_or(BridgeError::InvalidRequest("token missing".to_string()))?;
- info!(
- "state: {:?}, {:?}",
- state.csrf_token.secret(),
- state.nonce.secret()
- );
if Some(state.csrf_token.secret()) != auth_params.get("state") {
return Err(BridgeError::InvalidRequest(
"token validation failed".to_string(),
@@ -136,16 +188,124 @@ impl Authenticator {
.exchange_code(authorization_code)
.request_async(async_http_client)
.await?;
- info!("Got token {token:#?}");
- let id_token = token
- .id_token()
- .ok_or(BridgeError::InvalidRequest("Server did not return an IdToken".to_string()))?;
+ let id_token = token.id_token().ok_or(BridgeError::InvalidRequest(
+ "Server did not return an IdToken".to_string(),
+ ))?;
let claims = id_token.claims(&self.client.id_token_verifier(), &state.nonce)?;
+
+ // Verify access token hash.
+ if let Some(expected_access_token_hash) = claims.access_token_hash() {
+ let actual_access_token_hash =
+ AccessTokenHash::from_token(token.access_token(), &id_token.signing_alg()?)?;
+ if actual_access_token_hash != *expected_access_token_hash {
+ return Err(BridgeError::InvalidRequest(
+ "Invalid access token".to_string(),
+ ));
+ }
+ }
+
+ if claims.expiration() < Utc::now() {
+ return Err(BridgeError::Internal(format!(
+ "Token expired at {}",
+ claims.expiration()
+ )));
+ }
+
+ let refresh_token = token
+ .refresh_token()
+ .ok_or(BridgeError::Internal("Expected refresh token".to_string()))?;
- info!("Got claims {claims:#?}");
+ let mut session = AuthenticatedSession {
+ session_id,
+ expiration: claims.expiration(),
+ access_token: token.access_token().clone(),
+ refresh_token: refresh_token.clone(),
+ };
+
+ store_authenticated_session(pool, &mut session).await?;
+ Ok(session)
+ }
+
+ pub async fn user_info(
+ &self,
+ session: &mut AuthenticatedSession,
+ ) -> Result<UserInfo, BridgeError> {
+ self.maybe_refresh_token(session).await?;
+ let user_info: CoreUserInfoClaims = self
+ .client
+ .user_info(session.access_token.clone(), None)?
+ .request_async(async_http_client)
+ .await?;
+ info!("Resolved user info: {user_info:#?}");
+ Ok(UserInfo {
+ username: user_info
+ .preferred_username()
+ .ok_or(BridgeError::Internal(
+ "missing preferred username".to_string(),
+ ))?
+ .to_string(),
+ })
+ }
+}
+
+async fn store_authenticated_session(
+ pool: &PgPool,
+ session: &mut AuthenticatedSession,
+) -> Result<(), BridgeError> {
+ info!(
+ "Refresh token length: {}",
+ session.refresh_token.secret().len()
+ );
+ let record = sqlx::query!(
+ r#"
+ insert into sessions (
+ id,
+ access_token,
+ access_token_expiration,
+ refresh_token
+ ) values ($1, $2, $3, $4)
+ on conflict (id) do update set
+ access_token = EXCLUDED.access_token,
+ access_token_expiration = EXCLUDED.access_token_expiration,
+ refresh_token = EXCLUDED.refresh_token,
+ last_refresh = now()
+ returning *
+ "#,
+ session.session_id.0,
+ session.access_token.secret(),
+ session.expiration,
+ session.refresh_token.secret()
+ )
+ .fetch_one(pool)
+ .await?;
+ session.session_id = SessionId(record.id);
+ session.access_token = AccessToken::new(record.access_token);
+ session.expiration = record.access_token_expiration;
+ session.refresh_token = RefreshToken::new(record.refresh_token);
+ Ok(())
+}
- // params: {"session_state": "909b9959-041b-4a98-84d0-5f978bc8a679", "code": "2b4e95d1-0000-4b28-b49d-7a9de731e82b.909b9959-041b-4a98-84d0-5f978bc8a679.a382d869-4e34-42f1-a64d-24a224b9d338", "state": "a7Hff_hF_FOCqPCxmA1ZXg
- Err(BridgeError::Internal("todo".to_string()))
+pub async fn fetch_authenticated_session(
+ pool: &PgPool,
+ session_id: &SessionId,
+) -> Result<Option<AuthenticatedSession>, BridgeError> {
+ let record = sqlx::query!(
+ r#"
+ select * from sessions
+ where id = $1
+ "#,
+ session_id.0,
+ )
+ .fetch_optional(pool)
+ .await?;
+ match record {
+ None => Ok(None),
+ Some(record) => Ok(Some(AuthenticatedSession {
+ session_id: SessionId(record.id),
+ access_token: AccessToken::new(record.access_token),
+ expiration: record.access_token_expiration,
+ refresh_token: RefreshToken::new(record.refresh_token),
+ })),
}
}