1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
use std::{collections::HashMap, env, sync::Arc};
use axum::{
extract::{Extension, Query},
response::{Redirect, IntoResponse},
routing::get,
Json, Router, http::StatusCode,
};
use protocol::UserInfo;
use tower_cookies::{Cookie, CookieManagerLayer, Cookies};
use tower_http::trace::TraceLayer;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
mod auth;
mod error;
use crate::auth::{Authenticator, EndUserId};
use sqlx::{postgres::PgPoolOptions, PgPool};
use crate::error::BridgeError;
/// Shared, read-only state made available to every request handler.
///
/// A single instance is built in `main`, wrapped in an `Arc`, and attached to
/// the router through an `Extension` layer.
pub struct ServerContext {
/// Value of the `APP_URL` environment variable, read once at startup.
pub app_url: String,
/// Authenticator built by `Authenticator::from_env()`; the login handlers use
/// it to create login URLs and to complete the login callback.
pub authenticator: Authenticator,
/// Postgres connection pool; migrations are run against it in `main` before
/// the context is constructed.
pub db: PgPool,
}
/// Extractor type handlers use to receive the shared [`ServerContext`].
type ContextExtension = Extension<Arc<ServerContext>>;
/// Process entry point: configures logging, connects to the database, runs
/// migrations, builds the router and serves HTTP until the server stops.
///
/// Configuration is taken from the environment (optionally loaded from a
/// `.env` file): `RUST_LOG` (optional tracing filter), `DATABASE_URL`,
/// `BIND_ADDRESS` and `APP_URL` (all required).
///
/// # Panics
///
/// Startup failures are fatal. Panics if a required environment variable is
/// missing or not valid Unicode, if the database connection or migrations
/// fail, if `BIND_ADDRESS` cannot be parsed, or if the server exits with an
/// error.
#[tokio::main]
async fn main() {
    // Loading `.env` is best-effort: a missing file is not an error.
    dotenv::dotenv().ok();

    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new(
            // An unset `RUST_LOG` falls back to an empty filter string.
            env::var("RUST_LOG").unwrap_or_default(),
        ))
        .with(tracing_subscriber::fmt::layer())
        .init();

    info!("Opening database connection");
    let db_url = required_env("DATABASE_URL");
    let db_pool: PgPool = PgPoolOptions::new()
        .max_connections(10)
        .connect(&db_url)
        .await
        .expect("db connection");

    info!("Running db migrations");
    sqlx::migrate!().run(&db_pool).await.expect("db migration");

    let bind_address = required_env("BIND_ADDRESS");
    info!("Starting server on {}", bind_address);

    let app_url = required_env("APP_URL");
    let state = Arc::new(ServerContext {
        app_url,
        authenticator: Authenticator::from_env().await,
        db: db_pool,
    });

    let app = Router::new()
        .route("/api/user/info", get(user_info))
        .route("/api/login", get(login))
        .route(auth::LOGIN_CALLBACK, get(login_callback))
        .layer(CookieManagerLayer::new())
        .layer(Extension(state))
        .layer(TraceLayer::new_for_http());

    axum::Server::bind(
        &bind_address
            .parse()
            .expect("BIND_ADDRESS must be a valid socket address"),
    )
    .serve(app.into_make_service())
    .await
    .expect("HTTP server terminated with an error");
}

/// Reads a mandatory environment variable.
///
/// # Panics
///
/// Panics with a message naming the variable when it is unset or not valid
/// Unicode, so a misconfigured deployment is easy to diagnose.
fn required_env(name: &str) -> String {
    env::var(name)
        .unwrap_or_else(|err| panic!("environment variable {name} must be set: {err}"))
}
/// Handler for `GET /api/user/info`.
///
/// Currently a stub: it never looks anything up and always answers with
/// `None` wrapped in a JSON response.
async fn user_info() -> Json<Option<UserInfo>> {
    let current_user: Option<UserInfo> = None;
    Json(current_user)
}
async fn login_callback(
cookies: Cookies,
Query(params): Query<HashMap<String, String>>,
extension: ContextExtension,
) -> Result<(), BridgeError> {
let cookie = cookies.get("user-id").unwrap();
let user_id: EndUserId =
serde_json::from_str(&urlencoding::decode(cookie.value()).unwrap()).unwrap();
info!("cookie: {cookie:?}");
info!("params: {params:?}");
extension.authenticator.authenticate(user_id, params).await?;
Ok(())
}
async fn login(cookies: Cookies, extension: ContextExtension) -> Redirect {
let (user_id, auth_url) = extension.authenticator.get_login_url().await;
info!("Creating auth url for {user_id:?}");
let user_id = serde_json::to_string(&user_id).unwrap();
cookies.add(Cookie::new(
"user-id",
urlencoding::encode(&user_id).to_string(),
));
Redirect::temporary(auth_url.as_str())
}
|