summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKjetil Orbekk <kjetil.orbekk@gmail.com>2020-02-03 22:55:36 -0500
committerKjetil Orbekk <kjetil.orbekk@gmail.com>2020-02-03 22:55:36 -0500
commit6d0a4d03705b96b252a6b29d3b8c188b9c903b89 (patch)
treeb8ea3f7459ae4c9b22a976259e637cc7a3d695c7 /src
parentc459b5e85ef9b695b3c9a107b7cf7f08847c608f (diff)
Refactor importer to store tasks in postgresql
Diffstat (limited to 'src')
-rw-r--r--src/db.rs62
-rw-r--r--src/error.rs7
-rw-r--r--src/importer.rs257
-rw-r--r--src/models.rs59
-rw-r--r--src/schema.rs18
-rw-r--r--src/server.rs28
-rw-r--r--src/strava.rs10
7 files changed, 359 insertions, 82 deletions
diff --git a/src/db.rs b/src/db.rs
index 20123bf..f3a261c 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -6,6 +6,9 @@ use diesel::pg::PgConnection;
use diesel::ExpressionMethods;
use diesel::QueryDsl;
use diesel::RunQueryDsl;
+use std::time::Duration;
+use chrono::DateTime;
+use chrono::Utc;
pub const COST: u32 = 10;
@@ -98,3 +101,62 @@ pub fn get_strava_token(
.get_result::<models::StravaToken>(conn)?;
Ok(token)
}
+
+pub fn insert_task(
+ conn: &PgConnection,
+ task: &models::NewTask) -> Result<i64, Error> {
+ use crate::schema::tasks;
+ let id = diesel::insert_into(tasks::table)
+ .values(task)
+ .returning(tasks::id)
+ .get_result(conn)?;
+ Ok(id)
+}
+
+fn update_task_inner(conn: &PgConnection, task: &models::Task)
+ -> Result<models::Task, Error> {
+ use crate::schema::tasks;
+
+ diesel::delete(tasks::table.filter(tasks::columns::id.eq(task.id)))
+ .execute(conn)?;
+
+ let new_id = insert_task(conn, &models::NewTask {
+ start_at: task.start_at,
+ state: task.state,
+ username: &task.username,
+ payload: &task.payload,
+ })?;
+
+ let new_task = tasks::table.find(new_id)
+ .get_result::<models::Task>(conn)?;
+
+ Ok(new_task)
+}
+
+fn update_task(conn: &PgConnection, task: &models::Task) -> Result<models::Task, Error> {
+ conn.transaction(|| {
+ update_task_inner(conn, task)
+ })
+}
+
+pub fn take_task(
+ conn: &PgConnection,
+ state: models::TaskState,
+ start_before: DateTime<Utc>,
+ eta: DateTime<Utc>)
+ -> Result<models::Task, Error> {
+ use crate::schema::tasks;
+
+ conn.transaction(|| {
+ let mut task = tasks::table
+ .filter(tasks::state.eq(state))
+ .filter(tasks::start_at.lt(start_before))
+ .order(tasks::start_at.asc())
+ .first::<models::Task>(conn)?;
+
+ task.start_at = eta;
+ let task = update_task_inner(conn, &task)?;
+
+ Ok(task)
+ })
+}
diff --git a/src/error.rs b/src/error.rs
index 4ae2995..75a7568 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -67,6 +67,7 @@ pub enum Error {
CommunicationError(reqwest::Error),
ParseError(serde_json::error::Error),
StravaApiError(StravaApiError),
+ UnexpectedJson(Value),
AlreadyExists,
NotFound,
InternalError,
@@ -79,6 +80,7 @@ impl fmt::Display for Error {
Error::PasswordError(ref e) => e.fmt(f),
Error::CommunicationError(ref e) => e.fmt(f),
Error::ParseError(ref e) => e.fmt(f),
+ Error::UnexpectedJson(_) => f.write_str("UnexpectedJson"),
Error::StravaApiError(ref e) => e.fmt(f),
Error::AlreadyExists => f.write_str("AlreadyExists"),
Error::NotFound => f.write_str("NotFound"),
@@ -107,7 +109,10 @@ impl From<reqwest::Error> for Error {
impl From<DieselErr> for Error {
fn from(e: DieselErr) -> Error {
- Error::DieselError(e)
+ match e {
+ DieselErr::NotFound => Error::NotFound,
+ e => Error::DieselError(e)
+ }
}
}
diff --git a/src/importer.rs b/src/importer.rs
index 9ea7e35..6909350 100644
--- a/src/importer.rs
+++ b/src/importer.rs
@@ -7,6 +7,13 @@ use std::sync::Mutex;
use std::sync::RwLock;
use threadpool::ThreadPool;
use chrono::Utc;
+use timer::Timer;
+use timer::Guard;
+use std::time::Instant;
+use std::time::Duration;
+use std::thread;
+use serde::Deserialize;
+use serde::Serialize;
use crate::error::Error;
use crate::db;
@@ -18,92 +25,210 @@ use crate::Params;
pub const WORKERS: usize = 10;
pub const EMPTY_PARAMS: &[(&str, &str)] = &[];
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Command {
- ImportStravaUser(models::User),
- Quit,
+ ImportStravaUser { username: String },
}
-#[derive(Clone)]
-struct ImporterState {
- pool: ThreadPool,
- conn: Arc<Mutex<PgConnection>>,
- strava: Arc<RwLock<strava::StravaImpl>>,
- rx: Arc<Mutex<Receiver<Command>>>,
+macro_rules! clone {
+ ( [ $( $i:ident ),* ] $e:expr ) => {
+ {
+ $(let $i = $i.clone();)*
+ $e
+ }
+ }
}
-fn get_or_refresh_token<Strava: strava::StravaApi>(strava: &Strava, conn: &PgConnection, user: &models::User) -> Result<models::StravaToken, Error> {
- let mut token = db::get_strava_token(&conn, &user).expect("FIX");
+pub struct ImporterSharedData<StravaApi: strava::StravaApi + 'static> {
+ strava: RwLock<StravaApi>,
+ pool: Mutex<ThreadPool>,
+ conn: Mutex<PgConnection>,
+ running: Mutex<bool>,
+}
- if token.expires_at < Utc::now() {
- info!("refresh expired token: {:?}", token.expires_at);
- let new_token = strava.refresh_token(&From::from(&token))?;
- new_token.update_model(&mut token);
+pub struct Importer<StravaApi: strava::StravaApi + 'static> {
+ shared: Arc<ImporterSharedData<StravaApi>>,
+}
+
+fn run_periodically<S: strava::StravaApi>(
+ shared: Arc<ImporterSharedData<S>>,
+ period: Duration) {
+ let sleep_time = Duration::from_millis(1000);
+ let mut now = Instant::now();
+ loop {
+ while now.elapsed() < period {
+ if !*shared.running.lock().unwrap() {
+ return;
+ }
+ thread::sleep(sleep_time);
+ }
+ now = Instant::now();
+
+ info!("run_periodically: wakeup");
+ handle_tasks(shared.clone())
}
+}
- Ok(token)
+
+fn handle_one_task<S: strava::StravaApi>(
+ shared: Arc<ImporterSharedData<S>>) -> Result<models::Task, Error> {
+ let task = {
+ let conn = shared.conn.lock().unwrap();
+ let now = Utc::now();
+ let eta = now + chrono::Duration::seconds(5);
+
+ db::take_task(&conn,
+ models::TaskState::NEW,
+ now,
+ eta)?
+ };
+
+ let command = serde_json::from_value(task.payload.clone())?;
+
+ match command {
+ Command::ImportStravaUser{ username } => {
+ import_strava_user(shared, username.as_str())?
+ },
+ }
+
+ Ok(task)
}
-fn import_strava_user(state: ImporterState, user: models::User) {
- use std::thread::sleep;
- use std::time::Duration;
+fn handle_tasks<S: strava::StravaApi>(
+ shared: Arc<ImporterSharedData<S>>) {
+ let mut done = false;
+ while !done {
+ match handle_one_task(shared.clone()) {
+ Err(Error::NotFound) => {
+ info!("No more tasks");
+ done = true;
+ },
+ Err(e) => {
+ error!("Error handling task: {}", e);
+ }
+ Ok(t) => {
+ info!("Successfully handled task: {:?}", t);
+ }
+ };
+ }
+}
+
+impl<StravaApi: strava::StravaApi> Importer<StravaApi> {
+ pub fn new(conn: PgConnection, strava: StravaApi) -> Importer<StravaApi> {
+ let shared = Arc::new(ImporterSharedData {
+ pool: Mutex::new(ThreadPool::with_name("importer".to_string(), WORKERS)),
+ conn: Mutex::new(conn),
+ strava: RwLock::new(strava),
+ running: Mutex::new(false),
+ });
+ Importer { shared: shared }
+ }
- let strava = state.strava.read().expect("FIX");
- let conn = state.conn.lock().expect("FIX");
- let token = get_or_refresh_token(&*strava, &conn, &user).expect("FIX");
+ pub fn run(&self) {
+ info!("run()");
+ let pool = self.shared.pool.lock().unwrap();
+ let mut running = self.shared.running.lock().unwrap();
+ if !*running {
+ *running = true;
+ pool.execute({
+ let shared = self.shared.clone();
+ move || run_periodically(shared, Duration::from_secs(10))
+ });
+ }
+ }
+
+ pub fn join(&self) {
+ self.shared.pool.lock().expect("FIX").join()
+ }
+}
+
+fn import_strava_user<S: strava::StravaApi>(
+ shared: Arc<ImporterSharedData<S>>,
+ username: &str) -> Result<(), Error> {
+ let strava = shared.strava.read().unwrap();
+ let user = db::get_user(&shared.conn.lock().unwrap(), username)?;
+
+ let token = {
+ let conn = shared.conn.lock().unwrap();
+ get_or_refresh_token(&*strava, &conn, &user)?
+ };
+
+ let per_page = 30;
for page in 1.. {
let params = [
- ("page", &format!("{}", page)),
- ("per_page", &format!("{}", 200)),
+ ("page", &format!("{}", page)[..]),
+ ("per_page", &format!("{}", per_page)[..])
];
+
let result = strava
- .get("/athlete/activities", &token.access_token, &params)
- .expect("ok");
- // info!("import_strava_user: Got result: {:#?}", result);
- for activity in result.as_array().expect("FIX") {
+ .get("/athlete/activities", &token.access_token, &params[..])?;
+
+ let result = result.as_array().ok_or(
+ Error::UnexpectedJson(result.clone()))?;
+
+ for activity in result {
info!("activity id: {} start: {}", activity["id"], activity["start_date"]);
}
- sleep(Duration::from_secs(1));
- }
-}
-fn handle_command(state: ImporterState, command: Command) {
- info!("handle_command {:?}", command);
- match command {
- Command::ImportStravaUser(user) => import_strava_user(state, user),
- Command::Quit => (),
- }
+ if result.len() < per_page {
+ break;
+ }
+ thread::sleep(Duration::from_secs(1));
+ };
+
+ Err(Error::InternalError)
}
-fn receive_commands(state: ImporterState) {
- info!("receive_commands");
- match (|| -> Result<(), Box<dyn std::error::Error>> {
- let rx = state.rx.lock()?;
- let mut command = rx.recv()?;
- loop {
- info!("got command: {:?}", command);
- let state0 = state.clone();
- state.pool.execute(move || handle_command(state0, command));
- command = rx.recv()?;
- }
- })() {
- Ok(()) => (),
- Err(e) => {
- error!("receive_commands: {:?}", e);
- ()
- }
+fn get_or_refresh_token<Strava: strava::StravaApi>(strava: &Strava, conn: &PgConnection, user: &models::User) -> Result<models::StravaToken, Error> {
+ let mut token = db::get_strava_token(&conn, &user).expect("FIX");
+
+ if token.expires_at < Utc::now() {
+ info!("refresh expired token: {:?}", token.expires_at);
+ let new_token = strava.refresh_token(&From::from(&token))?;
+ new_token.update_model(&mut token);
}
-}
-pub fn run(pool: ThreadPool, conn: PgConnection, params: &Params) -> Sender<Command> {
- let (tx, rx0) = channel();
- let state = ImporterState {
- pool: pool.clone(),
- conn: Arc::new(Mutex::new(conn)),
- strava: Arc::new(RwLock::new(strava::StravaImpl::new(
- params.strava_client_id.clone(), params.strava_client_secret.clone()))),
- rx: Arc::new(Mutex::new(rx0)),
- };
- pool.execute(move || receive_commands(state));
- tx
+ Ok(token)
}
+
+// fn handle_command(state: Importer, command: Command) {
+// info!("handle_command {:?}", command);
+// match command {
+// Command::ImportStravaUser(user) => import_strava_user(state, user),
+// Command::Quit => (),
+// }
+// }
+
+// fn receive_commands(state: Importer) {
+// info!("receive_commands");
+// match (|| -> Result<(), Box<dyn std::error::Error>> {
+// let rx = state.rx.lock()?;
+// let mut command = rx.recv()?;
+// loop {
+// info!("got command: {:?}", command);
+// let state0 = state.clone();
+// state.pool.execute(move || handle_command(state0, command));
+// command = rx.recv()?;
+// }
+// })() {
+// Ok(()) => (),
+// Err(e) => {
+// error!("receive_commands: {:?}", e);
+// ()
+// }
+// }
+// }
+
+// pub fn run(pool: ThreadPool, conn: PgConnection, params: &Params) -> Sender<Command> {
+// let (tx, rx0) = channel();
+// let importer = Arc::new(Importer {
+// pool: Mutex::new(pool.clone()),
+// conn: Mutex::new(conn),
+// strava: RwLock::new(strava::StravaImpl::new(
+// params.strava_client_id.clone(), params.strava_client_secret.clone())),
+// rx: Mutex::new(rx0),
+// });
+// // pool.execute(move || receive_commands(state));
+// pool.execute(clone! { [importer] move || importer.run() });
+// tx
+// }
diff --git a/src/models.rs b/src/models.rs
index ce3dd19..0b7e5db 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -1,9 +1,68 @@
+use crate::schema::tasks;
use crate::schema::config;
use crate::schema::strava_tokens;
use crate::schema::users;
use chrono::DateTime;
use chrono::Utc;
use std::fmt;
+use serde_json::Value;
+use diesel::pg::Pg;
+use diesel::deserialize;
+use diesel::deserialize::FromSql;
+use diesel::serialize;
+use diesel::serialize::Output;
+use diesel::serialize::ToSql;
+use diesel::sql_types;
+use std::io::Write;
+
+#[derive(PartialEq, Debug, Clone, Copy, AsExpression, FromSqlRow)]
+#[sql_type = "sql_types::Text"]
+pub enum TaskState {
+ NEW = 0,
+ SUCCESSFUL,
+ FAILED,
+}
+
+impl ToSql<sql_types::Text, Pg> for TaskState {
+ fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result {
+ let t = match *self {
+ TaskState::NEW => "new".to_string(),
+ TaskState::SUCCESSFUL => "success".to_string(),
+ TaskState::FAILED => "failed".to_string(),
+ };
+ <String as ToSql<sql_types::Text, Pg>>::to_sql(&t, out)
+ }
+}
+
+impl FromSql<sql_types::Text, Pg> for TaskState {
+ fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> {
+ let s = <String as FromSql<sql_types::Text, Pg>>::from_sql(bytes)?;
+ match s.as_str() {
+ "new" => Ok(TaskState::NEW),
+ "success" => Ok(TaskState::SUCCESSFUL),
+ "failed" => Ok(TaskState::FAILED),
+ &_ => Err("Unrecognized task state".into()),
+ }
+ }
+}
+
+#[derive(Insertable)]
+#[table_name = "tasks"]
+pub struct NewTask<'a> {
+ pub start_at: DateTime<Utc>,
+ pub state: TaskState,
+ pub username: &'a str,
+ pub payload: &'a Value,
+}
+
+#[derive(Queryable, Debug, Clone)]
+pub struct Task {
+ pub id: i64,
+ pub state: TaskState,
+ pub start_at: DateTime<Utc>,
+ pub username: String,
+ pub payload: Value,
+}
#[derive(Insertable, Queryable)]
#[table_name = "config"]
diff --git a/src/schema.rs b/src/schema.rs
index 7cf2892..8748f3c 100644
--- a/src/schema.rs
+++ b/src/schema.rs
@@ -17,6 +17,16 @@ table! {
}
table! {
+ tasks (id) {
+ id -> Int8,
+ state -> Varchar,
+ start_at -> Timestamptz,
+ username -> Varchar,
+ payload -> Jsonb,
+ }
+}
+
+table! {
users (username) {
username -> Varchar,
password -> Varchar,
@@ -24,5 +34,11 @@ table! {
}
joinable!(strava_tokens -> users (username));
+joinable!(tasks -> users (username));
-allow_tables_to_appear_in_same_query!(config, strava_tokens, users,);
+allow_tables_to_appear_in_same_query!(
+ config,
+ strava_tokens,
+ tasks,
+ users,
+);
diff --git a/src/server.rs b/src/server.rs
index abc430b..f0dd591 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -17,6 +17,8 @@ use std::collections::HashMap;
use std::sync::mpsc::Sender;
use std::sync::Mutex;
use threadpool::ThreadPool;
+use chrono::Utc;
+use serde_json::to_value;
use crate::db;
use crate::error::Error;
@@ -133,14 +135,18 @@ fn link_strava_callback(
#[get("/import_strava")]
fn import_strava(
conn: Db,
- tx: State<Mutex<Sender<importer::Command>>>,
user: LoggedInUser,
) -> Result<(), Error> {
let user = db::get_user(&*conn, &user.username)?;
- tx.lock()
- .expect("FIX")
- .send(importer::Command::ImportStravaUser(user))
- .expect("FIX");
+ let command =
+ importer::Command::ImportStravaUser { username: user.username.clone() };
+ db::insert_task(&conn,
+ &models::NewTask {
+ start_at: Utc::now(),
+ state: models::TaskState::NEW,
+ username: user.username.as_str(),
+ payload: &to_value(command)?,
+ })?;
Ok(())
}
@@ -179,12 +185,16 @@ pub fn start(conn: diesel::PgConnection, db_url: &str, base_url: &str) {
.finalize()
.unwrap();
- let importer_pool = ThreadPool::with_name("import".to_string(), importer::WORKERS);
- let tx = importer::run(importer_pool.clone(), conn, &params);
+ let strava = strava::StravaImpl::new(
+ params.strava_client_id.clone(),
+ params.strava_client_secret.clone(),
+ );
+
+ let importer = importer::Importer::new(conn, strava);
+ importer.run();
rocket::custom(config)
.manage(params)
- .manage(Mutex::new(tx))
.mount(
"/",
routes![
@@ -200,5 +210,5 @@ pub fn start(conn: diesel::PgConnection, db_url: &str, base_url: &str) {
.attach(Db::fairing())
.launch();
- importer_pool.join();
+ importer.join();
}
diff --git a/src/strava.rs b/src/strava.rs
index 284d8b1..ff59c66 100644
--- a/src/strava.rs
+++ b/src/strava.rs
@@ -36,12 +36,12 @@ impl From<&models::StravaToken> for Token {
}
}
-pub trait StravaApi {
- fn get<T: Serialize + ?Sized>(
+pub trait StravaApi: Sync + Send {
+ fn get(
&self,
method: &str,
access_token: &str,
- parasm: &T,
+ params: &[(&str, &str)],
) -> Result<Value, Error>;
fn refresh_token(
@@ -70,11 +70,11 @@ impl StravaImpl {
}
impl StravaApi for StravaImpl {
- fn get<T: Serialize + ?Sized>(
+ fn get(
&self,
method: &str,
access_token: &str,
- params: &T,
+ params: &[(&str, &str)],
) -> Result<Value, Error> {
let uri = format!("{}{}{}", self.base_url, self.api_url, method);
let response = self