common.rs 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. use clap::App;
  2. use hbb_common::{
  3. allow_err, anyhow::{Context, Result}, get_version_number, log, tokio, ResultType
  4. };
  5. use ini::Ini;
  6. use sodiumoxide::crypto::sign;
  7. use std::{
  8. io::prelude::*,
  9. io::Read,
  10. net::SocketAddr,
  11. time::{Instant, SystemTime},
  12. };
  /// Returns an `Instant` lying one hour before the current moment.
  ///
  /// Falls back to the current instant when subtracting an hour is not
  /// representable by the monotonic clock.
  #[allow(dead_code)]
  pub(crate) fn get_expired_time() -> Instant {
      let one_hour = std::time::Duration::from_secs(3600);
      let current = Instant::now();
      match current.checked_sub(one_hour) {
          Some(earlier) => earlier,
          None => current,
      }
  }
  19. #[allow(dead_code)]
  20. pub(crate) fn test_if_valid_server(host: &str, name: &str) -> ResultType<SocketAddr> {
  21. use std::net::ToSocketAddrs;
  22. let res = if host.contains(':') {
  23. host.to_socket_addrs()?.next().context("")
  24. } else {
  25. format!("{}:{}", host, 0)
  26. .to_socket_addrs()?
  27. .next()
  28. .context("")
  29. };
  30. if res.is_err() {
  31. log::error!("Invalid {} {}: {:?}", name, host, res);
  32. }
  33. res
  34. }
  35. #[allow(dead_code)]
  36. pub(crate) fn get_servers(s: &str, tag: &str) -> Vec<String> {
  37. let servers: Vec<String> = s
  38. .split(',')
  39. .filter(|x| !x.is_empty() && test_if_valid_server(x, tag).is_ok())
  40. .map(|x| x.to_owned())
  41. .collect();
  42. log::info!("{}={:?}", tag, servers);
  43. servers
  44. }
  /// Converts an argument/config key into the environment-variable name used to
  /// store it: upper-cased, with every `_` turned into `-`.
  #[allow(dead_code)]
  #[inline]
  fn arg_name(name: &str) -> String {
      name.chars()
          .flat_map(char::to_uppercase)
          .map(|c| if c == '_' { '-' } else { c })
          .collect()
  }
  50. #[allow(dead_code)]
  51. pub fn init_args(args: &str, name: &str, about: &str) {
  52. let matches = App::new(name)
  53. .version(crate::version::VERSION)
  54. .author("Purslane Ltd. <info@rustdesk.com>")
  55. .about(about)
  56. .args_from_usage(args)
  57. .get_matches();
  58. if let Ok(v) = Ini::load_from_file(".env") {
  59. if let Some(section) = v.section(None::<String>) {
  60. section
  61. .iter()
  62. .for_each(|(k, v)| std::env::set_var(arg_name(k), v));
  63. }
  64. }
  65. if let Some(config) = matches.value_of("config") {
  66. if let Ok(v) = Ini::load_from_file(config) {
  67. if let Some(section) = v.section(None::<String>) {
  68. section
  69. .iter()
  70. .for_each(|(k, v)| std::env::set_var(arg_name(k), v));
  71. }
  72. }
  73. }
  74. for (k, v) in matches.args {
  75. if let Some(v) = v.vals.first() {
  76. std::env::set_var(arg_name(k), v.to_string_lossy().to_string());
  77. }
  78. }
  79. }
  80. #[allow(dead_code)]
  81. #[inline]
  82. pub fn get_arg(name: &str) -> String {
  83. get_arg_or(name, "".to_owned())
  84. }
  85. #[allow(dead_code)]
  86. #[inline]
  87. pub fn get_arg_or(name: &str, default: String) -> String {
  88. std::env::var(arg_name(name)).unwrap_or(default)
  89. }
  /// Current wall-clock time as whole seconds since the Unix epoch.
  ///
  /// Returns `0` if the system clock reports a time before the epoch.
  #[allow(dead_code)]
  #[inline]
  pub fn now() -> u64 {
      match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
          Ok(since_epoch) => since_epoch.as_secs(),
          Err(_) => 0,
      }
  }
  98. pub fn gen_sk(wait: u64) -> (String, Option<sign::SecretKey>) {
  99. let sk_file = "id_ed25519";
  100. if wait > 0 && !std::path::Path::new(sk_file).exists() {
  101. std::thread::sleep(std::time::Duration::from_millis(wait));
  102. }
  103. if let Ok(mut file) = std::fs::File::open(sk_file) {
  104. let mut contents = String::new();
  105. if file.read_to_string(&mut contents).is_ok() {
  106. let contents = contents.trim();
  107. let sk = base64::decode(contents).unwrap_or_default();
  108. if sk.len() == sign::SECRETKEYBYTES {
  109. let mut tmp = [0u8; sign::SECRETKEYBYTES];
  110. tmp[..].copy_from_slice(&sk);
  111. let pk = base64::encode(&tmp[sign::SECRETKEYBYTES / 2..]);
  112. log::info!("Private key comes from {}", sk_file);
  113. return (pk, Some(sign::SecretKey(tmp)));
  114. } else {
  115. // don't use log here, since it is async
  116. println!("Fatal error: malformed private key in {sk_file}.");
  117. std::process::exit(1);
  118. }
  119. }
  120. } else {
  121. let gen_func = || {
  122. let (tmp, sk) = sign::gen_keypair();
  123. (base64::encode(tmp), sk)
  124. };
  125. let (mut pk, mut sk) = gen_func();
  126. for _ in 0..300 {
  127. if !pk.contains('/') && !pk.contains(':') {
  128. break;
  129. }
  130. (pk, sk) = gen_func();
  131. }
  132. let pub_file = format!("{sk_file}.pub");
  133. if let Ok(mut f) = std::fs::File::create(&pub_file) {
  134. f.write_all(pk.as_bytes()).ok();
  135. if let Ok(mut f) = std::fs::File::create(sk_file) {
  136. let s = base64::encode(&sk);
  137. if f.write_all(s.as_bytes()).is_ok() {
  138. log::info!("Private/public key written to {}/{}", sk_file, pub_file);
  139. log::debug!("Public key: {}", pk);
  140. return (pk, Some(sk));
  141. }
  142. }
  143. }
  144. }
  145. ("".to_owned(), None)
  146. }
  147. #[cfg(unix)]
  148. pub async fn listen_signal() -> Result<()> {
  149. use hbb_common::tokio;
  150. use hbb_common::tokio::signal::unix::{signal, SignalKind};
  151. tokio::spawn(async {
  152. let mut s = signal(SignalKind::terminate())?;
  153. let terminate = s.recv();
  154. let mut s = signal(SignalKind::interrupt())?;
  155. let interrupt = s.recv();
  156. let mut s = signal(SignalKind::quit())?;
  157. let quit = s.recv();
  158. tokio::select! {
  159. _ = terminate => {
  160. log::info!("signal terminate");
  161. }
  162. _ = interrupt => {
  163. log::info!("signal interrupt");
  164. }
  165. _ = quit => {
  166. log::info!("signal quit");
  167. }
  168. }
  169. Ok(())
  170. })
  171. .await?
  172. }
  /// Non-Unix fallback: there are no Unix signals to wait for, so this future
  /// never completes.
  #[cfg(not(unix))]
  pub async fn listen_signal() -> Result<()> {
      std::future::pending::<Result<()>>().await
  }
  178. pub fn check_software_update() {
  179. const ONE_DAY_IN_SECONDS: u64 = 60 * 60 * 24;
  180. std::thread::spawn(move || loop {
  181. std::thread::spawn(move || allow_err!(check_software_update_()));
  182. std::thread::sleep(std::time::Duration::from_secs(ONE_DAY_IN_SECONDS));
  183. });
  184. }
  185. #[tokio::main(flavor = "current_thread")]
  186. async fn check_software_update_() -> hbb_common::ResultType<()> {
  187. let (request, url) = hbb_common::version_check_request(hbb_common::VER_TYPE_RUSTDESK_SERVER.to_string());
  188. let latest_release_response = reqwest::Client::builder().build()?
  189. .post(url)
  190. .json(&request)
  191. .send()
  192. .await?;
  193. let bytes = latest_release_response.bytes().await?;
  194. let resp: hbb_common::VersionCheckResponse = serde_json::from_slice(&bytes)?;
  195. let response_url = resp.url;
  196. let latest_release_version = response_url.rsplit('/').next().unwrap_or_default();
  197. if get_version_number(&latest_release_version) > get_version_number(crate::version::VERSION) {
  198. log::info!("new version is available: {}", latest_release_version);
  199. }
  200. Ok(())
  201. }