Browse Source

Fix #324 to remove unsafe

rustdesk 2 years ago
parent
commit
1142cf105b
2 changed files with 75 additions and 85 deletions
  1. 61 71
      src/relay_server.rs
  2. 14 14
      src/rendezvous_server.rs

+ 61 - 71
src/relay_server.rs

@@ -25,6 +25,7 @@ use std::{
25 25
     io::prelude::*,
26 26
     io::Error,
27 27
     net::SocketAddr,
28
+    sync::atomic::{AtomicUsize, Ordering},
28 29
 };
29 30
 
30 31
 type Usage = (usize, usize, usize, usize);
@@ -36,11 +37,11 @@ lazy_static::lazy_static! {
36 37
     static ref BLOCKLIST: RwLock<HashSet<String>> = Default::default();
37 38
 }
38 39
 
39
-static mut DOWNGRADE_THRESHOLD: f64 = 0.66;
40
-static mut DOWNGRADE_START_CHECK: usize = 1_800_000; // in ms
41
-static mut LIMIT_SPEED: usize = 4 * 1024 * 1024; // in bit/s
42
-static mut TOTAL_BANDWIDTH: usize = 1024 * 1024 * 1024; // in bit/s
43
-static mut SINGLE_BANDWIDTH: usize = 16 * 1024 * 1024; // in bit/s
40
+static DOWNGRADE_THRESHOLD_100: AtomicUsize = AtomicUsize::new(66); // 0.66
41
+static DOWNGRADE_START_CHECK: AtomicUsize = AtomicUsize::new(1_800_000); // in ms
42
+static LIMIT_SPEED: AtomicUsize = AtomicUsize::new(4 * 1024 * 1024); // in bit/s
43
+static TOTAL_BANDWIDTH: AtomicUsize = AtomicUsize::new(1024 * 1024 * 1024); // in bit/s
44
+static SINGLE_BANDWIDTH: AtomicUsize = AtomicUsize::new(16 * 1024 * 1024); // in bit/s
44 45
 const BLACKLIST_FILE: &str = "blacklist.txt";
45 46
 const BLOCKLIST_FILE: &str = "blocklist.txt";
46 47
 
@@ -99,57 +100,53 @@ fn check_params() {
99 100
         .map(|x| x.parse::<f64>().unwrap_or(0.))
100 101
         .unwrap_or(0.);
101 102
     if tmp > 0. {
102
-        unsafe {
103
-            DOWNGRADE_THRESHOLD = tmp;
104
-        }
103
+        DOWNGRADE_THRESHOLD_100.store((tmp * 100.) as _, Ordering::SeqCst);
105 104
     }
106
-    unsafe { log::info!("DOWNGRADE_THRESHOLD: {}", DOWNGRADE_THRESHOLD) };
105
+    log::info!(
106
+        "DOWNGRADE_THRESHOLD: {}",
107
+        DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100.
108
+    );
107 109
     let tmp = std::env::var("DOWNGRADE_START_CHECK")
108 110
         .map(|x| x.parse::<usize>().unwrap_or(0))
109 111
         .unwrap_or(0);
110 112
     if tmp > 0 {
111
-        unsafe {
112
-            DOWNGRADE_START_CHECK = tmp * 1000;
113
-        }
113
+        DOWNGRADE_START_CHECK.store(tmp * 1000, Ordering::SeqCst);
114 114
     }
115
-    unsafe { log::info!("DOWNGRADE_START_CHECK: {}s", DOWNGRADE_START_CHECK / 1000) };
115
+    log::info!(
116
+        "DOWNGRADE_START_CHECK: {}s",
117
+        DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000
118
+    );
116 119
     let tmp = std::env::var("LIMIT_SPEED")
117 120
         .map(|x| x.parse::<f64>().unwrap_or(0.))
118 121
         .unwrap_or(0.);
119 122
     if tmp > 0. {
120
-        unsafe {
121
-            LIMIT_SPEED = (tmp * 1024. * 1024.) as usize;
122
-        }
123
+        LIMIT_SPEED.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
123 124
     }
124
-    unsafe { log::info!("LIMIT_SPEED: {}Mb/s", LIMIT_SPEED as f64 / 1024. / 1024.) };
125
+    log::info!(
126
+        "LIMIT_SPEED: {}Mb/s",
127
+        LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024.
128
+    );
125 129
     let tmp = std::env::var("TOTAL_BANDWIDTH")
126 130
         .map(|x| x.parse::<f64>().unwrap_or(0.))
127 131
         .unwrap_or(0.);
128 132
     if tmp > 0. {
129
-        unsafe {
130
-            TOTAL_BANDWIDTH = (tmp * 1024. * 1024.) as usize;
131
-        }
133
+        TOTAL_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
132 134
     }
133
-    unsafe {
134
-        log::info!(
135
-            "TOTAL_BANDWIDTH: {}Mb/s",
136
-            TOTAL_BANDWIDTH as f64 / 1024. / 1024.
137
-        )
138
-    };
135
+
136
+    log::info!(
137
+        "TOTAL_BANDWIDTH: {}Mb/s",
138
+        TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
139
+    );
139 140
     let tmp = std::env::var("SINGLE_BANDWIDTH")
140 141
         .map(|x| x.parse::<f64>().unwrap_or(0.))
141 142
         .unwrap_or(0.);
142 143
     if tmp > 0. {
143
-        unsafe {
144
-            SINGLE_BANDWIDTH = (tmp * 1024. * 1024.) as usize;
145
-        }
144
+        SINGLE_BANDWIDTH.store((tmp * 1024. * 1024.) as usize, Ordering::SeqCst);
146 145
     }
147
-    unsafe {
148
-        log::info!(
149
-            "SINGLE_BANDWIDTH: {}Mb/s",
150
-            SINGLE_BANDWIDTH as f64 / 1024. / 1024.
151
-        )
152
-    };
146
+    log::info!(
147
+        "SINGLE_BANDWIDTH: {}Mb/s",
148
+        SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
149
+    )
153 150
 }
