Browse Source

refactor for preparing sled

open-trade 5 years ago
parent
commit
1f4f1cc8e2
1 changed files with 80 additions and 38 deletions
  1. 80 38
      src/rendezvous_server.rs

+ 80 - 38
src/rendezvous_server.rs

@@ -19,34 +19,38 @@ use hbb_common::{
19
 use std::{
19
 use std::{
20
     collections::HashMap,
20
     collections::HashMap,
21
     net::SocketAddr,
21
     net::SocketAddr,
22
-    sync::{Arc, Mutex},
22
+    sync::{Arc, Mutex, RwLock},
23
     time::Instant,
23
     time::Instant,
24
 };
24
 };
25
 
25
 
26
+#[derive(Clone)]
26
 struct Peer {
27
 struct Peer {
27
     socket_addr: SocketAddr,
28
     socket_addr: SocketAddr,
28
     last_reg_time: Instant,
29
     last_reg_time: Instant,
29
 }
30
 }
30
 
31
 
32
+#[derive(Clone)]
31
 struct PeerMap {
33
 struct PeerMap {
32
-    map: HashMap<String, Peer>,
34
+    map: Arc<RwLock<HashMap<String, Peer>>>,
33
     db: sled::Db,
35
     db: sled::Db,
34
 }
36
 }
35
 
37
 
36
 impl PeerMap {
38
 impl PeerMap {
37
     fn new() -> ResultType<Self> {
39
     fn new() -> ResultType<Self> {
38
         Ok(Self {
40
         Ok(Self {
39
-            map: HashMap::new(),
41
+            map: Default::default(),
40
             db: sled::open("./sled.db")?,
42
             db: sled::open("./sled.db")?,
41
         })
43
         })
42
     }
44
     }
43
 
45
 
46
+    #[inline]
44
     fn insert(&mut self, key: String, peer: Peer) {
47
     fn insert(&mut self, key: String, peer: Peer) {
45
-        self.map.insert(key, peer);
48
+        self.map.write().unwrap().insert(key, peer);
46
     }
49
     }
47
 
50
 
48
-    fn get(&self, key: &str) -> Option<&Peer> {
49
-        self.map.get(key)
51
+    #[inline]
52
+    fn get(&self, key: &str) -> Option<Peer> {
53
+        self.map.read().unwrap().get(key).map(|x| x.clone())
50
     }
54
     }
51
 }
55
 }
52
 
56
 
@@ -56,38 +60,40 @@ type Sink = SplitSink<Framed<TcpStream, BytesCodec>, Bytes>;
56
 #[derive(Clone)]
60
 #[derive(Clone)]
57
 pub struct RendezvousServer {
61
 pub struct RendezvousServer {
58
     tcp_punch: Arc<Mutex<HashMap<SocketAddr, Sink>>>,
62
     tcp_punch: Arc<Mutex<HashMap<SocketAddr, Sink>>>,
63
+    pm: PeerMap,
64
+    tx: mpsc::UnboundedSender<(RendezvousMessage, SocketAddr)>,
59
 }
65
 }
60
 
66
 
61
 impl RendezvousServer {
67
 impl RendezvousServer {
62
     pub async fn start(addr: &str) -> ResultType<()> {
68
     pub async fn start(addr: &str) -> ResultType<()> {
63
-        let mut pm = PeerMap::new()?;
64
         let mut socket = FramedSocket::new(addr).await?;
69
         let mut socket = FramedSocket::new(addr).await?;
70
+        let (tx, mut rx) = mpsc::unbounded_channel::<(RendezvousMessage, SocketAddr)>();
65
         let mut rs = Self {
71
         let mut rs = Self {
66
             tcp_punch: Arc::new(Mutex::new(HashMap::new())),
72
             tcp_punch: Arc::new(Mutex::new(HashMap::new())),
73
+            pm: PeerMap::new()?,
74
+            tx: tx.clone(),
67
         };
75
         };
68
-        let (tx, mut rx) = mpsc::unbounded_channel::<(SocketAddr, String)>();
69
         let mut listener = new_listener(addr, true).await?;
76
         let mut listener = new_listener(addr, true).await?;
70
         loop {
77
         loop {
71
             tokio::select! {
78
             tokio::select! {
72
-                Some((addr, id)) = rx.recv() => {
73
-                    allow_err!(rs.handle_punch_hole_request(addr, &id, &mut socket, true, &pm).await);
79
+                Some((msg, addr)) = rx.recv() => {
80
+                    allow_err!(socket.send(&msg, addr).await);
74
                 }
81
                 }
75
                 Some(Ok((bytes, addr))) = socket.next() => {
82
                 Some(Ok((bytes, addr))) = socket.next() => {
76
-                    allow_err!(rs.handle_msg(&bytes, addr, &mut socket, &mut pm).await);
83
+                    allow_err!(rs.handle_msg(&bytes, addr, &mut socket).await);
77
                 }
84
                 }
78
                 Ok((stream, addr)) = listener.accept() => {
85
                 Ok((stream, addr)) = listener.accept() => {
79
                     log::debug!("Tcp connection from {:?}", addr);
86
                     log::debug!("Tcp connection from {:?}", addr);
80
                     let (a, mut b) = Framed::new(stream, BytesCodec::new()).split();
87
                     let (a, mut b) = Framed::new(stream, BytesCodec::new()).split();
81
                     let tcp_punch = rs.tcp_punch.clone();
88
                     let tcp_punch = rs.tcp_punch.clone();
82
                     tcp_punch.lock().unwrap().insert(addr, a);
89
                     tcp_punch.lock().unwrap().insert(addr, a);
83
-                    let tx = tx.clone();
84
                     let mut rs = rs.clone();
90
                     let mut rs = rs.clone();
85
                     tokio::spawn(async move {
91
                     tokio::spawn(async move {
86
                         while let Some(Ok(bytes)) = b.next().await {
92
                         while let Some(Ok(bytes)) = b.next().await {
87
                             if let Ok(msg_in) = parse_from_bytes::<RendezvousMessage>(&bytes) {
93
                             if let Ok(msg_in) = parse_from_bytes::<RendezvousMessage>(&bytes) {
88
                                 match msg_in.union {
94
                                 match msg_in.union {
89
                                     Some(rendezvous_message::Union::punch_hole_request(ph)) => {
95
                                     Some(rendezvous_message::Union::punch_hole_request(ph)) => {
90
-                                        allow_err!(tx.send((addr, ph.id)));
96
+                                        allow_err!(rs.handle_tcp_punch_hole_request(addr, &ph.id).await);
91
                                     }
97
                                     }
92
                                     Some(rendezvous_message::Union::punch_hole_sent(phs)) => {
98
                                     Some(rendezvous_message::Union::punch_hole_sent(phs)) => {
93
                                         allow_err!(rs.handle_hole_sent(&phs, addr, None).await);
99
                                         allow_err!(rs.handle_hole_sent(&phs, addr, None).await);
@@ -109,12 +115,12 @@ impl RendezvousServer {
109
         }
115
         }
110
     }
116
     }
111
 
117
 
118
+    #[inline]
112
     async fn handle_msg(
119
     async fn handle_msg(
113
         &mut self,
120
         &mut self,
114
         bytes: &BytesMut,
121
         bytes: &BytesMut,
115
         addr: SocketAddr,
122
         addr: SocketAddr,
116
         socket: &mut FramedSocket,
123
         socket: &mut FramedSocket,
117
-        pm: &mut PeerMap,
118
     ) -> ResultType<()> {
124
     ) -> ResultType<()> {
119
         if let Ok(msg_in) = parse_from_bytes::<RendezvousMessage>(&bytes) {
125
         if let Ok(msg_in) = parse_from_bytes::<RendezvousMessage>(&bytes) {
120
             match msg_in.union {
126
             match msg_in.union {
@@ -122,7 +128,7 @@ impl RendezvousServer {
122
                     // B registered
128
                     // B registered
123
                     if rp.id.len() > 0 {
129
                     if rp.id.len() > 0 {
124
                         log::debug!("New peer registered: {:?} {:?}", &rp.id, &addr);
130
                         log::debug!("New peer registered: {:?} {:?}", &rp.id, &addr);
125
-                        pm.insert(
131
+                        self.pm.insert(
126
                             rp.id,
132
                             rp.id,
127
                             Peer {
133
                             Peer {
128
                                 socket_addr: addr,
134
                                 socket_addr: addr,
@@ -135,8 +141,7 @@ impl RendezvousServer {
135
                     }
141
                     }
136
                 }
142
                 }
137
                 Some(rendezvous_message::Union::punch_hole_request(ph)) => {
143
                 Some(rendezvous_message::Union::punch_hole_request(ph)) => {
138
-                    self.handle_punch_hole_request(addr, &ph.id, socket, false, &pm)
139
-                        .await?;
144
+                    self.handle_udp_punch_hole_request(addr, &ph.id).await?;
140
                 }
145
                 }
141
                 Some(rendezvous_message::Union::punch_hole_sent(phs)) => {
146
                 Some(rendezvous_message::Union::punch_hole_sent(phs)) => {
142
                     self.handle_hole_sent(&phs, addr, Some(socket)).await?;
147
                     self.handle_hole_sent(&phs, addr, Some(socket)).await?;
@@ -150,6 +155,7 @@ impl RendezvousServer {
150
         Ok(())
155
         Ok(())
151
     }
156
     }
152
 
157
 
158
+    #[inline]
153
     async fn handle_hole_sent<'a>(
159
     async fn handle_hole_sent<'a>(
154
         &mut self,
160
         &mut self,
155
         phs: &PunchHoleSent,
161
         phs: &PunchHoleSent,
@@ -172,11 +178,12 @@ impl RendezvousServer {
172
         if let Some(socket) = socket {
178
         if let Some(socket) = socket {
173
             socket.send(&msg_out, addr_a).await?;
179
             socket.send(&msg_out, addr_a).await?;
174
         } else {
180
         } else {
175
-            self.send_to_tcp(&msg_out, addr_a).await?;
181
+            self.send_to_tcp(&msg_out, addr_a).await;
176
         }
182
         }
177
         Ok(())
183
         Ok(())
178
     }
184
     }
179
 
185
 
186
+    #[inline]
180
     async fn handle_local_addr<'a>(
187
     async fn handle_local_addr<'a>(
181
         &mut self,
188
         &mut self,
182
         la: &LocalAddr,
189
         la: &LocalAddr,
@@ -199,36 +206,30 @@ impl RendezvousServer {
199
         if let Some(socket) = socket {
206
         if let Some(socket) = socket {
200
             socket.send(&msg_out, addr_a).await?;
207
             socket.send(&msg_out, addr_a).await?;
201
         } else {
208
         } else {
202
-            self.send_to_tcp(&msg_out, addr_a).await?;
209
+            self.send_to_tcp(&msg_out, addr_a).await;
203
         }
210
         }
204
         Ok(())
211
         Ok(())
205
     }
212
     }
206
 
213
 
214
+    #[inline]
207
     async fn handle_punch_hole_request(
215
     async fn handle_punch_hole_request(
208
         &mut self,
216
         &mut self,
209
         addr: SocketAddr,
217
         addr: SocketAddr,
210
         id: &str,
218
         id: &str,
211
-        socket: &mut FramedSocket,
212
-        is_tcp: bool,
213
-        pm: &PeerMap,
214
-    ) -> ResultType<()> {
219
+    ) -> ResultType<(RendezvousMessage, Option<SocketAddr>)> {
215
         // punch hole request from A, forward to B,
220
         // punch hole request from A, forward to B,
216
         // check if in same intranet first,
221
         // check if in same intranet first,
217
         // fetch local addrs if in same intranet.
222
         // fetch local addrs if in same intranet.
218
         // because punch hole won't work if in the same intranet,
223
         // because punch hole won't work if in the same intranet,
219
         // all routers will drop such self-connections.
224
         // all routers will drop such self-connections.
220
-        if let Some(peer) = pm.get(id) {
225
+        if let Some(peer) = self.pm.get(id) {
221
             if peer.last_reg_time.elapsed().as_millis() as i32 >= REG_TIMEOUT {
226
             if peer.last_reg_time.elapsed().as_millis() as i32 >= REG_TIMEOUT {
222
                 let mut msg_out = RendezvousMessage::new();
227
                 let mut msg_out = RendezvousMessage::new();
223
                 msg_out.set_punch_hole_response(PunchHoleResponse {
228
                 msg_out.set_punch_hole_response(PunchHoleResponse {
224
                     failure: punch_hole_response::Failure::OFFLINE.into(),
229
                     failure: punch_hole_response::Failure::OFFLINE.into(),
225
                     ..Default::default()
230
                     ..Default::default()
226
                 });
231
                 });
227
-                return if is_tcp {
228
-                    self.send_to_tcp(&msg_out, addr).await
229
-                } else {
230
-                    socket.send(&msg_out, addr).await
231
-                };
232
+                return Ok((msg_out, None));
232
             }
233
             }
233
             let mut msg_out = RendezvousMessage::new();
234
             let mut msg_out = RendezvousMessage::new();
234
             let same_intranet = match peer.socket_addr {
235
             let same_intranet = match peer.socket_addr {
@@ -265,32 +266,73 @@ impl RendezvousServer {
265
                     ..Default::default()
266
                     ..Default::default()
266
                 });
267
                 });
267
             }
268
             }
268
-            socket.send(&msg_out, peer.socket_addr).await?;
269
+            return Ok((msg_out, Some(peer.socket_addr)));
269
         } else {
270
         } else {
270
             let mut msg_out = RendezvousMessage::new();
271
             let mut msg_out = RendezvousMessage::new();
271
             msg_out.set_punch_hole_response(PunchHoleResponse {
272
             msg_out.set_punch_hole_response(PunchHoleResponse {
272
                 failure: punch_hole_response::Failure::ID_NOT_EXIST.into(),
273
                 failure: punch_hole_response::Failure::ID_NOT_EXIST.into(),
273
                 ..Default::default()
274
                 ..Default::default()
274
             });
275
             });
275
-            return if is_tcp {
276
-                self.send_to_tcp(&msg_out, addr).await
277
-            } else {
278
-                socket.send(&msg_out, addr).await
279
-            };
276
+            return Ok((msg_out, None));
280
         }
277
         }
281
-        Ok(())
282
     }
278
     }
283
 
279
 
284
-    async fn send_to_tcp(&mut self, msg: &RendezvousMessage, addr: SocketAddr) -> ResultType<()> {
280
+    #[inline]
281
+    async fn send_to_tcp(&mut self, msg: &RendezvousMessage, addr: SocketAddr) {
285
         let tcp = self.tcp_punch.lock().unwrap().remove(&addr);
282
         let tcp = self.tcp_punch.lock().unwrap().remove(&addr);
286
         if let Some(mut tcp) = tcp {
283
         if let Some(mut tcp) = tcp {
287
             if let Ok(bytes) = msg.write_to_bytes() {
284
             if let Ok(bytes) = msg.write_to_bytes() {
288
                 tokio::spawn(async move {
285
                 tokio::spawn(async move {
289
                     allow_err!(tcp.send(Bytes::from(bytes)).await);
286
                     allow_err!(tcp.send(Bytes::from(bytes)).await);
290
-                    log::debug!("Send punch hole to {} via tcp", addr);
291
                 });
287
                 });
292
             }
288
             }
293
         }
289
         }
290
+    }
291
+
292
+    #[inline]
293
+    async fn send_to_tcp_sync(
294
+        &mut self,
295
+        msg: &RendezvousMessage,
296
+        addr: SocketAddr,
297
+    ) -> ResultType<()> {
298
+        let tcp = self.tcp_punch.lock().unwrap().remove(&addr);
299
+        if let Some(mut tcp) = tcp {
300
+            if let Ok(bytes) = msg.write_to_bytes() {
301
+                tcp.send(Bytes::from(bytes)).await?;
302
+            }
303
+        }
304
+        Ok(())
305
+    }
306
+
307
+    #[inline]
308
+    async fn handle_tcp_punch_hole_request(
309
+        &mut self,
310
+        addr: SocketAddr,
311
+        id: &str,
312
+    ) -> ResultType<()> {
313
+        let (msg, to_addr) = self.handle_punch_hole_request(addr, id).await?;
314
+        if let Some(addr) = to_addr {
315
+            self.tx.send((msg, addr))?;
316
+        } else {
317
+            self.send_to_tcp_sync(&msg, addr).await?;
318
+        }
319
+        Ok(())
320
+    }
321
+
322
+    #[inline]
323
+    async fn handle_udp_punch_hole_request(
324
+        &mut self,
325
+        addr: SocketAddr,
326
+        id: &str,
327
+    ) -> ResultType<()> {
328
+        let (msg, to_addr) = self.handle_punch_hole_request(addr, id).await?;
329
+        self.tx.send((
330
+            msg,
331
+            match to_addr {
332
+                Some(addr) => addr,
333
+                None => addr,
334
+            },
335
+        ))?;
294
         Ok(())
336
         Ok(())
295
     }
337
     }
296
 }
338
 }