From 568b21aa76a452658b6b7f1b01e6ab75a49592cf Mon Sep 17 00:00:00 2001 From: Kjetil Orbekk Date: Sun, 2 Feb 2020 16:07:29 -0500 Subject: Refresh strava token if it's too old --- src/db.rs | 6 ++-- src/error.rs | 64 +++++++++++++++++++++++++++++++++ src/importer.rs | 32 +++++++++++++---- src/lib.rs | 6 ++++ src/server.rs | 25 +++++++------ src/strava.rs | 108 ++++++++++++++++++++++++++++++++++++++++++++------------ 6 files changed, 196 insertions(+), 45 deletions(-) diff --git a/src/db.rs b/src/db.rs index 091bc64..20123bf 100644 --- a/src/db.rs +++ b/src/db.rs @@ -87,8 +87,10 @@ pub fn get_user(conn: &PgConnection, username: &str) -> Result Result { +pub fn get_strava_token( + conn: &PgConnection, + user: &models::User, +) -> Result { use crate::schema::strava_tokens; let token = strava_tokens::table diff --git a/src/error.rs b/src/error.rs index 7e68288..4ae2995 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,15 +1,72 @@ use bcrypt::BcryptError; use diesel::result::Error as DieselErr; +use serde_json::Value; use std::convert::From; use std::error::Error as StdError; use std::fmt; +#[derive(Debug)] +pub struct StravaApiError{ + status: reqwest::StatusCode, + code: String, + field: String, + value: Value, +} + +impl StravaApiError { + pub fn new(status: reqwest::StatusCode, value: Value) -> StravaApiError { + let first_error = &value["errors"][0]; + + let code = first_error["code"].as_str().unwrap_or("unknown").to_string(); + let field = first_error["field"].as_str().unwrap_or("unknown").to_string(); + + StravaApiError { status, code, field, value } + } + } + +impl fmt::Display for StravaApiError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "field '{}' has error '{}'", self.field, self.code) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn strava_api_error_unknown() { + let data = r#""insane input""#; + let json = serde_json::from_str(data).unwrap(); + let error = Error::StravaApiError(StravaApiError::new(reqwest::StatusCode::UNAUTHORIZED, json)); + assert_eq!("field 'unknown' has error 'unknown'", + format!("{}", error)); + } + + #[test] + fn strava_api_error_invalid() { + let data = r#" + { + "errors":[ + {"code":"invalid", + "field":"access_token", + "resource":"Athlete"} + ], + "message":"Authorization Error" + }"#; + let json = serde_json::from_str(data).unwrap(); + let error = Error::StravaApiError(StravaApiError::new(reqwest::StatusCode::UNAUTHORIZED, json)); + assert_eq!("field 'access_token' has error 'invalid'", + format!("{}", error)); + } +} + #[derive(Debug)] pub enum Error { DieselError(DieselErr), PasswordError(BcryptError), CommunicationError(reqwest::Error), ParseError(serde_json::error::Error), + StravaApiError(StravaApiError), AlreadyExists, NotFound, InternalError, @@ -22,6 +79,7 @@ impl fmt::Display for Error { Error::PasswordError(ref e) => e.fmt(f), Error::CommunicationError(ref e) => e.fmt(f), Error::ParseError(ref e) => e.fmt(f), + Error::StravaApiError(ref e) => e.fmt(f), Error::AlreadyExists => f.write_str("AlreadyExists"), Error::NotFound => f.write_str("NotFound"), Error::InternalError => f.write_str("InternalError"), @@ -29,6 +87,12 @@ impl fmt::Display for Error { } } +impl From for Error { + fn from(e: StravaApiError) -> Error { + Error::StravaApiError(e) + } +} + impl From for Error { fn from(e: serde_json::error::Error) -> Error { Error::ParseError(e) diff --git a/src/importer.rs b/src/importer.rs index e1f73e2..ab77837 100644 --- a/src/importer.rs +++ b/src/importer.rs @@ -1,3 +1,4 @@ +use diesel::PgConnection; use std::sync::mpsc::channel; use std::sync::mpsc::Receiver; use std::sync::mpsc::Sender; @@ -5,12 +6,14 @@ use std::sync::Arc; use std::sync::Mutex; use std::sync::RwLock; use threadpool::ThreadPool; -use diesel::PgConnection; +use chrono::Utc; +use crate::error::Error; use crate::db; -use crate::strava; use crate::models; +use crate::strava; use crate::strava::StravaApi; +use crate::Params; pub const WORKERS: usize = 10; pub const EMPTY_PARAMS: &[(&str, &str)] = &[]; @@ -29,12 +32,26 @@ struct ImporterState { rx: Arc>>, } +fn get_or_refresh_token(strava: &Strava, conn: &PgConnection, user: &models::User) -> Result { + let mut token = db::get_strava_token(&conn, &user).expect("FIX"); + + if token.expires_at < Utc::now() { + info!("refresh expired token: {:?}", token.expires_at); + let new_token = strava.refresh_token(&From::from(&token))?; + new_token.update_model(&mut token); + } + + Ok(token) +} + fn import_strava_user(state: ImporterState, user: models::User) { let strava = state.strava.read().expect("FIX"); let conn = state.conn.lock().expect("FIX"); - let token = db::get_strava_token(&conn, &user).expect("FIX"); - let result = strava.get("/athlete/activities", &token.access_token, EMPTY_PARAMS).expect("ok"); - info!("Imported user. Got result: {:#?}", result); + let token = get_or_refresh_token(&*strava, &conn, &user).expect("FIX"); + let result = strava + .get("/athlete/activities", &token.access_token, EMPTY_PARAMS) + .expect("ok"); + info!("import_strava_user: Got result: {:#?}", result); } fn handle_command(state: ImporterState, command: Command) { @@ -65,12 +82,13 @@ fn receive_commands(state: ImporterState) { } } -pub fn run(pool: ThreadPool, conn: PgConnection) -> Sender { +pub fn run(pool: ThreadPool, conn: PgConnection, params: &Params) -> Sender { let (tx, rx0) = channel(); let state = ImporterState { pool: pool.clone(), conn: Arc::new(Mutex::new(conn)), - strava: Arc::new(RwLock::new(strava::StravaImpl::new())), + strava: Arc::new(RwLock::new(strava::StravaImpl::new( + params.strava_client_id.clone(), params.strava_client_secret.clone()))), rx: Arc::new(Mutex::new(rx0)), }; pool.execute(move || receive_commands(state)); diff --git a/src/lib.rs b/src/lib.rs index f55fef1..b80a235 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,3 +20,9 @@ pub mod models; mod schema; pub mod server; mod strava; + +pub struct Params { + pub base_url: String, + pub strava_client_id: String, + pub strava_client_secret: String, +} diff --git a/src/server.rs b/src/server.rs index 79c3226..abc430b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -14,21 +14,16 @@ use rocket::response::Redirect; use rocket::State; use rocket_contrib::templates::Template; use std::collections::HashMap; -use threadpool::ThreadPool; -use std::sync::Mutex; use std::sync::mpsc::Sender; +use std::sync::Mutex; +use threadpool::ThreadPool; use crate::db; use crate::error::Error; use crate::importer; use crate::models; use crate::strava; - -pub struct Params { - pub base_url: String, - pub strava_client_id: String, - pub strava_client_secret: String, -} +use crate::Params; #[database("db")] pub struct Db(diesel::PgConnection); @@ -62,7 +57,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for LoggedInUser { Ok(user) => { info!("Credentials: {:?}", user); Outcome::Success(user) - }, + } Err(Error::NotFound) => Outcome::Forward(()), Err(e) => Outcome::Failure((Status::InternalServerError, e)), } @@ -139,9 +134,13 @@ fn link_strava_callback( fn import_strava( conn: Db, tx: State>>, - user: LoggedInUser) -> Result<(), Error> { + user: LoggedInUser, +) -> Result<(), Error> { let user = db::get_user(&*conn, &user.username)?; - tx.lock().expect("FIX").send(importer::Command::ImportStravaUser(user)).expect("FIX"); + tx.lock() + .expect("FIX") + .send(importer::Command::ImportStravaUser(user)) + .expect("FIX"); Ok(()) } @@ -154,7 +153,7 @@ fn link_strava(params: State) -> Redirect { "response_type=code&", "redirect_uri={}&", "approval_prompt=force&", - "scope=read", + "scope=read_all,activity:read_all,profile:read_all", ), params.strava_client_id, format!("{}/link_strava_callback", params.base_url) @@ -181,7 +180,7 @@ pub fn start(conn: diesel::PgConnection, db_url: &str, base_url: &str) { .unwrap(); let importer_pool = ThreadPool::with_name("import".to_string(), importer::WORKERS); - let tx = importer::run(importer_pool.clone(), conn); + let tx = importer::run(importer_pool.clone(), conn, ¶ms); rocket::custom(config) .manage(params) diff --git a/src/strava.rs b/src/strava.rs index 6be5466..284d8b1 100644 --- a/src/strava.rs +++ b/src/strava.rs @@ -1,4 +1,6 @@ +use crate::error; use crate::error::Error; +use crate::models; use chrono::serde::ts_seconds; use chrono::DateTime; use chrono::Utc; @@ -8,53 +10,113 @@ use serde::Serialize; use serde_json::from_value; use serde_json::Value; +#[derive(Serialize, Deserialize, Debug)] +pub struct Token { + #[serde(with = "ts_seconds")] + pub expires_at: DateTime, + pub refresh_token: String, + pub access_token: String, +} + +impl Token { + pub fn update_model(&self, out: &mut models::StravaToken) { + out.expires_at = self.expires_at.clone(); + out.refresh_token = self.refresh_token.clone(); + out.access_token = self.access_token.clone(); + } +} + +impl From<&models::StravaToken> for Token { + fn from(t: &models::StravaToken) -> Token { + Token { + expires_at: t.expires_at, + refresh_token: t.refresh_token.clone(), + access_token: t.access_token.clone(), + } + } +} + pub trait StravaApi { - fn get(&self, method: &str, access_token: &str, parasm: &T) -> Result; + fn get( + &self, + method: &str, + access_token: &str, + parasm: &T, + ) -> Result; + + fn refresh_token( + &self, + token: &Token) -> Result; } pub struct StravaImpl { client: reqwest::blocking::Client, base_url: String, + api_url: String, + client_id: String, + client_secret: String, } impl StravaImpl { - pub fn new() -> StravaImpl { + pub fn new(client_id: String, client_secret: String) -> StravaImpl { StravaImpl { client: reqwest::blocking::Client::new(), - base_url: "https://www.strava.com/api/v3".to_string(), + base_url: "https://www.strava.com".to_string(), + api_url: "/api/v3".to_string(), + client_id, + client_secret, } } } impl StravaApi for StravaImpl { - fn get(&self, method: &str, access_token: &str, - params: &T) -> Result { - let uri = format!("{}{}", self.base_url, method); - let response = self.client.get(&uri) + fn get( + &self, + method: &str, + access_token: &str, + params: &T, + ) -> Result { + let uri = format!("{}{}{}", self.base_url, self.api_url, method); + let response = self + .client + .get(&uri) .bearer_auth(access_token) .query(params) .send()?; info!("StravaApi::get({}) returned {:?}", method, response); - let json = response.json()?; + let status = response.status(); + let json: Value = response.json()?; + + if !status.is_success() { + return Err(From::from(error::StravaApiError::new(status, json))); + } Ok(json) } -} -#[derive(Serialize, Deserialize, Debug)] -pub struct AthleteSummary { - id: i64, - username: String, - firstname: String, - lastname: String, -} + fn refresh_token( + &self, + token: &Token) -> Result { + let uri = format!("{}{}{}", self.base_url, self.api_url, "/oauth/token"); + let params = [ + ("client_id", self.client_id.as_str()), + ("client_secret", self.client_secret.as_str()), + ("grant_type", "refresh_token"), + ("refresh_token", token.refresh_token.as_str()), + ]; + let response = self + .client + .post(&uri) + .form(¶ms) + .send()?; + info!("StravaApi::refresh_token returned {:?}", response); + let status = response.status(); + let json: Value = response.json()?; -#[derive(Serialize, Deserialize, Debug)] -pub struct Token { - #[serde(with = "ts_seconds")] - pub expires_at: DateTime, - pub refresh_token: String, - pub access_token: String, - pub athlete: AthleteSummary, + if !status.is_success() { + return Err(From::from(error::StravaApiError::new(status, json))); + } + from_value(json).map_err(From::from) + } } pub fn exchange_token(client_id: &str, client_secret: &str, code: &str) -> Result { -- cgit v1.2.3