154 151
 
155 152
 async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
@@ -233,76 +230,68 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
233 230
             if let Some(v) = fds.next() {
234 231
                 if let Ok(v) = v.parse::<f64>() {
235 232
                     if v > 0. {
236
-                        unsafe {
237
-                            DOWNGRADE_THRESHOLD = v;
238
-                        }
233
+                        DOWNGRADE_THRESHOLD_100.store((v * 100.) as _, Ordering::SeqCst);
239 234
                     }
240 235
                 }
241 236
             } else {
242
-                unsafe {
243
-                    res = format!("{DOWNGRADE_THRESHOLD}\n");
244
-                }
237
+                res = format!(
238
+                    "{}\n",
239
+                    DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100.
240
+                );
245 241
             }
246 242
         }
247 243
         Some("downgrade-start-check" | "t") => {
248 244
             if let Some(v) = fds.next() {
249 245
                 if let Ok(v) = v.parse::<usize>() {
250 246
                     if v > 0 {
251
-                        unsafe {
252
-                            DOWNGRADE_START_CHECK = v * 1000;
253
-                        }
247
+                        DOWNGRADE_START_CHECK.store(v * 1000, Ordering::SeqCst);
254 248
                     }
255 249
                 }
256 250
             } else {
257
-                unsafe {
258
-                    res = format!("{}s\n", DOWNGRADE_START_CHECK / 1000);
259
-                }
251
+                res = format!("{}s\n", DOWNGRADE_START_CHECK.load(Ordering::SeqCst) / 1000);
260 252
             }
261 253
         }
262 254
         Some("limit-speed" | "ls") => {
263 255
             if let Some(v) = fds.next() {
264 256
                 if let Ok(v) = v.parse::<f64>() {
265 257
                     if v > 0. {
266
-                        unsafe {
267
-                            LIMIT_SPEED = (v * 1024. * 1024.) as _;
268
-                        }
258
+                        LIMIT_SPEED.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
269 259
                     }
270 260
                 }
271 261
             } else {
272
-                unsafe {
273
-                    res = format!("{}Mb/s\n", LIMIT_SPEED as f64 / 1024. / 1024.);
274
-                }
262
+                res = format!(
263
+                    "{}Mb/s\n",
264
+                    LIMIT_SPEED.load(Ordering::SeqCst) as f64 / 1024. / 1024.
265
+                );
275 266
             }
276 267
         }
