diff options
author | Kjetil Orbekk <kj@orbekk.com> | 2022-12-31 12:17:35 -0500 |
---|---|---|
committer | Kjetil Orbekk <kj@orbekk.com> | 2022-12-31 12:55:45 -0500 |
commit | 88366acba07b678466b42829887dcdda4f583686 (patch) | |
tree | e3450615a6b4b654ea7106a00cc6b551dd4ddb26 | |
parent | aa6d050b09dfbf3e5be112325e8e8d8a1f4dacf9 (diff) |
Add database conversion for bridge types
-rw-r--r-- | Cargo.lock | 1 | ||||
-rw-r--r-- | Cargo.toml | 3 | ||||
-rw-r--r-- | protocol/Cargo.toml | 1 | ||||
-rw-r--r-- | protocol/src/card.rs | 62 | ||||
-rw-r--r-- | server/migrations/20221008120534_init.down.sql | 2 | ||||
-rw-r--r-- | server/migrations/20221008120534_init.up.sql | 7 | ||||
-rw-r--r-- | server/tests/db_test.rs | 45 |
7 files changed, 117 insertions, 4 deletions
@@ -1578,6 +1578,7 @@ dependencies = [ "regex", "serde", "serde_json", + "sqlx", "strum", "strum_macros", "tokio", @@ -16,3 +16,6 @@ opt-level = 'z' # opt-level = 's' # link time optimization using using whole-program analysis lto = true + +[profile.dev.package.sqlx-macros] +opt-level = 3
\ No newline at end of file diff --git a/protocol/Cargo.toml b/protocol/Cargo.toml index 9d25188..dc1510d 100644 --- a/protocol/Cargo.toml +++ b/protocol/Cargo.toml @@ -17,6 +17,7 @@ log = "0.4" regex = "1.0" lazy_static = "1.4" async-trait = "0.1.58" +sqlx = { version = "0.6", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono", "json", "offline"] } [dev-dependencies] env_logger = "0.10.0" diff --git a/protocol/src/card.rs b/protocol/src/card.rs index 771d5e9..31d389a 100644 --- a/protocol/src/card.rs +++ b/protocol/src/card.rs @@ -1,6 +1,13 @@ use anyhow::anyhow; pub(crate) use serde::{Deserialize, Serialize}; +use sqlx::encode::IsNull; +use sqlx::error::BoxDynError; +use sqlx::postgres::PgArgumentBuffer; +use sqlx::postgres::PgTypeInfo; +use sqlx::Postgres; +use sqlx::postgres::PgValueRef; +use strum_macros::FromRepr; use std::fmt; use strum::EnumCount; use strum::IntoEnumIterator; @@ -18,7 +25,9 @@ use strum_macros::EnumIter; EnumCount, Serialize, Deserialize, + FromRepr, )] +#[repr(u8)] pub enum Suit { Club, Diamond, @@ -36,7 +45,9 @@ pub enum Suit { EnumIter, Serialize, Deserialize, + FromRepr, )] +#[repr(u8)] pub enum Rank { Two = 2, Three, @@ -67,6 +78,57 @@ impl fmt::Display for Suit { } } +impl sqlx::Type<Postgres> for Rank { + fn type_info() -> PgTypeInfo { + <i16 as sqlx::Type<Postgres>>::type_info() + } +} + +impl sqlx::Decode<'_, Postgres> for Rank { + fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> { + let value = <i16 as sqlx::Decode<Postgres>>::decode(value)?; + Ok(Rank::from_repr(u8::try_from(value).expect("domain check")).expect("domain check")) + } +} + +impl sqlx::Encode<'_, Postgres> for Rank { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + let pg_value = *self as i16; + <i16 as sqlx::Encode<'_, Postgres>>::encode_by_ref(&pg_value, buf) + } +} + +impl sqlx::Type<Postgres> for Suit { + fn type_info() -> PgTypeInfo { + <&str as sqlx::Type<Postgres>>::type_info() + } +} + +impl sqlx::Decode<'_, Postgres> for Suit { + fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> { + let value = <&str as sqlx::Decode<Postgres>>::decode(value)?; + match value { + "club" => Ok(Suit::Club), + "diamond" => Ok(Suit::Diamond), + "heart" => Ok(Suit::Heart), + "spade" => Ok(Suit::Spade), + _ => panic!("invalid suit enum value"), + } + } +} + +impl sqlx::Encode<'_, Postgres> for Suit { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + let pg_value = match *self { + Suit::Club => "club", + Suit::Diamond => "diamond", + Suit::Heart => "heart", + Suit::Spade => "spade", + }; + <&str as sqlx::Encode<'_, Postgres>>::encode_by_ref(&pg_value, buf) + } +} + impl std::str::FromStr for Suit { type Err = anyhow::Error; diff --git a/server/migrations/20221008120534_init.down.sql b/server/migrations/20221008120534_init.down.sql index 50ba511..3bff171 100644 --- a/server/migrations/20221008120534_init.down.sql +++ b/server/migrations/20221008120534_init.down.sql @@ -6,4 +6,6 @@ drop table if exists object_journal; drop table if exists active_tables; drop table if exists players; drop type if exists player_position; +drop type if exists suit; +drop domain if exists rank; commit; diff --git a/server/migrations/20221008120534_init.up.sql b/server/migrations/20221008120534_init.up.sql index db9fd3c..05b7697 100644 --- a/server/migrations/20221008120534_init.up.sql +++ b/server/migrations/20221008120534_init.up.sql @@ -16,6 +16,11 @@ create table active_tables ( id uuid primary key not null ); +create type player_position as enum ('west', 'north', 'east', 'south'); +create type suit as enum ('club', 'diamond', 'heart', 'spade'); +create domain rank smallint check (value between 2 and 14); + +-- TODO: Remove this. create table object_journal ( id uuid not null, seq bigint not null, @@ -23,8 +28,6 @@ create table object_journal ( ); create unique index journal_entry on object_journal (id, seq); -create type player_position as enum ('west', 'north', 'east', 'south'); - create table table_players ( active_tables_id uuid not null references active_tables (id), player_id varchar(64) not null references players (id), diff --git a/server/tests/db_test.rs b/server/tests/db_test.rs index 6acc21e..e495d78 100644 --- a/server/tests/db_test.rs +++ b/server/tests/db_test.rs @@ -1,3 +1,4 @@ +use protocol::card::{Rank, Suit}; use sqlx::Row; use tracing::info; @@ -8,7 +9,47 @@ mod common; async fn basic_db_query() -> Result<(), anyhow::Error> { let db = common::TestDb::new().await; let db = db.db(); - let number = sqlx::query!(r#"select 1 + 1 as number"#).fetch_one(&db).await?; - assert_eq!(number.number, Some(2)); + let row = sqlx::query!(r#"select 1 + 1 as number"#) + .fetch_one(&db) + .await?; + assert_eq!(row.number, Some(2)); + Ok(()) +} + +#[tokio::test] +#[ignore] +async fn rank_type_conversion() -> Result<(), anyhow::Error> { + let db = common::TestDb::new().await; + let db = db.db(); + let row = sqlx::query!(r#"select cast (2 as rank) as "rank!: Rank""#) + .fetch_one(&db) + .await?; + assert_eq!(Rank::Two, row.rank); + + let rank = Rank::Ace; + let (equal,): (bool,) = sqlx::query_as(r#"select ($1 is not distinct from 14)"#) + .bind(rank) + .fetch_one(&db) + .await?; + assert!(equal); + Ok(()) +} + +#[tokio::test] +#[ignore] +async fn suit_type_conversion() -> Result<(), anyhow::Error> { + let db = common::TestDb::new().await; + let db = db.db(); + let row = sqlx::query!(r#"select cast ('heart' as suit) as "suit!: Suit""#) + .fetch_one(&db) + .await?; + assert_eq!(Suit::Heart, row.suit); + + let suit = Suit::Spade; + let (equal,): (bool,) = sqlx::query_as(r#"select ($1 is not distinct from 'spade')"#) + .bind(suit) + .fetch_one(&db) + .await?; + assert!(equal); Ok(()) } |