Просмотр исходного кода

Merge branch 'master' into fix-clippy-warning

RustDesk лет назад: 3
Родитель
Сommit
0e01cfcd3a
3 измененных файлов с 52 добавлено и 48 удалено
  1. 12 22
      libs/hbb_common/src/udp.rs
  2. 6 11
      src/relay_server.rs
  3. 34 15
      src/rendezvous_server.rs

+ 12 - 22
libs/hbb_common/src/udp.rs

@@ -5,7 +5,7 @@ use futures::{SinkExt, StreamExt};
5
 use protobuf::Message;
5
 use protobuf::Message;
6
 use socket2::{Domain, Socket, Type};
6
 use socket2::{Domain, Socket, Type};
7
 use std::net::SocketAddr;
7
 use std::net::SocketAddr;
8
-use tokio::net::{ToSocketAddrs, UdpSocket};
8
+use tokio::net::{lookup_host, ToSocketAddrs, UdpSocket};
9
 use tokio_socks::{udp::Socks5UdpFramed, IntoTargetAddr, TargetAddr, ToProxyAddrs};
9
 use tokio_socks::{udp::Socks5UdpFramed, IntoTargetAddr, TargetAddr, ToProxyAddrs};
10
 use tokio_util::{codec::BytesCodec, udp::UdpFramed};
10
 use tokio_util::{codec::BytesCodec, udp::UdpFramed};
11
 
11
 
@@ -37,38 +37,28 @@ fn new_socket(addr: SocketAddr, reuse: bool, buf_size: usize) -> Result<Socket,
37
         addr,
37
         addr,
38
         socket.recv_buffer_size()
38
         socket.recv_buffer_size()
39
     );
39
     );
40
+    if addr.is_ipv6() && addr.ip().is_unspecified() && addr.port() > 0 {
41
+        socket.set_only_v6(false).ok();
42
+    }
40
     socket.bind(&addr.into())?;
43
     socket.bind(&addr.into())?;
41
     Ok(socket)
44
     Ok(socket)
42
 }
45
 }
43
 
46
 
44
 impl FramedSocket {
47
 impl FramedSocket {
45
     pub async fn new<T: ToSocketAddrs>(addr: T) -> ResultType<Self> {
48
     pub async fn new<T: ToSocketAddrs>(addr: T) -> ResultType<Self> {
46
-        let socket = UdpSocket::bind(addr).await?;
47
-        Ok(Self::Direct(UdpFramed::new(socket, BytesCodec::new())))
48
-    }
49
-
50
-    pub async fn new_reuse<T: std::net::ToSocketAddrs>(addr: T) -> ResultType<Self> {
51
-        let addr = addr
52
-            .to_socket_addrs()?
53
-            .next()
54
-            .context("could not resolve to any address")?;
55
-        let socket = new_socket(addr, true, 0)?.into_udp_socket();
56
-        Ok(Self::Direct(UdpFramed::new(
57
-            UdpSocket::from_std(socket)?,
58
-            BytesCodec::new(),
59
-        )))
49
+        Self::new_reuse(addr, false, 0).await
60
     }
50
     }
61
 
51
 
62
-    pub async fn new_with_buf_size<T: std::net::ToSocketAddrs>(
52
+    pub async fn new_reuse<T: ToSocketAddrs>(
63
         addr: T,
53
         addr: T,
54
+        reuse: bool,
64
         buf_size: usize,
55
         buf_size: usize,
65
     ) -> ResultType<Self> {
56
     ) -> ResultType<Self> {
66
-        let addr = addr
67
-            .to_socket_addrs()?
57
+        let addr = lookup_host(&addr).await?
68
             .next()
58
             .next()
69
             .context("could not resolve to any address")?;
59
             .context("could not resolve to any address")?;
70
         Ok(Self::Direct(UdpFramed::new(
60
         Ok(Self::Direct(UdpFramed::new(
71
-            UdpSocket::from_std(new_socket(addr, false, buf_size)?.into_udp_socket())?,
61
+            UdpSocket::from_std(new_socket(addr, reuse, buf_size)?.into_udp_socket())?,
72
             BytesCodec::new(),
62
             BytesCodec::new(),
73
         )))
63
         )))
74
     }
64
     }
@@ -168,12 +158,12 @@ impl FramedSocket {
168
         }
158
         }
169
     }
159
     }
170
 
160
 
171
-    pub fn is_ipv4(&self) -> bool {
161
+    pub fn local_addr(&self) -> Option<SocketAddr> {
172
         if let FramedSocket::Direct(x) = self {
162
         if let FramedSocket::Direct(x) = self {
173
             if let Ok(v) = x.get_ref().local_addr() {
163
             if let Ok(v) = x.get_ref().local_addr() {
174
-                return v.is_ipv4();
164
+                return Some(v);
175
             }
165
             }
176
         }
166
         }
177
-        true
167
+        None
178
     }
168
     }
