1use std::error::Error;
4use std::fmt::{self, Display, Formatter};
5use std::io::{Error as IoError, Write};
6use std::marker::PhantomData;
7use std::num::ParseIntError;
8use std::str::Utf8Error;
9
10use bytes::buf::BufMut;
11use bytes::{Buf, BytesMut};
12use memchr::memmem;
13use serde::{de::DeserializeOwned, Serialize};
14use tracing::{trace, warn};
15
16#[cfg(feature = "runtime-agnostic")]
17use async_codec_lite::{Decoder, Encoder};
18#[cfg(feature = "runtime-tokio")]
19use tokio_util::codec::{Decoder, Encoder};
20
21#[derive(Debug)]
23pub enum ParseError {
24 Body(serde_json::Error),
26 Encode(IoError),
28 Headers(httparse::Error),
30 InvalidContentType,
32 InvalidContentLength(ParseIntError),
34 MissingContentLength,
36 Utf8(Utf8Error),
38}
39
40impl Display for ParseError {
41 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
42 match *self {
43 ParseError::Body(ref e) => write!(f, "unable to parse JSON body: {e}"),
44 ParseError::Encode(ref e) => write!(f, "failed to encode response: {e}"),
45 ParseError::Headers(ref e) => write!(f, "failed to parse headers: {e}"),
46 ParseError::InvalidContentType => write!(f, "unable to parse content type"),
47 ParseError::InvalidContentLength(ref e) => {
48 write!(f, "unable to parse content length: {e}")
49 }
50 ParseError::MissingContentLength => {
51 write!(f, "missing required `Content-Length` header")
52 }
53 ParseError::Utf8(ref e) => write!(f, "request contains invalid UTF8: {e}"),
54 }
55 }
56}
57
58impl Error for ParseError {
59 fn source(&self) -> Option<&(dyn Error + 'static)> {
60 match *self {
61 ParseError::Body(ref e) => Some(e),
62 ParseError::Encode(ref e) => Some(e),
63 ParseError::InvalidContentLength(ref e) => Some(e),
64 ParseError::Utf8(ref e) => Some(e),
65 _ => None,
66 }
67 }
68}
69
70impl From<serde_json::Error> for ParseError {
71 fn from(error: serde_json::Error) -> Self {
72 ParseError::Body(error)
73 }
74}
75
76impl From<IoError> for ParseError {
77 fn from(error: IoError) -> Self {
78 ParseError::Encode(error)
79 }
80}
81
82impl From<httparse::Error> for ParseError {
83 fn from(error: httparse::Error) -> Self {
84 ParseError::Headers(error)
85 }
86}
87
88impl From<ParseIntError> for ParseError {
89 fn from(error: ParseIntError) -> Self {
90 ParseError::InvalidContentLength(error)
91 }
92}
93
94impl From<Utf8Error> for ParseError {
95 fn from(error: Utf8Error) -> Self {
96 ParseError::Utf8(error)
97 }
98}
99
/// Encodes and decodes Language Server Protocol messages framed by a `Content-Length` header.
///
/// `T` is the message type written and read as the JSON body. `Default` is implemented by hand
/// rather than derived, so that constructing a codec does not require `T: Default`.
#[derive(Debug)]
pub struct LanguageServerCodec<T> {
    /// Body length announced by the most recently parsed header block whose body has not been
    /// consumed yet. `None` means the decoder expects a header block next.
    content_len: Option<usize>,
    _marker: PhantomData<T>,
}
105
106impl<T> Default for LanguageServerCodec<T> {
107 fn default() -> Self {
108 LanguageServerCodec {
109 content_len: None,
110 _marker: PhantomData,
111 }
112 }
113}
114
#[cfg(feature = "runtime-agnostic")]
impl<T: Serialize> Encoder for LanguageServerCodec<T> {
    type Item = T;
    type Error = ParseError;

    /// Serializes `item` to JSON and appends it to `dst` as a `Content-Length`-framed message.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::Body`] if `item` cannot be serialized, or [`ParseError::Encode`]
    /// if writing into `dst` reports an I/O error.
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let body = serde_json::to_string(&item)?;
        trace!("-> {}", body);

        let body_len = body.len();
        // Fixed header text: "Content-Length: " (16 bytes) plus the "\r\n\r\n" separator (4).
        let frame_len = body_len + number_of_digits(body_len) + 20;
        dst.reserve(frame_len);