277 268
         Some("total-bandwidth" | "tb") => {
278 269
             if let Some(v) = fds.next() {
279 270
                 if let Ok(v) = v.parse::<f64>() {
280 271
                     if v > 0. {
281
-                        unsafe {
282
-                            TOTAL_BANDWIDTH = (v * 1024. * 1024.) as _;
283
-                            limiter.set_speed_limit(TOTAL_BANDWIDTH as _);
284
-                        }
272
+                        TOTAL_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
273
+                        limiter.set_speed_limit(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _);
285 274
                     }
286 275
                 }
287 276
             } else {
288
-                unsafe {
289
-                    res = format!("{}Mb/s\n", TOTAL_BANDWIDTH as f64 / 1024. / 1024.);
290
-                }
277
+                res = format!(
278
+                    "{}Mb/s\n",
279
+                    TOTAL_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
280
+                );
291 281
             }
292 282
         }
293 283
         Some("single-bandwidth" | "sb") => {
294 284
             if let Some(v) = fds.next() {
295 285
                 if let Ok(v) = v.parse::<f64>() {
296 286
                     if v > 0. {
297
-                        unsafe {
298
-                            SINGLE_BANDWIDTH = (v * 1024. * 1024.) as _;
299
-                        }
287
+                        SINGLE_BANDWIDTH.store((v * 1024. * 1024.) as _, Ordering::SeqCst);
300 288
                     }
301 289
                 }
302 290
             } else {
303
-                unsafe {
304
-                    res = format!("{}Mb/s\n", SINGLE_BANDWIDTH as f64 / 1024. / 1024.);
305
-                }
291
+                res = format!(
292
+                    "{}Mb/s\n",
293
+                    SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64 / 1024. / 1024.
294
+                );
306 295
             }
307 296
         }