179
 }
169
 }

+ 6 - 11
src/relay_server.rs

@@ -8,7 +8,7 @@ use hbb_common::{
8
     protobuf::Message as _,
8
     protobuf::Message as _,
9
     rendezvous_proto::*,
9
     rendezvous_proto::*,
10
     sleep,
10
     sleep,
11
-    tcp::{new_listener, FramedStream},
11
+    tcp::{listen_any, FramedStream},
12
     timeout,
12
     timeout,
13
     tokio::{
13
     tokio::{
14
         self,
14
         self,
@@ -77,19 +77,14 @@ pub async fn start(port: &str, key: &str) -> ResultType<()> {
77
         BLOCKLIST_FILE,
77
         BLOCKLIST_FILE,
78
         BLOCKLIST.read().await.len()
78
         BLOCKLIST.read().await.len()
79
     );
79
     );
80
-    let addr = format!("0.0.0.0:{}", port);
81
-    log::info!("Listening on tcp {}", addr);
82
-    let addr2 = format!("0.0.0.0:{}", port.parse::<u16>().unwrap() + 2);
83
-    log::info!("Listening on websocket {}", addr2);
80
+    let port: u16 = port.parse()?;
81
+    log::info!("Listening on tcp :{}", port);
82
+    let port2 = port + 2;
83
+    log::info!("Listening on websocket :{}", port2);
84
     let main_task = async move {
84
     let main_task = async move {
85
         loop {
85
         loop {
86
             log::info!("Start");
86
             log::info!("Start");
87
-            io_loop(
88
-                new_listener(&addr, false).await?,
89
-                new_listener(&addr2, false).await?,
90
-                &key,
91
-            )
92
-            .await;
87
+            io_loop(listen_any(port).await?, listen_any(port2).await?, &key).await;
93
         }
88
         }
94
     };
89
     };
95
     let listen_signal = crate::common::listen_signal();
90
     let listen_signal = crate::common::listen_signal();

+ 34 - 15
src/rendezvous_server.rs

@@ -15,7 +15,7 @@ use hbb_common::{
15
         register_pk_response::Result::{TOO_FREQUENT, UUID_MISMATCH},
15
         register_pk_response::Result::{TOO_FREQUENT, UUID_MISMATCH},
16
         *,
16
         *,
17
     },
17
     },
18
-    tcp::{new_listener, FramedStream},
18
+    tcp::{listen_any, FramedStream},
19
     timeout,
19
     timeout,
20
     tokio::{
20
     tokio::{
21
         self,
21
         self,
@@ -32,7 +32,7 @@ use ipnetwork::Ipv4Network;
32
 use sodiumoxide::crypto::sign;
32
 use sodiumoxide::crypto::sign;
33
 use std::{
33
 use std::{
34
     collections::HashMap,
34
     collections::HashMap,
35
-    net::{IpAddr, Ipv4Addr, SocketAddr},
35
+    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
36
     sync::Arc,
36
     sync::Arc,
37
     time::Instant,
37
     time::Instant,
38
 };
38
 };
@@ -92,15 +92,15 @@ impl RendezvousServer {
92
     pub async fn start(port: i32, serial: i32, key: &str, rmem: usize) -> ResultType<()> {
92
     pub async fn start(port: i32, serial: i32, key: &str, rmem: usize) -> ResultType<()> {
93
         let (key, sk) = Self::get_server_sk(key);
93
         let (key, sk) = Self::get_server_sk(key);
94
         let addr = format!("0.0.0.0:{}", port);
94
         let addr = format!("0.0.0.0:{}", port);
95
-        let addr2 = format!("0.0.0.0:{}", port - 1);
96
-        let addr3 = format!("0.0.0.0:{}", port + 2);
95
+        let nat_port = port - 1;
96
+        let ws_port = port + 2;
97
         let pm = PeerMap::new().await?;
97
         let pm = PeerMap::new().await?;
98
         log::info!("serial={}", serial);
98
         log::info!("serial={}", serial);
99
         let rendezvous_servers = get_servers(&get_arg("rendezvous-servers"), "rendezvous-servers");
99
         let rendezvous_servers = get_servers(&get_arg("rendezvous-servers"), "rendezvous-servers");
100
-        log::info!("Listening on tcp/udp {}", addr);
101
-        log::info!("Listening on tcp {}, extra port for NAT test", addr2);
102
-        log::info!("Listening on websocket {}", addr3);
103
-        let mut socket = FramedSocket::new_with_buf_size(&addr, rmem).await?;
100
+        log::info!("Listening on tcp/udp :{}", port);
101
+        log::info!("Listening on tcp :{}, extra port for NAT test", nat_port);
102
+        log::info!("Listening on websocket :{}", ws_port);
103
+        let mut socket = create_udp_listener(port, rmem).await?;
104
         let (tx, mut rx) = mpsc::unbounded_channel::<Data>();
104
         let (tx, mut rx) = mpsc::unbounded_channel::<Data>();
105
         let software_url = get_arg("software-url");
105
         let software_url = get_arg("software-url");
106
         let version = hbb_common::get_version_from_url(&software_url);
106
         let version = hbb_common::get_version_from_url(&software_url);
@@ -138,9 +138,9 @@ impl RendezvousServer {
138
         log::info!("local-ip: {:?}", rs.inner.local_ip);
138
         log::info!("local-ip: {:?}", rs.inner.local_ip);
139
         std::env::set_var("PORT_FOR_API", port.to_string());
139
         std::env::set_var("PORT_FOR_API", port.to_string());
140
         rs.parse_relay_servers(&get_arg("relay-servers"));
140
         rs.parse_relay_servers(&get_arg("relay-servers"));
141
-        let mut listener = new_listener(&addr, false).await?;
142
-        let mut listener2 = new_listener(&addr2, false).await?;
143
-        let mut listener3 = new_listener(&addr3, false).await?;
141
+        let mut listener = create_tcp_listener(port).await?;
142
+        let mut listener2 = create_tcp_listener(nat_port).await?;
143
+        let mut listener3 = create_tcp_listener(ws_port).await?;
144
         let test_addr = std::env::var("TEST_HBBS").unwrap_or_default();
144
         let test_addr = std::env::var("TEST_HBBS").unwrap_or_default();
145
         if std::env::var("ALWAYS_USE_RELAY")
145
         if std::env::var("ALWAYS_USE_RELAY")
146
             .unwrap_or_default()
146
             .unwrap_or_default()
@@ -186,19 +186,19 @@ impl RendezvousServer {
186
                 {
186
                 {
187
                     LoopFailure::UdpSocket => {
187
                     LoopFailure::UdpSocket => {
188
                         drop(socket);
188
                         drop(socket);
189
-                        socket = FramedSocket::new_with_buf_size(&addr, rmem).await?;
189
+                        socket = create_udp_listener(port, rmem).await?;
190
                     }
190
                     }
191
                     LoopFailure::Listener => {
191
                     LoopFailure::Listener => {
192
                         drop(listener);
192
                         drop(listener);
193
-                        listener = new_listener(&addr, false).await?;
193
+                        listener = create_tcp_listener(port).await?;
194
                     }
194
                     }
195
                     LoopFailure::Listener2 => {
195
                     LoopFailure::Listener2 => {
196
                         drop(listener2);
196
                         drop(listener2);
197
-                        listener2 = new_listener(&addr2, false).await?;
197
+                        listener2 = create_tcp_listener(nat_port).await?;
198
                     }
198
                     }
199
                     LoopFailure::Listener3 => {
199
                     LoopFailure::Listener3 => {
200
                         drop(listener3);
200
                         drop(listener3);
201
-                        listener3 = new_listener(&addr3, false).await?;
201
+                        listener3 = create_tcp_listener(ws_port).await?;
202
                     }
202
                     }
203
                 }
203
                 }
204
             }
204
             }
@@ -1267,3 +1267,22 @@ async fn send_rk_res(
1267
     });
1267
     });
1268
     socket.send(&msg_out, addr).await
1268
     socket.send(&msg_out, addr).await
1269
 }
1269
 }
1270
+
1271
+async fn create_udp_listener(port: i32, rmem: usize) -> ResultType<FramedSocket> {
1272
+    let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port as _);
1273
+    if let Ok(s) = FramedSocket::new_reuse(&addr, false, rmem).await {
1274
+        log::debug!("listen on udp {:?}", s.local_addr());
1275
+        return Ok(s);
1276
+    }
1277
+    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port as _);
1278
+    let s = FramedSocket::new_reuse(&addr, false, rmem).await?;
1279
+    log::debug!("listen on udp {:?}", s.local_addr());
1280
+    return Ok(s);
1281
+}
1282
+
1283
+#[inline]
1284
+async fn create_tcp_listener(port: i32) -> ResultType<TcpListener> {
1285
+    let s = listen_any(port as _).await?;
1286
+    log::debug!("listen on tcp {:?}", s.local_addr());
1287
+    Ok(s)
1288
+}