open-trade лет назад: 5
Родитель
Сommit
d16fb31ecf
3 измененных файлов с 86 добавлено и 50 удалено
  1. 1 1
      libs/hbb_common
  2. 82 46
      src/rendezvous_server.rs
  3. 3 3
      src/sled_async.rs

+ 1 - 1
libs/hbb_common

@@ -1 +1 @@
1
-Subproject commit 5e977383041f35f856ee2dfc5fcf07c8300e2da5
1
+Subproject commit c86ebe2402ee1f092ce8e44e89368234708a766f

+ 82 - 46
src/rendezvous_server.rs

@@ -65,38 +65,10 @@ impl PeerMap {
65
     }
65
     }
66
 
66
 
67
     #[inline]
67
     #[inline]
68
-    async fn update_addr(&mut self, key: String, socket_addr: SocketAddr) {
69
-        let mut lock = self.map.write().unwrap();
70
-        let last_reg_time = Instant::now();
71
-        if let Some(old) = lock.get_mut(&key) {
72
-            old.socket_addr = socket_addr;
73
-            old.last_reg_time = last_reg_time;
74
-        } else {
75
-            let mut me = self.clone();
76
-            tokio::spawn(async move {
77
-                let v = me.db.get(key.clone()).await;
78
-                let pk = if let Some(v) = super::SledAsync::deserialize::<PeerSerde>(&v) {
79
-                    v.pk
80
-                } else {
81
-                    Vec::new()
82
-                };
83
-                me.map.write().unwrap().insert(
84
-                    key,
85
-                    Peer {
86
-                        socket_addr,
87
-                        last_reg_time,
88
-                        pk,
89
-                    },
90
-                );
91
-            });
92
-        }
93
-    }
94
-
95
-    #[inline]
96
-    fn update_key(&mut self, key: String, socket_addr: SocketAddr, pk: Vec<u8>) {
68
+    fn update_pk(&mut self, id: String, socket_addr: SocketAddr, pk: Vec<u8>) {
97
         let mut lock = self.map.write().unwrap();
69
         let mut lock = self.map.write().unwrap();
98
         lock.insert(
70
         lock.insert(
99
-            key.clone(),
71
+            id.clone(),
100
             Peer {
72
             Peer {
101
                 socket_addr,
73
                 socket_addr,
102
                 last_reg_time: Instant::now(),
74
                 last_reg_time: Instant::now(),
@@ -104,19 +76,20 @@ impl PeerMap {
104
             },
76
             },
105
         );
77
         );
106
         let ip = socket_addr.ip().to_string();
78
         let ip = socket_addr.ip().to_string();
107
-        self.db.insert(key, PeerSerde { ip, pk });
79
+        self.db.insert(id, PeerSerde { ip, pk });
108
     }
80
     }
109
 
81
 
110
     #[inline]
82
     #[inline]
111
-    async fn get(&mut self, key: String) -> Option<Peer> {
112
-        let p = self.map.read().unwrap().get(&key).map(|x| x.clone());
83
+    async fn get(&mut self, id: &str) -> Option<Peer> {
84
+        let p = self.map.read().unwrap().get(id).map(|x| x.clone());
113
         if p.is_some() {
85
         if p.is_some() {
114
             return p;
86
             return p;
115
         } else {
87
         } else {
116
-            let v = self.db.get(key.clone()).await;
88
+            let id = id.to_owned();
89
+            let v = self.db.get(id.clone()).await;
117
             if let Some(v) = super::SledAsync::deserialize::<PeerSerde>(&v) {
90
             if let Some(v) = super::SledAsync::deserialize::<PeerSerde>(&v) {
118
                 self.map.write().unwrap().insert(
91
                 self.map.write().unwrap().insert(
119
-                    key,
92
+                    id,
120
                     Peer {
93
                     Peer {
121
                         pk: v.pk,
94
                         pk: v.pk,
122
                         ..Default::default()
95
                         ..Default::default()
@@ -129,19 +102,20 @@ impl PeerMap {
129
     }
102
     }
130
 
103
 
131
     #[inline]
104
     #[inline]
132
-    fn is_in_memory(&self, key: &str) -> bool {
133
-        self.map.read().unwrap().contains_key(key)
105
+    fn is_in_memory(&self, id: &str) -> bool {
106
+        self.map.read().unwrap().contains_key(id)
134
     }
107
     }
135
 }
108
 }
136
 
109
 
137
 const REG_TIMEOUT: i32 = 30_000;
110
 const REG_TIMEOUT: i32 = 30_000;
138
 type Sink = SplitSink<Framed<TcpStream, BytesCodec>, Bytes>;
111
 type Sink = SplitSink<Framed<TcpStream, BytesCodec>, Bytes>;
112
+type Sender = mpsc::UnboundedSender<(RendezvousMessage, SocketAddr)>;
139
 
113
 
140
 #[derive(Clone)]
114
 #[derive(Clone)]
141
 pub struct RendezvousServer {
115
 pub struct RendezvousServer {
142
     tcp_punch: Arc<Mutex<HashMap<SocketAddr, Sink>>>,
116
     tcp_punch: Arc<Mutex<HashMap<SocketAddr, Sink>>>,
143
     pm: PeerMap,
117
     pm: PeerMap,
144
-    tx: mpsc::UnboundedSender<(RendezvousMessage, SocketAddr)>,
118
+    tx: Sender,
145
 }
119
 }
146
 
120
 
147
 impl RendezvousServer {
121
 impl RendezvousServer {
@@ -208,19 +182,29 @@ impl RendezvousServer {
208
                     // B registered
182
                     // B registered
209
                     if rp.id.len() > 0 {
183
                     if rp.id.len() > 0 {
210
                         log::debug!("New peer registered: {:?} {:?}", &rp.id, &addr);
184
                         log::debug!("New peer registered: {:?} {:?}", &rp.id, &addr);
211
-                        self.pm.update_addr(rp.id, addr).await;
212
-                        let mut msg_out = RendezvousMessage::new();
213
-                        msg_out.set_register_peer_response(RegisterPeerResponse::default());
214
-                        socket.send(&msg_out, addr).await?
185
+                        self.update_addr(rp.id, addr, socket).await?;
215
                     }
186
                     }
216
                 }
187
                 }
217
-                Some(rendezvous_message::Union::register_key(rk)) => {
188
+                Some(rendezvous_message::Union::register_pk(rk)) => {
218
                     let id = rk.id;
189
                     let id = rk.id;
219
-                    if let Some(peer) = self.pm.get(id.clone()).await {
190
+                    let mut res = register_pk_response::Result::OK;
191
+                    if let Some(peer) = self.pm.get(&id).await {
220
                         if peer.pk.is_empty() {
192
                         if peer.pk.is_empty() {
221
-                            self.pm.update_key(id, addr, rk.key);
193
+                            self.pm.update_pk(id, addr, rk.pk);
194
+                        } else {
195
+                            if peer.pk != rk.pk {
196
+                                res = register_pk_response::Result::PK_MISMATCH;
197
+                            }
222
                         }
198
                         }
199
+                    } else {
200
+                        self.pm.update_pk(id, addr, rk.pk);
223
                     }
201
                     }
202
+                    let mut msg_out = RendezvousMessage::new();
203
+                    msg_out.set_register_pk_response(RegisterPkResponse {
204
+                        result: res.into(),
205
+                        ..Default::default()
206
+                    });
207
+                    socket.send(&msg_out, addr).await?
224
                 }
208
                 }
225
                 Some(rendezvous_message::Union::punch_hole_request(ph)) => {
209
                 Some(rendezvous_message::Union::punch_hole_request(ph)) => {
226
                     let id = ph.id;
210
                     let id = ph.id;
@@ -249,6 +233,58 @@ impl RendezvousServer {
249
         Ok(())
233
         Ok(())
250
     }
234
     }
251
 
235
 
236
+    #[inline]
237
+    async fn update_addr(
238
+        &mut self,
239
+        id: String,
240
+        socket_addr: SocketAddr,
241
+        socket: &mut FramedSocket,
242
+    ) -> ResultType<()> {
243
+        let mut lock = self.pm.map.write().unwrap();
244
+        let last_reg_time = Instant::now();
245
+        if let Some(old) = lock.get_mut(&id) {
246
+            old.socket_addr = socket_addr;
247
+            old.last_reg_time = last_reg_time;
248
+            let request_pk = old.pk.is_empty();
249
+            drop(lock);
250
+            let mut msg_out = RendezvousMessage::new();
251
+            msg_out.set_register_peer_response(RegisterPeerResponse {
252
+                request_pk,
253
+                ..Default::default()
254
+            });
255
+            socket.send(&msg_out, socket_addr).await?;
256
+        } else {
257
+            drop(lock);
258
+            let mut pm = self.pm.clone();
259
+            let tx = self.tx.clone();
260
+            tokio::spawn(async move {
261
+                let v = pm.db.get(id.clone()).await;
262
+                let pk = {
263
+                    if let Some(v) = super::SledAsync::deserialize::<PeerSerde>(&v) {
264
+                        v.pk
265
+                    } else {
266
+                        Vec::new()
267
+                    }
268
+                };
269
+                let mut msg_out = RendezvousMessage::new();
270
+                msg_out.set_register_peer_response(RegisterPeerResponse {
271
+                    request_pk: pk.is_empty(),
272
+                    ..Default::default()
273
+                });
274
+                tx.send((msg_out, socket_addr)).ok();
275
+                pm.map.write().unwrap().insert(
276
+                    id,
277
+                    Peer {
278
+                        socket_addr,
279
+                        last_reg_time,
280
+                        pk,
281
+                    },
282
+                );
283
+            });
284
+        }
285
+        Ok(())
286
+    }
287
+
252
     #[inline]
288
     #[inline]
253
     async fn handle_hole_sent<'a>(
289
     async fn handle_hole_sent<'a>(
254
         &mut self,
290
         &mut self,
@@ -316,7 +352,7 @@ impl RendezvousServer {
316
         // fetch local addrs if in same intranet.
352
         // fetch local addrs if in same intranet.
317
         // because punch hole won't work if in the same intranet,
353
         // because punch hole won't work if in the same intranet,
318
         // all routers will drop such self-connections.
354
         // all routers will drop such self-connections.
319
-        if let Some(peer) = self.pm.get(id.clone()).await {
355
+        if let Some(peer) = self.pm.get(&id).await {
320
             if peer.last_reg_time.elapsed().as_millis() as i32 >= REG_TIMEOUT {
356
             if peer.last_reg_time.elapsed().as_millis() as i32 >= REG_TIMEOUT {
321
                 let mut msg_out = RendezvousMessage::new();
357
                 let mut msg_out = RendezvousMessage::new();
322
                 msg_out.set_punch_hole_response(PunchHoleResponse {
358
                 msg_out.set_punch_hole_response(PunchHoleResponse {

+ 3 - 3
src/sled_async.rs

@@ -8,7 +8,7 @@ use hbb_common::{
8
 enum Action {
8
 enum Action {
9
     Insert((String, Vec<u8>)),
9
     Insert((String, Vec<u8>)),
10
     Get((String, mpsc::Sender<Option<sled::IVec>>)),
10
     Get((String, mpsc::Sender<Option<sled::IVec>>)),
11
-    Close,
11
+    _Close,
12
 }
12
 }
13
 
13
 
14
 #[derive(Clone)]
14
 #[derive(Clone)]
@@ -55,14 +55,14 @@ impl SledAsync {
55
                             .await
55
                             .await
56
                     );
56
                     );
57
                 }
57
                 }
58
-                Action::Close => break,
58
+                Action::_Close => break,
59
             }
59
             }
60
         }
60
         }
61
     }
61
     }
62
 
62
 
63
     pub fn _close(self, j: std::thread::JoinHandle<()>) {
63
     pub fn _close(self, j: std::thread::JoinHandle<()>) {
64
         if let Some(tx) = &self.tx {
64
         if let Some(tx) = &self.tx {
65
-            allow_err!(tx.send(Action::Close));
65
+            allow_err!(tx.send(Action::_Close));
66
         }
66
         }
67
         allow_err!(j.join());
67
         allow_err!(j.join());
68
     }
68
     }