socket_client.rs 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. use crate::{
  2. config::{Config, NetworkType},
  3. tcp::FramedStream,
  4. udp::FramedSocket,
  5. ResultType,
  6. };
  7. use anyhow::Context;
  8. use std::net::SocketAddr;
  9. use tokio::net::ToSocketAddrs;
  10. use tokio_socks::{IntoTargetAddr, TargetAddr};
  11. #[inline]
  12. pub fn check_port<T: std::string::ToString>(host: T, port: i32) -> String {
  13. let host = host.to_string();
  14. if crate::is_ipv6_str(&host) {
  15. if host.starts_with('[') {
  16. return host;
  17. }
  18. return format!("[{host}]:{port}");
  19. }
  20. if !host.contains(':') {
  21. return format!("{host}:{port}");
  22. }
  23. host
  24. }
  25. #[inline]
  26. pub fn increase_port<T: std::string::ToString>(host: T, offset: i32) -> String {
  27. let host = host.to_string();
  28. if crate::is_ipv6_str(&host) {
  29. if host.starts_with('[') {
  30. let tmp: Vec<&str> = host.split("]:").collect();
  31. if tmp.len() == 2 {
  32. let port: i32 = tmp[1].parse().unwrap_or(0);
  33. if port > 0 {
  34. return format!("{}]:{}", tmp[0], port + offset);
  35. }
  36. }
  37. }
  38. } else if host.contains(':') {
  39. let tmp: Vec<&str> = host.split(':').collect();
  40. if tmp.len() == 2 {
  41. let port: i32 = tmp[1].parse().unwrap_or(0);
  42. if port > 0 {
  43. return format!("{}:{}", tmp[0], port + offset);
  44. }
  45. }
  46. }
  47. host
  48. }
  49. pub fn test_if_valid_server(host: &str) -> String {
  50. let host = check_port(host, 0);
  51. use std::net::ToSocketAddrs;
  52. match Config::get_network_type() {
  53. NetworkType::Direct => match host.to_socket_addrs() {
  54. Err(err) => err.to_string(),
  55. Ok(_) => "".to_owned(),
  56. },
  57. NetworkType::ProxySocks => match &host.into_target_addr() {
  58. Err(err) => err.to_string(),
  59. Ok(_) => "".to_owned(),
  60. },
  61. }
  62. }
  63. pub trait IsResolvedSocketAddr {
  64. fn resolve(&self) -> Option<&SocketAddr>;
  65. }
  66. impl IsResolvedSocketAddr for SocketAddr {
  67. fn resolve(&self) -> Option<&SocketAddr> {
  68. Some(self)
  69. }
  70. }
  71. impl IsResolvedSocketAddr for String {
  72. fn resolve(&self) -> Option<&SocketAddr> {
  73. None
  74. }
  75. }
  76. impl IsResolvedSocketAddr for &str {
  77. fn resolve(&self) -> Option<&SocketAddr> {
  78. None
  79. }
  80. }
  81. #[inline]
  82. pub async fn connect_tcp<
  83. 't,
  84. T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display,
  85. >(
  86. target: T,
  87. ms_timeout: u64,
  88. ) -> ResultType<FramedStream> {
  89. connect_tcp_local(target, None, ms_timeout).await
  90. }
  91. pub async fn connect_tcp_local<
  92. 't,
  93. T: IntoTargetAddr<'t> + ToSocketAddrs + IsResolvedSocketAddr + std::fmt::Display,
  94. >(
  95. target: T,
  96. local: Option<SocketAddr>,
  97. ms_timeout: u64,
  98. ) -> ResultType<FramedStream> {
  99. if let Some(conf) = Config::get_socks() {
  100. return FramedStream::connect(
  101. conf.proxy.as_str(),
  102. target,
  103. local,
  104. conf.username.as_str(),
  105. conf.password.as_str(),
  106. ms_timeout,
  107. )
  108. .await;
  109. }
  110. if let Some(target) = target.resolve() {
  111. if let Some(local) = local {
  112. if local.is_ipv6() && target.is_ipv4() {
  113. let target = query_nip_io(target).await?;
  114. return FramedStream::new(target, Some(local), ms_timeout).await;
  115. }
  116. }
  117. }
  118. FramedStream::new(target, local, ms_timeout).await
  119. }
  120. #[inline]
  121. pub fn is_ipv4(target: &TargetAddr<'_>) -> bool {
  122. match target {
  123. TargetAddr::Ip(addr) => addr.is_ipv4(),
  124. _ => true,
  125. }
  126. }
  127. #[inline]
  128. pub async fn query_nip_io(addr: &SocketAddr) -> ResultType<SocketAddr> {
  129. tokio::net::lookup_host(format!("{}.nip.io:{}", addr.ip(), addr.port()))
  130. .await?
  131. .find(|x| x.is_ipv6())
  132. .context("Failed to get ipv6 from nip.io")
  133. }
  134. #[inline]
  135. pub fn ipv4_to_ipv6(addr: String, ipv4: bool) -> String {
  136. if !ipv4 && crate::is_ipv4_str(&addr) {
  137. if let Some(ip) = addr.split(':').next() {
  138. return addr.replace(ip, &format!("{ip}.nip.io"));
  139. }
  140. }
  141. addr
  142. }
  143. async fn test_target(target: &str) -> ResultType<SocketAddr> {
  144. if let Ok(Ok(s)) = super::timeout(1000, tokio::net::TcpStream::connect(target)).await {
  145. if let Ok(addr) = s.peer_addr() {
  146. return Ok(addr);
  147. }
  148. }
  149. tokio::net::lookup_host(target)
  150. .await?
  151. .next()
  152. .context(format!("Failed to look up host for {target}"))
  153. }
  154. #[inline]
  155. pub async fn new_udp_for(
  156. target: &str,
  157. ms_timeout: u64,
  158. ) -> ResultType<(FramedSocket, TargetAddr<'static>)> {
  159. let (ipv4, target) = if NetworkType::Direct == Config::get_network_type() {
  160. let addr = test_target(target).await?;
  161. (addr.is_ipv4(), addr.into_target_addr()?)
  162. } else {
  163. (true, target.into_target_addr()?)
  164. };
  165. Ok((
  166. new_udp(Config::get_any_listen_addr(ipv4), ms_timeout).await?,
  167. target.to_owned(),
  168. ))
  169. }
  170. async fn new_udp<T: ToSocketAddrs>(local: T, ms_timeout: u64) -> ResultType<FramedSocket> {
  171. match Config::get_socks() {
  172. None => Ok(FramedSocket::new(local).await?),
  173. Some(conf) => {
  174. let socket = FramedSocket::new_proxy(
  175. conf.proxy.as_str(),
  176. local,
  177. conf.username.as_str(),
  178. conf.password.as_str(),
  179. ms_timeout,
  180. )
  181. .await?;
  182. Ok(socket)
  183. }
  184. }
  185. }
  186. pub async fn rebind_udp_for(
  187. target: &str,
  188. ) -> ResultType<Option<(FramedSocket, TargetAddr<'static>)>> {
  189. if Config::get_network_type() != NetworkType::Direct {
  190. return Ok(None);
  191. }
  192. let addr = test_target(target).await?;
  193. let v4 = addr.is_ipv4();
  194. Ok(Some((
  195. FramedSocket::new(Config::get_any_listen_addr(v4)).await?,
  196. addr.into_target_addr()?.to_owned(),
  197. )))
  198. }
