rustdesk лет назад: 2
Родитель
Сommit
1142cf105b
2 измененных файлов с 75 добавлено и 85 удалено
  1. 61 71
      src/relay_server.rs
  2. 14 14
      src/rendezvous_server.rs

+ 61 - 71
src/relay_server.rs

@@ -25,6 +25,7 @@ use std::{
25
     io::prelude::*,
25
     io::prelude::*,
26
     io::Error,
26
     io::Error,
27
     net::SocketAddr,
27
     net::SocketAddr,
28
+    sync::atomic::{AtomicUsize, Ordering},
28
 };
29
 };
29
 
30
 
30
 type Usage = (usize, usize, usize, usize);
31
 type Usage = (usize, usize, usize, usize);
@@ -36,11 +37,11 @@ lazy_static::lazy_static! {
36
     static ref BLOCKLIST: RwLock<HashSet<String>> = Default::default();
37
     static ref BLOCKLIST: RwLock<HashSet<String>> = Default::default();
37
 }
38
 }
38
 
39
 
39
-static mut DOWNGRADE_THRESHOLD: f64 = 0.66;
40
-static mut DOWNGRADE_START_CHECK: usize = 1_800_000; // in ms
41
-static mut LIMIT_SPEED: usize = 4 * 1024 * 1024; // in bit/s
42
-static mut TOTAL_BANDWIDTH: usize = 1024 * 1024 * 1024; // in bit/s
43
-static mut SINGLE_BANDWIDTH: usize = 16 * 1024 * 1024; // in bit/s
40
+static DOWNGRADE_THRESHOLD_100: AtomicUsize = AtomicUsize::new(66); // 0.66
41
+static DOWNGRADE_START_CHECK: AtomicUsize = AtomicUsize::new(1_800_000); // in ms
42
+static LIMIT_SPEED: AtomicUsize = AtomicUsize::new(4 * 1024 * 1024); // in bit/s
43
+static TOTAL_BANDWIDTH: AtomicUsize = AtomicUsize::new(1024 * 1024 * 1024); // in bit/s
44
+static SINGLE_BANDWIDTH: AtomicUsize = AtomicUsize::new(16 * 1024 * 1024); // in bit/s
44
 const BLACKLIST_FILE: &str = "blacklist.txt";
45
 const BLACKLIST_FILE: &str = "blacklist.txt";
45
 const BLOCKLIST_FILE: &str = "blocklist.txt";
46
 const BLOCKLIST_FILE: &str = "blocklist.txt";
46
 
47
 
