use async_trait::async_trait; use hbb_common::{log, ResultType}; use serde_json::value::Value; use sqlx::{ sqlite::SqliteConnectOptions, ConnectOptions, Connection, Error as SqlxError, SqliteConnection, }; use std::{ops::DerefMut, str::FromStr}; //use sqlx::postgres::PgPoolOptions; //use sqlx::mysql::MySqlPoolOptions; pub(crate) type DB = sqlx::Sqlite; pub(crate) type MapValue = serde_json::map::Map; pub(crate) type MapStr = std::collections::HashMap; type Pool = deadpool::managed::Pool; pub struct DbPool { url: String, } #[async_trait] impl deadpool::managed::Manager for DbPool { type Type = SqliteConnection; type Error = SqlxError; async fn create(&self) -> Result { let mut opt = SqliteConnectOptions::from_str(&self.url).unwrap(); opt.log_statements(log::LevelFilter::Debug); SqliteConnection::connect_with(&opt).await } async fn recycle( &self, obj: &mut SqliteConnection, ) -> deadpool::managed::RecycleResult { Ok(obj.ping().await?) } } #[derive(Clone)] pub struct Database { pool: Pool, } #[derive(Default)] pub struct Peer { pub guid: Vec, pub id: String, pub uuid: Vec, pub pk: Vec, pub user: Option>, pub info: String, pub status: Option, } impl Database { pub async fn new(url: &str) -> ResultType { if !std::path::Path::new(url).exists() { std::fs::File::create(url).ok(); } let n: usize = std::env::var("MAX_DATABASE_CONNECTIONS") .unwrap_or("1".to_owned()) .parse() .unwrap_or(1); log::debug!("MAX_DATABASE_CONNECTIONS={}", n); let pool = Pool::new( DbPool { url: url.to_owned(), }, n, ); let _ = pool.get().await?; // test let db = Database { pool }; db.create_tables().await?; Ok(db) } async fn create_tables(&self) -> ResultType<()> { sqlx::query!( " create table if not exists peer ( guid blob primary key not null, id varchar(100) not null, uuid blob not null, pk blob not null, created_at datetime not null default(current_timestamp), user blob, status tinyint, note varchar(300), info text not null ) without rowid; create unique index if not exists index_peer_id on peer (id); create index if not exists index_peer_user on peer (user); create index if not exists index_peer_created_at on peer (created_at); create index if not exists index_peer_status on peer (status); " ) .execute(self.pool.get().await?.deref_mut()) .await?; Ok(()) } pub async fn get_peer(&self, id: &str) -> ResultType> { Ok(sqlx::query_as!( Peer, "select guid, id, uuid, pk, user, status, info from peer where id = ?", id ) .fetch_optional(self.pool.get().await?.deref_mut()) .await?) } pub async fn get_peer_id(&self, guid: &[u8]) -> ResultType> { Ok(sqlx::query!("select id from peer where guid = ?", guid) .fetch_optional(self.pool.get().await?.deref_mut()) .await? .map(|x| x.id)) } #[inline] pub async fn get_conn(&self) -> ResultType> { Ok(self.pool.get().await?) } pub async fn update_peer(&self, payload: MapValue, guid: &[u8]) -> ResultType<()> { let mut conn = self.get_conn().await?; let mut tx = conn.begin().await?; if let Some(v) = payload.get("note") { let v = get_str(v); sqlx::query!("update peer set note = ? where guid = ?", v, guid) .execute(&mut tx) .await?; } tx.commit().await?; Ok(()) } pub async fn insert_peer( &self, id: &str, uuid: &Vec, pk: &Vec, info: &str, ) -> ResultType> { let guid = uuid::Uuid::new_v4().as_bytes().to_vec(); sqlx::query!( "insert into peer(guid, id, uuid, pk, info) values(?, ?, ?, ?, ?)", guid, id, uuid, pk, info ) .execute(self.pool.get().await?.deref_mut()) .await?; Ok(guid) } pub async fn update_pk( &self, guid: &Vec, id: &str, pk: &Vec, info: &str, ) -> ResultType<()> { sqlx::query!( "update peer set id=?, pk=?, info=? where guid=?", id, pk, info, guid ) .execute(self.pool.get().await?.deref_mut()) .await?; Ok(()) } } #[cfg(test)] mod tests { use hbb_common::tokio; #[test] fn test_insert() { insert(); } #[tokio::main(flavor = "multi_thread")] async fn insert() { let db = super::Database::new("test.sqlite3").await.unwrap(); let mut jobs = vec![]; for i in 0..10000 { let cloned = db.clone(); let id = i.to_string(); let a = tokio::spawn(async move { let empty_vec = Vec::new(); cloned .insert_peer(&id, &empty_vec, &empty_vec, "") .await .unwrap(); }); jobs.push(a); } for i in 0..10000 { let cloned = db.clone(); let id = i.to_string(); let a = tokio::spawn(async move { cloned.get_peer(&id).await.unwrap(); }); jobs.push(a); } hbb_common::futures::future::join_all(jobs).await; } } #[inline] pub fn guid2str(guid: &Vec) -> String { let mut bytes = [0u8; 16]; bytes[..].copy_from_slice(&guid); uuid::Uuid::from_bytes(bytes).to_string() } pub(crate) fn get_str(v: &Value) -> Option<&str> { match v { Value::String(v) => { let v = v.trim(); if v.is_empty() { None } else { Some(v) } } _ => None, } }