// Unit tests for the address helpers in this module.
// NOTE(review): `test_nat64` and `test_test_if_valid_server` perform real DNS
// lookups (rustdesk.com, nip.io, "a", "1.1.1.1"). Their outcome therefore depends on
// the network (and configured network type) of the machine running the tests.
#[cfg(test)]
mod tests {
    use std::net::ToSocketAddrs;
    use super::*;
    // Synchronous #[test] entry point for the async checks below.
    #[test]
    fn test_nat64() {
        test_nat64_async();
    }
    // Runs the async body on a single-threaded tokio runtime.
    #[tokio::main(flavor = "current_thread")]
    async fn test_nat64_async() {
        // Pure string rewriting: only IPv4 literals are rewritten, only when the
        // `ipv4` flag is false, and any `:port` suffix is preserved.
        assert_eq!(ipv4_to_ipv6("1.1.1.1".to_owned(), true), "1.1.1.1");
        assert_eq!(ipv4_to_ipv6("1.1.1.1".to_owned(), false), "1.1.1.1.nip.io");
        assert_eq!(
            ipv4_to_ipv6("1.1.1.1:8080".to_owned(), false),
            "1.1.1.1.nip.io:8080"
        );
        assert_eq!(
            ipv4_to_ipv6("rustdesk.com".to_owned(), false),
            "rustdesk.com"
        );
        // If the first address the resolver returns for rustdesk.com is IPv6, expect
        // nip.io to yield an IPv6 address for 1.1.1.1. Otherwise expect that lookup to
        // find none and error out.
        // NOTE(review): this is a heuristic about the test machine's network.
        if ("rustdesk.com:80")
            .to_socket_addrs()
            .unwrap()
            .next()
            .unwrap()
            .is_ipv6()
        {
            assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap())
                .await
                .unwrap()
                .is_ipv6());
            return;
        }
        assert!(query_nip_io(&"1.1.1.1:80".parse().unwrap()).await.is_err());
    }
    // NOTE(review): results depend on `Config::get_network_type()` and on DNS.
    #[test]
    fn test_test_if_valid_server() {
        assert!(!test_if_valid_server("a").is_empty());
        // on Linux, "1" is resolved to "0.0.0.1" (NOTE(review): presumably why no
        // assertion for a bare "1" appears here)
        assert!(test_if_valid_server("1.1.1.1").is_empty());
        assert!(test_if_valid_server("1.1.1.1:1").is_empty());
    }
    // `check_port` appends a default port only when missing. `increase_port` shifts
    // an existing port. Inputs such as "1:2" and "22:1:13" are classified as IPv6
    // by `crate::is_ipv6_str`, as these assertions show.
    #[test]
    fn test_check_port() {
        assert_eq!(check_port("[1:2]:12", 32), "[1:2]:12");
        assert_eq!(check_port("1:2", 32), "[1:2]:32");
        assert_eq!(check_port("z1:2", 32), "z1:2");
        assert_eq!(check_port("1.1.1.1", 32), "1.1.1.1:32");
        assert_eq!(check_port("1.1.1.1:32", 32), "1.1.1.1:32");
        assert_eq!(check_port("test.com:32", 0), "test.com:32");
        assert_eq!(increase_port("[1:2]:12", 1), "[1:2]:13");
        assert_eq!(increase_port("1.2.2.4:12", 1), "1.2.2.4:13");
        assert_eq!(increase_port("1.2.2.4", 1), "1.2.2.4");
        assert_eq!(increase_port("test.com", 1), "test.com");
        assert_eq!(increase_port("test.com:13", 4), "test.com:17");
        assert_eq!(increase_port("1:13", 4), "1:13");
        assert_eq!(increase_port("22:1:13", 4), "22:1:13");
        assert_eq!(increase_port("z1:2", 1), "z1:3");
    }
}