@@ -99,57 +100,53 @@ fn check_params() {
99
         .map(|x| x.parse::<f64>().unwrap_or(0.))
100
         .map(|x| x.parse::<f64>().unwrap_or(0.))
100
         .unwrap_or(0.);
101
         .unwrap_or(0.);
101
     if tmp > 0. {
102
     if tmp > 0. {
102
-        unsafe {
103
-            DOWNGRADE_THRESHOLD = tmp;
104
-        }
103
+        DOWNGRADE_THRESHOLD_100.store((tmp * 100.) as _, Ordering::SeqCst);
105
     }
104
     }
106
-    unsafe { log::info!("DOWNGRADE_THRESHOLD: {}", DOWNGRADE_THRESHOLD) };
105
+    log::info!(
106
+        "DOWNGRADE_THRESHOLD: {}",
107
+        DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100.
108
+    );
107
     let tmp = std::env::var("DOWNGRADE_START_CHECK")
109
     let tmp = std::env::var("DOWNGRADE_START_CHECK")
108
         .map(|x| x.parse::<usize>().unwrap_or(0))
110
         .map(|x| x.parse::<usize>().unwrap_or(0))
109
         .unwrap_or(0);
111
         .unwrap_or(0);
110
     if tmp > 0 {
112
     if tmp > 0 {
111
-        unsafe {
112
-            DOWNGRADE_START_CHECK = tmp * 1000;
113
-        }
113
+        DOWNGRADE_START_CHECK.store(tmp * 1000, Ordering::SeqCst);
114
     }
114
     }
115
-    unsafe { log::info!("DOWNGRADE_START_CHECK: {}s", DOWNGRADE_START_CHECK / 1000) };
115
+    log::info!(
116
+        "DOWNGRADE_START_CHECK: {}s",
117
+        DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000
118
+    );
116
     let tmp = std::env::var("LIMIT_SPEED")
119
     let tmp = std::env::var("LIMIT_SPEED")
117
         .map(|x| x.parse::<f64>().unwrap_or(0.))
120
         .map(|x| x.parse::<f64>().unwrap_or(0.))
118
         .unwrap_or(0.);
121
         .unwrap_or(0.);
119
     if tmp > 0. {
122
     if tmp > 0. {
120
-        unsafe {
121
-            LIMIT_SPEED = (tmp * 1024. * 1024.) as usize;
122
-        }
123
+        LIMIT_SPEED.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
123
     }
124
     }
124
-    unsafe { log::info!("LIMIT_SPEED: {}Mb/s", LIMIT_SPEED as f64 / 1024. / 1024.) };
125
+    log::info!(
126
+        "LIMIT_SPEED: {}Mb/s",
127
+        LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024.
128
+    );
125
     let tmp = std::env::var("TOTAL_BANDWIDTH")
129
     let tmp = std::env::var("TOTAL_BANDWIDTH")
126
         .map(|x| x.parse::<f64>().unwrap_or(0.))
130
         .map(|x| x.parse::<f64>().unwrap_or(0.))
127
         .unwrap_or(0.);
131
         .unwrap_or(0.);
128
     if tmp > 0. {
132
     if tmp > 0. {
129
-        unsafe {
130
-            TOTAL_BANDWIDTH = (tmp * 1024. * 1024.) as usize;
131
-        }
133
+        TOTAL_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
132
     }
134
     }
133
-    unsafe {
134
-        log::info!(
135
-            "TOTAL_BANDWIDTH: {}Mb/s",
136
-            TOTAL_BANDWIDTH as f64 / 1024. / 1024.
137
-        )
138
-    };
135
+
136
+    log::info!(
137
+        "TOTAL_BANDWIDTH: {}Mb/s",
138
+        TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
139
+    );
139
     let tmp = std::env::var("SINGLE_BANDWIDTH")
140
     let tmp = std::env::var("SINGLE_BANDWIDTH")
140
         .map(|x| x.parse::<f64>().unwrap_or(0.))
141
         .map(|x| x.parse::<f64>().unwrap_or(0.))
141
         .unwrap_or(0.);
142
         .unwrap_or(0.);
142
     if tmp > 0. {
143
     if tmp > 0. {
143
-        unsafe {
144
-            SINGLE_BANDWIDTH = (tmp * 1024. * 1024.) as usize;
145
-        }
144
+        SINGLE_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
146
     }
145
     }
147
-    unsafe {
148
-        log::info!(
149
-            "SINGLE_BANDWIDTH: {}Mb/s",
150
-            SINGLE_BANDWIDTH as f64 / 1024. / 1024.
151
-        )
152
-    };
146
+    log::info!(
147
+        "SINGLE_BANDWIDTH: {}Mb/s",
148
+        SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
149
+    )
153
 }
150
 }
154
 
151
 
155
 async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
152
 async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
@@ -233,76 +230,68 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
233
             if let Some(v) = fds.next() {
230
             if let Some(v) = fds.next() {
234
                 if let Ok(v) = v.parse::<f64>() {
231
                 if let Ok(v) = v.parse::<f64>() {
235
                     if v > 0. {
232
                     if v > 0. {
236
-                        unsafe {
237
-                            DOWNGRADE_THRESHOLD = v;
238
-                        }
233
+                        DOWNGRADE_THRESHOLD_100.store((v * 100.) as _, Ordering::SeqCst);
239
                     }
234
                     }
240
                 }
235
                 }
241
             } else {
236
             } else {
242
-                unsafe {
243
-                    res = format!("{DOWNGRADE_THRESHOLD}\n");
244
-                }
237
+                res = format!(
238
+                    "{}\n",
239
+                    DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100.
240
+                );
245
             }
241
             }
246
         }
242
         }
