bytes_codec.rs 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. use bytes::{Buf, BufMut, Bytes, BytesMut};
  2. use std::io;
  3. use tokio_util::codec::{Decoder, Encoder};
  4. #[derive(Debug, Clone, Copy)]
  5. pub struct BytesCodec {
  6. state: DecodeState,
  7. raw: bool,
  8. max_packet_length: usize,
  9. }
  /// Position of the decoder within the frame currently being read.
  #[derive(Clone, Copy, Debug)]
  enum DecodeState {
      /// No header has been consumed yet for the next frame.
      Head,
      /// The header has been consumed; the payload of this many bytes is pending.
      Data(usize),
  }
  15. impl Default for BytesCodec {
  16. fn default() -> Self {
  17. Self::new()
  18. }
  19. }
  20. impl BytesCodec {
  21. pub fn new() -> Self {
  22. Self {
  23. state: DecodeState::Head,
  24. raw: false,
  25. max_packet_length: usize::MAX,
  26. }
  27. }
  28. pub fn set_raw(&mut self) {
  29. self.raw = true;
  30. }
  31. pub fn set_max_packet_length(&mut self, n: usize) {
  32. self.max_packet_length = n;
  33. }
  34. fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
  35. if src.is_empty() {
  36. return Ok(None);
  37. }
  38. let head_len = ((src[0] & 0x3) + 1) as usize;
  39. if src.len() < head_len {
  40. return Ok(None);
  41. }
  42. let mut n = src[0] as usize;
  43. if head_len > 1 {
  44. n |= (src[1] as usize) << 8;
  45. }
  46. if head_len > 2 {
  47. n |= (src[2] as usize) << 16;
  48. }
  49. if head_len > 3 {
  50. n |= (src[3] as usize) << 24;
  51. }
  52. n >>= 2;
  53. if n > self.max_packet_length {
  54. return Err(io::Error::new(io::ErrorKind::InvalidData, "Too big packet"));
  55. }
  56. src.advance(head_len);
  57. src.reserve(n);
  58. Ok(Some(n))
  59. }
  60. fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
  61. if src.len() < n {
  62. return Ok(None);
  63. }
  64. Ok(Some(src.split_to(n)))
  65. }
  66. }
  67. impl Decoder for BytesCodec {
  68. type Item = BytesMut;
  69. type Error = io::Error;
  70. fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BytesMut>, io::Error> {
  71. if self.raw {
  72. if !src.is_empty() {
  73. let len = src.len();
  74. return Ok(Some(src.split_to(len)));
  75. } else {
  76. return Ok(None);
  77. }
  78. }
  79. let n = match self.state {
  80. DecodeState::Head => match self.decode_head(src)? {
  81. Some(n) => {
  82. self.state = DecodeState::Data(n);
  83. n
  84. }
  85. None => return Ok(None),
  86. },
  87. DecodeState::Data(n) => n,
  88. };
  89. match self.decode_data(n, src)? {
  90. Some(data) => {
  91. self.state = DecodeState::Head;
  92. Ok(Some(data))
  93. }
  94. None => Ok(None),
  95. }
  96. }
  97. }
  98. impl Encoder<Bytes> for BytesCodec {
  99. type Error = io::Error;
  100. fn encode(&mut self, data: Bytes, buf: &mut BytesMut) -> Result<(), io::Error> {
  101. if self.raw {
  102. buf.reserve(data.len());
  103. buf.put(data);
  104. return Ok(());
  105. }
  106. if data.len() <= 0x3F {
  107. buf.put_u8((data.len() << 2) as u8);
  108. } else if data.len() <= 0x3FFF {
  109. buf.put_u16_le((data.len() << 2) as u16 | 0x1);
  110. } else if data.len() <= 0x3FFFFF {
  111. let h = (data.len() << 2) as u32 | 0x2;
  112. buf.put_u16_le((h & 0xFFFF) as u16);
  113. buf.put_u8((h >> 16) as u8);
  114. } else if data.len() <= 0x3FFFFFFF {
  115. buf.put_u32_le((data.len() << 2) as u32 | 0x3);
  116. } else {
  117. return Err(io::Error::new(io::ErrorKind::InvalidInput, "Overflow"));
  118. }
  119. buf.extend(data);
  120. Ok(())
  121. }
  122. }
  #[cfg(test)]
  mod tests {
      use super::*;

      /// Encodes `len` copies of `fill` with `codec`, appending the wire bytes to `buf`.
      fn encode_payload(codec: &mut BytesCodec, len: usize, fill: u8, buf: &mut BytesMut) {
          assert!(codec.encode(Bytes::from(vec![fill; len]), buf).is_ok());
      }

      /// Decodes one frame and checks its length and (when non-empty) its first byte.
      fn expect_frame(codec: &mut BytesCodec, buf: &mut BytesMut, len: usize, fill: u8) {
          match codec.decode(buf) {
              Ok(Some(res)) => {
                  assert_eq!(res.len(), len);
                  if len > 0 {
                      assert_eq!(res[0], fill);
                  }
              }
              other => panic!("expected a {}-byte frame, got {:?}", len, other),
          }
      }

      /// Asserts that the decoder needs more input before it can yield a frame.
      fn expect_pending(codec: &mut BytesCodec, buf: &mut BytesMut) {
          let res = codec.decode(buf);
          assert!(matches!(res, Ok(None)), "expected Ok(None), got {:?}", res);
      }

      #[test]
      fn test_codec1() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          encode_payload(&mut codec, 0x3F, 1, &mut buf);
          let buf_saved = buf.clone();
          assert_eq!(buf.len(), 0x3F + 1);
          expect_frame(&mut codec, &mut buf, 0x3F, 1);

          // Feed the same frame in pieces to a fresh decoder.
          let mut codec2 = BytesCodec::new();
          let mut buf2 = BytesMut::new();
          expect_pending(&mut codec2, &mut buf2);
          buf2.extend_from_slice(&buf_saved[0..1]);
          expect_pending(&mut codec2, &mut buf2);
          buf2.extend_from_slice(&buf_saved[1..]);
          expect_frame(&mut codec2, &mut buf2, 0x3F, 1);
      }

      #[test]
      fn test_codec2() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          assert!(codec.encode(Bytes::new(), &mut buf).is_ok());
          assert_eq!(buf.len(), 1);
          encode_payload(&mut codec, 0x3F + 1, 2, &mut buf);
          assert_eq!(buf.len(), 0x3F + 2 + 2);
          expect_frame(&mut codec, &mut buf, 0, 0);
          expect_frame(&mut codec, &mut buf, 0x3F + 1, 2);
      }

      #[test]
      fn test_codec3() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          encode_payload(&mut codec, 0x3F - 1, 3, &mut buf);
          assert_eq!(buf.len(), 0x3F + 1 - 1);
          expect_frame(&mut codec, &mut buf, 0x3F - 1, 3);
      }

      #[test]
      fn test_codec4() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          encode_payload(&mut codec, 0x3FFF, 4, &mut buf);
          assert_eq!(buf.len(), 0x3FFF + 2);
          expect_frame(&mut codec, &mut buf, 0x3FFF, 4);
      }

      #[test]
      fn test_codec5() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          encode_payload(&mut codec, 0x3FFFFF, 5, &mut buf);
          assert_eq!(buf.len(), 0x3FFFFF + 3);
          expect_frame(&mut codec, &mut buf, 0x3FFFFF, 5);
      }

      #[test]
      fn test_codec6() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          encode_payload(&mut codec, 0x3FFFFF + 1, 6, &mut buf);
          let buf_saved = buf.clone();
          assert_eq!(buf.len(), 0x3FFFFF + 4 + 1);
          expect_frame(&mut codec, &mut buf, 0x3FFFFF + 1, 6);

          // Split across the header boundary and inside the payload.
          let mut codec2 = BytesCodec::new();
          let mut buf2 = BytesMut::new();
          buf2.extend_from_slice(&buf_saved[0..1]);
          expect_pending(&mut codec2, &mut buf2);
          buf2.extend_from_slice(&buf_saved[1..6]);
          expect_pending(&mut codec2, &mut buf2);
          buf2.extend_from_slice(&buf_saved[6..]);
          expect_frame(&mut codec2, &mut buf2, 0x3FFFFF + 1, 6);
      }

      #[test]
      fn test_three_byte_header_lower_bound() {
          let mut codec = BytesCodec::new();
          let mut buf = BytesMut::new();
          encode_payload(&mut codec, 0x3FFF + 1, 7, &mut buf);
          assert_eq!(buf.len(), 0x3FFF + 1 + 3);
          assert_eq!(buf[0] & 0x3, 2);
          expect_frame(&mut codec, &mut buf, 0x3FFF + 1, 7);
      }

      #[test]
      fn test_raw_mode_passthrough() {
          let mut codec = BytesCodec::new();
          codec.set_raw();
          let mut buf = BytesMut::new();
          expect_pending(&mut codec, &mut buf);
          assert!(codec.encode(Bytes::from_static(b"abc"), &mut buf).is_ok());
          assert_eq!(&buf[..], &b"abc"[..]);
          match codec.decode(&mut buf) {
              Ok(Some(res)) => assert_eq!(&res[..], &b"abc"[..]),
              other => panic!("expected raw chunk, got {:?}", other),
          }
          assert!(buf.is_empty());
      }

      #[test]
      fn test_max_packet_length() {
          let mut codec = BytesCodec::new();
          codec.set_max_packet_length(0x3F);
          let mut buf = BytesMut::new();
          encode_payload(&mut codec, 0x3F, 8, &mut buf);
          expect_frame(&mut codec, &mut buf, 0x3F, 8);
          encode_payload(&mut codec, 0x3F + 1, 8, &mut buf);
          match codec.decode(&mut buf) {
              Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
              other => panic!("expected InvalidData, got {:?}", other),
          }
      }
  }