diff options
author | Kjetil Orbekk <kj@orbekk.com> | 2022-12-23 07:37:08 -0500 |
---|---|---|
committer | Kjetil Orbekk <kj@orbekk.com> | 2022-12-23 07:37:08 -0500 |
commit | 38f4ef0073c43e478e14c3dd0cc28943b360f013 (patch) | |
tree | 17c456acd9a9b5898a9a7b728d14a38932d69a80 /server | |
parent | eae8b9b7a40c3f2a52f319e695b280a41618fdd8 (diff) |
Use new type safe state handling from axum 0.6
Diffstat (limited to 'server')
-rw-r--r-- | server/src/auth.rs | 17 | ||||
-rw-r--r-- | server/src/main.rs | 50 | ||||
-rw-r--r-- | server/src/server.rs | 4 |
3 files changed, 35 insertions, 36 deletions
diff --git a/server/src/auth.rs b/server/src/auth.rs index d0f6c38..a924f44 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -6,10 +6,10 @@ use std::{ sync::{Arc, Mutex}, }; -use crate::{error::BridgeError, server::ContextExtension}; +use crate::{error::BridgeError, server::{ServerState, ServerContext}}; use async_trait::async_trait; use axum::{ - extract::{FromRequest, FromRequestParts}, + extract::{FromRequestParts, State, FromRef}, response::{IntoResponse, Response}, http::request::Parts, }; use chrono::{DateTime, Utc}; @@ -368,21 +368,20 @@ pub async fn fetch_authenticated_session( } #[async_trait] -impl<B> FromRequestParts<B> for AuthenticatedSession +impl<S> FromRequestParts<S> for AuthenticatedSession where - B: Send + Sync, + S: Send + Sync, + Arc<ServerContext>: FromRef<S> { type Rejection = Response; async fn from_request_parts( - parts: &mut Parts, state: &B + parts: &mut Parts, state: &S ) -> Result<Self, Self::Rejection> { let cookies = Cookies::from_request_parts(parts, state) .await .map_err(|e| e.into_response())?; - let extension = ContextExtension::from_request_parts(parts, state) - .await - .map_err(|e| e.into_response())?; + let state = Arc::<ServerContext>::from_ref(state); let cookie = match cookies.get("user-id") { None => return Err(BridgeError::NotLoggedIn.into_response()), Some(v) => v, @@ -396,7 +395,7 @@ where } Ok(s) => s, }; - let session = match crate::auth::fetch_authenticated_session(&extension.db, &session_id) + let session = match crate::auth::fetch_authenticated_session(&state.db, &session_id) .await .map_err(|e| e.into_response())? { diff --git a/server/src/main.rs b/server/src/main.rs index 03ae3e3..ab5cdcc 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -4,14 +4,14 @@ use uuid::Uuid; use auth::AuthenticatedSession; use axum::{ - extract::{Extension, Path, Query}, + extract::{Path, Query, State}, response::{Html, Redirect}, routing::{delete, get, post}, Json, Router, }; use protocol::bridge_engine::{Bid, GameStatePlayerView, Player}; use protocol::{Table, UserInfo}; -use server::ContextExtension; +use server::ServerState; use tower_cookies::{Cookie, CookieManagerLayer, Cookies}; use tower_http::trace::TraceLayer; use tracing::{info, log::warn}; @@ -117,8 +117,8 @@ async fn main() { .route("/api/login", get(login)) .route(auth::LOGIN_CALLBACK, get(login_callback)) .layer(CookieManagerLayer::new()) - .layer(Extension(state)) - .layer(TraceLayer::new_for_http()); + .layer(TraceLayer::new_for_http()) + .with_state(state); axum::Server::bind(&bind_address.parse().unwrap()) .serve(app.into_make_service()) @@ -145,12 +145,12 @@ async fn fake_login() -> Html<&'static str> { async fn get_table_view( _session: AuthenticatedSession, - extension: ContextExtension, + State(state): ServerState, Path(id): Path<Uuid>, ) -> Result<Json<protocol::bridge_engine::GameStatePlayerView>, BridgeError> { info!("Getting table state for {id:}"); let player_position = Player::South; - let jnl = DbJournal::new(extension.db.clone(), id); + let jnl = DbJournal::new(state.db.clone(), id); let mut table = play::Table::new_or_replay(jnl).await?; info!("Advancing play"); while table.game()?.current_player() != player_position { @@ -166,12 +166,12 @@ async fn get_table_view( async fn post_bid( _session: AuthenticatedSession, - extension: ContextExtension, + State(state): ServerState, Path(id): Path<Uuid>, Json(bid): Json<Bid>, ) -> Result<Json<()>, BridgeError> { info!("Getting table state for {id:}"); - let jnl = DbJournal::new(extension.db.clone(), id); + let jnl = DbJournal::new(state.db.clone(), id); let mut table = play::Table::replay(jnl).await?; if !table.game()?.is_bidding() { return Err(BridgeError::InvalidRequest( @@ -191,7 +191,7 @@ async fn post_bid( async fn leave_table( session: AuthenticatedSession, - extension: ContextExtension, + State(state): ServerState, ) -> Result<(), BridgeError> { sqlx::query!( r#" @@ -199,16 +199,16 @@ async fn leave_table( "#, session.player_id ) - .execute(&extension.db) + .execute(&state.db) .await?; Ok(()) } async fn create_table( session: AuthenticatedSession, - extension: ContextExtension, + State(state): ServerState, ) -> Result<Json<Uuid>, BridgeError> { - let txn = extension.db.begin().await?; + let txn = state.db.begin().await?; let table_id = sqlx::query!( r#" insert into active_tables (id) @@ -217,7 +217,7 @@ async fn create_table( "#, Uuid::new_v4() ) - .fetch_one(&extension.db) + .fetch_one(&state.db) .await? .id; @@ -231,7 +231,7 @@ async fn create_table( table_id, session.player_id ) - .execute(&extension.db) + .execute(&state.db) .await?; txn.commit().await?; @@ -240,20 +240,20 @@ async fn create_table( async fn user_info( session: Option<AuthenticatedSession>, - extension: ContextExtension, + State(state): ServerState, ) -> Result<Json<Option<UserInfo>>, BridgeError> { let mut session = match session { None => return Ok(Json(None)), Some(s) => s, }; Ok(Json(Some(UserInfo { - username: extension.authenticator.user_info(&mut session).await?, - table: user_table(extension, &session).await?, + username: state.authenticator.user_info(&mut session).await?, + table: user_table(&*state, &session).await?, }))) } async fn user_table( - extension: ContextExtension, + state: &ServerContext, session: &AuthenticatedSession, ) -> Result<Option<Table>, BridgeError> { Ok(sqlx::query_as!( @@ -266,27 +266,27 @@ async fn user_table( "#, session.player_id ) - .fetch_optional(&extension.db) + .fetch_optional(&state.db) .await?) } async fn login_callback( cookies: Cookies, Query(params): Query<HashMap<String, String>>, - extension: ContextExtension, + State(state): ServerState, ) -> Result<Redirect, BridgeError> { let cookie = cookies.get("user-id").unwrap(); let user_id: SessionId = SessionId::from_str(cookie.value())?; - let session = extension + let session = state .authenticator - .authenticate(&extension.db, user_id, params) + .authenticate(&state.db, user_id, params) .await?; info!("Logged in session: {session:?}"); - Ok(Redirect::temporary(&extension.app_url)) + Ok(Redirect::temporary(&state.app_url)) } -async fn login(cookies: Cookies, extension: ContextExtension) -> Redirect { - let (user_id, auth_url) = extension.authenticator.get_login_url().await; +async fn login(cookies: Cookies, State(state): ServerState) -> Redirect { + let (user_id, auth_url) = state.authenticator.get_login_url().await; info!("Creating auth url for {user_id:?}"); let user_id = serde_json::to_string(&user_id).unwrap(); let mut cookie = Cookie::new("user-id", user_id.to_string()); diff --git a/server/src/server.rs b/server/src/server.rs index 647abf9..4e563df 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -1,7 +1,7 @@ use sqlx::PgPool; use std::sync::Arc; -use axum::Extension; +use axum::extract::State; use crate::auth::Authenticator; @@ -10,4 +10,4 @@ pub struct ServerContext { pub authenticator: Box<dyn Authenticator + Send + Sync>, pub db: PgPool, } -pub type ContextExtension = Extension<Arc<ServerContext>>; +pub type ServerState = State<Arc<ServerContext>>; |