common.rs 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. use clap::App;
  2. use hbb_common::{
  3. anyhow::{Context, Result},
  4. log, ResultType,
  5. };
  6. use ini::Ini;
  7. use sodiumoxide::crypto::sign;
  8. use std::{
  9. io::prelude::*,
  10. io::Read,
  11. net::SocketAddr,
  12. time::{Instant, SystemTime},
  13. };
  14. #[allow(dead_code)]
  15. pub(crate) fn get_expired_time() -> Instant {
  16. let now = Instant::now();
  17. now.checked_sub(std::time::Duration::from_secs(3600))
  18. .unwrap_or(now)
  19. }
  20. #[allow(dead_code)]
  21. pub(crate) fn test_if_valid_server(host: &str, name: &str) -> ResultType<SocketAddr> {
  22. use std::net::ToSocketAddrs;
  23. let res = if host.contains(':') {
  24. host.to_socket_addrs()?.next().context("")
  25. } else {
  26. format!("{}:{}", host, 0)
  27. .to_socket_addrs()?
  28. .next()
  29. .context("")
  30. };
  31. if res.is_err() {
  32. log::error!("Invalid {} {}: {:?}", name, host, res);
  33. }
  34. res
  35. }
  36. #[allow(dead_code)]
  37. pub(crate) fn get_servers(s: &str, tag: &str) -> Vec<String> {
  38. let servers: Vec<String> = s
  39. .split(',')
  40. .filter(|x| !x.is_empty() && test_if_valid_server(x, tag).is_ok())
  41. .map(|x| x.to_owned())
  42. .collect();
  43. log::info!("{}={:?}", tag, servers);
  44. servers
  45. }
  46. #[allow(dead_code)]
  47. #[inline]
  48. fn arg_name(name: &str) -> String {
  49. name.to_uppercase().replace('_', "-")
  50. }
  51. #[allow(dead_code)]
  52. pub fn init_args(args: &str, name: &str, about: &str) {
  53. let matches = App::new(name)
  54. .version(crate::version::VERSION)
  55. .author("Purslane Ltd. <info@rustdesk.com>")
  56. .about(about)
  57. .args_from_usage(args)
  58. .get_matches();
  59. if let Ok(v) = Ini::load_from_file(".env") {
  60. if let Some(section) = v.section(None::<String>) {
  61. section
  62. .iter()
  63. .for_each(|(k, v)| std::env::set_var(arg_name(k), v));
  64. }
  65. }
  66. if let Some(config) = matches.value_of("config") {
  67. if let Ok(v) = Ini::load_from_file(config) {
  68. if let Some(section) = v.section(None::<String>) {
  69. section
  70. .iter()
  71. .for_each(|(k, v)| std::env::set_var(arg_name(k), v));
  72. }
  73. }
  74. }
  75. for (k, v) in matches.args {
  76. if let Some(v) = v.vals.first() {
  77. std::env::set_var(arg_name(k), v.to_string_lossy().to_string());
  78. }
  79. }
  80. }
  81. #[allow(dead_code)]
  82. #[inline]
  83. pub fn get_arg(name: &str) -> String {
  84. get_arg_or(name, "".to_owned())
  85. }
  86. #[allow(dead_code)]
  87. #[inline]
  88. pub fn get_arg_or(name: &str, default: String) -> String {
  89. std::env::var(arg_name(name)).unwrap_or(default)
  90. }
  91. #[allow(dead_code)]
  92. #[inline]
  93. pub fn now() -> u64 {
  94. SystemTime::now()
  95. .duration_since(SystemTime::UNIX_EPOCH)
  96. .map(|x| x.as_secs())
  97. .unwrap_or_default()
  98. }
  99. pub fn gen_sk(wait: u64) -> (String, Option<sign::SecretKey>) {
  100. let sk_file = "id_ed25519";
  101. if wait > 0 && !std::path::Path::new(sk_file).exists() {
  102. std::thread::sleep(std::time::Duration::from_millis(wait));
  103. }
  104. if let Ok(mut file) = std::fs::File::open(sk_file) {
  105. let mut contents = String::new();
  106. if file.read_to_string(&mut contents).is_ok() {
  107. let contents = contents.trim();
  108. let sk = base64::decode(contents).unwrap_or_default();
  109. if sk.len() == sign::SECRETKEYBYTES {
  110. let mut tmp = [0u8; sign::SECRETKEYBYTES];
  111. tmp[..].copy_from_slice(&sk);
  112. let pk = base64::encode(&tmp[sign::SECRETKEYBYTES / 2..]);
  113. log::info!("Private key comes from {}", sk_file);
  114. return (pk, Some(sign::SecretKey(tmp)));
  115. } else {
  116. // don't use log here, since it is async
  117. println!("Fatal error: malformed private key in {sk_file}.");
  118. std::process::exit(1);
  119. }
  120. }
  121. } else {
  122. let gen_func = || {
  123. let (tmp, sk) = sign::gen_keypair();
  124. (base64::encode(tmp), sk)
  125. };
  126. let (mut pk, mut sk) = gen_func();
  127. for _ in 0..300 {
  128. if !pk.contains('/') && !pk.contains(':') {
  129. break;
  130. }
  131. (pk, sk) = gen_func();
  132. }
  133. let pub_file = format!("{sk_file}.pub");
  134. if let Ok(mut f) = std::fs::File::create(&pub_file) {
  135. f.write_all(pk.as_bytes()).ok();
  136. if let Ok(mut f) = std::fs::File::create(sk_file) {
  137. let s = base64::encode(&sk);
  138. if f.write_all(s.as_bytes()).is_ok() {
  139. log::info!("Private/public key written to {}/{}", sk_file, pub_file);
  140. log::debug!("Public key: {}", pk);
  141. return (pk, Some(sk));
  142. }
  143. }
  144. }
  145. }
  146. ("".to_owned(), None)
  147. }
  148. #[cfg(unix)]
  149. pub async fn listen_signal() -> Result<()> {
  150. use hbb_common::tokio;
  151. use hbb_common::tokio::signal::unix::{signal, SignalKind};
  152. tokio::spawn(async {
  153. let mut s = signal(SignalKind::terminate())?;
  154. let terminate = s.recv();
  155. let mut s = signal(SignalKind::interrupt())?;
  156. let interrupt = s.recv();
  157. let mut s = signal(SignalKind::quit())?;
  158. let quit = s.recv();
  159. tokio::select! {
  160. _ = terminate => {
  161. log::info!("signal terminate");
  162. }
  163. _ = interrupt => {
  164. log::info!("signal interrupt");
  165. }
  166. _ = quit => {
  167. log::info!("signal quit");
  168. }
  169. }
  170. Ok(())
  171. })
  172. .await?
  173. }
  174. #[cfg(not(unix))]
  175. pub async fn listen_signal() -> Result<()> {
  176. let () = std::future::pending().await;
  177. unreachable!();
  178. }