247
         Some("downgrade-start-check" | "t") => {
243
         Some("downgrade-start-check" | "t") => {
248
             if let Some(v) = fds.next() {
244
             if let Some(v) = fds.next() {
249
                 if let Ok(v) = v.parse::<usize>() {
245
                 if let Ok(v) = v.parse::<usize>() {
250
                     if v > 0 {
246
                     if v > 0 {
251
-                        unsafe {
252
-                            DOWNGRADE_START_CHECK = v * 1000;
253
-                        }
247
+                        DOWNGRADE_START_CHECK.store(v * 1000, Ordering::SeqCst);
254
                     }
248
                     }
255
                 }
249
                 }
256
             } else {
250
             } else {
257
-                unsafe {
258
-                    res = format!("{}s\n", DOWNGRADE_START_CHECK / 1000);
259
-                }
251
+                res = format!("{}s\n", DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000);
260
             }
252
             }
261
         }
253
         }
262
         Some("limit-speed" | "ls") => {
254
         Some("limit-speed" | "ls") => {
263
             if let Some(v) = fds.next() {
255
             if let Some(v) = fds.next() {
264
                 if let Ok(v) = v.parse::<f64>() {
256
                 if let Ok(v) = v.parse::<f64>() {
265
                     if v > 0. {
257
                     if v > 0. {
266
-                        unsafe {
267
-                            LIMIT_SPEED = (v * 1024. * 1024.) as _;
268
-                        }
258
+                        LIMIT_SPEED.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
269
                     }
259
                     }
270
                 }
260
                 }
271
             } else {
261
             } else {
272
-                unsafe {
273
-                    res = format!("{}Mb/s\n", LIMIT_SPEED as f64 / 1024. / 1024.);
274
-                }
262
+                res = format!(
263
+                    "{}Mb/s\n",
264
+                    LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024.
265
+                );
275
             }
266
             }
276
         }
267
         }
277
         Some("total-bandwidth" | "tb") => {
268
         Some("total-bandwidth" | "tb") => {
278
             if let Some(v) = fds.next() {
269
             if let Some(v) = fds.next() {
279
                 if let Ok(v) = v.parse::<f64>() {
270
                 if let Ok(v) = v.parse::<f64>() {
280
                     if v > 0. {
271
                     if v > 0. {
281
-                        unsafe {
282
-                            TOTAL_BANDWIDTH = (v * 1024. * 1024.) as _;
283
-                            limiter.set_speed_limit(TOTAL_BANDWIDTH as _);
284
-                        }
272
+                        TOTAL_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
273
+                        limiter.set_speed_limit(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _);
285
                     }
274
                     }
286
                 }
275
                 }
287
             } else {
276
             } else {
288
-                unsafe {
289
-                    res = format!("{}Mb/s\n", TOTAL_BANDWIDTH as f64 / 1024. / 1024.);
290
-                }
277
+                res = format!(
278
+                    "{}Mb/s\n",
279
+                    TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
280
+                );
291
             }
281
             }
292
         }
282
         }
293
         Some("single-bandwidth" | "sb") => {
283
         Some("single-bandwidth" | "sb") => {
294
             if let Some(v) = fds.next() {
284
             if let Some(v) = fds.next() {
295
                 if let Ok(v) = v.parse::<f64>() {
285
                 if let Ok(v) = v.parse::<f64>() {
296
                     if v > 0. {
286
                     if v > 0. {
297
-                        unsafe {
298
-                            SINGLE_BANDWIDTH = (v * 1024. * 1024.) as _;
299
-                        }
287
+                        SINGLE_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
300
                     }
288
                     }
301
                 }
289
                 }
302
             } else {
290
             } else {
303
-                unsafe {
304
-                    res = format!("{}Mb/s\n", SINGLE_BANDWIDTH as f64 / 1024. / 1024.);
305
-                }
291
+                res = format!(
292
+                    "{}Mb/s\n",
293
+                    SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
294
+                );
306
             }
295
             }
307
         }
296
         }
