diff options
author | Kjetil Orbekk <kj@orbekk.com> | 2022-10-08 18:33:22 -0400 |
---|---|---|
committer | Kjetil Orbekk <kj@orbekk.com> | 2022-10-08 18:33:22 -0400 |
commit | a7d833d6b7729f09bef891b0c8b7bd998ac17abf (patch) | |
tree | 018bba6c2ff1a58ed5b739939f63a3929d0dc662 | |
parent | 30102e5da48b53806b33f04041a46bec4c3b2fa3 (diff) |
Add bridge table to db; introduce player ids from oauth subject ids
-rw-r--r-- | Cargo.lock | 6 | ||||
-rw-r--r-- | protocol/Cargo.toml | 1 | ||||
-rw-r--r-- | protocol/src/lib.rs | 7 | ||||
-rw-r--r-- | server/.env | 2 | ||||
-rw-r--r-- | server/Cargo.toml | 1 | ||||
-rw-r--r-- | server/migrations/20221008120534_init.down.sql | 4 | ||||
-rw-r--r-- | server/migrations/20221008120534_init.up.sql | 18 | ||||
-rw-r--r-- | server/src/auth.rs | 52 | ||||
-rw-r--r-- | server/src/error.rs | 8 | ||||
-rw-r--r-- | server/src/main.rs | 28 |
10 files changed, 101 insertions, 26 deletions
@@ -1437,6 +1437,7 @@ version = "0.1.0" dependencies = [ "serde", "serde_json", + "uuid", ] [[package]] @@ -1775,6 +1776,7 @@ dependencies = [ "serde_json", "sqlx", "thiserror", + "time 0.1.44", "tokio", "tower", "tower-cookies", @@ -2408,9 +2410,9 @@ checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" [[package]] name = "uuid" -version = "1.1.2" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd6469f4314d5f1ffec476e05f17cc9a78bc7a27a6a857842170bdf8d6f98d2f" +checksum = "683f0a095f6dcf74520a5f17a12452ae6f970abbd2443299a1e226fd38195f2b" dependencies = [ "getrandom", "rand", diff --git a/protocol/Cargo.toml b/protocol/Cargo.toml index 2836fad..92c0a5e 100644 --- a/protocol/Cargo.toml +++ b/protocol/Cargo.toml @@ -8,3 +8,4 @@ edition = "2021" [dependencies] serde = { version = "1.0.145", features = ["derive"] } serde_json = "1.0.85" +uuid = "1.2.0" diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index a56554f..4b1bb06 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -1,6 +1,13 @@ use serde::{Deserialize, Serialize}; +use uuid::Uuid; #[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)] pub struct UserInfo { pub username: String, + pub table: Option<Table>, +} + +#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)] +pub struct Table { + pub id: Uuid, } diff --git a/server/.env b/server/.env index e575250..66c6f58 100644 --- a/server/.env +++ b/server/.env @@ -1,4 +1,4 @@ -RUST_LOG=info,tower_http=debug,server=debug +RUST_LOG=info,tower_http=debug,server=info,sqlx=warn BIND_ADDRESS=[::]:11121 RUST_BACKTRACE=1 OPENID_ISSUER_URL=https://auth.orbekk.com/realms/test diff --git a/server/Cargo.toml b/server/Cargo.toml index 423ada2..94b2684 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -27,3 +27,4 @@ chrono = { version = "0.4.22", features = ["serde"] } thiserror = "1.0.37" reqwest = "0.11.12" cookie = "0.16.1" +time = "0.1.44" diff --git a/server/migrations/20221008120534_init.down.sql b/server/migrations/20221008120534_init.down.sql index c855f63..ec1abe4 100644 --- a/server/migrations/20221008120534_init.down.sql +++ b/server/migrations/20221008120534_init.down.sql @@ -1,3 +1,5 @@ -- Add down migration script here -drop table if exists users; +drop table if exists players; drop table if exists sessions; +drop table if exists table_players; +drop table if exists active_tables; diff --git a/server/migrations/20221008120534_init.up.sql b/server/migrations/20221008120534_init.up.sql index b3527eb..6a46f84 100644 --- a/server/migrations/20221008120534_init.up.sql +++ b/server/migrations/20221008120534_init.up.sql @@ -1,8 +1,24 @@ -- Add up migration script here +create table players ( + id varchar(64) primary key not null +); + create table sessions ( - id uuid primary key, + id uuid primary key not null, + player_id varchar(64) references players (id) not null, access_token varchar(2048) not null, access_token_expiration timestamp with time zone not null, refresh_token varchar(1024) not null, last_refresh timestamp with time zone not null default now() ); + +create table active_tables ( + id uuid primary key not null +); + +create table table_players ( + active_tables_id uuid not null references active_tables (id), + player_id varchar(64) not null references players (id), + primary key(active_tables_id, player_id) +); +create unique index player_table on table_players (player_id); diff --git a/server/src/auth.rs b/server/src/auth.rs index 44f16ea..e30cd6e 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -20,7 +20,7 @@ use openidconnect::{ use protocol::UserInfo; use serde::{Deserialize, Serialize}; use sqlx::PgPool; -use tracing::info; +use tracing::{info, error, debug}; use uuid::Uuid; pub struct LoginState { @@ -30,6 +30,7 @@ pub struct LoginState { #[derive(Debug)] pub struct AuthenticatedSession { + pub player_id: String, pub session_id: SessionId, expiration: DateTime<Utc>, access_token: AccessToken, @@ -142,12 +143,19 @@ impl Authenticator { return Ok(()); } info!("Refreshing expiring token: {}", session.expiration); + let refresh_start = Utc::now(); let new_token = self .client .exchange_refresh_token(&session.refresh_token) .request_async(async_http_client) .await?; - info!("Got new token: {new_token:#?}"); + debug!("Got new token: {new_token:#?}"); + // TODO: Validate token? + if let Some(expires_in) = new_token.expires_in() { + session.expiration = refresh_start + chrono::Duration::from_std(expires_in)?; + } else { + error!("Token is missing expiration! Will refresh token every time."); + } if let Some(refresh_token) = new_token.refresh_token() { session.refresh_token = refresh_token.clone(); } @@ -215,8 +223,9 @@ impl Authenticator { let refresh_token = token .refresh_token() .ok_or(BridgeError::Internal("Expected refresh token".to_string()))?; - + let mut session = AuthenticatedSession { + player_id: claims.subject().to_string(), session_id, expiration: claims.expiration(), access_token: token.access_token().clone(), @@ -230,22 +239,20 @@ impl Authenticator { pub async fn user_info( &self, session: &mut AuthenticatedSession, - ) -> Result<UserInfo, BridgeError> { + ) -> Result<String, BridgeError> { self.maybe_refresh_token(session).await?; let user_info: CoreUserInfoClaims = self .client .user_info(session.access_token.clone(), None)? .request_async(async_http_client) .await?; - info!("Resolved user info: {user_info:#?}"); - Ok(UserInfo { - username: user_info - .preferred_username() - .ok_or(BridgeError::Internal( - "missing preferred username".to_string(), - ))? - .to_string(), - }) + debug!("Resolved user info: {user_info:#?}"); + Ok(user_info + .preferred_username() + .ok_or(BridgeError::Internal( + "missing preferred username".to_string(), + ))? + .to_string()) } } @@ -253,18 +260,30 @@ async fn store_authenticated_session( pool: &PgPool, session: &mut AuthenticatedSession, ) -> Result<(), BridgeError> { - info!( + debug!( "Refresh token length: {}", session.refresh_token.secret().len() ); + sqlx::query!( + r#" + insert into players (id) + values ($1) + on conflict do nothing + "#, + session.player_id + ) + .execute(pool) + .await?; + let record = sqlx::query!( r#" insert into sessions ( id, + player_id, access_token, access_token_expiration, refresh_token - ) values ($1, $2, $3, $4) + ) values ($1, $2, $3, $4, $5) on conflict (id) do update set access_token = EXCLUDED.access_token, access_token_expiration = EXCLUDED.access_token_expiration, @@ -273,12 +292,14 @@ async fn store_authenticated_session( returning * "#, session.session_id.0, + session.player_id, session.access_token.secret(), session.expiration, session.refresh_token.secret() ) .fetch_one(pool) .await?; + session.player_id = record.player_id; session.session_id = SessionId(record.id); session.access_token = AccessToken::new(record.access_token); session.expiration = record.access_token_expiration; @@ -302,6 +323,7 @@ pub async fn fetch_authenticated_session( match record { None => Ok(None), Some(record) => Ok(Some(AuthenticatedSession { + player_id: record.player_id, session_id: SessionId(record.id), access_token: AccessToken::new(record.access_token), expiration: record.access_token_expiration, diff --git a/server/src/error.rs b/server/src/error.rs index 1a45e96..cea23e7 100644 --- a/server/src/error.rs +++ b/server/src/error.rs @@ -2,8 +2,7 @@ use axum::{http::StatusCode, response::IntoResponse}; use openidconnect::{core::CoreErrorResponseType, ClaimsVerificationError, StandardErrorResponse}; use tracing::error; -type UserInfoError = - openidconnect::UserInfoError<openidconnect::reqwest::Error<reqwest::Error>>; +type UserInfoError = openidconnect::UserInfoError<openidconnect::reqwest::Error<reqwest::Error>>; type RequestTokenError = openidconnect::RequestTokenError< openidconnect::reqwest::Error<reqwest::Error>, @@ -38,11 +37,14 @@ pub enum BridgeError { #[error("Internal server error: {0}")] Internal(String), + + #[error("Duration out of range")] + DurationOutOfRange(#[from] time::OutOfRangeError), } impl IntoResponse for BridgeError { fn into_response(self) -> axum::response::Response { - error!("Error occurred: {self:?}"); + error!("Error occurred: {self:?}"); (StatusCode::INTERNAL_SERVER_ERROR, format!("Error: {self}")).into_response() } } diff --git a/server/src/main.rs b/server/src/main.rs index 87f95e4..22f9e19 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,12 +1,13 @@ use std::{collections::HashMap, env, str::FromStr, sync::Arc}; +use auth::AuthenticatedSession; use axum::{ extract::{Extension, Query}, response::Redirect, routing::get, Json, Router, }; -use protocol::UserInfo; +use protocol::{Table, UserInfo}; use tower_cookies::{Cookie, CookieManagerLayer, Cookies}; use tower_http::trace::TraceLayer; use tracing::info; @@ -52,7 +53,7 @@ async fn main() { let app_url = env::var("APP_URL").unwrap(); let state = Arc::new(ServerContext { - app_url: app_url, + app_url, authenticator: Authenticator::from_env(db_pool.clone()).await, db: db_pool, }); @@ -93,7 +94,28 @@ async fn user_info( None => return Ok(Json(None)), Some(v) => v, }; - Ok(Json(Some(extension.authenticator.user_info(&mut session).await?))) + Ok(Json(Some(UserInfo { + username: extension.authenticator.user_info(&mut session).await?, + table: get_table(&extension.db, &session).await?, + }))) +} + +async fn get_table( + db: &PgPool, + session: &AuthenticatedSession, +) -> Result<Option<Table>, BridgeError> { + Ok(sqlx::query_as!( + Table, + r#" + select tables.id + from table_players players + natural join active_tables tables + where player_id = $1 + "#, + session.player_id + ) + .fetch_optional(db) + .await?) } async fn login_callback( |