summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKjetil Orbekk <kj@orbekk.com>2022-12-31 12:17:35 -0500
committerKjetil Orbekk <kj@orbekk.com>2022-12-31 12:55:45 -0500
commit88366acba07b678466b42829887dcdda4f583686 (patch)
treee3450615a6b4b654ea7106a00cc6b551dd4ddb26
parentaa6d050b09dfbf3e5be112325e8e8d8a1f4dacf9 (diff)
Add database conversion for bridge types
-rw-r--r--Cargo.lock1
-rw-r--r--Cargo.toml3
-rw-r--r--protocol/Cargo.toml1
-rw-r--r--protocol/src/card.rs62
-rw-r--r--server/migrations/20221008120534_init.down.sql2
-rw-r--r--server/migrations/20221008120534_init.up.sql7
-rw-r--r--server/tests/db_test.rs45
7 files changed, 117 insertions, 4 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 04acb88..548641d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1578,6 +1578,7 @@ dependencies = [
"regex",
"serde",
"serde_json",
+ "sqlx",
"strum",
"strum_macros",
"tokio",
diff --git a/Cargo.toml b/Cargo.toml
index 7fbfcaa..6366b84 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,3 +16,6 @@ opt-level = 'z'
# opt-level = 's'
# link time optimization using using whole-program analysis
lto = true
+
+[profile.dev.package.sqlx-macros]
+opt-level = 3 \ No newline at end of file
diff --git a/protocol/Cargo.toml b/protocol/Cargo.toml
index 9d25188..dc1510d 100644
--- a/protocol/Cargo.toml
+++ b/protocol/Cargo.toml
@@ -17,6 +17,7 @@ log = "0.4"
regex = "1.0"
lazy_static = "1.4"
async-trait = "0.1.58"
+sqlx = { version = "0.6", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono", "json", "offline"] }
[dev-dependencies]
env_logger = "0.10.0"
diff --git a/protocol/src/card.rs b/protocol/src/card.rs
index 771d5e9..31d389a 100644
--- a/protocol/src/card.rs
+++ b/protocol/src/card.rs
@@ -1,6 +1,13 @@
use anyhow::anyhow;
pub(crate) use serde::{Deserialize, Serialize};
+use sqlx::encode::IsNull;
+use sqlx::error::BoxDynError;
+use sqlx::postgres::PgArgumentBuffer;
+use sqlx::postgres::PgTypeInfo;
+use sqlx::Postgres;
+use sqlx::postgres::PgValueRef;
+use strum_macros::FromRepr;
use std::fmt;
use strum::EnumCount;
use strum::IntoEnumIterator;
@@ -18,7 +25,9 @@ use strum_macros::EnumIter;
EnumCount,
Serialize,
Deserialize,
+ FromRepr,
)]
+#[repr(u8)]
pub enum Suit {
Club,
Diamond,
@@ -36,7 +45,9 @@ pub enum Suit {
EnumIter,
Serialize,
Deserialize,
+ FromRepr,
)]
+#[repr(u8)]
pub enum Rank {
Two = 2,
Three,
@@ -67,6 +78,57 @@ impl fmt::Display for Suit {
}
}
+impl sqlx::Type<Postgres> for Rank {
+ fn type_info() -> PgTypeInfo {
+ <i16 as sqlx::Type<Postgres>>::type_info()
+ }
+}
+
+impl sqlx::Decode<'_, Postgres> for Rank {
+ fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+ let value = <i16 as sqlx::Decode<Postgres>>::decode(value)?;
+ Ok(Rank::from_repr(u8::try_from(value).expect("domain check")).expect("domain check"))
+ }
+}
+
+impl sqlx::Encode<'_, Postgres> for Rank {
+ fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
+ let pg_value = *self as i16;
+ <i16 as sqlx::Encode<'_, Postgres>>::encode_by_ref(&pg_value, buf)
+ }
+}
+
+impl sqlx::Type<Postgres> for Suit {
+ fn type_info() -> PgTypeInfo {
+ <&str as sqlx::Type<Postgres>>::type_info()
+ }
+}
+
+impl sqlx::Decode<'_, Postgres> for Suit {
+ fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
+ let value = <&str as sqlx::Decode<Postgres>>::decode(value)?;
+ match value {
+ "club" => Ok(Suit::Club),
+ "diamond" => Ok(Suit::Diamond),
+ "heart" => Ok(Suit::Heart),
+ "spade" => Ok(Suit::Spade),
+ _ => panic!("invalid suit enum value"),
+ }
+ }
+}
+
+impl sqlx::Encode<'_, Postgres> for Suit {
+ fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
+ let pg_value = match *self {
+ Suit::Club => "club",
+ Suit::Diamond => "diamond",
+ Suit::Heart => "heart",
+ Suit::Spade => "spade",
+ };
+ <&str as sqlx::Encode<'_, Postgres>>::encode_by_ref(&pg_value, buf)
+ }
+}
+
impl std::str::FromStr for Suit {
type Err = anyhow::Error;
diff --git a/server/migrations/20221008120534_init.down.sql b/server/migrations/20221008120534_init.down.sql
index 50ba511..3bff171 100644
--- a/server/migrations/20221008120534_init.down.sql
+++ b/server/migrations/20221008120534_init.down.sql
@@ -6,4 +6,6 @@ drop table if exists object_journal;
drop table if exists active_tables;
drop table if exists players;
drop type if exists player_position;
+drop type if exists suit;
+drop domain if exists rank;
commit;
diff --git a/server/migrations/20221008120534_init.up.sql b/server/migrations/20221008120534_init.up.sql
index db9fd3c..05b7697 100644
--- a/server/migrations/20221008120534_init.up.sql
+++ b/server/migrations/20221008120534_init.up.sql
@@ -16,6 +16,11 @@ create table active_tables (
id uuid primary key not null
);
+create type player_position as enum ('west', 'north', 'east', 'south');
+create type suit as enum ('club', 'diamond', 'heart', 'spade');
+create domain rank smallint check (value between 2 and 14);
+
+-- TODO: Remove this.
create table object_journal (
id uuid not null,
seq bigint not null,
@@ -23,8 +28,6 @@ create table object_journal (
);
create unique index journal_entry on object_journal (id, seq);
-create type player_position as enum ('west', 'north', 'east', 'south');
-
create table table_players (
active_tables_id uuid not null references active_tables (id),
player_id varchar(64) not null references players (id),
diff --git a/server/tests/db_test.rs b/server/tests/db_test.rs
index 6acc21e..e495d78 100644
--- a/server/tests/db_test.rs
+++ b/server/tests/db_test.rs
@@ -1,3 +1,4 @@
+use protocol::card::{Rank, Suit};
use sqlx::Row;
use tracing::info;
@@ -8,7 +9,47 @@ mod common;
async fn basic_db_query() -> Result<(), anyhow::Error> {
let db = common::TestDb::new().await;
let db = db.db();
- let number = sqlx::query!(r#"select 1 + 1 as number"#).fetch_one(&db).await?;
- assert_eq!(number.number, Some(2));
+ let row = sqlx::query!(r#"select 1 + 1 as number"#)
+ .fetch_one(&db)
+ .await?;
+ assert_eq!(row.number, Some(2));
+ Ok(())
+}
+
+#[tokio::test]
+#[ignore]
+async fn rank_type_conversion() -> Result<(), anyhow::Error> {
+ let db = common::TestDb::new().await;
+ let db = db.db();
+ let row = sqlx::query!(r#"select cast (2 as rank) as "rank!: Rank""#)
+ .fetch_one(&db)
+ .await?;
+ assert_eq!(Rank::Two, row.rank);
+
+ let rank = Rank::Ace;
+ let (equal,): (bool,) = sqlx::query_as(r#"select ($1 is not distinct from 14)"#)
+ .bind(rank)
+ .fetch_one(&db)
+ .await?;
+ assert!(equal);
+ Ok(())
+}
+
+#[tokio::test]
+#[ignore]
+async fn suit_type_conversion() -> Result<(), anyhow::Error> {
+ let db = common::TestDb::new().await;
+ let db = db.db();
+ let row = sqlx::query!(r#"select cast ('heart' as suit) as "suit!: Suit""#)
+ .fetch_one(&db)
+ .await?;
+ assert_eq!(Suit::Heart, row.suit);
+
+ let suit = Suit::Spade;
+ let (equal,): (bool,) = sqlx::query_as(r#"select ($1 is not distinct from 'spade')"#)
+ .bind(suit)
+ .fetch_one(&db)
+ .await?;
+ assert!(equal);
Ok(())
}