bytes_codec.rs 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. use bytes::{Buf, BufMut, Bytes, BytesMut};
  2. use std::io;
  3. use tokio_util::codec::{Decoder, Encoder};
  4. #[derive(Debug, Clone, Copy)]
  5. pub struct BytesCodec {
  6. state: DecodeState,
  7. raw: bool,
  8. max_packet_length: usize,
  9. }
  /// Where the decoder is within the current frame.
  #[derive(Debug, Clone, Copy)]
  enum DecodeState {
      /// Waiting for a complete length header.
      Head,
      /// Header consumed; waiting for a payload of this many bytes.
      Data(usize),
  }
  15. impl BytesCodec {
  16. pub fn new() -> Self {
  17. Self {
  18. state: DecodeState::Head,
  19. raw: false,
  20. max_packet_length: usize::MAX,
  21. }
  22. }
  23. pub fn set_raw(&mut self) {
  24. self.raw = true;
  25. }
  26. pub fn set_max_packet_length(&mut self, n: usize) {
  27. self.max_packet_length = n;
  28. }
  29. fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
  30. if src.is_empty() {
  31. return Ok(None);
  32. }
  33. let head_len = ((src[0] & 0x3) + 1) as usize;
  34. if src.len() < head_len {
  35. return Ok(None);
  36. }
  37. let mut n = src[0] as usize;
  38. if head_len > 1 {
  39. n |= (src[1] as usize) << 8;
  40. }
  41. if head_len > 2 {
  42. n |= (src[2] as usize) << 16;
  43. }
  44. if head_len > 3 {
  45. n |= (src[3] as usize) << 24;
  46. }
  47. n >>= 2;
  48. if n > self.max_packet_length {
  49. return Err(io::Error::new(io::ErrorKind::InvalidData, "Too big packet"));
  50. }
  51. src.advance(head_len);
  52. src.reserve(n);
  53. return Ok(Some(n));
  54. }
  55. fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
  56. if src.len() < n {
  57. return Ok(None);
  58. }
  59. Ok(Some(src.split_to(n)))
  60. }
  61. }
  62. impl Decoder for BytesCodec {
  63. type Item = BytesMut;
  64. type Error = io::Error;
  65. fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
  66. if self.raw {
  67. if !src.is_empty() {
  68. let len = src.len();
  69. return Ok(Some(src.split_to(len)));
  70. } else {
  71. return Ok(None);
  72. }
  73. }
  74. let n = match self.state {
  75. DecodeState::Head => match self.decode_head(src)? {
  76. Some(n) => {
  77. self.state = DecodeState::Data(n);
  78. n
  79. }
  80. None => return Ok(None),
  81. },
  82. DecodeState::Data(n) => n,
  83. };
  84. match self.decode_data(n, src)? {
  85. Some(data) => {
  86. self.state = DecodeState::Head;
  87. Ok(Some(data))
  88. }
  89. None => Ok(None),
  90. }
  91. }
  92. }
  93. impl Encoder<Bytes> for BytesCodec {
  94. type Error = io::Error;
  95. fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
  96. if self.raw {
  97. buf.reserve(data.len());
  98. buf.put(data);
  99. return Ok(());
  100. }
  101. if data.len() <= 0x3F {
  102. buf.put_u8((data.len() << 2) as u8);
  103. } else if data.len() <= 0x3FFF {
  104. buf.put_u16_le((data.len() << 2) as u16 | 0x1);
  105. } else if data.len() <= 0x3FFFFF {
  106. let h = (data.len() << 2) as u32 | 0x2;
  107. buf.put_u16_le((h & 0xFFFF) as u16);
  108. buf.put_u8((h >> 16) as u8);
  109. } else if data.len() <= 0x3FFFFFFF {
  110. buf.put_u32_le((data.len() << 2) as u32 | 0x3);
  111. } else {
  112. return Err(io::Error::new(io::ErrorKind::InvalidInput, "Overflow"));
  113. }
  114. buf.extend(data);
  115. Ok(())
  116. }
  117. }
  #[cfg(test)]
  mod tests {
      use super::*;

      /// Decodes one frame, failing the test on an error or an incomplete frame.
      fn decode_frame(codec: &mut BytesCodec, buf: &mut BytesMut) -> BytesMut {
          codec
              .decode(buf)
              .expect("decode failed")
              .expect("expected a complete frame")
      }

      #[test]
      fn test_codec1() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          codec.encode(vec![1u8; 0x3F].into(), &mut buf).expect("encode failed");
          let buf_saved = buf.clone();
          assert_eq!(buf.len(), 0x3F + 1);
          assert_eq!(&buf[..1], &[0xFC]);
          let res = decode_frame(&mut codec, &mut buf);
          assert_eq!(res.len(), 0x3F);
          assert_eq!(res[0], 1);

          // Same frame fed in pieces: nothing, then the header, then the payload.
          let mut codec2 = BytesCodec::new();
          let mut buf2 = BytesMut::new();
          assert!(matches!(codec2.decode(&mut buf2), Ok(None)));
          buf2.extend_from_slice(&buf_saved[0..1]);
          assert!(matches!(codec2.decode(&mut buf2), Ok(None)));
          buf2.extend_from_slice(&buf_saved[1..]);
          let res = decode_frame(&mut codec2, &mut buf2);
          assert_eq!(res.len(), 0x3F);
          assert_eq!(res[0], 1);
      }

      #[test]
      fn test_codec2() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          codec.encode("".into(), &mut buf).expect("encode failed");
          assert_eq!(buf.len(), 1);
          codec.encode(vec![2u8; 0x3F + 1].into(), &mut buf).expect("encode failed");
          assert_eq!(buf.len(), 0x3F + 2 + 2);
          assert_eq!(&buf[..3], &[0x00, 0x01, 0x01]);
          assert_eq!(decode_frame(&mut codec, &mut buf).len(), 0);
          let res = decode_frame(&mut codec, &mut buf);
          assert_eq!(res.len(), 0x3F + 1);
          assert_eq!(res[0], 2);
      }

      #[test]
      fn test_codec3() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          codec.encode(vec![3u8; 0x3F - 1].into(), &mut buf).expect("encode failed");
          assert_eq!(buf.len(), 0x3F + 1 - 1);
          let res = decode_frame(&mut codec, &mut buf);
          assert_eq!(res.len(), 0x3F - 1);
          assert_eq!(res[0], 3);
      }

      #[test]
      fn test_codec4() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          codec.encode(vec![4u8; 0x3FFF].into(), &mut buf).expect("encode failed");
          assert_eq!(buf.len(), 0x3FFF + 2);
          assert_eq!(&buf[..2], &[0xFD, 0xFF]);
          let res = decode_frame(&mut codec, &mut buf);
          assert_eq!(res.len(), 0x3FFF);
          assert_eq!(res[0], 4);
      }

      #[test]
      fn test_codec5() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          codec.encode(vec![5u8; 0x3FFFFF].into(), &mut buf).expect("encode failed");
          assert_eq!(buf.len(), 0x3FFFFF + 3);
          assert_eq!(&buf[..3], &[0xFE, 0xFF, 0xFF]);
          let res = decode_frame(&mut codec, &mut buf);
          assert_eq!(res.len(), 0x3FFFFF);
          assert_eq!(res[0], 5);
      }

      #[test]
      fn test_codec6() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          codec.encode(vec![6u8; 0x3FFFFF + 1].into(), &mut buf).expect("encode failed");
          let buf_saved = buf.clone();
          assert_eq!(buf.len(), 0x3FFFFF + 4 + 1);
          assert_eq!(&buf[..4], &[0x03, 0x00, 0x00, 0x01]);
          let res = decode_frame(&mut codec, &mut buf);
          assert_eq!(res.len(), 0x3FFFFF + 1);
          assert_eq!(res[0], 6);

          // Partial header, then full header plus a partial payload, then the rest.
          let mut codec2 = BytesCodec::new();
          let mut buf2 = BytesMut::new();
          buf2.extend_from_slice(&buf_saved[0..1]);
          assert!(matches!(codec2.decode(&mut buf2), Ok(None)));
          buf2.extend_from_slice(&buf_saved[1..6]);
          assert!(matches!(codec2.decode(&mut buf2), Ok(None)));
          buf2.extend_from_slice(&buf_saved[6..]);
          let res = decode_frame(&mut codec2, &mut buf2);
          assert_eq!(res.len(), 0x3FFFFF + 1);
          assert_eq!(res[0], 6);
      }

      #[test]
      fn test_raw_mode() {
          let mut codec = BytesCodec::new();
          codec.set_raw();
          let mut buf = BytesMut::new();
          assert!(matches!(codec.decode(&mut buf), Ok(None)));
          codec.encode(Bytes::from_static(b"abc"), &mut buf).expect("encode failed");
          assert_eq!(&buf[..], b"abc");
          let res = decode_frame(&mut codec, &mut buf);
          assert_eq!(&res[..], b"abc");
          assert!(buf.is_empty());
      }

      #[test]
      fn test_max_packet_length() {
          let mut codec = BytesCodec::new();
          codec.set_max_packet_length(4);
          let mut buf = BytesMut::new();
          codec.encode(Bytes::from_static(b"hello"), &mut buf).expect("encode failed");
          let err = codec.decode(&mut buf).expect_err("oversized packet accepted");
          assert_eq!(err.kind(), io::ErrorKind::InvalidData);
      }
  }