summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKjetil Orbekk <kjetil.orbekk@gmail.com>2020-02-02 16:07:29 -0500
committerKjetil Orbekk <kjetil.orbekk@gmail.com>2020-02-02 16:07:29 -0500
commit568b21aa76a452658b6b7f1b01e6ab75a49592cf (patch)
treefa1d106d5b7814f6375a19382bd316b7aa4e29b7
parenta226e7d888df3342f26e7eaaf1a24d0397d4dbad (diff)
Refresh strava token if it's too old
-rw-r--r--src/db.rs6
-rw-r--r--src/error.rs64
-rw-r--r--src/importer.rs32
-rw-r--r--src/lib.rs6
-rw-r--r--src/server.rs25
-rw-r--r--src/strava.rs108
6 files changed, 196 insertions, 45 deletions
diff --git a/src/db.rs b/src/db.rs
index 091bc64..20123bf 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -87,8 +87,10 @@ pub fn get_user(conn: &PgConnection, username: &str) -> Result<models::User, Err
Ok(user)
}
-pub fn get_strava_token(conn: &PgConnection, user: &models::User)
- -> Result<models::StravaToken, Error> {
+pub fn get_strava_token(
+ conn: &PgConnection,
+ user: &models::User,
+) -> Result<models::StravaToken, Error> {
use crate::schema::strava_tokens;
let token = strava_tokens::table
diff --git a/src/error.rs b/src/error.rs
index 7e68288..4ae2995 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,15 +1,72 @@
use bcrypt::BcryptError;
use diesel::result::Error as DieselErr;
+use serde_json::Value;
use std::convert::From;
use std::error::Error as StdError;
use std::fmt;
#[derive(Debug)]
+pub struct StravaApiError{
+ status: reqwest::StatusCode,
+ code: String,
+ field: String,
+ value: Value,
+}
+
+impl StravaApiError {
+ pub fn new(status: reqwest::StatusCode, value: Value) -> StravaApiError {
+ let first_error = &value["errors"][0];
+
+ let code = first_error["code"].as_str().unwrap_or("unknown").to_string();
+ let field = first_error["field"].as_str().unwrap_or("unknown").to_string();
+
+ StravaApiError { status, code, field, value }
+ }
+ }
+
+impl fmt::Display for StravaApiError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "field '{}' has error '{}'", self.field, self.code)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ #[test]
+ fn strava_api_error_unknown() {
+ let data = r#""insane input""#;
+ let json = serde_json::from_str(data).unwrap();
+ let error = Error::StravaApiError(StravaApiError::new(reqwest::StatusCode::UNAUTHORIZED, json));
+ assert_eq!("field 'unknown' has error 'unknown'",
+ format!("{}", error));
+ }
+
+ #[test]
+ fn strava_api_error_invalid() {
+ let data = r#"
+ {
+ "errors":[
+ {"code":"invalid",
+ "field":"access_token",
+ "resource":"Athlete"}
+ ],
+ "message":"Authorization Error"
+ }"#;
+ let json = serde_json::from_str(data).unwrap();
+ let error = Error::StravaApiError(StravaApiError::new(reqwest::StatusCode::UNAUTHORIZED, json));
+ assert_eq!("field 'access_token' has error 'invalid'",
+ format!("{}", error));
+ }
+}
+
+#[derive(Debug)]
pub enum Error {
DieselError(DieselErr),
PasswordError(BcryptError),
CommunicationError(reqwest::Error),
ParseError(serde_json::error::Error),
+ StravaApiError(StravaApiError),
AlreadyExists,
NotFound,
InternalError,
@@ -22,6 +79,7 @@ impl fmt::Display for Error {
Error::PasswordError(ref e) => e.fmt(f),
Error::CommunicationError(ref e) => e.fmt(f),
Error::ParseError(ref e) => e.fmt(f),
+ Error::StravaApiError(ref e) => e.fmt(f),
Error::AlreadyExists => f.write_str("AlreadyExists"),
Error::NotFound => f.write_str("NotFound"),
Error::InternalError => f.write_str("InternalError"),
@@ -29,6 +87,12 @@ impl fmt::Display for Error {
}
}
+impl From<StravaApiError> for Error {
+ fn from(e: StravaApiError) -> Error {
+ Error::StravaApiError(e)
+ }
+}
+
impl From<serde_json::error::Error> for Error {
fn from(e: serde_json::error::Error) -> Error {
Error::ParseError(e)
diff --git a/src/importer.rs b/src/importer.rs
index e1f73e2..ab77837 100644
--- a/src/importer.rs
+++ b/src/importer.rs
@@ -1,3 +1,4 @@
+use diesel::PgConnection;
use std::sync::mpsc::channel;
use std::sync::mpsc::Receiver;
use std::sync::mpsc::Sender;
@@ -5,12 +6,14 @@ use std::sync::Arc;
use std::sync::Mutex;
use std::sync::RwLock;
use threadpool::ThreadPool;
-use diesel::PgConnection;
+use chrono::Utc;
+use crate::error::Error;
use crate::db;
-use crate::strava;
use crate::models;
+use crate::strava;
use crate::strava::StravaApi;
+use crate::Params;
pub const WORKERS: usize = 10;
pub const EMPTY_PARAMS: &[(&str, &str)] = &[];
@@ -29,12 +32,26 @@ struct ImporterState {
rx: Arc<Mutex<Receiver<Command>>>,
}
+fn get_or_refresh_token<Strava: strava::StravaApi>(strava: &Strava, conn: &PgConnection, user: &models::User) -> Result<models::StravaToken, Error> {
+ let mut token = db::get_strava_token(&conn, &user).expect("FIX");
+
+ if token.expires_at < Utc::now() {
+ info!("refresh expired token: {:?}", token.expires_at);
+ let new_token = strava.refresh_token(&From::from(&token))?;
+ new_token.update_model(&mut token);
+ }
+
+ Ok(token)
+}
+
fn import_strava_user(state: ImporterState, user: models::User) {
let strava = state.strava.read().expect("FIX");
let conn = state.conn.lock().expect("FIX");
- let token = db::get_strava_token(&conn, &user).expect("FIX");
- let result = strava.get("/athlete/activities", &token.access_token, EMPTY_PARAMS).expect("ok");
- info!("Imported user. Got result: {:#?}", result);
+ let token = get_or_refresh_token(&*strava, &conn, &user).expect("FIX");
+ let result = strava
+ .get("/athlete/activities", &token.access_token, EMPTY_PARAMS)
+ .expect("ok");
+ info!("import_strava_user: Got result: {:#?}", result);
}
fn handle_command(state: ImporterState, command: Command) {
@@ -65,12 +82,13 @@ fn receive_commands(state: ImporterState) {
}
}
-pub fn run(pool: ThreadPool, conn: PgConnection) -> Sender<Command> {
+pub fn run(pool: ThreadPool, conn: PgConnection, params: &Params) -> Sender<Command> {
let (tx, rx0) = channel();
let state = ImporterState {
pool: pool.clone(),
conn: Arc::new(Mutex::new(conn)),
- strava: Arc::new(RwLock::new(strava::StravaImpl::new())),
+ strava: Arc::new(RwLock::new(strava::StravaImpl::new(
+ params.strava_client_id.clone(), params.strava_client_secret.clone()))),
rx: Arc::new(Mutex::new(rx0)),
};
pool.execute(move || receive_commands(state));
diff --git a/src/lib.rs b/src/lib.rs
index f55fef1..b80a235 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -20,3 +20,9 @@ pub mod models;
mod schema;
pub mod server;
mod strava;
+
+pub struct Params {
+ pub base_url: String,
+ pub strava_client_id: String,
+ pub strava_client_secret: String,
+}
diff --git a/src/server.rs b/src/server.rs
index 79c3226..abc430b 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -14,21 +14,16 @@ use rocket::response::Redirect;
use rocket::State;
use rocket_contrib::templates::Template;
use std::collections::HashMap;
-use threadpool::ThreadPool;
-use std::sync::Mutex;
use std::sync::mpsc::Sender;
+use std::sync::Mutex;
+use threadpool::ThreadPool;
use crate::db;
use crate::error::Error;
use crate::importer;
use crate::models;
use crate::strava;
-
-pub struct Params {
- pub base_url: String,
- pub strava_client_id: String,
- pub strava_client_secret: String,
-}
+use crate::Params;
#[database("db")]
pub struct Db(diesel::PgConnection);
@@ -62,7 +57,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for LoggedInUser {
Ok(user) => {
info!("Credentials: {:?}", user);
Outcome::Success(user)
- },
+ }
Err(Error::NotFound) => Outcome::Forward(()),
Err(e) => Outcome::Failure((Status::InternalServerError, e)),
}
@@ -139,9 +134,13 @@ fn link_strava_callback(
fn import_strava(
conn: Db,
tx: State<Mutex<Sender<importer::Command>>>,
- user: LoggedInUser) -> Result<(), Error> {
+ user: LoggedInUser,
+) -> Result<(), Error> {
let user = db::get_user(&*conn, &user.username)?;
- tx.lock().expect("FIX").send(importer::Command::ImportStravaUser(user)).expect("FIX");
+ tx.lock()
+ .expect("FIX")
+ .send(importer::Command::ImportStravaUser(user))
+ .expect("FIX");
Ok(())
}
@@ -154,7 +153,7 @@ fn link_strava(params: State<Params>) -> Redirect {
"response_type=code&",
"redirect_uri={}&",
"approval_prompt=force&",
- "scope=read",
+ "scope=read_all,activity:read_all,profile:read_all",
),
params.strava_client_id,
format!("{}/link_strava_callback", params.base_url)
@@ -181,7 +180,7 @@ pub fn start(conn: diesel::PgConnection, db_url: &str, base_url: &str) {
.unwrap();
let importer_pool = ThreadPool::with_name("import".to_string(), importer::WORKERS);
- let tx = importer::run(importer_pool.clone(), conn);
+ let tx = importer::run(importer_pool.clone(), conn, &params);
rocket::custom(config)
.manage(params)
diff --git a/src/strava.rs b/src/strava.rs
index 6be5466..284d8b1 100644
--- a/src/strava.rs
+++ b/src/strava.rs
@@ -1,4 +1,6 @@
+use crate::error;
use crate::error::Error;
+use crate::models;
use chrono::serde::ts_seconds;
use chrono::DateTime;
use chrono::Utc;
@@ -8,53 +10,113 @@ use serde::Serialize;
use serde_json::from_value;
use serde_json::Value;
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Token {
+ #[serde(with = "ts_seconds")]
+ pub expires_at: DateTime<Utc>,
+ pub refresh_token: String,
+ pub access_token: String,
+}
+
+impl Token {
+ pub fn update_model(&self, out: &mut models::StravaToken) {
+ out.expires_at = self.expires_at.clone();
+ out.refresh_token = self.refresh_token.clone();
+ out.access_token = self.access_token.clone();
+ }
+}
+
+impl From<&models::StravaToken> for Token {
+ fn from(t: &models::StravaToken) -> Token {
+ Token {
+ expires_at: t.expires_at,
+ refresh_token: t.refresh_token.clone(),
+ access_token: t.access_token.clone(),
+ }
+ }
+}
+
pub trait StravaApi {
- fn get<T: Serialize + ?Sized>(&self, method: &str, access_token: &str, parasm: &T) -> Result<Value, Error>;
+ fn get<T: Serialize + ?Sized>(
+ &self,
+ method: &str,
+ access_token: &str,
+ parasm: &T,
+ ) -> Result<Value, Error>;
+
+ fn refresh_token(
+ &self,
+ token: &Token) -> Result<Token, Error>;
}
pub struct StravaImpl {
client: reqwest::blocking::Client,
base_url: String,
+ api_url: String,
+ client_id: String,
+ client_secret: String,
}
impl StravaImpl {
- pub fn new() -> StravaImpl {
+ pub fn new(client_id: String, client_secret: String) -> StravaImpl {
StravaImpl {
client: reqwest::blocking::Client::new(),
- base_url: "https://www.strava.com/api/v3".to_string(),
+ base_url: "https://www.strava.com".to_string(),
+ api_url: "/api/v3".to_string(),
+ client_id,
+ client_secret,
}
}
}
impl StravaApi for StravaImpl {
- fn get<T: Serialize + ?Sized>(&self, method: &str, access_token: &str,
- params: &T) -> Result<Value, Error> {
- let uri = format!("{}{}", self.base_url, method);
- let response = self.client.get(&uri)
+ fn get<T: Serialize + ?Sized>(
+ &self,
+ method: &str,
+ access_token: &str,
+ params: &T,
+ ) -> Result<Value, Error> {
+ let uri = format!("{}{}{}", self.base_url, self.api_url, method);
+ let response = self
+ .client
+ .get(&uri)
.bearer_auth(access_token)
.query(params)
.send()?;
info!("StravaApi::get({}) returned {:?}", method, response);
- let json = response.json()?;
+ let status = response.status();
+ let json: Value = response.json()?;
+
+ if !status.is_success() {
+ return Err(From::from(error::StravaApiError::new(status, json)));
+ }
Ok(json)
}
-}
-#[derive(Serialize, Deserialize, Debug)]
-pub struct AthleteSummary {
- id: i64,
- username: String,
- firstname: String,
- lastname: String,
-}
+ fn refresh_token(
+ &self,
+ token: &Token) -> Result<Token, Error> {
+ let uri = format!("{}{}{}", self.base_url, self.api_url, "/oauth/token");
+ let params = [
+ ("client_id", self.client_id.as_str()),
+ ("client_secret", self.client_secret.as_str()),
+ ("grant_type", "refresh_token"),
+ ("refresh_token", token.refresh_token.as_str()),
+ ];
+ let response = self
+ .client
+ .post(&uri)
+ .form(&params)
+ .send()?;
+ info!("StravaApi::refresh_token returned {:?}", response);
+ let status = response.status();
+ let json: Value = response.json()?;
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Token {
- #[serde(with = "ts_seconds")]
- pub expires_at: DateTime<Utc>,
- pub refresh_token: String,
- pub access_token: String,
- pub athlete: AthleteSummary,
+ if !status.is_success() {
+ return Err(From::from(error::StravaApiError::new(status, json)));
+ }
+ from_value(json).map_err(From::from)
+ }
}
pub fn exchange_token(client_id: &str, client_secret: &str, code: &str) -> Result<Token, Error> {