| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647 |
- use async_speed_limit::Limiter;
- use async_trait::async_trait;
- use hbb_common::{
- allow_err, bail,
- bytes::{Bytes, BytesMut},
- futures_util::{sink::SinkExt, stream::StreamExt},
- log,
- protobuf::Message as _,
- rendezvous_proto::*,
- sleep,
- tcp::{listen_any, FramedStream},
- timeout,
- tokio::{
- self,
- io::{AsyncReadExt, AsyncWriteExt},
- net::{TcpListener, TcpStream},
- sync::{Mutex, RwLock},
- time::{interval, Duration},
- },
- ResultType,
- };
- use sodiumoxide::crypto::sign;
- use std::{
- collections::{HashMap, HashSet},
- io::prelude::*,
- io::Error,
- net::SocketAddr,
- sync::atomic::{AtomicUsize, Ordering},
- };
- type Usage = (usize, usize, usize, usize);
- lazy_static::lazy_static! {
- static ref PEERS: Mutex<HashMap<String, Box<dyn StreamTrait>>> = Default::default();
- static ref USAGE: RwLock<HashMap<String, Usage>> = Default::default();
- static ref BLACKLIST: RwLock<HashSet<String>> = Default::default();
- static ref BLOCKLIST: RwLock<HashSet<String>> = Default::default();
- }
- static DOWNGRADE_THRESHOLD_100: AtomicUsize = AtomicUsize::new(66); // 0.66
- static DOWNGRADE_START_CHECK: AtomicUsize = AtomicUsize::new(1_800_000); // in ms
- static LIMIT_SPEED: AtomicUsize = AtomicUsize::new(4 * 1024 * 1024); // in bit/s
- static TOTAL_BANDWIDTH: AtomicUsize = AtomicUsize::new(1024 * 1024 * 1024); // in bit/s
- static SINGLE_BANDWIDTH: AtomicUsize = AtomicUsize::new(16 * 1024 * 1024); // in bit/s
- const BLACKLIST_FILE: &str = "blacklist.txt";
- const BLOCKLIST_FILE: &str = "blocklist.txt";
- #[tokio::main(flavor = "multi_thread")]
- pub async fn start(port: &str, key: &str) -> ResultType<()> {
- let key = get_server_sk(key);
- if let Ok(mut file) = std::fs::File::open(BLACKLIST_FILE) {
- let mut contents = String::new();
- if file.read_to_string(&mut contents).is_ok() {
- for x in contents.split('\n') {
- if let Some(ip) = x.trim().split(' ').next() {
- BLACKLIST.write().await.insert(ip.to_owned());
- }
- }
- }
- }
- log::info!(
- "#blacklist({}): {}",
- BLACKLIST_FILE,
- BLACKLIST.read().await.len()
- );
- if let Ok(mut file) = std::fs::File::open(BLOCKLIST_FILE) {
- let mut contents = String::new();
- if file.read_to_string(&mut contents).is_ok() {
- for x in contents.split('\n') {
- if let Some(ip) = x.trim().split(' ').next() {
- BLOCKLIST.write().await.insert(ip.to_owned());
- }
- }
- }
- }
- log::info!(
- "#blocklist({}): {}",
- BLOCKLIST_FILE,
- BLOCKLIST.read().await.len()
- );
- let port: u16 = port.parse()?;
- log::info!("Listening on tcp :{}", port);
- let port2 = port + 2;
- log::info!("Listening on websocket :{}", port2);
- let main_task = async move {
- loop {
- log::info!("Start");
- io_loop(listen_any(port).await?, listen_any(port2).await?, &key).await;
- }
- };
- let listen_signal = crate::common::listen_signal();
- tokio::select!(
- res = main_task => res,
- res = listen_signal => res,
- )
- }
- fn check_params() {
- let tmp = std::env::var("DOWNGRADE_THRESHOLD")
- .map(|x| x.parse::<f64>().unwrap_or(0.))
- .unwrap_or(0.);
- if tmp > 0. {
- DOWNGRADE_THRESHOLD_100.store((tmp * 100.) as _, Ordering::SeqCst);
- }
- log::info!(
- "DOWNGRADE_THRESHOLD: {}",
- DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100.
- );
- let tmp = std::env::var("DOWNGRADE_START_CHECK")
- .map(|x| x.parse::<usize>().unwrap_or(0))
- .unwrap_or(0);
- if tmp > 0 {
- DOWNGRADE_START_CHECK.store(tmp * 1000, Ordering::SeqCst);
- }
- log::info!(
- "DOWNGRADE_START_CHECK: {}s",
- DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000
- );
- let tmp = std::env::var("LIMIT_SPEED")
- .map(|x| x.parse::<f64>().unwrap_or(0.))
- .unwrap_or(0.);
- if tmp > 0. {
- LIMIT_SPEED.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
- }
- log::info!(
- "LIMIT_SPEED: {}Mb/s",
- LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024.
- );
- let tmp = std::env::var("TOTAL_BANDWIDTH")
- .map(|x| x.parse::<f64>().unwrap_or(0.))
- .unwrap_or(0.);
- if tmp > 0. {
- TOTAL_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
- }
- log::info!(
- "TOTAL_BANDWIDTH: {}Mb/s",
- TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
- );
- let tmp = std::env::var("SINGLE_BANDWIDTH")
- .map(|x| x.parse::<f64>().unwrap_or(0.))
- .unwrap_or(0.);
- if tmp > 0. {
- SINGLE_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
- }
- log::info!(
- "SINGLE_BANDWIDTH: {}Mb/s",
- SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
- )
- }
- async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
- use std::fmt::Write;
- let mut res = "".to_owned();
- let mut fds = cmd.trim().split(' ');
- match fds.next() {
- Some("h") => {
- res = format!(
- "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n",
- "blacklist-add(ba) <ip>",
- "blacklist-remove(br) <ip>",
- "blacklist(b) <ip>",
- "blocklist-add(Ba) <ip>",
- "blocklist-remove(Br) <ip>",
- "blocklist(B) <ip>",
- "downgrade-threshold(dt) [value]",
- "downgrade-start-check(t) [value(second)]",
- "limit-speed(ls) [value(Mb/s)]",
- "total-bandwidth(tb) [value(Mb/s)]",
- "single-bandwidth(sb) [value(Mb/s)]",
- "usage(u)"
- )
- }
- Some("blacklist-add" | "ba") => {
- if let Some(ip) = fds.next() {
- for ip in ip.split('|') {
- BLACKLIST.write().await.insert(ip.to_owned());
- }
- }
- }
- Some("blacklist-remove" | "br") => {
- if let Some(ip) = fds.next() {
- if ip == "all" {
- BLACKLIST.write().await.clear();
- } else {
- for ip in ip.split('|') {
- BLACKLIST.write().await.remove(ip);
- }
- }
- }
- }
- Some("blacklist" | "b") => {
- if let Some(ip) = fds.next() {
- res = format!("{}\n", BLACKLIST.read().await.get(ip).is_some());
- } else {
- for ip in BLACKLIST.read().await.clone().into_iter() {
- let _ = writeln!(res, "{ip}");
- }
- }
- }
- Some("blocklist-add" | "Ba") => {
- if let Some(ip) = fds.next() {
- for ip in ip.split('|') {
- BLOCKLIST.write().await.insert(ip.to_owned());
- }
- }
- }
- Some("blocklist-remove" | "Br") => {
- if let Some(ip) = fds.next() {
- if ip == "all" {
- BLOCKLIST.write().await.clear();
- } else {
- for ip in ip.split('|') {
- BLOCKLIST.write().await.remove(ip);
- }
- }
- }
- }
- Some("blocklist" | "B") => {
- if let Some(ip) = fds.next() {
- res = format!("{}\n", BLOCKLIST.read().await.get(ip).is_some());
- } else {
- for ip in BLOCKLIST.read().await.clone().into_iter() {
- let _ = writeln!(res, "{ip}");
- }
- }
- }
- Some("downgrade-threshold" | "dt") => {
- if let Some(v) = fds.next() {
- if let Ok(v) = v.parse::<f64>() {
- if v > 0. {
- DOWNGRADE_THRESHOLD_100.store((v * 100.) as _, Ordering::SeqCst);
- }
- }
- } else {
- res = format!(
- "{}\n",
- DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100.
- );
- }
- }
- Some("downgrade-start-check" | "t") => {
- if let Some(v) = fds.next() {
- if let Ok(v) = v.parse::<usize>() {
- if v > 0 {
- DOWNGRADE_START_CHECK.store(v * 1000, Ordering::SeqCst);
- }
- }
- } else {
- res = format!("{}s\n", DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000);
- }
- }
- Some("limit-speed" | "ls") => {
- if let Some(v) = fds.next() {
- if let Ok(v) = v.parse::<f64>() {
- if v > 0. {
- LIMIT_SPEED.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
- }
- }
- } else {
- res = format!(
- "{}Mb/s\n",
- LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024.
- );
- }
- }
- Some("total-bandwidth" | "tb") => {
- if let Some(v) = fds.next() {
- if let Ok(v) = v.parse::<f64>() {
- if v > 0. {
- TOTAL_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
- limiter.set_speed_limit(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _);
- }
- }
- } else {
- res = format!(
- "{}Mb/s\n",
- TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
- );
- }
- }
- Some("single-bandwidth" | "sb") => {
- if let Some(v) = fds.next() {
- if let Ok(v) = v.parse::<f64>() {
- if v > 0. {
- SINGLE_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
- }
- }
- } else {
- res = format!(
- "{}Mb/s\n",
- SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
- );
- }
- }
- Some("usage" | "u") => {
- let mut tmp: Vec<(String, Usage)> = USAGE
- .read()
- .await
- .iter()
- .map(|x| (x.0.clone(), *x.1))
- .collect();
- tmp.sort_by(|a, b| ((b.1).1).partial_cmp(&(a.1).1).unwrap());
- for (ip, (elapsed, total, highest, speed)) in tmp {
- if elapsed == 0 {
- continue;
- }
- let _ = writeln!(
- res,
- "{}: {}s {:.2}MB {}kb/s {}kb/s {}kb/s",
- ip,
- elapsed / 1000,
- total as f64 / 1024. / 1024. / 8.,
- highest,
- total / elapsed,
- speed
- );
- }
- }
- _ => {}
- }
- res
- }
- async fn io_loop(listener: TcpListener, listener2: TcpListener, key: &str) {
- check_params();
- let limiter = <Limiter>::new(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _);
- loop {
- tokio::select! {
- res = listener.accept() => {
- match res {
- Ok((stream, addr)) => {
- stream.set_nodelay(true).ok();
- handle_connection(stream, addr, &limiter, key, false).await;
- }
- Err(err) => {
- log::error!("listener.accept failed: {}", err);
- break;
- }
- }
- }
- res = listener2.accept() => {
- match res {
- Ok((stream, addr)) => {
- stream.set_nodelay(true).ok();
- handle_connection(stream, addr, &limiter, key, true).await;
- }
- Err(err) => {
- log::error!("listener2.accept failed: {}", err);
- break;
- }
- }
- }
- }
- }
- }
- async fn handle_connection(
- stream: TcpStream,
- addr: SocketAddr,
- limiter: &Limiter,
- key: &str,
- ws: bool,
- ) {
- let ip = hbb_common::try_into_v4(addr).ip();
- if !ws && ip.is_loopback() {
- let limiter = limiter.clone();
- tokio::spawn(async move {
- let mut stream = stream;
- let mut buffer = [0; 1024];
- if let Ok(Ok(n)) = timeout(1000, stream.read(&mut buffer[..])).await {
- if let Ok(data) = std::str::from_utf8(&buffer[..n]) {
- let res = check_cmd(data, limiter).await;
- stream.write(res.as_bytes()).await.ok();
- }
- }
- });
- return;
- }
- let ip = ip.to_string();
- if BLOCKLIST.read().await.get(&ip).is_some() {
- log::info!("{} blocked", ip);
- return;
- }
- let key = key.to_owned();
- let limiter = limiter.clone();
- tokio::spawn(async move {
- allow_err!(make_pair(stream, addr, &key, limiter, ws).await);
- });
- }
- async fn make_pair(
- stream: TcpStream,
- mut addr: SocketAddr,
- key: &str,
- limiter: Limiter,
- ws: bool,
- ) -> ResultType<()> {
- if ws {
- use tokio_tungstenite::tungstenite::handshake::server::{Request, Response};
- let callback = |req: &Request, response: Response| {
- let headers = req.headers();
- let real_ip = headers
- .get("X-Real-IP")
- .or_else(|| headers.get("X-Forwarded-For"))
- .and_then(|header_value| header_value.to_str().ok());
- if let Some(ip) = real_ip {
- if ip.contains('.') {
- addr = format!("{ip}:0").parse().unwrap_or(addr);
- } else {
- addr = format!("[{ip}]:0").parse().unwrap_or(addr);
- }
- }
- Ok(response)
- };
- let ws_stream = tokio_tungstenite::accept_hdr_async(stream, callback).await?;
- make_pair_(ws_stream, addr, key, limiter).await;
- } else {
- make_pair_(FramedStream::from(stream, addr), addr, key, limiter).await;
- }
- Ok(())
- }
- async fn make_pair_(stream: impl StreamTrait, addr: SocketAddr, key: &str, limiter: Limiter) {
- let mut stream = stream;
- if let Ok(Some(Ok(bytes))) = timeout(30_000, stream.recv()).await {
- if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) {
- if let Some(rendezvous_message::Union::RequestRelay(rf)) = msg_in.union {
- if !key.is_empty() && rf.licence_key != key {
- return;
- }
- if !rf.uuid.is_empty() {
- let mut peer = PEERS.lock().await.remove(&rf.uuid);
- if let Some(peer) = peer.as_mut() {
- log::info!("Relayrequest {} from {} got paired", rf.uuid, addr);
- let id = format!("{}:{}", addr.ip(), addr.port());
- USAGE.write().await.insert(id.clone(), Default::default());
- if !stream.is_ws() && !peer.is_ws() {
- peer.set_raw();
- stream.set_raw();
- log::info!("Both are raw");
- }
- if let Err(err) = relay(addr, &mut stream, peer, limiter, id.clone()).await
- {
- log::info!("Relay of {} closed: {}", addr, err);
- } else {
- log::info!("Relay of {} closed", addr);
- }
- USAGE.write().await.remove(&id);
- } else {
- log::info!("New relay request {} from {}", rf.uuid, addr);
- PEERS.lock().await.insert(rf.uuid.clone(), Box::new(stream));
- sleep(30.).await;
- PEERS.lock().await.remove(&rf.uuid);
- }
- }
- }
- }
- }
- }
- async fn relay(
- addr: SocketAddr,
- stream: &mut impl StreamTrait,
- peer: &mut Box<dyn StreamTrait>,
- total_limiter: Limiter,
- id: String,
- ) -> ResultType<()> {
- let ip = addr.ip().to_string();
- let mut tm = std::time::Instant::now();
- let mut elapsed = 0;
- let mut total = 0;
- let mut total_s = 0;
- let mut highest_s = 0;
- let mut downgrade: bool = false;
- let mut blacked: bool = false;
- let sb = SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64;
- let limiter = <Limiter>::new(sb);
- let blacklist_limiter = <Limiter>::new(LIMIT_SPEED.load(Ordering::SeqCst) as _);
- let downgrade_threshold =
- (sb * DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. / 1000.) as usize; // in bit/ms
- let mut timer = interval(Duration::from_secs(3));
- let mut last_recv_time = std::time::Instant::now();
- loop {
- tokio::select! {
- res = peer.recv() => {
- if let Some(Ok(bytes)) = res {
- last_recv_time = std::time::Instant::now();
- let nb = bytes.len() * 8;
- if blacked || downgrade {
- blacklist_limiter.consume(nb).await;
- } else {
- limiter.consume(nb).await;
- }
- total_limiter.consume(nb).await;
- total += nb;
- total_s += nb;
- if !bytes.is_empty() {
- stream.send_raw(bytes.into()).await?;
- }
- } else {
- break;
- }
- },
- res = stream.recv() => {
- if let Some(Ok(bytes)) = res {
- last_recv_time = std::time::Instant::now();
- let nb = bytes.len() * 8;
- if blacked || downgrade {
- blacklist_limiter.consume(nb).await;
- } else {
- limiter.consume(nb).await;
- }
- total_limiter.consume(nb).await;
- total += nb;
- total_s += nb;
- if !bytes.is_empty() {
- peer.send_raw(bytes.into()).await?;
- }
- } else {
- break;
- }
- },
- _ = timer.tick() => {
- if last_recv_time.elapsed().as_secs() > 30 {
- bail!("Timeout");
- }
- }
- }
- let n = tm.elapsed().as_millis() as usize;
- if n >= 1_000 {
- if BLOCKLIST.read().await.get(&ip).is_some() {
- log::info!("{} blocked", ip);
- break;
- }
- blacked = BLACKLIST.read().await.get(&ip).is_some();
- tm = std::time::Instant::now();
- let speed = total_s / n;
- if speed > highest_s {
- highest_s = speed;
- }
- elapsed += n;
- USAGE.write().await.insert(
- id.clone(),
- (elapsed as _, total as _, highest_s as _, speed as _),
- );
- total_s = 0;
- if elapsed > DOWNGRADE_START_CHECK.load(Ordering::SeqCst)
- && !downgrade
- && total > elapsed * downgrade_threshold
- {
- downgrade = true;
- log::info!(
- "Downgrade {}, exceed downgrade threshold {}bit/ms in {}ms",
- id,
- downgrade_threshold,
- elapsed
- );
- }
- }
- }
- Ok(())
- }
- fn get_server_sk(key: &str) -> String {
- let mut key = key.to_owned();
- if let Ok(sk) = base64::decode(&key) {
- if sk.len() == sign::SECRETKEYBYTES {
- log::info!("The key is a crypto private key");
- key = base64::encode(&sk[(sign::SECRETKEYBYTES / 2)..]);
- }
- }
- if key == "-" || key == "_" {
- let (pk, _) = crate::common::gen_sk(300);
- key = pk;
- }
- if !key.is_empty() {
- log::info!("Key: {}", key);
- }
- key
- }
- #[async_trait]
- trait StreamTrait: Send + Sync + 'static {
- async fn recv(&mut self) -> Option<Result<BytesMut, Error>>;
- async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()>;
- fn is_ws(&self) -> bool;
- fn set_raw(&mut self);
- }
- #[async_trait]
- impl StreamTrait for FramedStream {
- async fn recv(&mut self) -> Option<Result<BytesMut, Error>> {
- self.next().await
- }
- async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()> {
- self.send_bytes(bytes).await
- }
- fn is_ws(&self) -> bool {
- false
- }
- fn set_raw(&mut self) {
- self.set_raw();
- }
- }
- #[async_trait]
- impl StreamTrait for tokio_tungstenite::WebSocketStream<TcpStream> {
- async fn recv(&mut self) -> Option<Result<BytesMut, Error>> {
- if let Some(msg) = self.next().await {
- match msg {
- Ok(msg) => {
- match msg {
- tungstenite::Message::Binary(bytes) => {
- Some(Ok(bytes[..].into())) // to-do: poor performance
- }
- _ => Some(Ok(BytesMut::new())),
- }
- }
- Err(err) => Some(Err(Error::new(std::io::ErrorKind::Other, err.to_string()))),
- }
- } else {
- None
- }
- }
- async fn send_raw(&mut self, bytes: Bytes) -> ResultType<()> {
- Ok(self
- .send(tungstenite::Message::Binary(bytes.to_vec()))
- .await?) // to-do: poor performance
- }
- fn is_ws(&self) -> bool {
- true
- }
- fn set_raw(&mut self) {}
- }
|