summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorKjetil Orbekk <kj@orbekk.com>2022-12-23 07:37:08 -0500
committerKjetil Orbekk <kj@orbekk.com>2022-12-23 07:37:08 -0500
commit38f4ef0073c43e478e14c3dd0cc28943b360f013 (patch)
tree17c456acd9a9b5898a9a7b728d14a38932d69a80 /server
parenteae8b9b7a40c3f2a52f319e695b280a41618fdd8 (diff)
Use new type safe state handling from axum 0.6
Diffstat (limited to 'server')
-rw-r--r--server/src/auth.rs17
-rw-r--r--server/src/main.rs50
-rw-r--r--server/src/server.rs4
3 files changed, 35 insertions, 36 deletions
diff --git a/server/src/auth.rs b/server/src/auth.rs
index d0f6c38..a924f44 100644
--- a/server/src/auth.rs
+++ b/server/src/auth.rs
@@ -6,10 +6,10 @@ use std::{
sync::{Arc, Mutex},
};
-use crate::{error::BridgeError, server::ContextExtension};
+use crate::{error::BridgeError, server::{ServerState, ServerContext}};
use async_trait::async_trait;
use axum::{
- extract::{FromRequest, FromRequestParts},
+ extract::{FromRequestParts, State, FromRef},
response::{IntoResponse, Response}, http::request::Parts,
};
use chrono::{DateTime, Utc};
@@ -368,21 +368,20 @@ pub async fn fetch_authenticated_session(
}
#[async_trait]
-impl<B> FromRequestParts<B> for AuthenticatedSession
+impl<S> FromRequestParts<S> for AuthenticatedSession
where
- B: Send + Sync,
+ S: Send + Sync,
+ Arc<ServerContext>: FromRef<S>
{
type Rejection = Response;
async fn from_request_parts(
- parts: &mut Parts, state: &B
+ parts: &mut Parts, state: &S
) -> Result<Self, Self::Rejection> {
let cookies = Cookies::from_request_parts(parts, state)
.await
.map_err(|e| e.into_response())?;
- let extension = ContextExtension::from_request_parts(parts, state)
- .await
- .map_err(|e| e.into_response())?;
+ let state = Arc::<ServerContext>::from_ref(state);
let cookie = match cookies.get("user-id") {
None => return Err(BridgeError::NotLoggedIn.into_response()),
Some(v) => v,
@@ -396,7 +395,7 @@ where
}
Ok(s) => s,
};
- let session = match crate::auth::fetch_authenticated_session(&extension.db, &session_id)
+ let session = match crate::auth::fetch_authenticated_session(&state.db, &session_id)
.await
.map_err(|e| e.into_response())?
{
diff --git a/server/src/main.rs b/server/src/main.rs
index 03ae3e3..ab5cdcc 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -4,14 +4,14 @@ use uuid::Uuid;
use auth::AuthenticatedSession;
use axum::{
- extract::{Extension, Path, Query},
+ extract::{Path, Query, State},
response::{Html, Redirect},
routing::{delete, get, post},
Json, Router,
};
use protocol::bridge_engine::{Bid, GameStatePlayerView, Player};
use protocol::{Table, UserInfo};
-use server::ContextExtension;
+use server::ServerState;
use tower_cookies::{Cookie, CookieManagerLayer, Cookies};
use tower_http::trace::TraceLayer;
use tracing::{info, log::warn};
@@ -117,8 +117,8 @@ async fn main() {
.route("/api/login", get(login))
.route(auth::LOGIN_CALLBACK, get(login_callback))
.layer(CookieManagerLayer::new())
- .layer(Extension(state))
- .layer(TraceLayer::new_for_http());
+ .layer(TraceLayer::new_for_http())
+ .with_state(state);
axum::Server::bind(&bind_address.parse().unwrap())
.serve(app.into_make_service())
@@ -145,12 +145,12 @@ async fn fake_login() -> Html<&'static str> {
async fn get_table_view(
_session: AuthenticatedSession,
- extension: ContextExtension,
+ State(state): ServerState,
Path(id): Path<Uuid>,
) -> Result<Json<protocol::bridge_engine::GameStatePlayerView>, BridgeError> {
info!("Getting table state for {id:}");
let player_position = Player::South;
- let jnl = DbJournal::new(extension.db.clone(), id);
+ let jnl = DbJournal::new(state.db.clone(), id);
let mut table = play::Table::new_or_replay(jnl).await?;
info!("Advancing play");
while table.game()?.current_player() != player_position {
@@ -166,12 +166,12 @@ async fn get_table_view(
async fn post_bid(
_session: AuthenticatedSession,
- extension: ContextExtension,
+ State(state): ServerState,
Path(id): Path<Uuid>,
Json(bid): Json<Bid>,
) -> Result<Json<()>, BridgeError> {
info!("Getting table state for {id:}");
- let jnl = DbJournal::new(extension.db.clone(), id);
+ let jnl = DbJournal::new(state.db.clone(), id);
let mut table = play::Table::replay(jnl).await?;
if !table.game()?.is_bidding() {
return Err(BridgeError::InvalidRequest(
@@ -191,7 +191,7 @@ async fn post_bid(
async fn leave_table(
session: AuthenticatedSession,
- extension: ContextExtension,
+ State(state): ServerState,
) -> Result<(), BridgeError> {
sqlx::query!(
r#"
@@ -199,16 +199,16 @@ async fn leave_table(
"#,
session.player_id
)
- .execute(&extension.db)
+ .execute(&state.db)
.await?;
Ok(())
}
async fn create_table(
session: AuthenticatedSession,
- extension: ContextExtension,
+ State(state): ServerState,
) -> Result<Json<Uuid>, BridgeError> {
- let txn = extension.db.begin().await?;
+ let txn = state.db.begin().await?;
let table_id = sqlx::query!(
r#"
insert into active_tables (id)
@@ -217,7 +217,7 @@ async fn create_table(
"#,
Uuid::new_v4()
)
- .fetch_one(&extension.db)
+ .fetch_one(&state.db)
.await?
.id;
@@ -231,7 +231,7 @@ async fn create_table(
table_id,
session.player_id
)
- .execute(&extension.db)
+ .execute(&state.db)
.await?;
txn.commit().await?;
@@ -240,20 +240,20 @@ async fn create_table(
async fn user_info(
session: Option<AuthenticatedSession>,
- extension: ContextExtension,
+ State(state): ServerState,
) -> Result<Json<Option<UserInfo>>, BridgeError> {
let mut session = match session {
None => return Ok(Json(None)),
Some(s) => s,
};
Ok(Json(Some(UserInfo {
- username: extension.authenticator.user_info(&mut session).await?,
- table: user_table(extension, &session).await?,
+ username: state.authenticator.user_info(&mut session).await?,
+ table: user_table(&*state, &session).await?,
})))
}
async fn user_table(
- extension: ContextExtension,
+ state: &ServerContext,
session: &AuthenticatedSession,
) -> Result<Option<Table>, BridgeError> {
Ok(sqlx::query_as!(
@@ -266,27 +266,27 @@ async fn user_table(
"#,
session.player_id
)
- .fetch_optional(&extension.db)
+ .fetch_optional(&state.db)
.await?)
}
async fn login_callback(
cookies: Cookies,
Query(params): Query<HashMap<String, String>>,
- extension: ContextExtension,
+ State(state): ServerState,
) -> Result<Redirect, BridgeError> {
let cookie = cookies.get("user-id").unwrap();
let user_id: SessionId = SessionId::from_str(cookie.value())?;
- let session = extension
+ let session = state
.authenticator
- .authenticate(&extension.db, user_id, params)
+ .authenticate(&state.db, user_id, params)
.await?;
info!("Logged in session: {session:?}");
- Ok(Redirect::temporary(&extension.app_url))
+ Ok(Redirect::temporary(&state.app_url))
}
-async fn login(cookies: Cookies, extension: ContextExtension) -> Redirect {
- let (user_id, auth_url) = extension.authenticator.get_login_url().await;
+async fn login(cookies: Cookies, State(state): ServerState) -> Redirect {
+ let (user_id, auth_url) = state.authenticator.get_login_url().await;
info!("Creating auth url for {user_id:?}");
let user_id = serde_json::to_string(&user_id).unwrap();
let mut cookie = Cookie::new("user-id", user_id.to_string());
diff --git a/server/src/server.rs b/server/src/server.rs
index 647abf9..4e563df 100644
--- a/server/src/server.rs
+++ b/server/src/server.rs
@@ -1,7 +1,7 @@
use sqlx::PgPool;
use std::sync::Arc;
-use axum::Extension;
+use axum::extract::State;
use crate::auth::Authenticator;
@@ -10,4 +10,4 @@ pub struct ServerContext {
pub authenticator: Box<dyn Authenticator + Send + Sync>,
pub db: PgPool,
}
-pub type ContextExtension = Extension<Arc<ServerContext>>;
+pub type ServerState = State<Arc<ServerContext>>;