summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorKjetil Orbekk <kj@orbekk.com>2022-10-08 17:22:48 -0400
committerKjetil Orbekk <kj@orbekk.com>2022-10-08 17:22:48 -0400
commit30102e5da48b53806b33f04041a46bec4c3b2fa3 (patch)
treecf9fd3ce1f8c449cb4cb1b8837015c7b514b916b /server
parent1cbf881835fc33859a31645f886c5d3787ed48f8 (diff)
Add token refresh and persist sessions in the db
Diffstat (limited to 'server')
-rw-r--r--server/.env2
-rw-r--r--server/Cargo.toml3
-rw-r--r--server/migrations/20221008120534_init.up.sql3
-rw-r--r--server/src/auth.rs210
-rw-r--r--server/src/error.rs28
-rw-r--r--server/src/main.rs68
6 files changed, 260 insertions, 54 deletions
diff --git a/server/.env b/server/.env
index 61c7e89..e575250 100644
--- a/server/.env
+++ b/server/.env
@@ -1,4 +1,4 @@
-RUST_LOG=info
+RUST_LOG=info,tower_http=debug,server=debug
BIND_ADDRESS=[::]:11121
RUST_BACKTRACE=1
OPENID_ISSUER_URL=https://auth.orbekk.com/realms/test
diff --git a/server/Cargo.toml b/server/Cargo.toml
index 0651fef..423ada2 100644
--- a/server/Cargo.toml
+++ b/server/Cargo.toml
@@ -21,8 +21,9 @@ uuid = { version = "1.1.2", features = ["serde", "fast-rng", "v4"] }
tower-cookies = "0.7.0"
tower = { version = "0.4.13", features = ["full"] }
urlencoding = "2.1.2"
-sqlx = { version = "0.6", features = [ "runtime-tokio-native-tls" , "postgres" ] }
+sqlx = { version = "0.6", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono"] }
anyhow = "1.0.65"
chrono = { version = "0.4.22", features = ["serde"] }
thiserror = "1.0.37"
reqwest = "0.11.12"
+cookie = "0.16.1"
diff --git a/server/migrations/20221008120534_init.up.sql b/server/migrations/20221008120534_init.up.sql
index 301d2eb..b3527eb 100644
--- a/server/migrations/20221008120534_init.up.sql
+++ b/server/migrations/20221008120534_init.up.sql
@@ -3,5 +3,6 @@ create table sessions (
id uuid primary key,
access_token varchar(2048) not null,
access_token_expiration timestamp with time zone not null,
- refresh_token varchar(512) not null
+ refresh_token varchar(1024) not null,
+ last_refresh timestamp with time zone not null default now()
);
diff --git a/server/src/auth.rs b/server/src/auth.rs
index 01ee467..44f16ea 100644
--- a/server/src/auth.rs
+++ b/server/src/auth.rs
@@ -2,20 +2,24 @@ use std::{
collections::HashMap,
env,
num::NonZeroUsize,
+ str::FromStr,
sync::{Arc, Mutex},
};
use crate::error::BridgeError;
-use chrono::Utc;
+use chrono::{DateTime, Utc};
use lru::LruCache;
use openidconnect::{
- core::{CoreClient, CoreProviderMetadata, CoreResponseType},
+ core::{CoreClient, CoreProviderMetadata, CoreResponseType, CoreUserInfoClaims},
reqwest::async_http_client,
url::Url,
- AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
- IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RedirectUrl, Scope, TokenResponse,
+ AccessToken, AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret,
+ CsrfToken, IssuerUrl, Nonce, OAuth2TokenResponse, RedirectUrl, RefreshToken, Scope,
+ TokenResponse,
};
+use protocol::UserInfo;
use serde::{Deserialize, Serialize};
+use sqlx::PgPool;
use tracing::info;
use uuid::Uuid;
@@ -24,22 +28,49 @@ pub struct LoginState {
nonce: Nonce,
}
+#[derive(Debug)]
+pub struct AuthenticatedSession {
+ pub session_id: SessionId,
+ expiration: DateTime<Utc>,
+ access_token: AccessToken,
+ refresh_token: RefreshToken,
+}
+
pub struct Authenticator {
pub client: CoreClient,
- pub login_cache: Arc<Mutex<LruCache<EndUserId, LoginState>>>,
+ pub login_cache: Arc<Mutex<LruCache<SessionId, LoginState>>>,
+ pub db: PgPool,
}
#[derive(Eq, PartialEq, Hash, Debug, Clone, Serialize, Deserialize)]
-pub struct EndUserId(Uuid);
+pub struct SessionId(Uuid);
-impl EndUserId {
+impl SessionId {
pub fn new() -> Self {
Self(Uuid::new_v4())
}
}
+impl ToString for SessionId {
+ fn to_string(&self) -> String {
+ self.0.to_string()
+ }
+}
+
+impl FromStr for SessionId {
+ type Err = BridgeError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ Ok(SessionId(Uuid::from_str(s)?))
+ }
+}
+
const LOGIN_CACHE_SIZE: usize = 50;
+fn token_safe_time() -> chrono::Duration {
+ chrono::Duration::seconds(30)
+}
+
pub const LOGIN_CALLBACK: &'static str = "/api/login_callback";
fn redirect_url(app_url: &str) -> RedirectUrl {
RedirectUrl::new(format!("{}{}", app_url, LOGIN_CALLBACK)).unwrap()
@@ -47,6 +78,7 @@ fn redirect_url(app_url: &str) -> RedirectUrl {
impl Authenticator {
pub async fn new(
+ db: PgPool,
issuer_url: IssuerUrl,
client_id: ClientId,
client_secret: ClientSecret,
@@ -63,6 +95,7 @@ impl Authenticator {
.set_redirect_uri(redirect_uri);
Self {
+ db,
client,
login_cache: Arc::new(Mutex::new(LruCache::new(
NonZeroUsize::new(LOGIN_CACHE_SIZE).unwrap(),
@@ -70,9 +103,10 @@ impl Authenticator {
}
}
- pub async fn from_env() -> Self {
+ pub async fn from_env(db: PgPool) -> Self {
let app_url = env::var("APP_URL").unwrap();
Authenticator::new(
+ db,
IssuerUrl::new(env::var("OPENID_ISSUER_URL").unwrap()).unwrap(),
ClientId::new(env::var("OPENID_CLIENT_ID").unwrap()),
ClientSecret::new(env::var("OPENID_CLIENT_SECRET").unwrap()),
@@ -81,7 +115,7 @@ impl Authenticator {
.await
}
- pub async fn get_login_url(&self) -> (EndUserId, Url) {
+ pub async fn get_login_url(&self) -> (SessionId, Url) {
let (auth_url, csrf_token, nonce) = self
.client
.authorize_url(
@@ -92,7 +126,7 @@ impl Authenticator {
.add_scope(Scope::new("email".to_string()))
.add_scope(Scope::new("profile".to_string()))
.url();
- let user_id = EndUserId::new();
+ let user_id = SessionId::new();
self.login_cache
.lock()
.unwrap()
@@ -100,23 +134,41 @@ impl Authenticator {
(user_id, auth_url)
}
+ pub async fn maybe_refresh_token(
+ &self,
+ session: &mut AuthenticatedSession,
+ ) -> Result<(), BridgeError> {
+ if session.expiration > Utc::now() + token_safe_time() {
+ return Ok(());
+ }
+ info!("Refreshing expiring token: {}", session.expiration);
+ let new_token = self
+ .client
+ .exchange_refresh_token(&session.refresh_token)
+ .request_async(async_http_client)
+ .await?;
+ info!("Got new token: {new_token:#?}");
+ if let Some(refresh_token) = new_token.refresh_token() {
+ session.refresh_token = refresh_token.clone();
+ }
+ session.access_token = new_token.access_token().clone();
+ store_authenticated_session(&self.db, session).await?;
+ Ok(())
+ }
+
pub async fn authenticate(
&self,
- user_id: EndUserId,
+ pool: &PgPool,
+ session_id: SessionId,
auth_params: HashMap<String, String>,
- ) -> Result<(), BridgeError> {
+ ) -> Result<AuthenticatedSession, BridgeError> {
// TODO: If the token is missing from the cache, client should retry logging in.
let state = self
.login_cache
.lock()
.unwrap()
- .pop(&user_id)
+ .pop(&session_id)
.ok_or(BridgeError::InvalidRequest("token missing".to_string()))?;
- info!(
- "state: {:?}, {:?}",
- state.csrf_token.secret(),
- state.nonce.secret()
- );
if Some(state.csrf_token.secret()) != auth_params.get("state") {
return Err(BridgeError::InvalidRequest(
"token validation failed".to_string(),
@@ -136,16 +188,124 @@ impl Authenticator {
.exchange_code(authorization_code)
.request_async(async_http_client)
.await?;
- info!("Got token {token:#?}");
- let id_token = token
- .id_token()
- .ok_or(BridgeError::InvalidRequest("Server did not return an IdToken".to_string()))?;
+ let id_token = token.id_token().ok_or(BridgeError::InvalidRequest(
+ "Server did not return an IdToken".to_string(),
+ ))?;
let claims = id_token.claims(&self.client.id_token_verifier(), &state.nonce)?;
+
+ // Verify access token hash.
+ if let Some(expected_access_token_hash) = claims.access_token_hash() {
+ let actual_access_token_hash =
+ AccessTokenHash::from_token(token.access_token(), &id_token.signing_alg()?)?;
+ if actual_access_token_hash != *expected_access_token_hash {
+ return Err(BridgeError::InvalidRequest(
+ "Invalid access token".to_string(),
+ ));
+ }
+ }
+
+ if claims.expiration() < Utc::now() {
+ return Err(BridgeError::Internal(format!(
+ "Token expired at {}",
+ claims.expiration()
+ )));
+ }
+
+ let refresh_token = token
+ .refresh_token()
+ .ok_or(BridgeError::Internal("Expected refresh token".to_string()))?;
- info!("Got claims {claims:#?}");
+ let mut session = AuthenticatedSession {
+ session_id,
+ expiration: claims.expiration(),
+ access_token: token.access_token().clone(),
+ refresh_token: refresh_token.clone(),
+ };
+
+ store_authenticated_session(pool, &mut session).await?;
+ Ok(session)
+ }
+
+ pub async fn user_info(
+ &self,
+ session: &mut AuthenticatedSession,
+ ) -> Result<UserInfo, BridgeError> {
+ self.maybe_refresh_token(session).await?;
+ let user_info: CoreUserInfoClaims = self
+ .client
+ .user_info(session.access_token.clone(), None)?
+ .request_async(async_http_client)
+ .await?;
+ info!("Resolved user info: {user_info:#?}");
+ Ok(UserInfo {
+ username: user_info
+ .preferred_username()
+ .ok_or(BridgeError::Internal(
+ "missing preferred username".to_string(),
+ ))?
+ .to_string(),
+ })
+ }
+}
+
+async fn store_authenticated_session(
+ pool: &PgPool,
+ session: &mut AuthenticatedSession,
+) -> Result<(), BridgeError> {
+ info!(
+ "Refresh token length: {}",
+ session.refresh_token.secret().len()
+ );
+ let record = sqlx::query!(
+ r#"
+ insert into sessions (
+ id,
+ access_token,
+ access_token_expiration,
+ refresh_token
+ ) values ($1, $2, $3, $4)
+ on conflict (id) do update set
+ access_token = EXCLUDED.access_token,
+ access_token_expiration = EXCLUDED.access_token_expiration,
+ refresh_token = EXCLUDED.refresh_token,
+ last_refresh = now()
+ returning *
+ "#,
+ session.session_id.0,
+ session.access_token.secret(),
+ session.expiration,
+ session.refresh_token.secret()
+ )
+ .fetch_one(pool)
+ .await?;
+ session.session_id = SessionId(record.id);
+ session.access_token = AccessToken::new(record.access_token);
+ session.expiration = record.access_token_expiration;
+ session.refresh_token = RefreshToken::new(record.refresh_token);
+ Ok(())
+}
- // params: {"session_state": "909b9959-041b-4a98-84d0-5f978bc8a679", "code": "2b4e95d1-0000-4b28-b49d-7a9de731e82b.909b9959-041b-4a98-84d0-5f978bc8a679.a382d869-4e34-42f1-a64d-24a224b9d338", "state": "a7Hff_hF_FOCqPCxmA1ZXg
- Err(BridgeError::Internal("todo".to_string()))
+pub async fn fetch_authenticated_session(
+ pool: &PgPool,
+ session_id: &SessionId,
+) -> Result<Option<AuthenticatedSession>, BridgeError> {
+ let record = sqlx::query!(
+ r#"
+ select * from sessions
+ where id = $1
+ "#,
+ session_id.0,
+ )
+ .fetch_optional(pool)
+ .await?;
+ match record {
+ None => Ok(None),
+ Some(record) => Ok(Some(AuthenticatedSession {
+ session_id: SessionId(record.id),
+ access_token: AccessToken::new(record.access_token),
+ expiration: record.access_token_expiration,
+ refresh_token: RefreshToken::new(record.refresh_token),
+ })),
}
}
diff --git a/server/src/error.rs b/server/src/error.rs
index 439e81b..1a45e96 100644
--- a/server/src/error.rs
+++ b/server/src/error.rs
@@ -1,5 +1,9 @@
use axum::{http::StatusCode, response::IntoResponse};
-use openidconnect::{core::CoreErrorResponseType, StandardErrorResponse, ClaimsVerificationError};
+use openidconnect::{core::CoreErrorResponseType, ClaimsVerificationError, StandardErrorResponse};
+use tracing::error;
+
+type UserInfoError =
+ openidconnect::UserInfoError<openidconnect::reqwest::Error<reqwest::Error>>;
type RequestTokenError = openidconnect::RequestTokenError<
openidconnect::reqwest::Error<reqwest::Error>,
@@ -11,18 +15,34 @@ pub enum BridgeError {
#[error("Invalid request: {0}")]
InvalidRequest(String),
- #[error("Backend request failed")]
- Backend(#[from] RequestTokenError),
+ #[error("Requesting token failed")]
+ OpenidRequestTokenError(#[from] RequestTokenError),
+
+ #[error("Requesting user info failed")]
+ OpenidUserInfoError(#[from] UserInfoError),
+
+ #[error("Failed to configure OpenId request")]
+ OpenIdConfigurationError(#[from] openidconnect::ConfigurationError),
#[error("Unexpected authorization error")]
UnexpectedInvalidAuthorization(#[from] ClaimsVerificationError),
-
+
+ #[error("Authentication error")]
+ SigningFailed(#[from] openidconnect::SigningError),
+
+ #[error("Database error")]
+ SqlxError(#[from] sqlx::Error),
+
+ #[error("Uuid parse failed")]
+ UuidError(#[from] uuid::Error),
+
#[error("Internal server error: {0}")]
Internal(String),
}
impl IntoResponse for BridgeError {
fn into_response(self) -> axum::response::Response {
+ error!("Error occurred: {self:?}");
(StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {self}")).into_response()
}
}
diff --git a/server/src/main.rs b/server/src/main.rs
index 4183abb..87f95e4 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -1,10 +1,10 @@
-use std::{collections::HashMap, env, sync::Arc};
+use std::{collections::HashMap, env, str::FromStr, sync::Arc};
use axum::{
extract::{Extension, Query},
- response::{Redirect, IntoResponse},
+ response::Redirect,
routing::get,
- Json, Router, http::StatusCode,
+ Json, Router,
};
use protocol::UserInfo;
use tower_cookies::{Cookie, CookieManagerLayer, Cookies};
@@ -13,9 +13,9 @@ use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
mod auth;
mod error;
-use crate::auth::{Authenticator, EndUserId};
-use sqlx::{postgres::PgPoolOptions, PgPool};
+use crate::auth::{Authenticator, SessionId};
use crate::error::BridgeError;
+use sqlx::{postgres::PgPoolOptions, PgPool};
pub struct ServerContext {
pub app_url: String,
@@ -39,7 +39,9 @@ async fn main() {
let db_url = env::var("DATABASE_URL").unwrap();
let db_pool: PgPool = PgPoolOptions::new()
.max_connections(10)
- .connect(&db_url).await.expect("db connection");
+ .connect(&db_url)
+ .await
+ .expect("db connection");
info!("Running db migrations");
sqlx::migrate!().run(&db_pool).await.expect("db migration");
@@ -51,8 +53,8 @@ async fn main() {
let state = Arc::new(ServerContext {
app_url: app_url,
- authenticator: Authenticator::from_env().await,
- db: db_pool,
+ authenticator: Authenticator::from_env(db_pool.clone()).await,
+ db: db_pool,
});
let app = Router::new()
@@ -69,32 +71,54 @@ async fn main() {
.unwrap();
}
-async fn user_info() -> Json<Option<UserInfo>> {
- Json(None)
+async fn user_info(
+ cookies: Cookies,
+ extension: ContextExtension,
+) -> Result<Json<Option<UserInfo>>, BridgeError> {
+ let cookie = match cookies.get("user-id") {
+ None => return Ok(Json(None)),
+ Some(v) => v,
+ };
+
+ let session_id: SessionId = match SessionId::from_str(cookie.value()) {
+ Err(e) => {
+ info!("Clearing cookie that failed to parse {cookie:?}: {e}");
+ cookies.remove(cookie.into_owned());
+ return Ok(Json(None));
+ }
+ Ok(s) => s,
+ };
+ let mut session =
+ match crate::auth::fetch_authenticated_session(&extension.db, &session_id).await? {
+ None => return Ok(Json(None)),
+ Some(v) => v,
+ };
+ Ok(Json(Some(extension.authenticator.user_info(&mut session).await?)))
}
async fn login_callback(
cookies: Cookies,
Query(params): Query<HashMap<String, String>>,
extension: ContextExtension,
-) -> Result<(), BridgeError> {
+) -> Result<Redirect, BridgeError> {
let cookie = cookies.get("user-id").unwrap();
- let user_id: EndUserId =
- serde_json::from_str(&urlencoding::decode(cookie.value()).unwrap()).unwrap();
- info!("cookie: {cookie:?}");
- info!("params: {params:?}");
- extension.authenticator.authenticate(user_id, params).await?;
- Ok(())
+ let user_id: SessionId = SessionId::from_str(cookie.value())?;
+ let session = extension
+ .authenticator
+ .authenticate(&extension.db, user_id, params)
+ .await?;
+ info!("Logged in session: {session:?}");
+ Ok(Redirect::temporary(&extension.app_url))
}
async fn login(cookies: Cookies, extension: ContextExtension) -> Redirect {
let (user_id, auth_url) = extension.authenticator.get_login_url().await;
info!("Creating auth url for {user_id:?}");
let user_id = serde_json::to_string(&user_id).unwrap();
- cookies.add(Cookie::new(
- "user-id",
- urlencoding::encode(&user_id).to_string(),
- ));
+ let mut cookie = Cookie::new("user-id", user_id.to_string());
+ cookie.set_http_only(true);
+ cookie.set_secure(true);
+ cookie.set_same_site(cookie::SameSite::Lax);
+ cookies.add(cookie);
Redirect::temporary(auth_url.as_str())
}
-