relay_server.rs 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647
  1. use async_speed_limit::Limiter;
  2. use async_trait::async_trait;
  3. use hbb_common::{
  4. allow_err, bail,
  5. bytes::{Bytes, BytesMut},
  6. futures_util::{sink::SinkExt, stream::StreamExt},
  7. log,
  8. protobuf::Message as _,
  9. rendezvous_proto::*,
  10. sleep,
  11. tcp::{listen_any, FramedStream},
  12. timeout,
  13. tokio::{
  14. self,
  15. io::{AsyncReadExt, AsyncWriteExt},
  16. net::{TcpListener, TcpStream},
  17. sync::{Mutex, RwLock},
  18. time::{interval, Duration},
  19. },
  20. ResultType,
  21. };
  22. use sodiumoxide::crypto::sign;
  23. use std::{
  24. collections::{HashMap, HashSet},
  25. io::prelude::*,
  26. io::Error,
  27. net::SocketAddr,
  28. sync::atomic::{AtomicUsize, Ordering},
  29. };
  30. type Usage = (usize, usize, usize, usize);
  31. lazy_static::lazy_static! {
  32. static ref PEERS: Mutex<HashMap<String, Box<dyn StreamTrait>>> = Default::default();
  33. static ref USAGE: RwLock<HashMap<String, Usage>> = Default::default();
  34. static ref BLACKLIST: RwLock<HashSet<String>> = Default::default();
  35. static ref BLOCKLIST: RwLock<HashSet<String>> = Default::default();
  36. }
  37. static DOWNGRADE_THRESHOLD_100: AtomicUsize = AtomicUsize::new(66); // 0.66
  38. static DOWNGRADE_START_CHECK: AtomicUsize = AtomicUsize::new(1_800_000); // in ms
  39. static LIMIT_SPEED: AtomicUsize = AtomicUsize::new(4 * 1024 * 1024); // in bit/s
  40. static TOTAL_BANDWIDTH: AtomicUsize = AtomicUsize::new(1024 * 1024 * 1024); // in bit/s
  41. static SINGLE_BANDWIDTH: AtomicUsize = AtomicUsize::new(16 * 1024 * 1024); // in bit/s
  42. const BLACKLIST_FILE: &str = "blacklist.txt";
  43. const BLOCKLIST_FILE: &str = "blocklist.txt";
  44. #[tokio::main(flavor = "multi_thread")]
  45. pub async fn start(port: &str, key: &str) -> ResultType<()> {
  46. let key = get_server_sk(key);
  47. if let Ok(mut file) = std::fs::File::open(BLACKLIST_FILE) {
  48. let mut contents = String::new();
  49. if file.read_to_string(&mut contents).is_ok() {
  50. for x in contents.split('\n') {
  51. if let Some(ip) = x.trim().split(' ').next() {
  52. BLACKLIST.write().await.insert(ip.to_owned());
  53. }
  54. }
  55. }
  56. }
  57. log::info!(
  58. "#blacklist({}): {}",
  59. BLACKLIST_FILE,
  60. BLACKLIST.read().await.len()
  61. );
  62. if let Ok(mut file) = std::fs::File::open(BLOCKLIST_FILE) {
  63. let mut contents = String::new();
  64. if file.read_to_string(&mut contents).is_ok() {
  65. for x in contents.split('\n') {
  66. if let Some(ip) = x.trim().split(' ').next() {
  67. BLOCKLIST.write().await.insert(ip.to_owned());
  68. }
  69. }
  70. }
  71. }
  72. log::info!(
  73. "#blocklist({}): {}",
  74. BLOCKLIST_FILE,
  75. BLOCKLIST.read().await.len()
  76. );
  77. let port: u16 = port.parse()?;
  78. log::info!("Listening on tcp :{}", port);
  79. let port2 = port + 2;
  80. log::info!("Listening on websocket :{}", port2);
  81. let main_task = async move {
  82. loop {
  83. log::info!("Start");
  84. io_loop(listen_any(port).await?, listen_any(port2).await?, &key).await;
  85. }
  86. };
  87. let listen_signal = crate::common::listen_signal();
  88. tokio::select!(
  89. res = main_task => res,
  90. res = listen_signal => res,
  91. )
  92. }
  93. fn check_params() {
  94. let tmp = std::env::var("DOWNGRADE_THRESHOLD")
  95. .map(|x| x.parse::<f64>().unwrap_or(0.))
  96. .unwrap_or(0.);
  97. if tmp > 0. {
  98. DOWNGRADE_THRESHOLD_100.store((tmp * 100.) as _, Ordering::SeqCst);
  99. }
  100. log::info!(
  101. "DOWNGRADE_THRESHOLD: {}",
  102. DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100.
  103. );
  104. let tmp = std::env::var("DOWNGRADE_START_CHECK")
  105. .map(|x| x.parse::<usize>().unwrap_or(0))
  106. .unwrap_or(0);
  107. if tmp > 0 {
  108. DOWNGRADE_START_CHECK.store(tmp * 1000, Ordering::SeqCst);
  109. }
  110. log::info!(
  111. "DOWNGRADE_START_CHECK: {}s",
  112. DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000
  113. );
  114. let tmp = std::env::var("LIMIT_SPEED")
  115. .map(|x| x.parse::<f64>().unwrap_or(0.))
  116. .unwrap_or(0.);
  117. if tmp > 0. {
  118. LIMIT_SPEED.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
  119. }
  120. log::info!(
  121. "LIMIT_SPEED: {}Mb/s",
  122. LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024.
  123. );
  124. let tmp = std::env::var("TOTAL_BANDWIDTH")
  125. .map(|x| x.parse::<f64>().unwrap_or(0.))
  126. .unwrap_or(0.);
  127. if tmp > 0. {
  128. TOTAL_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
  129. }
  130. log::info!(
  131. "TOTAL_BANDWIDTH: {}Mb/s",
  132. TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
  133. );
  134. let tmp = std::env::var("SINGLE_BANDWIDTH")
  135. .map(|x| x.parse::<f64>().unwrap_or(0.))
  136. .unwrap_or(0.);
  137. if tmp > 0. {
  138. SINGLE_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
  139. }
  140. log::info!(
  141. "SINGLE_BANDWIDTH: {}Mb/s",
  142. SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
  143. )
  144. }
  145. async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
  146. use std::fmt::Write;
  147. let mut res = "".to_owned();
  148. let mut fds = cmd.trim().split(' ');
  149. match fds.next() {
  150. Some("h") => {
  151. res = format!(
  152. "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n",
  153. "blacklist-add(ba) <ip>",
  154. "blacklist-remove(br) <ip>",
  155. "blacklist(b) <ip>",
  156. "blocklist-add(Ba) <ip>",
  157. "blocklist-remove(Br) <ip>",
  158. "blocklist(B) <ip>",
  159. "downgrade-threshold(dt) [value]",
  160. "downgrade-start-check(t) [value(second)]",
  161. "limit-speed(ls) [value(Mb/s)]",
  162. "total-bandwidth(tb) [value(Mb/s)]",
  163. "single-bandwidth(sb) [value(Mb/s)]",
  164. "usage(u)"
  165. )
  166. }
  167. Some("blacklist-add" | "ba") => {
  168. if let Some(ip) = fds.next() {
  169. for ip in ip.split('|') {
  170. BLACKLIST.write().await.insert(ip.to_owned());
  171. }
  172. }
  173. }
  174. Some("blacklist-remove" | "br") => {
  175. if let Some(ip) = fds.next() {
  176. if ip == "all" {
  177. BLACKLIST.write().await.clear();
  178. } else {
  179. for ip in ip.split('|') {
  180. BLACKLIST.write().await.remove(ip);
  181. }
  182. }
  183. }
  184. }
  185. Some("blacklist" | "b") => {
  186. if let Some(ip) = fds.next() {
  187. res = format!("{}\n", BLACKLIST.read().await.get(ip).is_some());
  188. } else {
  189. for ip in BLACKLIST.read().await.clone().into_iter() {
  190. let _ = writeln!(res, "{ip}");
  191. }
  192. }
  193. }
  194. Some("blocklist-add" | "Ba") => {
  195. if let Some(ip) = fds.next() {
  196. for ip in ip.split('|') {
  197. BLOCKLIST.write().await.insert(ip.to_owned());
  198. }
  199. }
  200. }
  201. Some("blocklist-remove" | "Br") => {
  202. if let Some(ip) = fds.next() {
  203. if ip == "all" {
  204. BLOCKLIST.write().await.clear();
  205. } else {
  206. for ip in ip.split('|') {
  207. BLOCKLIST.write().await.remove(ip);
  208. }
  209. }
  210. }
  211. }
  212. Some("blocklist" | "B") => {
  213. if let Some(ip) = fds.next() {
  214. res = format!("{}\n", BLOCKLIST.read().await.get(ip).is_some());
  215. } else {
  216. for ip in BLOCKLIST.read().await.clone().into_iter() {
  217. let _ = writeln!(res, "{ip}");
  218. }
  219. }
  220. }
  221. Some("downgrade-threshold" | "dt") => {
  222. if let Some(v) = fds.next() {
  223. if let Ok(v) = v.parse::<f64>() {
  224. if v > 0. {
  225. DOWNGRADE_THRESHOLD_100.store((v * 100.) as _, Ordering::SeqCst);
  226. }
  227. }
  228. } else {
  229. res = format!(
  230. "{}\n",
  231. DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100.
  232. );
  233. }
  234. }
  235. Some("downgrade-start-check" | "t") => {
  236. if let Some(v) = fds.next() {
  237. if let Ok(v) = v.parse::<usize>() {
  238. if v > 0 {
  239. DOWNGRADE_START_CHECK.store(v * 1000, Ordering::SeqCst);
  240. }
  241. }
  242. } else {
  243. res = format!("{}s\n", DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000);
  244. }
  245. }
  246. Some("limit-speed" | "ls") => {
  247. if let Some(v) = fds.next() {
  248. if let Ok(v) = v.parse::<f64>() {
  249. if v > 0. {
  250. LIMIT_SPEED.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
  251. }
  252. }
  253. } else {
  254. res = format!(
  255. "{}Mb/s\n",
  256. LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024.
  257. );
  258. }
  259. }
  260. Some("total-bandwidth" | "tb") => {
  261. if let Some(v) = fds.next() {
  262. if let Ok(v) = v.parse::<f64>() {
  263. if v > 0. {
  264. TOTAL_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
  265. limiter.set_speed_limit(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _);
  266. }
  267. }
  268. } else {
  269. res = format!(
  270. "{}Mb/s\n",
  271. TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
  272. );
  273. }
  274. }
  275. Some("single-bandwidth" | "sb") => {
  276. if let Some(v) = fds.next() {
  277. if let Ok(v) = v.parse::<f64>() {
  278. if v > 0. {
  279. SINGLE_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
  280. }
  281. }
  282. } else {
  283. res = format!(
  284. "{}Mb/s\n",
  285. SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
  286. );
  287. }
  288. }
  289. Some("usage" | "u") => {
  290. let mut tmp: Vec<(String, Usage)> = USAGE
  291. .read()
  292. .await
  293. .iter()
  294. .map(|x| (x.0.clone(), *x.1))
  295. .collect();
  296. tmp.sort_by(|a, b| ((b.1).1).partial_cmp(&(a.1).1).unwrap());
  297. for (ip, (elapsed, total, highest, speed)) in tmp {
  298. if elapsed == 0 {
  299. continue;
  300. }
  301. let _ = writeln!(
  302. res,
  303. "{}: {}s {:.2}MB {}kb/s {}kb/s {}kb/s",
  304. ip,
  305. elapsed / 1000,
  306. total as f64 / 1024. / 1024. / 8.,
  307. highest,
  308. total / elapsed,
  309. speed
  310. );
  311. }
  312. }
  313. _ => {}
  314. }
  315. res
  316. }
  317. async fn io_loop(listener: TcpListener, listener2: TcpListener, key: &str) {
  318. check_params();
  319. let limiter = <Limiter>::new(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _);
  320. loop {
  321. tokio::select! {
  322. res = listener.accept() => {
  323. match res {
  324. Ok((stream, addr)) => {
  325. stream.set_nodelay(true).ok();
  326. handle_connection(stream, addr, &limiter, key, false).await;
  327. }
  328. Err(err) => {
  329. log::error!("listener.accept failed: {}", err);
  330. break;
  331. }
  332. }
  333. }
  334. res = listener2.accept() => {
  335. match res {
  336. Ok((stream, addr)) => {
  337. stream.set_nodelay(true).ok();
  338. handle_connection(stream, addr, &limiter, key, true).await;
  339. }
  340. Err(err) => {
  341. log::error!("listener2.accept failed: {}", err);
  342. break;
  343. }
  344. }
  345. }
  346. }
  347. }
  348. }
  349. async fn handle_connection(
  350. stream: TcpStream,
  351. addr: SocketAddr,
  352. limiter: &Limiter,
  353. key: &str,
  354. ws: bool,
  355. ) {
  356. let ip = hbb_common::try_into_v4(addr).ip();
  357. if !ws && ip.is_loopback() {
  358. let limiter = limiter.clone();
  359. tokio::spawn(async move {
  360. let mut stream = stream;
  361. let mut buffer = [0; 1024];
  362. if let Ok(Ok(n)) = timeout(1000, stream.read(&mut buffer[..])).await {
  363. if let Ok(data) = std::str::from_utf8(&buffer[..n]) {
  364. let res = check_cmd(data, limiter).await;
  365. stream.write(res.as_bytes()).await.ok();
  366. }
  367. }
  368. });
  369. return;
  370. }
  371. let ip = ip.to_string();
  372. if BLOCKLIST.read().await.get(&ip).is_some() {
  373. log::info!("{} blocked", ip);
  374. return;
  375. }
  376. let key = key.to_owned();
  377. let limiter = limiter.clone();
  378. tokio::spawn(async move {
  379. allow_err!(make_pair(stream, addr, &key, limiter, ws).await);
  380. });
  381. }
  382. async fn make_pair(
  383. stream: TcpStream,
  384. mut addr: SocketAddr,
  385. key: &str,
  386. limiter: Limiter,
  387. ws: bool,
  388. ) -> ResultType<()> {
  389. if ws {
  390. use tokio_tungstenite::tungstenite::handshake::server::{Request, Response};
  391. let callback = |req: &Request, response: Response| {
  392. let headers = req.headers();
  393. let real_ip = headers
  394. .get("X-Real-IP")
  395. .or_else(|| headers.get("X-Forwarded-For"))
  396. .and_then(|header_value| header_value.to_str().ok());
  397. if let Some(ip) = real_ip {
  398. if ip.contains('.') {
  399. addr = format!("{ip}:0").parse().unwrap_or(addr);
  400. } else {
  401. addr = format!("[{ip}]:0").parse().unwrap_or(addr);
  402. }
  403. }
  404. Ok(response)
  405. };
  406. let ws_stream = tokio_tungstenite::accept_hdr_async(stream, callback).await?;
  407. make_pair_(ws_stream, addr, key, limiter).await;
  408. } else {
  409. make_pair_(FramedStream::from(stream, addr), addr, key, limiter).await;
  410. }
  411. Ok(())
  412. }
  413. async fn make_pair_(stream: impl StreamTrait, addr: SocketAddr, key: &str, limiter: Limiter) {
  414. let mut stream = stream;
  415. if let Ok(Some(Ok(bytes))) = timeout(30_000, stream.recv()).await {
  416. if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) {
  417. if let Some(rendezvous_message::Union::RequestRelay(rf)) = msg_in.union {
  418. if !key.is_empty() && rf.licence_key != key {
  419. return;
  420. }
  421. if !rf.uuid.is_empty() {
  422. let mut peer = PEERS.lock().await.remove(&rf.uuid);
  423. if let Some(peer) = peer.as_mut() {
  424. log::info!("Relayrequest {} from {} got paired", rf.uuid, addr);
  425. let id = format!("{}:{}", addr.ip(), addr.port());
  426. USAGE.write().await.insert(id.clone(), Default::default());
  427. if !stream.is_ws() && !peer.is_ws() {
  428. peer.set_raw();
  429. stream.set_raw();
  430. log::info!("Both are raw");
  431. }
  432. if let Err(err) = relay(addr, &mut stream, peer, limiter, id.clone()).await
  433. {
  434. log::info!("Relay of {} closed: {}", addr, err);
  435. } else {
  436. log::info!("Relay of {} closed", addr);
  437. }
  438. USAGE.write().await.remove(&id);
  439. } else {
  440. log::info!("New relay request {} from {}", rf.uuid, addr);
  441. PEERS.lock().await.insert(rf.uuid.clone(), Box::new(stream));
  442. sleep(30.).await;
  443. PEERS.lock().await.remove(&rf.uuid);
  444. }
  445. }
  446. }
  447. }
  448. }
  449. }
  450. async fn relay(
  451. addr: SocketAddr,
  452. stream: &mut impl StreamTrait,
  453. peer: &mut Box<dyn StreamTrait>,
  454. total_limiter: Limiter,
  455. id: String,
  456. ) -> ResultType<()> {
  457. let ip = addr.ip().to_string();
  458. let mut tm = std::time::Instant::now();
  459. let mut elapsed = 0;
  460. let mut total = 0;
  461. let mut total_s = 0;
  462. let mut highest_s = 0;
  463. let mut downgrade: bool = false;
  464. let mut blacked: bool = false;
  465. let sb = SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64;
  466. let limiter = <Limiter>::new(sb);
  467. let blacklist_limiter = <Limiter>::new(LIMIT_SPEED.load(Ordering::SeqCst) as _);
  468. let downgrade_threshold =
  469. (sb * DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. / 1000.) as usize; // in bit/ms
  470. let mut timer = interval(Duration::from_secs(3));
  471. let mut last_recv_time = std::time::Instant::now();
  472. loop {
  473. tokio::select! {
  474. res = peer.recv() => {
  475. if let Some(Ok(bytes)) = res {
  476. last_recv_time = std::time::Instant::now();
  477. let nb = bytes.len() * 8;
  478. if blacked || downgrade {
  479. blacklist_limiter.consume(nb).await;
  480. } else {
  481. limiter.consume(nb).await;
  482. }
  483. total_limiter.consume(nb).await;
  484. total += nb;
  485. total_s += nb;
  486. if !bytes.is_empty() {
  487. stream.send_raw(bytes.into()).await?;
  488. }
  489. } else {
  490. break;
  491. }
  492. },
  493. res = stream.recv() => {
  494. if let Some(Ok(bytes)) = res {
  495. last_recv_time = std::time::Instant::now();
  496. let nb = bytes.len() * 8;
  497. if blacked || downgrade {
  498. blacklist_limiter.consume(nb).await;
  499. } else {
  500. limiter.consume(nb).await;
  501. }
  502. total_limiter.consume(nb).await;
  503. total += nb;
  504. total_s += nb;
  505. if !bytes.is_empty() {
  506. peer.send_raw(bytes.into()).await?;
  507. }
  508. } else {
  509. break;
  510. }
  511. },
  512. _ = timer.tick() => {
  513. if last_recv_time.elapsed().as_secs() > 30 {
  514. bail!("Timeout");
  515. }
  516. }
  517. }
  518. let n = tm.elapsed().as_millis() as usize;
  519. if n >= 1_000 {
  520. if BLOCKLIST.read().await.get(&ip).is_some() {
  521. log::info!("{} blocked", ip);
  522. break;
  523. }
  524. blacked = BLACKLIST.read().await.get(&ip).is_some();
  525. tm = std::time::Instant::now();
  526. let speed = total_s / n;
  527. if speed > highest_s {
  528. highest_s = speed;
  529. }
  530. elapsed += n;
  531. USAGE.write().await.insert(
  532. id.clone(),
  533. (elapsed as _, total as _, highest_s as _, speed as _),
  534. );
  535. total_s = 0;
  536. if elapsed > DOWNGRADE_START_CHECK.load(Ordering::SeqCst)
  537. && !downgrade
  538. && total > elapsed * downgrade_threshold
  539. {
  540. downgrade = true;
  541. log::info!(
  542. "Downgrade {}, exceed downgrade threshold {}bit/ms in {}ms",
  543. id,
  544. downgrade_threshold,
  545. elapsed
  546. );
  547. }
  548. }
  549. }
  550. Ok(())
  551. }
  552. fn get_server_sk(key: &str) -> String {
  553. let mut key = key.to_owned();
  554. if let Ok(sk) = base64::decode(&key) {
  555. if sk.len() == sign::SECRETKEYBYTES {
  556. log::info!("The key is a crypto private key");
  557. key = base64::encode(&sk[(sign::SECRETKEYBYTES / 2)..]);
  558. }
  559. }
  560. if key == "-" || key == "_" {
  561. let (pk, _) = crate::common::gen_sk(300);
  562. key = pk;
  563. }
  564. if !key.is_empty() {
  565. log::info!("Key: {}", key);
  566. }
  567. key
  568. }
  569. #[async_trait]
  570. trait StreamTrait: Send + Sync + 'static {
  571. async fn recv(&mut self) -> Option<Result<BytesMut, Error>>;
  572. async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()>;
  573. fn is_ws(&self) -> bool;
  574. fn set_raw(&mut self);
  575. }
  576. #[async_trait]
  577. impl StreamTrait for FramedStream {
  578. async fn recv(&mut self) -> Option<Result<BytesMut, Error>> {
  579. self.next().await
  580. }
  581. async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()> {
  582. self.send_bytes(bytes).await
  583. }
  584. fn is_ws(&self) -> bool {
  585. false
  586. }
  587. fn set_raw(&mut self) {
  588. self.set_raw();
  589. }
  590. }
  591. #[async_trait]
  592. impl StreamTrait for tokio_tungstenite::WebSocketStream<TcpStream> {
  593. async fn recv(&mut self) -> Option<Result<BytesMut, Error>> {
  594. if let Some(msg) = self.next().await {
  595. match msg {
  596. Ok(msg) => {
  597. match msg {
  598. tungstenite::Message::Binary(bytes) => {
  599. Some(Ok(bytes[..].into())) // to-do: poor performance
  600. }
  601. _ => Some(Ok(BytesMut::new())),
  602. }
  603. }
  604. Err(err) => Some(Err(Error::new(std::io::ErrorKind::Other, err.to_string()))),
  605. }
  606. } else {
  607. None
  608. }
  609. }
  610. async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()> {
  611. Ok(self
  612. .send(tungstenite::Message::Binary(bytes.to_vec()))
  613. .await?) // to-do: poor performance
  614. }
  615. fn is_ws(&self) -> bool {
  616. true
  617. }
  618. fn set_raw(&mut self) {}
  619. }