308
         Some("usage" | "u") => {
297
         Some("usage" | "u") => {
@@ -336,7 +325,7 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
336
 
325
 
337
 async fn io_loop(listener: TcpListener, listener2: TcpListener, key: &str) {
326
 async fn io_loop(listener: TcpListener, listener2: TcpListener, key: &str) {
338
     check_params();
327
     check_params();
339
-    let limiter = <Limiter>::new(unsafe { TOTAL_BANDWIDTH as _ });
328
+    let limiter = <Limiter>::new(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _);
340
     loop {
329
     loop {
341
         tokio::select! {
330
         tokio::select! {
342
             res = listener.accept() => {
331
             res = listener.accept() => {
@@ -475,10 +464,11 @@ async fn relay(
475
     let mut highest_s = 0;
464
     let mut highest_s = 0;
476
     let mut downgrade: bool = false;
465
     let mut downgrade: bool = false;
477
     let mut blacked: bool = false;
466
     let mut blacked: bool = false;
478
-    let limiter = <Limiter>::new(unsafe { SINGLE_BANDWIDTH as _ });
479
-    let blacklist_limiter = <Limiter>::new(unsafe { LIMIT_SPEED as _ });
467
+    let sb = SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64;
468
+    let limiter = <Limiter>::new(sb);
469
+    let blacklist_limiter = <Limiter>::new(LIMIT_SPEED.load(Ordering::SeqCst) as _);
480
     let downgrade_threshold =
470
     let downgrade_threshold =
481
-        (unsafe { SINGLE_BANDWIDTH as f64 * DOWNGRADE_THRESHOLD } / 1000.) as usize; // in bit/ms
471
+        (sb * DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. / 1000.) as usize; // in bit/ms
482
     let mut timer = interval(Duration::from_secs(3));
472
     let mut timer = interval(Duration::from_secs(3));
483
     let mut last_recv_time = std::time::Instant::now();
473
     let mut last_recv_time = std::time::Instant::now();
484
     loop {
474
     loop {
@@ -546,7 +536,7 @@ async fn relay(
546
                 (elapsed as _, total as _, highest_s as _, speed as _),
536
                 (elapsed as _, total as _, highest_s as _, speed as _),
547
             );
537
             );
548
             total_s = 0;
538
             total_s = 0;
549
-            if elapsed > unsafe { DOWNGRADE_START_CHECK }
539
+            if elapsed > DOWNGRADE_START_CHECK.load(Ordering::SeqCst)
550
                 && !downgrade
540
                 && !downgrade
551
                 && total > elapsed * downgrade_threshold
541
                 && total > elapsed * downgrade_threshold
552
             {
542
             {

+ 14 - 14
src/rendezvous_server.rs

@@ -35,6 +35,7 @@ use sodiumoxide::crypto::sign;
35
 use std::{
35
 use std::{
36
     collections::HashMap,
36
     collections::HashMap,
37
     net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
37
     net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
38
+    sync::atomic::{AtomicBool, AtomicUsize, Ordering},
38
     sync::Arc,
39
     sync::Arc,
39
     time::Instant,
40
     time::Instant,
40
 };
41
 };
@@ -55,10 +56,10 @@ enum Sink {
55
 }
56
 }
56
 type Sender = mpsc::UnboundedSender<Data>;
57
 type Sender = mpsc::UnboundedSender<Data>;
57
 type Receiver = mpsc::UnboundedReceiver<Data>;
58
 type Receiver = mpsc::UnboundedReceiver<Data>;
58
-static mut ROTATION_RELAY_SERVER: usize = 0;
59
+static ROTATION_RELAY_SERVER: AtomicUsize = AtomicUsize::new(0);
59
 type RelayServers = Vec<String>;
60
 type RelayServers = Vec<String>;
60
 static CHECK_RELAY_TIMEOUT: u64 = 3_000;
61
 static CHECK_RELAY_TIMEOUT: u64 = 3_000;
61
-static mut ALWAYS_USE_RELAY: bool = false;
62
+static ALWAYS_USE_RELAY: AtomicBool = AtomicBool::new(false);
62
 
63
 
63
 #[derive(Clone)]
64
 #[derive(Clone)]
64
 struct Inner {
65
 struct Inner {
@@ -147,13 +148,11 @@ impl RendezvousServer {
147
             .to_uppercase()
148
             .to_uppercase()
148
             == "Y"
149
             == "Y"
149
         {
150
         {
150
-            unsafe {
151
-                ALWAYS_USE_RELAY = true;
152
-            }
151
+            ALWAYS_USE_RELAY.store(true, Ordering::SeqCst);
153
         }
152
         }
154
         log::info!(
153
         log::info!(
155
             "ALWAYS_USE_RELAY={}",
154
             "ALWAYS_USE_RELAY={}",
156
-            if unsafe { ALWAYS_USE_RELAY } {
155
+            if ALWAYS_USE_RELAY.load(Ordering::SeqCst) {
157
                 "Y"
156
                 "Y"
158
             } else {
157
             } else {
159
                 "N"
158
                 "N"
@@ -711,7 +710,7 @@ impl RendezvousServer {
711
             let peer_is_lan = self.is_lan(peer_addr);
710
             let peer_is_lan = self.is_lan(peer_addr);
712
             let is_lan = self.is_lan(addr);
711
             let is_lan = self.is_lan(addr);
713
             let mut relay_server = self.get_relay_server(addr.ip(), peer_addr.ip());
712
             let mut relay_server = self.get_relay_server(addr.ip(), peer_addr.ip());
714
-            if unsafe { ALWAYS_USE_RELAY } || (peer_is_lan ^ is_lan) {
713
+            if ALWAYS_USE_RELAY.load(Ordering::SeqCst) || (peer_is_lan ^ is_lan) {
715
                 if peer_is_lan {
714
                 if peer_is_lan {
716
                     // https://github.com/rustdesk/rustdesk-server/issues/24
715
                     // https://github.com/rustdesk/rustdesk-server/issues/24
717
                     relay_server = self.inner.local_ip.clone()
716
                     relay_server = self.inner.local_ip.clone()
@@ -905,10 +904,7 @@ impl RendezvousServer {
905
         } else if self.relay_servers.len() == 1 {
904
         } else if self.relay_servers.len() == 1 {
906
             return self.relay_servers[0].clone();
905
             return self.relay_servers[0].clone();
907
         }
906
         }
908
-        let i = unsafe {
909
-            ROTATION_RELAY_SERVER += 1;
910
-            ROTATION_RELAY_SERVER % self.relay_servers.len()
911
-        };
907
+        let i = ROTATION_RELAY_SERVER.fetch_add(1, Ordering::SeqCst) % self.relay_servers.len();
912
         self.relay_servers[i].clone()
908
         self.relay_servers[i].clone()
913
     }
909
     }
914
 
910
 
@@ -1027,13 +1023,17 @@ impl RendezvousServer {
1027
             Some("always-use-relay" | "aur") => {
1023
             Some("always-use-relay" | "aur") => {
1028
                 if let Some(rs) = fds.next() {
1024
                 if let Some(rs) = fds.next() {
1029
                     if rs.to_uppercase() == "Y" {
1025
                     if rs.to_uppercase() == "Y" {
1030
-                        unsafe { ALWAYS_USE_RELAY = true };
1026
+                        ALWAYS_USE_RELAY.store(true, Ordering::SeqCst);
1031
                     } else {
1027
                     } else {
1032
-                        unsafe { ALWAYS_USE_RELAY = false };
1028
+                        ALWAYS_USE_RELAY.store(false, Ordering::SeqCst);
1033
                     }
1029
                     }
1034
                     self.tx.send(Data::RelayServers0(rs.to_owned())).ok();
1030
                     self.tx.send(Data::RelayServers0(rs.to_owned())).ok();
1035
                 } else {
1031
                 } else {
1036
-                    let _ = writeln!(res, "ALWAYS_USE_RELAY: {:?}", unsafe { ALWAYS_USE_RELAY });
1032
+                    let _ = writeln!(
1033
+                        res,
1034
+                        "ALWAYS_USE_RELAY: {:?}",
1035
+                        ALWAYS_USE_RELAY.load(Ordering::SeqCst)
1036
+                    );
1037
                 }
1037
                 }
1038
             }
1038
             }
1039
             Some("test-geo" | "tg") => {
1039
             Some("test-geo" | "tg") => {