diff options
author | Kjetil Orbekk <kjetil.orbekk@gmail.com> | 2020-02-03 22:55:36 -0500 |
---|---|---|
committer | Kjetil Orbekk <kjetil.orbekk@gmail.com> | 2020-02-03 22:55:36 -0500 |
commit | 6d0a4d03705b96b252a6b29d3b8c188b9c903b89 (patch) | |
tree | b8ea3f7459ae4c9b22a976259e637cc7a3d695c7 /src | |
parent | c459b5e85ef9b695b3c9a107b7cf7f08847c608f (diff) |
Refactor importer to store tasks in postgresql
Diffstat (limited to 'src')
-rw-r--r-- | src/db.rs | 62 | ||||
-rw-r--r-- | src/error.rs | 7 | ||||
-rw-r--r-- | src/importer.rs | 257 | ||||
-rw-r--r-- | src/models.rs | 59 | ||||
-rw-r--r-- | src/schema.rs | 18 | ||||
-rw-r--r-- | src/server.rs | 28 | ||||
-rw-r--r-- | src/strava.rs | 10 |
7 files changed, 359 insertions, 82 deletions
@@ -6,6 +6,9 @@ use diesel::pg::PgConnection; use diesel::ExpressionMethods; use diesel::QueryDsl; use diesel::RunQueryDsl; +use std::time::Duration; +use chrono::DateTime; +use chrono::Utc; pub const COST: u32 = 10; @@ -98,3 +101,62 @@ pub fn get_strava_token( .get_result::<models::StravaToken>(conn)?; Ok(token) } + +pub fn insert_task( + conn: &PgConnection, + task: &models::NewTask) -> Result<i64, Error> { + use crate::schema::tasks; + let id = diesel::insert_into(tasks::table) + .values(task) + .returning(tasks::id) + .get_result(conn)?; + Ok(id) +} + +fn update_task_inner(conn: &PgConnection, task: &models::Task) + -> Result<models::Task, Error> { + use crate::schema::tasks; + + diesel::delete(tasks::table.filter(tasks::columns::id.eq(task.id))) + .execute(conn)?; + + let new_id = insert_task(conn, &models::NewTask { + start_at: task.start_at, + state: task.state, + username: &task.username, + payload: &task.payload, + })?; + + let new_task = tasks::table.find(new_id) + .get_result::<models::Task>(conn)?; + + Ok(new_task) +} + +fn update_task(conn: &PgConnection, task: &models::Task) -> Result<models::Task, Error> { + conn.transaction(|| { + update_task_inner(conn, task) + }) +} + +pub fn take_task( + conn: &PgConnection, + state: models::TaskState, + start_before: DateTime<Utc>, + eta: DateTime<Utc>) + -> Result<models::Task, Error> { + use crate::schema::tasks; + + conn.transaction(|| { + let mut task = tasks::table + .filter(tasks::state.eq(state)) + .filter(tasks::start_at.lt(start_before)) + .order(tasks::start_at.asc()) + .first::<models::Task>(conn)?; + + task.start_at = eta; + let task = update_task_inner(conn, &task)?; + + Ok(task) + }) +} diff --git a/src/error.rs b/src/error.rs index 4ae2995..75a7568 100644 --- a/src/error.rs +++ b/src/error.rs @@ -67,6 +67,7 @@ pub enum Error { CommunicationError(reqwest::Error), ParseError(serde_json::error::Error), StravaApiError(StravaApiError), + UnexpectedJson(Value), AlreadyExists, NotFound, InternalError, @@ -79,6 +80,7 @@ impl fmt::Display for Error { Error::PasswordError(ref e) => e.fmt(f), Error::CommunicationError(ref e) => e.fmt(f), Error::ParseError(ref e) => e.fmt(f), + Error::UnexpectedJson(_) => f.write_str("UnexpectedJson"), Error::StravaApiError(ref e) => e.fmt(f), Error::AlreadyExists => f.write_str("AlreadyExists"), Error::NotFound => f.write_str("NotFound"), @@ -107,7 +109,10 @@ impl From<reqwest::Error> for Error { impl From<DieselErr> for Error { fn from(e: DieselErr) -> Error { - Error::DieselError(e) + match e { + DieselErr::NotFound => Error::NotFound, + e => Error::DieselError(e) + } } } diff --git a/src/importer.rs b/src/importer.rs index 9ea7e35..6909350 100644 --- a/src/importer.rs +++ b/src/importer.rs @@ -7,6 +7,13 @@ use std::sync::Mutex; use std::sync::RwLock; use threadpool::ThreadPool; use chrono::Utc; +use timer::Timer; +use timer::Guard; +use std::time::Instant; +use std::time::Duration; +use std::thread; +use serde::Deserialize; +use serde::Serialize; use crate::error::Error; use crate::db; @@ -18,92 +25,210 @@ use crate::Params; pub const WORKERS: usize = 10; pub const EMPTY_PARAMS: &[(&str, &str)] = &[]; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Command { - ImportStravaUser(models::User), - Quit, + ImportStravaUser { username: String }, } -#[derive(Clone)] -struct ImporterState { - pool: ThreadPool, - conn: Arc<Mutex<PgConnection>>, - strava: Arc<RwLock<strava::StravaImpl>>, - rx: Arc<Mutex<Receiver<Command>>>, +macro_rules! clone { + ( [ $( $i:ident ),* ] $e:expr ) => { + { + $(let $i = $i.clone();)* + $e + } + } } -fn get_or_refresh_token<Strava: strava::StravaApi>(strava: &Strava, conn: &PgConnection, user: &models::User) -> Result<models::StravaToken, Error> { - let mut token = db::get_strava_token(&conn, &user).expect("FIX"); +pub struct ImporterSharedData<StravaApi: strava::StravaApi + 'static> { + strava: RwLock<StravaApi>, + pool: Mutex<ThreadPool>, + conn: Mutex<PgConnection>, + running: Mutex<bool>, +} - if token.expires_at < Utc::now() { - info!("refresh expired token: {:?}", token.expires_at); - let new_token = strava.refresh_token(&From::from(&token))?; - new_token.update_model(&mut token); +pub struct Importer<StravaApi: strava::StravaApi + 'static> { + shared: Arc<ImporterSharedData<StravaApi>>, +} + +fn run_periodically<S: strava::StravaApi>( + shared: Arc<ImporterSharedData<S>>, + period: Duration) { + let sleep_time = Duration::from_millis(1000); + let mut now = Instant::now(); + loop { + while now.elapsed() < period { + if !*shared.running.lock().unwrap() { + return; + } + thread::sleep(sleep_time); + } + now = Instant::now(); + + info!("run_periodically: wakeup"); + handle_tasks(shared.clone()) } +} - Ok(token) + +fn handle_one_task<S: strava::StravaApi>( + shared: Arc<ImporterSharedData<S>>) -> Result<models::Task, Error> { + let task = { + let conn = shared.conn.lock().unwrap(); + let now = Utc::now(); + let eta = now + chrono::Duration::seconds(5); + + db::take_task(&conn, + models::TaskState::NEW, + now, + eta)? + }; + + let command = serde_json::from_value(task.payload.clone())?; + + match command { + Command::ImportStravaUser{ username } => { + import_strava_user(shared, username.as_str())? + }, + } + + Ok(task) } -fn import_strava_user(state: ImporterState, user: models::User) { - use std::thread::sleep; - use std::time::Duration; +fn handle_tasks<S: strava::StravaApi>( + shared: Arc<ImporterSharedData<S>>) { + let mut done = false; + while !done { + match handle_one_task(shared.clone()) { + Err(Error::NotFound) => { + info!("No more tasks"); + done = true; + }, + Err(e) => { + error!("Error handling task: {}", e); + } + Ok(t) => { + info!("Successfully handled task: {:?}", t); + } + }; + } +} + +impl<StravaApi: strava::StravaApi> Importer<StravaApi> { + pub fn new(conn: PgConnection, strava: StravaApi) -> Importer<StravaApi> { + let shared = Arc::new(ImporterSharedData { + pool: Mutex::new(ThreadPool::with_name("importer".to_string(), WORKERS)), + conn: Mutex::new(conn), + strava: RwLock::new(strava), + running: Mutex::new(false), + }); + Importer { shared: shared } + } - let strava = state.strava.read().expect("FIX"); - let conn = state.conn.lock().expect("FIX"); - let token = get_or_refresh_token(&*strava, &conn, &user).expect("FIX"); + pub fn run(&self) { + info!("run()"); + let pool = self.shared.pool.lock().unwrap(); + let mut running = self.shared.running.lock().unwrap(); + if !*running { + *running = true; + pool.execute({ + let shared = self.shared.clone(); + move || run_periodically(shared, Duration::from_secs(10)) + }); + } + } + + pub fn join(&self) { + self.shared.pool.lock().expect("FIX").join() + } +} + +fn import_strava_user<S: strava::StravaApi>( + shared: Arc<ImporterSharedData<S>>, + username: &str) -> Result<(), Error> { + let strava = shared.strava.read().unwrap(); + let user = db::get_user(&shared.conn.lock().unwrap(), username)?; + + let token = { + let conn = shared.conn.lock().unwrap(); + get_or_refresh_token(&*strava, &conn, &user)? + }; + + let per_page = 30; for page in 1.. { let params = [ - ("page", &format!("{}", page)), - ("per_page", &format!("{}", 200)), + ("page", &format!("{}", page)[..]), + ("per_page", &format!("{}", per_page)[..]) ]; + let result = strava - .get("/athlete/activities", &token.access_token, ¶ms) - .expect("ok"); - // info!("import_strava_user: Got result: {:#?}", result); - for activity in result.as_array().expect("FIX") { + .get("/athlete/activities", &token.access_token, ¶ms[..])?; + + let result = result.as_array().ok_or( + Error::UnexpectedJson(result.clone()))?; + + for activity in result { info!("activity id: {} start: {}", activity["id"], activity["start_date"]); } - sleep(Duration::from_secs(1)); - } -} -fn handle_command(state: ImporterState, command: Command) { - info!("handle_command {:?}", command); - match command { - Command::ImportStravaUser(user) => import_strava_user(state, user), - Command::Quit => (), - } + if result.len() < per_page { + break; + } + thread::sleep(Duration::from_secs(1)); + }; + + Err(Error::InternalError) } -fn receive_commands(state: ImporterState) { - info!("receive_commands"); - match (|| -> Result<(), Box<dyn std::error::Error>> { - let rx = state.rx.lock()?; - let mut command = rx.recv()?; - loop { - info!("got command: {:?}", command); - let state0 = state.clone(); - state.pool.execute(move || handle_command(state0, command)); - command = rx.recv()?; - } - })() { - Ok(()) => (), - Err(e) => { - error!("receive_commands: {:?}", e); - () - } +fn get_or_refresh_token<Strava: strava::StravaApi>(strava: &Strava, conn: &PgConnection, user: &models::User) -> Result<models::StravaToken, Error> { + let mut token = db::get_strava_token(&conn, &user).expect("FIX"); + + if token.expires_at < Utc::now() { + info!("refresh expired token: {:?}", token.expires_at); + let new_token = strava.refresh_token(&From::from(&token))?; + new_token.update_model(&mut token); } -} -pub fn run(pool: ThreadPool, conn: PgConnection, params: &Params) -> Sender<Command> { - let (tx, rx0) = channel(); - let state = ImporterState { - pool: pool.clone(), - conn: Arc::new(Mutex::new(conn)), - strava: Arc::new(RwLock::new(strava::StravaImpl::new( - params.strava_client_id.clone(), params.strava_client_secret.clone()))), - rx: Arc::new(Mutex::new(rx0)), - }; - pool.execute(move || receive_commands(state)); - tx + Ok(token) } + +// fn handle_command(state: Importer, command: Command) { +// info!("handle_command {:?}", command); +// match command { +// Command::ImportStravaUser(user) => import_strava_user(state, user), +// Command::Quit => (), +// } +// } + +// fn receive_commands(state: Importer) { +// info!("receive_commands"); +// match (|| -> Result<(), Box<dyn std::error::Error>> { +// let rx = state.rx.lock()?; +// let mut command = rx.recv()?; +// loop { +// info!("got command: {:?}", command); +// let state0 = state.clone(); +// state.pool.execute(move || handle_command(state0, command)); +// command = rx.recv()?; +// } +// })() { +// Ok(()) => (), +// Err(e) => { +// error!("receive_commands: {:?}", e); +// () +// } +// } +// } + +// pub fn run(pool: ThreadPool, conn: PgConnection, params: &Params) -> Sender<Command> { +// let (tx, rx0) = channel(); +// let importer = Arc::new(Importer { +// pool: Mutex::new(pool.clone()), +// conn: Mutex::new(conn), +// strava: RwLock::new(strava::StravaImpl::new( +// params.strava_client_id.clone(), params.strava_client_secret.clone())), +// rx: Mutex::new(rx0), +// }); +// // pool.execute(move || receive_commands(state)); +// pool.execute(clone! { [importer] move || importer.run() }); +// tx +// } diff --git a/src/models.rs b/src/models.rs index ce3dd19..0b7e5db 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,9 +1,68 @@ +use crate::schema::tasks; use crate::schema::config; use crate::schema::strava_tokens; use crate::schema::users; use chrono::DateTime; use chrono::Utc; use std::fmt; +use serde_json::Value; +use diesel::pg::Pg; +use diesel::deserialize; +use diesel::deserialize::FromSql; +use diesel::serialize; +use diesel::serialize::Output; +use diesel::serialize::ToSql; +use diesel::sql_types; +use std::io::Write; + +#[derive(PartialEq, Debug, Clone, Copy, AsExpression, FromSqlRow)] +#[sql_type = "sql_types::Text"] +pub enum TaskState { + NEW = 0, + SUCCESSFUL, + FAILED, +} + +impl ToSql<sql_types::Text, Pg> for TaskState { + fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result { + let t = match *self { + TaskState::NEW => "new".to_string(), + TaskState::SUCCESSFUL => "success".to_string(), + TaskState::FAILED => "failed".to_string(), + }; + <String as ToSql<sql_types::Text, Pg>>::to_sql(&t, out) + } +} + +impl FromSql<sql_types::Text, Pg> for TaskState { + fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> { + let s = <String as FromSql<sql_types::Text, Pg>>::from_sql(bytes)?; + match s.as_str() { + "new" => Ok(TaskState::NEW), + "success" => Ok(TaskState::SUCCESSFUL), + "failed" => Ok(TaskState::FAILED), + &_ => Err("Unrecognized task state".into()), + } + } +} + +#[derive(Insertable)] +#[table_name = "tasks"] +pub struct NewTask<'a> { + pub start_at: DateTime<Utc>, + pub state: TaskState, + pub username: &'a str, + pub payload: &'a Value, +} + +#[derive(Queryable, Debug, Clone)] +pub struct Task { + pub id: i64, + pub state: TaskState, + pub start_at: DateTime<Utc>, + pub username: String, + pub payload: Value, +} #[derive(Insertable, Queryable)] #[table_name = "config"] diff --git a/src/schema.rs b/src/schema.rs index 7cf2892..8748f3c 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -17,6 +17,16 @@ table! { } table! { + tasks (id) { + id -> Int8, + state -> Varchar, + start_at -> Timestamptz, + username -> Varchar, + payload -> Jsonb, + } +} + +table! { users (username) { username -> Varchar, password -> Varchar, @@ -24,5 +34,11 @@ table! { } joinable!(strava_tokens -> users (username)); +joinable!(tasks -> users (username)); -allow_tables_to_appear_in_same_query!(config, strava_tokens, users,); +allow_tables_to_appear_in_same_query!( + config, + strava_tokens, + tasks, + users, +); diff --git a/src/server.rs b/src/server.rs index abc430b..f0dd591 100644 --- a/src/server.rs +++ b/src/server.rs @@ -17,6 +17,8 @@ use std::collections::HashMap; use std::sync::mpsc::Sender; use std::sync::Mutex; use threadpool::ThreadPool; +use chrono::Utc; +use serde_json::to_value; use crate::db; use crate::error::Error; @@ -133,14 +135,18 @@ fn link_strava_callback( #[get("/import_strava")] fn import_strava( conn: Db, - tx: State<Mutex<Sender<importer::Command>>>, user: LoggedInUser, ) -> Result<(), Error> { let user = db::get_user(&*conn, &user.username)?; - tx.lock() - .expect("FIX") - .send(importer::Command::ImportStravaUser(user)) - .expect("FIX"); + let command = + importer::Command::ImportStravaUser { username: user.username.clone() }; + db::insert_task(&conn, + &models::NewTask { + start_at: Utc::now(), + state: models::TaskState::NEW, + username: user.username.as_str(), + payload: &to_value(command)?, + })?; Ok(()) } @@ -179,12 +185,16 @@ pub fn start(conn: diesel::PgConnection, db_url: &str, base_url: &str) { .finalize() .unwrap(); - let importer_pool = ThreadPool::with_name("import".to_string(), importer::WORKERS); - let tx = importer::run(importer_pool.clone(), conn, ¶ms); + let strava = strava::StravaImpl::new( + params.strava_client_id.clone(), + params.strava_client_secret.clone(), + ); + + let importer = importer::Importer::new(conn, strava); + importer.run(); rocket::custom(config) .manage(params) - .manage(Mutex::new(tx)) .mount( "/", routes![ @@ -200,5 +210,5 @@ pub fn start(conn: diesel::PgConnection, db_url: &str, base_url: &str) { .attach(Db::fairing()) .launch(); - importer_pool.join(); + importer.join(); } diff --git a/src/strava.rs b/src/strava.rs index 284d8b1..ff59c66 100644 --- a/src/strava.rs +++ b/src/strava.rs @@ -36,12 +36,12 @@ impl From<&models::StravaToken> for Token { } } -pub trait StravaApi { - fn get<T: Serialize + ?Sized>( +pub trait StravaApi: Sync + Send { + fn get( &self, method: &str, access_token: &str, - parasm: &T, + params: &[(&str, &str)], ) -> Result<Value, Error>; fn refresh_token( @@ -70,11 +70,11 @@ impl StravaImpl { } impl StravaApi for StravaImpl { - fn get<T: Serialize + ?Sized>( + fn get( &self, method: &str, access_token: &str, - params: &T, + params: &[(&str, &str)], ) -> Result<Value, Error> { let uri = format!("{}{}{}", self.base_url, self.api_url, method); let response = self |