308 297
         Some("usage" | "u") => {
@@ -336,7 +325,7 @@ async fn check_cmd(cmd: &str, limiter: Limiter) -> String {
336 325
 
337 326
 async fn io_loop(listener: TcpListener, listener2: TcpListener, key: &str) {
338 327
     check_params();
339
-    let limiter = <Limiter>::new(unsafe { TOTAL_BANDWIDTH as _ });
328
+    let limiter = <Limiter>::new(TOTAL_BANDWIDTH.load(Ordering::SeqCst) as _);
340 329
     loop {
341 330
         tokio::select! {
342 331
             res = listener.accept() => {
@@ -475,10 +464,11 @@ async fn relay(
475 464
     let mut highest_s = 0;
476 465
     let mut downgrade: bool = false;
477 466
     let mut blacked: bool = false;
478
-    let limiter = <Limiter>::new(unsafe { SINGLE_BANDWIDTH as _ });
479
-    let blacklist_limiter = <Limiter>::new(unsafe { LIMIT_SPEED as _ });
467
+    let sb = SINGLE_BANDWIDTH.load(Ordering::SeqCst) as f64;
468
+    let limiter = <Limiter>::new(sb);
469
+    let blacklist_limiter = <Limiter>::new(LIMIT_SPEED.load(Ordering::SeqCst) as _);
480 470
     let downgrade_threshold =
481
-        (unsafe { SINGLE_BANDWIDTH as f64 * DOWNGRADE_THRESHOLD } / 1000.) as usize; // in bit/ms
471
+        (sb * DOWNGRADE_THRESHOLD_100.load(Ordering::SeqCst) as f64 / 100. / 1000.) as usize; // in bit/ms
482 472
     let mut timer = interval(Duration::from_secs(3));
483 473
     let mut last_recv_time = std::time::Instant::now();
484 474
     loop {
@@ -546,7 +536,7 @@ async fn relay(
546 536
                 (elapsed as _, total as _, highest_s as _, speed as _),
547 537
             );
548 538
             total_s = 0;
549
-            if elapsed > unsafe { DOWNGRADE_START_CHECK }
539
+            if elapsed > DOWNGRADE_START_CHECK.load(Ordering::SeqCst)
550 540
                 && !downgrade
551 541
                 && total > elapsed * downgrade_threshold
552 542
             {

+ 14 - 14
src/rendezvous_server.rs

@@ -35,6 +35,7 @@ use sodiumoxide::crypto::sign;
35 35
 use std::{
36 36
     collections::HashMap,
37 37
     net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
38
+    sync::atomic::{AtomicBool, AtomicUsize, Ordering},
38 39
     sync::Arc,
39 40
     time::Instant,
40 41
 };
@@ -55,10 +56,10 @@ enum Sink {
55 56
 }
56 57
 type Sender = mpsc::UnboundedSender<Data>;
57 58
 type Receiver = mpsc::UnboundedReceiver<Data>;
58
-static mut ROTATION_RELAY_SERVER: usize = 0;
59
+static ROTATION_RELAY_SERVER: AtomicUsize = AtomicUsize::new(0);
59 60
 type RelayServers = Vec<String>;
60 61
 static CHECK_RELAY_TIMEOUT: u64 = 3_000;
61
-static mut ALWAYS_USE_RELAY: bool = false;
62
+static ALWAYS_USE_RELAY: AtomicBool = AtomicBool::new(false);
62 63
 
63 64
 #[derive(Clone)]
64 65
 struct Inner {
@@ -147,13 +148,11 @@ impl RendezvousServer {
147 148
             .to_uppercase()
148 149
             == "Y"
149 150
         {
150
-            unsafe {
151
-                ALWAYS_USE_RELAY = true;
152
-            }
151
+            ALWAYS_USE_RELAY.store(true, Ordering::SeqCst);
153 152
         }
154 153
         log::info!(
155 154
             "ALWAYS_USE_RELAY={}",
156
-            if unsafe { ALWAYS_USE_RELAY } {
155
+            if ALWAYS_USE_RELAY.load(Ordering::SeqCst) {
157 156
                 "Y"
158 157
             } else {
159 158
                 "N"
@@ -711,7 +710,7 @@ impl RendezvousServer {
711 710
             let peer_is_lan = self.is_lan(peer_addr);
712 711
             let is_lan = self.is_lan(addr);
713 712
             let mut relay_server = self.get_relay_server(addr.ip(), peer_addr.ip());
714
-            if unsafe { ALWAYS_USE_RELAY } || (peer_is_lan ^ is_lan) {
713
+            if ALWAYS_USE_RELAY.load(Ordering::SeqCst) || (peer_is_lan ^ is_lan) {
715 714
                 if peer_is_lan {
716 715
                     // https://github.com/rustdesk/rustdesk-server/issues/24
717 716
                     relay_server = self.inner.local_ip.clone()
@@ -905,10 +904,7 @@ impl RendezvousServer {
905 904
         } else if self.relay_servers.len() == 1 {
906 905
             return self.relay_servers[0].clone();
907 906
         }
908
-        let i = unsafe {
909
-            ROTATION_RELAY_SERVER += 1;
910
-            ROTATION_RELAY_SERVER % self.relay_servers.len()
911
-        };
907
+        let i = ROTATION_RELAY_SERVER.fetch_add(1, Ordering::SeqCst) % self.relay_servers.len();
912 908
         self.relay_servers[i].clone()
913 909
     }
914 910
 
@@ -1027,13 +1023,17 @@ impl RendezvousServer {
1027 1023
             Some("always-use-relay" | "aur") => {
1028 1024
                 if let Some(rs) = fds.next() {
1029 1025
                     if rs.to_uppercase() == "Y" {
1030
-                        unsafe { ALWAYS_USE_RELAY = true };
1026
+                        ALWAYS_USE_RELAY.store(true, Ordering::SeqCst);
1031 1027
                     } else {
1032
-                        unsafe { ALWAYS_USE_RELAY = false };
1028
+                        ALWAYS_USE_RELAY.store(false, Ordering::SeqCst);
1033 1029
                     }
1034 1030
                     self.tx.send(Data::RelayServers0(rs.to_owned())).ok();
1035 1031
                 } else {
1036
-                    let _ = writeln!(res, "ALWAYS_USE_RELAY: {:?}", unsafe { ALWAYS_USE_RELAY });
1032
+                    let _ = writeln!(
1033
+                        res,
1034
+                        "ALWAYS_USE_RELAY: {:?}",
1035
+                        ALWAYS_USE_RELAY.load(Ordering::SeqCst)
1036
+                    );
1037 1037
                 }
1038 1038
             }
1039 1039
             Some("test-geo" | "tg") => {