rendezvous_server.rs 20 KB


  1. use hbb_common::{
  2. allow_err,
  3. bytes::{Bytes, BytesMut},
  4. bytes_codec::BytesCodec,
  5. futures_util::{
  6. sink::SinkExt,
  7. stream::{SplitSink, StreamExt},
  8. },
  9. log,
  10. protobuf::Message as _,
  11. rendezvous_proto::*,
  12. tcp::new_listener,
  13. timeout,
  14. tokio::{self, net::TcpStream, sync::mpsc},
  15. tokio_util::codec::Framed,
  16. udp::FramedSocket,
  17. AddrMangle, ResultType,
  18. };
  19. use serde_derive::{Deserialize, Serialize};
  20. use std::{
  21. collections::HashMap,
  22. net::SocketAddr,
  23. sync::{Arc, Mutex, RwLock},
  24. time::Instant,
  25. };
  /// A peer as tracked in memory by the rendezvous server.
  #[derive(Clone, Debug)]
  struct Peer {
      /// Address the peer last registered from (`0.0.0.0:0` for a default peer).
      socket_addr: SocketAddr,
      /// When the peer last registered; punch-hole requests compare its age
      /// against `REG_TIMEOUT` to decide whether the peer is online.
      last_reg_time: Instant,
      /// uuid supplied by the peer when it registered its public key.
      uuid: Vec<u8>,
      /// Public key supplied by the peer.
      pk: Vec<u8>,
  }

  impl Default for Peer {
      /// An unknown peer: unspecified address, empty uuid/pk, and a registration
      /// time back-dated (when the clock allows) so it is considered offline.
      fn default() -> Self {
          let now = Instant::now();
          // `checked_sub` returns `None` when the result would fall before the
          // platform's monotonic-clock origin (e.g. on Windows within an hour of
          // boot). Try a smaller offset and finally fall back to `now` instead of
          // panicking the server loop.
          let last_reg_time = [3600, 60]
              .iter()
              .find_map(|&secs| now.checked_sub(std::time::Duration::from_secs(secs)))
              .unwrap_or(now);
          Self {
              socket_addr: "0.0.0.0:0".parse().unwrap(),
              last_reg_time,
              uuid: Vec::new(),
              pk: Vec::new(),
          }
      }
  }
  45. #[derive(Debug, Serialize, Deserialize, Default)]
  46. struct PeerSerde {
  47. #[serde(default)]
  48. ip: String,
  49. #[serde(default)]
  50. uuid: Vec<u8>,
  51. #[serde(default)]
  52. pk: Vec<u8>,
  53. }
  54. #[derive(Clone)]
  55. struct PeerMap {
  56. map: Arc<RwLock<HashMap<String, Peer>>>,
  57. db: super::SledAsync,
  58. }
  59. impl PeerMap {
  60. fn new() -> ResultType<Self> {
  61. Ok(Self {
  62. map: Default::default(),
  63. db: super::SledAsync::new("./sled.db", true)?,
  64. })
  65. }
  66. #[inline]
  67. fn update_pk(&mut self, id: String, socket_addr: SocketAddr, uuid: Vec<u8>, pk: Vec<u8>) {
  68. log::info!("update_pk {} {:?} {:?} {:?}", id, socket_addr, uuid, pk);
  69. let mut lock = self.map.write().unwrap();
  70. lock.insert(
  71. id.clone(),
  72. Peer {
  73. socket_addr,
  74. last_reg_time: Instant::now(),
  75. uuid: uuid.clone(),
  76. pk: pk.clone(),
  77. },
  78. );
  79. drop(lock);
  80. let ip = socket_addr.ip().to_string();
  81. self.db.insert(id, PeerSerde { ip, uuid, pk });
  82. }
  83. #[inline]
  84. async fn get(&mut self, id: &str) -> Option<Peer> {
  85. let p = self.map.read().unwrap().get(id).map(|x| x.clone());
  86. if p.is_some() {
  87. return p;
  88. } else {
  89. let id = id.to_owned();
  90. let v = self.db.get(id.clone()).await;
  91. if let Some(v) = super::SledAsync::deserialize::<PeerSerde>(&v) {
  92. self.map.write().unwrap().insert(
  93. id,
  94. Peer {
  95. uuid: v.uuid,
  96. pk: v.pk,
  97. ..Default::default()
  98. },
  99. );
  100. return Some(Peer::default());
  101. }
  102. }
  103. None
  104. }
  105. #[inline]
  106. fn is_in_memory(&self, id: &str) -> bool {
  107. self.map.read().unwrap().contains_key(id)
  108. }
  109. }
  /// Registration lifetime in milliseconds: a peer whose last registration is at
  /// least this old is reported offline to punch-hole requests.
  const REG_TIMEOUT: i32 = 30 * 1000;
  111. type Sink = SplitSink<Framed<TcpStream, BytesCodec>, Bytes>;
  112. type Sender = mpsc::UnboundedSender<(RendezvousMessage, SocketAddr)>;
  113. #[derive(Clone)]
  114. pub struct RendezvousServer {
  115. tcp_punch: Arc<Mutex<HashMap<SocketAddr, Sink>>>,
  116. pm: PeerMap,
  117. tx: Sender,
  118. relay_server: String,
  119. serial: i32,
  120. rendezvous_servers: Vec<String>,
  121. }
  122. impl RendezvousServer {
  123. pub async fn start(
  124. addr: &str,
  125. relay_server: String,
  126. serial: i32,
  127. rendezvous_servers: Vec<String>,
  128. ) -> ResultType<()> {
  129. let mut socket = FramedSocket::new(addr).await?;
  130. let (tx, mut rx) = mpsc::unbounded_channel::<(RendezvousMessage, SocketAddr)>();
  131. let mut rs = Self {
  132. tcp_punch: Arc::new(Mutex::new(HashMap::new())),
  133. pm: PeerMap::new()?,
  134. tx: tx.clone(),
  135. relay_server,
  136. serial,
  137. rendezvous_servers,
  138. };
  139. let mut listener = new_listener(addr, false).await?;
  140. loop {
  141. tokio::select! {
  142. Some((msg, addr)) = rx.recv() => {
  143. allow_err!(socket.send(&msg, addr).await);
  144. }
  145. Some(Ok((bytes, addr))) = socket.next() => {
  146. allow_err!(rs.handle_msg(&bytes, addr, &mut socket).await);
  147. }
  148. Ok((stream, addr)) = listener.accept() => {
  149. log::debug!("Tcp connection from {:?}", addr);
  150. let (a, mut b) = Framed::new(stream, BytesCodec::new()).split();
  151. let tcp_punch = rs.tcp_punch.clone();
  152. tcp_punch.lock().unwrap().insert(addr, a);
  153. let mut rs = rs.clone();
  154. tokio::spawn(async move {
  155. while let Ok(Some(Ok(bytes))) = timeout(30_000, b.next()).await {
  156. if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) {
  157. match msg_in.union {
  158. Some(rendezvous_message::Union::punch_hole_request(ph)) => {
  159. allow_err!(rs.handle_tcp_punch_hole_request(addr, ph.id).await);
  160. }
  161. Some(rendezvous_message::Union::request_relay(mut rf)) => {
  162. if let Some(peer) = rs.pm.map.read().unwrap().get(&rf.id).map(|x| x.clone()) {
  163. let mut msg_out = RendezvousMessage::new();
  164. rf.socket_addr = AddrMangle::encode(addr);
  165. msg_out.set_request_relay(rf);
  166. rs.tx.send((msg_out, peer.socket_addr)).ok();
  167. }
  168. }
  169. Some(rendezvous_message::Union::request_relay_response(mut rfr)) => {
  170. let addr_b = AddrMangle::decode(&rfr.socket_addr);
  171. rfr.socket_addr = Default::default();
  172. let mut msg_out = RendezvousMessage::new();
  173. msg_out.set_request_relay_response(rfr);
  174. let sender_b = rs.tcp_punch.lock().unwrap().remove(&addr_b);
  175. if let Some(mut sender_b) = sender_b {
  176. if let Ok(bytes) = msg_out.write_to_bytes() {
  177. allow_err!(sender_b.send(Bytes::from(bytes)).await);
  178. }
  179. }
  180. break;
  181. }
  182. Some(rendezvous_message::Union::punch_hole_sent(phs)) => {
  183. allow_err!(rs.handle_hole_sent(phs, addr, None).await);
  184. break;
  185. }
  186. Some(rendezvous_message::Union::local_addr(la)) => {
  187. allow_err!(rs.handle_local_addr(&la, addr, None).await);
  188. break;
  189. }
  190. _ => {
  191. break;
  192. }
  193. }
  194. } else {
  195. break;
  196. }
  197. }
  198. rs.tcp_punch.lock().unwrap().remove(&addr);
  199. log::debug!("Tcp connection from {:?} closed", addr);
  200. });
  201. }
  202. }
  203. }
  204. }
  205. #[inline]
  206. async fn handle_msg(
  207. &mut self,
  208. bytes: &BytesMut,
  209. addr: SocketAddr,
  210. socket: &mut FramedSocket,
  211. ) -> ResultType<()> {
  212. if let Ok(msg_in) = RendezvousMessage::parse_from_bytes(&bytes) {
  213. match msg_in.union {
  214. Some(rendezvous_message::Union::register_peer(rp)) => {
  215. // B registered
  216. if rp.id.len() > 0 {
  217. log::debug!("New peer registered: {:?} {:?}", &rp.id, &addr);
  218. self.update_addr(rp.id, addr, socket).await?;
  219. if self.serial != rp.serial {
  220. let mut msg_out = RendezvousMessage::new();
  221. let mut mi = MiscInfo::new();
  222. mi.set_configure_update(ConfigUpdate {
  223. serial: self.serial,
  224. rendezvous_servers: self.rendezvous_servers.clone(),
  225. ..Default::default()
  226. });
  227. msg_out.set_misc_info(mi);
  228. socket.send(&msg_out, addr).await?;
  229. }
  230. }
  231. }
  232. Some(rendezvous_message::Union::register_pk(rk)) => {
  233. if rk.uuid.is_empty() {
  234. return Ok(());
  235. }
  236. let id = rk.id;
  237. let mut res = register_pk_response::Result::OK;
  238. if let Some(peer) = self.pm.get(&id).await {
  239. if !peer.uuid.is_empty() && peer.uuid != rk.uuid {
  240. log::warn!(
  241. "Peer {} uuid mismatch: {:?} vs {:?}",
  242. id,
  243. rk.uuid,
  244. peer.uuid
  245. );
  246. res = register_pk_response::Result::UUID_MISMATCH;
  247. } else if peer.uuid.is_empty() || peer.pk != rk.pk {
  248. self.pm.update_pk(id, addr, rk.uuid, rk.pk);
  249. }
  250. } else {
  251. self.pm.update_pk(id, addr, rk.uuid, rk.pk);
  252. }
  253. let mut msg_out = RendezvousMessage::new();
  254. msg_out.set_register_pk_response(RegisterPkResponse {
  255. result: res.into(),
  256. ..Default::default()
  257. });
  258. socket.send(&msg_out, addr).await?
  259. }
  260. Some(rendezvous_message::Union::punch_hole_request(ph)) => {
  261. let id = ph.id;
  262. if self.pm.is_in_memory(&id) {
  263. self.handle_udp_punch_hole_request(addr, id).await?;
  264. } else {
  265. // not in memory, fetch from db with spawn in case blocking me
  266. let mut me = self.clone();
  267. tokio::spawn(async move {
  268. allow_err!(me.handle_udp_punch_hole_request(addr, id).await);
  269. });
  270. }
  271. }
  272. Some(rendezvous_message::Union::punch_hole_sent(phs)) => {
  273. self.handle_hole_sent(phs, addr, Some(socket)).await?;
  274. }
  275. Some(rendezvous_message::Union::local_addr(la)) => {
  276. self.handle_local_addr(&la, addr, Some(socket)).await?;
  277. }
  278. Some(rendezvous_message::Union::misc_info(mi)) => match mi.union {
  279. Some(misc_info::Union::configure_update(mut cu)) => {
  280. if addr.ip() == std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))
  281. {
  282. self.serial = cu.serial;
  283. self.rendezvous_servers = cu
  284. .rendezvous_servers
  285. .drain(..)
  286. .map(|x| {
  287. if !x.contains(":") {
  288. format!("{}:21116", x)
  289. } else {
  290. x
  291. }
  292. })
  293. .filter(|x| x.parse::<std::net::SocketAddr>().is_ok())
  294. .collect();
  295. log::info!(
  296. "configure updated: serial={} rendezvous-servers={:?}",
  297. self.serial,
  298. self.rendezvous_servers
  299. );
  300. }
  301. }
  302. _ => {}
  303. },
  304. _ => {}
  305. }
  306. }
  307. Ok(())
  308. }
  309. #[inline]
  310. async fn update_addr(
  311. &mut self,
  312. id: String,
  313. socket_addr: SocketAddr,
  314. socket: &mut FramedSocket,
  315. ) -> ResultType<()> {
  316. let mut lock = self.pm.map.write().unwrap();
  317. let last_reg_time = Instant::now();
  318. if let Some(old) = lock.get_mut(&id) {
  319. old.socket_addr = socket_addr;
  320. old.last_reg_time = last_reg_time;
  321. let request_pk = old.pk.is_empty();
  322. drop(lock);
  323. let mut msg_out = RendezvousMessage::new();
  324. msg_out.set_register_peer_response(RegisterPeerResponse {
  325. request_pk,
  326. ..Default::default()
  327. });
  328. socket.send(&msg_out, socket_addr).await?;
  329. } else {
  330. drop(lock);
  331. let mut pm = self.pm.clone();
  332. let tx = self.tx.clone();
  333. tokio::spawn(async move {
  334. let v = pm.db.get(id.clone()).await;
  335. let (uuid, pk) = {
  336. if let Some(v) = super::SledAsync::deserialize::<PeerSerde>(&v) {
  337. (v.uuid, v.pk)
  338. } else {
  339. (Vec::new(), Vec::new())
  340. }
  341. };
  342. let mut msg_out = RendezvousMessage::new();
  343. msg_out.set_register_peer_response(RegisterPeerResponse {
  344. request_pk: pk.is_empty(),
  345. ..Default::default()
  346. });
  347. tx.send((msg_out, socket_addr)).ok();
  348. pm.map.write().unwrap().insert(
  349. id,
  350. Peer {
  351. socket_addr,
  352. last_reg_time,
  353. uuid,
  354. pk,
  355. },
  356. );
  357. });
  358. }
  359. Ok(())
  360. }
  361. #[inline]
  362. async fn handle_hole_sent<'a>(
  363. &mut self,
  364. phs: PunchHoleSent,
  365. addr: SocketAddr,
  366. socket: Option<&'a mut FramedSocket>,
  367. ) -> ResultType<()> {
  368. // punch hole sent from B, tell A that B is ready to be connected
  369. let addr_a = AddrMangle::decode(&phs.socket_addr);
  370. log::debug!(
  371. "{} punch hole response to {:?} from {:?}",
  372. if socket.is_none() { "TCP" } else { "UDP" },
  373. &addr_a,
  374. &addr
  375. );
  376. let mut msg_out = RendezvousMessage::new();
  377. let pk = match self.pm.get(&phs.id).await {
  378. Some(peer) => peer.pk,
  379. _ => Vec::new(),
  380. };
  381. let mut relay_server = phs.relay_server;
  382. if relay_server.is_empty() {
  383. relay_server = self.relay_server.clone();
  384. }
  385. msg_out.set_punch_hole_response(PunchHoleResponse {
  386. socket_addr: AddrMangle::encode(addr),
  387. pk,
  388. relay_server,
  389. ..Default::default()
  390. });
  391. if let Some(socket) = socket {
  392. socket.send(&msg_out, addr_a).await?;
  393. } else {
  394. self.send_to_tcp(&msg_out, addr_a).await;
  395. }
  396. Ok(())
  397. }
  398. #[inline]
  399. async fn handle_local_addr<'a>(
  400. &mut self,
  401. la: &LocalAddr,
  402. addr: SocketAddr,
  403. socket: Option<&'a mut FramedSocket>,
  404. ) -> ResultType<()> {
  405. // relay local addrs of B to A
  406. let addr_a = AddrMangle::decode(&la.socket_addr);
  407. log::debug!(
  408. "{} local addrs response to {:?} from {:?}",
  409. if socket.is_none() { "TCP" } else { "UDP" },
  410. &addr_a,
  411. &addr
  412. );
  413. let mut msg_out = RendezvousMessage::new();
  414. msg_out.set_punch_hole_response(PunchHoleResponse {
  415. socket_addr: la.local_addr.clone(),
  416. ..Default::default()
  417. });
  418. if let Some(socket) = socket {
  419. socket.send(&msg_out, addr_a).await?;
  420. } else {
  421. self.send_to_tcp(&msg_out, addr_a).await;
  422. }
  423. Ok(())
  424. }
  425. #[inline]
  426. async fn handle_punch_hole_request(
  427. &mut self,
  428. addr: SocketAddr,
  429. id: String,
  430. ) -> ResultType<(RendezvousMessage, Option<SocketAddr>)> {
  431. // punch hole request from A, relay to B,
  432. // check if in same intranet first,
  433. // fetch local addrs if in same intranet.
  434. // because punch hole won't work if in the same intranet,
  435. // all routers will drop such self-connections.
  436. if let Some(peer) = self.pm.get(&id).await {
  437. if peer.last_reg_time.elapsed().as_millis() as i32 >= REG_TIMEOUT {
  438. let mut msg_out = RendezvousMessage::new();
  439. msg_out.set_punch_hole_response(PunchHoleResponse {
  440. failure: punch_hole_response::Failure::OFFLINE.into(),
  441. ..Default::default()
  442. });
  443. return Ok((msg_out, None));
  444. }
  445. let mut msg_out = RendezvousMessage::new();
  446. let same_intranet = match peer.socket_addr {
  447. SocketAddr::V4(a) => match addr {
  448. SocketAddr::V4(b) => a.ip() == b.ip(),
  449. _ => false,
  450. },
  451. SocketAddr::V6(a) => match addr {
  452. SocketAddr::V6(b) => a.ip() == b.ip(),
  453. _ => false,
  454. },
  455. };
  456. let socket_addr = AddrMangle::encode(addr);
  457. if same_intranet {
  458. log::debug!(
  459. "Fetch local addr {:?} {:?} request from {:?}",
  460. id,
  461. &peer.socket_addr,
  462. &addr
  463. );
  464. msg_out.set_fetch_local_addr(FetchLocalAddr {
  465. socket_addr,
  466. ..Default::default()
  467. });
  468. } else {
  469. log::debug!(
  470. "Punch hole {:?} {:?} request from {:?}",
  471. id,
  472. &peer.socket_addr,
  473. &addr
  474. );
  475. msg_out.set_punch_hole(PunchHole {
  476. socket_addr,
  477. ..Default::default()
  478. });
  479. }
  480. return Ok((msg_out, Some(peer.socket_addr)));
  481. } else {
  482. let mut msg_out = RendezvousMessage::new();
  483. msg_out.set_punch_hole_response(PunchHoleResponse {
  484. failure: punch_hole_response::Failure::ID_NOT_EXIST.into(),
  485. ..Default::default()
  486. });
  487. return Ok((msg_out, None));
  488. }
  489. }
  490. #[inline]
  491. async fn send_to_tcp(&mut self, msg: &RendezvousMessage, addr: SocketAddr) {
  492. let tcp = self.tcp_punch.lock().unwrap().remove(&addr);
  493. if let Some(mut tcp) = tcp {
  494. if let Ok(bytes) = msg.write_to_bytes() {
  495. tokio::spawn(async move {
  496. allow_err!(tcp.send(Bytes::from(bytes)).await);
  497. });
  498. }
  499. }
  500. }
  501. #[inline]
  502. async fn send_to_tcp_sync(
  503. &mut self,
  504. msg: &RendezvousMessage,
  505. addr: SocketAddr,
  506. ) -> ResultType<()> {
  507. let tcp = self.tcp_punch.lock().unwrap().remove(&addr);
  508. if let Some(mut tcp) = tcp {
  509. if let Ok(bytes) = msg.write_to_bytes() {
  510. tcp.send(Bytes::from(bytes)).await?;
  511. }
  512. }
  513. Ok(())
  514. }
  515. #[inline]
  516. async fn handle_tcp_punch_hole_request(
  517. &mut self,
  518. addr: SocketAddr,
  519. id: String,
  520. ) -> ResultType<()> {
  521. let (msg, to_addr) = self.handle_punch_hole_request(addr, id).await?;
  522. if let Some(addr) = to_addr {
  523. self.tx.send((msg, addr))?;
  524. } else {
  525. self.send_to_tcp_sync(&msg, addr).await?;
  526. }
  527. Ok(())
  528. }
  529. #[inline]
  530. async fn handle_udp_punch_hole_request(
  531. &mut self,
  532. addr: SocketAddr,
  533. id: String,
  534. ) -> ResultType<()> {
  535. let (msg, to_addr) = self.handle_punch_hole_request(addr, id).await?;
  536. self.tx.send((
  537. msg,
  538. match to_addr {
  539. Some(addr) => addr,
  540. None => addr,
  541. },
  542. ))?;
  543. Ok(())
  544. }
  545. }