        let mut sink = dst.writer();
        write!(sink, "Content-Length: {}\r\n\r\n{}", body_len, body)?;
        sink.flush()?;
        Ok(())
    }
}
134
#[cfg(feature = "runtime-tokio")]
impl<T: Serialize> Encoder<T> for LanguageServerCodec<T> {
    type Error = ParseError;

    /// Serializes `item` to JSON and appends it to `dst` as a `Content-Length`-framed message.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::Body`] if `item` cannot be serialized, or [`ParseError::Encode`]
    /// if writing into `dst` reports an I/O error.
    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let body = serde_json::to_string(&item)?;
        trace!("-> {}", body);

        let body_len = body.len();
        // Fixed header text: "Content-Length: " (16 bytes) plus the "\r\n\r\n" separator (4).
        let frame_len = body_len + number_of_digits(body_len) + 20;
        dst.reserve(frame_len);

        let mut sink = dst.writer();
        write!(sink, "Content-Length: {}\r\n\r\n{}", body_len, body)?;
        sink.flush()?;
        Ok(())
    }
}
153
/// Returns the number of decimal digits needed to print `n`.
///
/// `0` is printed as "0", so it needs one digit. The previous loop returned 0 for it.
#[inline]
fn number_of_digits(mut n: usize) -> usize {
    let mut num_digits = 1;

    while n >= 10 {
        n /= 10;
        num_digits += 1;
    }

    num_digits
}
165
166impl<T: DeserializeOwned> Decoder for LanguageServerCodec<T> {
167 type Item = T;
168 type Error = ParseError;
169
170 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
171 if let Some(content_len) = self.content_len {
172 if src.len() < content_len {
173 return Ok(None);
174 }
175
176 let bytes = &src[..content_len];
177 let message = std::str::from_utf8(bytes)?;
178
179 let result = if message.is_empty() {
180 Ok(None)
181 } else {
182 trace!("<- {}", message);
183 match serde_json::from_str(message) {
184 Ok(parsed) => Ok(Some(parsed)),
185 Err(err) => Err(err.into()),
186 }
187 };
188
189 src.advance(content_len);
190 self.content_len = None; result
193 } else {
194 let mut dst = [httparse::EMPTY_HEADER; 2];
195
196 let (headers_len, headers) = match httparse::parse_headers(src, &mut dst)? {
197 httparse::Status::Complete(output) => output,
198 httparse::Status::Partial => return Ok(None),
199 };
200
201 match decode_headers(headers) {
202 Ok(content_len) => {
203 src.advance(headers_len);
204 self.content_len = Some(content_len);
205 self.decode(src) }
207 Err(err) => {
208 match err {
209 ParseError::MissingContentLength => {}
210 _ => src.advance(headers_len),
211 }
212
213 src.advance(memmem::find(src, b"Content-Length").unwrap_or_default());
215 Err(err)
216 }
217 }
218 }
219 }
220}
221
222fn decode_headers(headers: &[httparse::Header<'_>]) -> Result<usize, ParseError> {
223 let mut content_len = None;
224
225 for header in headers {
226 match header.name {
227 "Content-Length" => {
228 let string = std::str::from_utf8(header.value)?;
229 let parsed_len = string.parse()?;
230 content_len = Some(parsed_len);
231 }
232 "Content-Type" => {
233 let string = std::str::from_utf8(header.value)?;
234 let charset = string
235 .split(';')
236 .skip(1)
237 .map(|param| param.trim())
238 .find_map(|param| param.strip_prefix("charset="));
239
240 match charset {
241 Some("utf-8") | Some("utf8") => {}
242 _ => return Err(ParseError::InvalidContentType),
243 }
244 }
245 other => warn!("encountered unsupported header: {:?}", other),
246 }
247 }
248
249 if let Some(content_len) = content_len {
250 Ok(content_len)
251 } else {
252 Err(ParseError::MissingContentLength)
253 }
254}
255
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use serde_json::Value;

    use super::*;

    /// Asserts that `$expression` matches the given pattern. Otherwise it panics, showing the
    /// expected pattern and the actual value (via `Debug`).
    macro_rules! assert_err {
        ($expression:expr, $($pattern:tt)+) => {
            match $expression {
                $($pattern)+ => (),
                ref e => panic!("expected `{}` but got `{:?}`", stringify!($($pattern)+), e),
            }
        }
    }

    /// Builds a frame by hand: a `Content-Length` header for `message`, an optional
    /// `Content-Type` header, the blank separator line, then `message` itself.
    fn encode_message(content_type: Option<&str>, message: &str) -> String {
        let content_type = content_type
            .map(|ty| format!("\r\nContent-Type: {ty}"))
            .unwrap_or_default();

        format!(
            "Content-Length: {}{}\r\n\r\n{}",
            message.len(),
            content_type,
            message
        )
    }

    // Encoding a JSON value must produce exactly the hand-built frame, and decoding that frame
    // must return the same value.
    #[test]
    fn encode_and_decode() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let encoded = encode_message(None, decoded);

        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::new();
        let item: Value = serde_json::from_str(decoded).unwrap();
        codec.encode(item, &mut buffer).unwrap();
        assert_eq!(buffer, BytesMut::from(encoded.as_str()));

        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(decoded));
    }

    // `Content-Type` is optional. When present, its charset must be utf-8/utf8; the media type
    // itself is not checked.
    #[test]
    fn decodes_optional_content_type() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let content_type = "application/vscode-jsonrpc; charset=utf-8";
        let encoded = encode_message(Some(content_type), decoded);

        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded_: Value = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(decoded_));

        let content_type = "application/vscode-jsonrpc; charset=utf8";
        let encoded = encode_message(Some(content_type), decoded);

        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded_: Value = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(decoded_));

        // An unsupported charset is rejected.
        let content_type = "application/vscode-jsonrpc; charset=invalid";
        let encoded = encode_message(Some(content_type), decoded);

        let mut buffer = BytesMut::from(encoded.as_str());
        assert_err!(
            codec.decode(&mut buffer),
            Err(ParseError::InvalidContentType)
        );

        // A `Content-Type` with no charset parameter is also rejected.
        let content_type = "application/vscode-jsonrpc";
        let encoded = encode_message(Some(content_type), decoded);

        let mut buffer = BytesMut::from(encoded.as_str());
        assert_err!(
            codec.decode(&mut buffer),
            Err(ParseError::InvalidContentType)
        );

        // Any media type is fine as long as the charset is UTF-8.
        let content_type = "this-mime-should-be-ignored; charset=utf8";
        let encoded = encode_message(Some(content_type), decoded);

        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded_: Value = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(decoded_));
    }

    // A frame with `Content-Length: 0` decodes to no message rather than an error.
    #[test]
    fn decodes_zero_length_message() {
        let content_type = "application/vscode-jsonrpc; charset=utf-8";
        let encoded = encode_message(Some(content_type), "");

        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(encoded.as_str());
        let message: Option<Value> = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, None);
    }

    // After a header error, the decoder must resynchronize on the next `Content-Length` and
    // still deliver the valid frames around it.
    #[test]
    fn recovers_from_parse_error() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let encoded = encode_message(None, decoded);
        // Junk prefix, a valid frame, a frame with a non-numeric length, another valid frame.
        let mixed = format!("foobar{encoded}Content-Length: foobar\r\n\r\n{encoded}");

        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(mixed.as_str());
        // "foobar" is glued onto the first header name, so no `Content-Length` is recognized.
        assert_err!(
            codec.decode(&mut buffer),
            Err(ParseError::MissingContentLength)
        );

        let message: Option<Value> = codec.decode(&mut buffer).unwrap();
        let first_valid = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(first_valid));
        assert_err!(
            codec.decode(&mut buffer),
            Err(ParseError::InvalidContentLength(_))
        );

        let message = codec.decode(&mut buffer).unwrap();
        let second_valid = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(second_valid));

        // The buffer is now drained.
        let message = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, None);
    }

    // Feeds one frame in growing pieces. Each partial view returns `Ok(None)`, and the message
    // appears only once the whole body is buffered.
    #[test]
    fn decodes_small_chunks() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let content_type = "application/vscode-jsonrpc; charset=utf-8";
        let encoded = encode_message(Some(content_type), decoded);

        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(encoded.as_str());

        // Cut inside the header block.
        let rest = buffer.split_off(40);
        let message = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, None);
        buffer.unsplit(rest);

        // Cut after the header block but inside the body; the headers are consumed here.
        let rest = buffer.split_off(80);
        let message = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, None);
        buffer.unsplit(rest);

        // Only part of the remaining body is visible.
        let rest = buffer.split_off(16);
        let message = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, None);
        buffer.unsplit(rest);

        let decoded: Value = serde_json::from_str(decoded).unwrap();
        let message = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, Some(decoded));
    }
}