tower_lsp/
codec.rs

1//! Encoder and decoder for Language Server Protocol messages.
2
3use std::error::Error;
4use std::fmt::{self, Display, Formatter};
5use std::io::{Error as IoError, Write};
6use std::marker::PhantomData;
7use std::num::ParseIntError;
8use std::str::Utf8Error;
9
10use bytes::buf::BufMut;
11use bytes::{Buf, BytesMut};
12use memchr::memmem;
13use serde::{de::DeserializeOwned, Serialize};
14use tracing::{trace, warn};
15
16#[cfg(feature = "runtime-agnostic")]
17use async_codec_lite::{Decoder, Encoder};
18#[cfg(feature = "runtime-tokio")]
19use tokio_util::codec::{Decoder, Encoder};
20
/// Errors that can occur when processing an LSP message.
///
/// Every variant that wraps a lower-level error prints that error as part of its own
/// `Display` output.
#[derive(Debug)]
pub enum ParseError {
    /// Failed to parse the JSON body.
    ///
    /// Also produced when serializing an outgoing message to JSON fails, because every
    /// `serde_json::Error` converts into this variant.
    Body(serde_json::Error),
    /// Failed to encode the response.
    ///
    /// Wraps the I/O error reported while writing the framed message into the output buffer.
    Encode(IoError),
    /// Failed to parse headers.
    ///
    /// The header block could not be parsed by `httparse`.
    Headers(httparse::Error),
    /// The media type in the `Content-Type` header is invalid.
    ///
    /// Raised when the header carries no `charset` parameter or an unsupported charset.
    InvalidContentType,
    /// The length value in the `Content-Length` header is invalid.
    ///
    /// The header value could not be parsed as an unsigned integer.
    InvalidContentLength(ParseIntError),
    /// Request lacks the required `Content-Length` header.
    MissingContentLength,
    /// Request contains invalid UTF8.
    ///
    /// Applies to header values as well as to the message body.
    Utf8(Utf8Error),
}
39
40impl Display for ParseError {
41    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
42        match *self {
43            ParseError::Body(ref e) => write!(f, "unable to parse JSON body: {e}"),
44            ParseError::Encode(ref e) => write!(f, "failed to encode response: {e}"),
45            ParseError::Headers(ref e) => write!(f, "failed to parse headers: {e}"),
46            ParseError::InvalidContentType => write!(f, "unable to parse content type"),
47            ParseError::InvalidContentLength(ref e) => {
48                write!(f, "unable to parse content length: {e}")
49            }
50            ParseError::MissingContentLength => {
51                write!(f, "missing required `Content-Length` header")
52            }
53            ParseError::Utf8(ref e) => write!(f, "request contains invalid UTF8: {e}"),
54        }
55    }
56}
57
58impl Error for ParseError {
59    fn source(&self) -> Option<&(dyn Error + 'static)> {
60        match *self {
61            ParseError::Body(ref e) => Some(e),
62            ParseError::Encode(ref e) => Some(e),
63            ParseError::InvalidContentLength(ref e) => Some(e),
64            ParseError::Utf8(ref e) => Some(e),
65            _ => None,
66        }
67    }
68}
69
70impl From<serde_json::Error> for ParseError {
71    fn from(error: serde_json::Error) -> Self {
72        ParseError::Body(error)
73    }
74}
75
76impl From<IoError> for ParseError {
77    fn from(error: IoError) -> Self {
78        ParseError::Encode(error)
79    }
80}
81
82impl From<httparse::Error> for ParseError {
83    fn from(error: httparse::Error) -> Self {
84        ParseError::Headers(error)
85    }
86}
87
88impl From<ParseIntError> for ParseError {
89    fn from(error: ParseIntError) -> Self {
90        ParseError::InvalidContentLength(error)
91    }
92}
93
94impl From<Utf8Error> for ParseError {
95    fn from(error: Utf8Error) -> Self {
96        ParseError::Utf8(error)
97    }
98}
99
/// Encodes and decodes Language Server Protocol messages.
///
/// Messages are framed as a header block containing `Content-Length`, a blank line, and a
/// JSON body. `T` is the message type that is serialized on encode and deserialized on decode.
pub struct LanguageServerCodec<T> {
    /// Body length announced by the header block that was parsed last, or `None` while the
    /// decoder is still waiting for a header block.
    content_len: Option<usize>,
    /// Ties the codec to the message type `T` without storing a value of that type.
    _marker: PhantomData<T>,
}
105
106impl<T> Default for LanguageServerCodec<T> {
107    fn default() -> Self {
108        LanguageServerCodec {
109            content_len: None,
110            _marker: PhantomData,
111        }
112    }
113}
114
#[cfg(feature = "runtime-agnostic")]
impl<T: Serialize> Encoder for LanguageServerCodec<T> {
    type Item = T;
    type Error = ParseError;

    /// Serializes `item` to JSON and appends it to `dst` as one framed message:
    /// a `Content-Length` header, a blank line, and the JSON body.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::Body` if serialization fails and `ParseError::Encode` if writing
    /// into the buffer fails.
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let body = serde_json::to_string(&item)?;
        trace!("-> {}", body);

        // Grow the buffer once up front. The constant 20 covers the fixed framing text:
        // `Content-Length: ` (16 bytes) plus `\r\n\r\n` (4 bytes); the rest is the decimal
        // length and the body itself.
        let body_len = body.len();
        dst.reserve(body_len + number_of_digits(body_len) + 20);

        let mut sink = dst.writer();
        write!(sink, "Content-Length: {}\r\n\r\n{}", body_len, body)?;
        sink.flush()?;

        Ok(())
    }
}
134
#[cfg(feature = "runtime-tokio")]
impl<T: Serialize> Encoder<T> for LanguageServerCodec<T> {
    type Error = ParseError;

    /// Serializes `item` to JSON and appends it to `dst` as one framed message:
    /// a `Content-Length` header, a blank line, and the JSON body.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::Body` if serialization fails and `ParseError::Encode` if writing
    /// into the buffer fails.
    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let body = serde_json::to_string(&item)?;
        trace!("-> {}", body);

        // Grow the buffer once up front. The constant 20 covers the fixed framing text:
        // `Content-Length: ` (16 bytes) plus `\r\n\r\n` (4 bytes); the rest is the decimal
        // length and the body itself.
        let body_len = body.len();
        dst.reserve(body_len + number_of_digits(body_len) + 20);

        let mut sink = dst.writer();
        write!(sink, "Content-Length: {}\r\n\r\n{}", body_len, body)?;
        sink.flush()?;

        Ok(())
    }
}
153
/// Returns how many decimal digits are needed to print `n`.
///
/// Zero is printed as the single digit `0`, so the result is always at least 1.
#[inline]
fn number_of_digits(mut n: usize) -> usize {
    // Every number has at least one digit; each further division by ten that leaves a
    // non-zero quotient accounts for one more.
    let mut num_digits = 1;

    while n >= 10 {
        n /= 10;
        num_digits += 1;
    }

    num_digits
}
165
166impl<T: DeserializeOwned> Decoder for LanguageServerCodec<T> {
167    type Item = T;
168    type Error = ParseError;
169
170    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
171        if let Some(content_len) = self.content_len {
172            if src.len() < content_len {
173                return Ok(None);
174            }
175
176            let bytes = &src[..content_len];
177            let message = std::str::from_utf8(bytes)?;
178
179            let result = if message.is_empty() {
180                Ok(None)
181            } else {
182                trace!("<- {}", message);
183                match serde_json::from_str(message) {
184                    Ok(parsed) => Ok(Some(parsed)),
185                    Err(err) => Err(err.into()),
186                }
187            };
188
189            src.advance(content_len);
190            self.content_len = None; // Reset state in preparation for parsing next message.
191
192            result
193        } else {
194            let mut dst = [httparse::EMPTY_HEADER; 2];
195
196            let (headers_len, headers) = match httparse::parse_headers(src, &mut dst)? {
197                httparse::Status::Complete(output) => output,
198                httparse::Status::Partial => return Ok(None),
199            };
200
201            match decode_headers(headers) {
202                Ok(content_len) => {
203                    src.advance(headers_len);
204                    self.content_len = Some(content_len);
205                    self.decode(src) // Recurse right back in, now that `Content-Length` is known.
206                }
207                Err(err) => {
208                    match err {
209                        ParseError::MissingContentLength => {}
210                        _ => src.advance(headers_len),
211                    }
212
213                    // Skip any garbage bytes by scanning ahead for another potential message.
214                    src.advance(memmem::find(src, b"Content-Length").unwrap_or_default());
215                    Err(err)
216                }
217            }
218        }
219    }
220}
221
222fn decode_headers(headers: &[httparse::Header<'_>]) -> Result<usize, ParseError> {
223    let mut content_len = None;
224
225    for header in headers {
226        match header.name {
227            "Content-Length" => {
228                let string = std::str::from_utf8(header.value)?;
229                let parsed_len = string.parse()?;
230                content_len = Some(parsed_len);
231            }
232            "Content-Type" => {
233                let string = std::str::from_utf8(header.value)?;
234                let charset = string
235                    .split(';')
236                    .skip(1)
237                    .map(|param| param.trim())
238                    .find_map(|param| param.strip_prefix("charset="));
239
240                match charset {
241                    Some("utf-8") | Some("utf8") => {}
242                    _ => return Err(ParseError::InvalidContentType),
243                }
244            }
245            other => warn!("encountered unsupported header: {:?}", other),
246        }
247    }
248
249    if let Some(content_len) = content_len {
250        Ok(content_len)
251    } else {
252        Err(ParseError::MissingContentLength)
253    }
254}
255
// Unit tests covering the encoder, the decoder, and the decoder's error recovery.
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use serde_json::Value;

    use super::*;

    // Asserts that an expression matches the given pattern, panicking with the expected
    // pattern and the actual value otherwise.
    macro_rules! assert_err {
        ($expression:expr, $($pattern:tt)+) => {
            match $expression {
                $($pattern)+ => (),
                ref e => panic!("expected `{}` but got `{:?}`", stringify!($($pattern)+), e),
            }
        }
    }

    // Builds the wire form of `message`: a `Content-Length` header, an optional
    // `Content-Type` header, a blank line, and the body.
    fn encode_message(content_type: Option<&str>, message: &str) -> String {
        let content_type = content_type
            .map(|ty| format!("\r\nContent-Type: {ty}"))
            .unwrap_or_default();

        format!(
            "Content-Length: {}{}\r\n\r\n{}",
            message.len(),
            content_type,
            message
        )
    }

    // Encoding a value produces the expected wire form, and decoding that wire form yields
    // the same value again.
    #[test]
    fn encode_and_decode() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let encoded = encode_message(None, decoded);

        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::new();
        let item: Value = serde_json::from_str(decoded).unwrap();
        codec.encode(item, &mut buffer).unwrap();
        assert_eq!(buffer, BytesMut::from(encoded.as_str()));

        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(decoded));
    }

    // A `Content-Type` header is accepted only when it names a supported charset; the media
    // type itself is not checked.
    #[test]
    fn decodes_optional_content_type() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let content_type = "application/vscode-jsonrpc; charset=utf-8";
        let encoded = encode_message(Some(content_type), decoded);

        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded_: Value = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(decoded_));

        // The alternative spelling `utf8` is accepted as well.
        let content_type = "application/vscode-jsonrpc; charset=utf8";
        let encoded = encode_message(Some(content_type), decoded);

        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded_: Value = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(decoded_));

        // An unknown charset is rejected.
        let content_type = "application/vscode-jsonrpc; charset=invalid";
        let encoded = encode_message(Some(content_type), decoded);

        let mut buffer = BytesMut::from(encoded.as_str());
        assert_err!(
            codec.decode(&mut buffer),
            Err(ParseError::InvalidContentType)
        );

        // A header without any charset parameter is rejected too.
        let content_type = "application/vscode-jsonrpc";
        let encoded = encode_message(Some(content_type), decoded);

        let mut buffer = BytesMut::from(encoded.as_str());
        assert_err!(
            codec.decode(&mut buffer),
            Err(ParseError::InvalidContentType)
        );

        // Only the charset matters; an arbitrary media type is tolerated.
        let content_type = "this-mime-should-be-ignored; charset=utf8";
        let encoded = encode_message(Some(content_type), decoded);

        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded_: Value = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(decoded_));
    }

    // A message announcing an empty body decodes to no item rather than to an error.
    #[test]
    fn decodes_zero_length_message() {
        let content_type = "application/vscode-jsonrpc; charset=utf-8";
        let encoded = encode_message(Some(content_type), "");

        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(encoded.as_str());
        let message: Option<Value> = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, None);
    }

    // After each reported error the decoder resynchronizes on the next `Content-Length`
    // marker, so the valid messages surrounding the garbage are still delivered.
    #[test]
    fn recovers_from_parse_error() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let encoded = encode_message(None, decoded);
        // Layout: leading garbage, a valid message, a header block with an unparsable
        // length, and a second valid message.
        let mixed = format!("foobar{encoded}Content-Length: foobar\r\n\r\n{encoded}");

        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(mixed.as_str());
        // The garbage fuses with the header name, so no `Content-Length` header is found.
        assert_err!(
            codec.decode(&mut buffer),
            Err(ParseError::MissingContentLength)
        );

        let message: Option<Value> = codec.decode(&mut buffer).unwrap();
        let first_valid = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(first_valid));
        assert_err!(
            codec.decode(&mut buffer),
            Err(ParseError::InvalidContentLength(_))
        );

        let message = codec.decode(&mut buffer).unwrap();
        let second_valid = serde_json::from_str(decoded).unwrap();
        assert_eq!(message, Some(second_valid));

        // Nothing is left in the buffer.
        let message = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, None);
    }

    // A message arriving in several pieces yields `None` until it is complete, and the
    // decoder keeps its progress between calls.
    #[test]
    fn decodes_small_chunks() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let content_type = "application/vscode-jsonrpc; charset=utf-8";
        let encoded = encode_message(Some(content_type), decoded);

        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(encoded.as_str());

        // Only the first 40 bytes are visible: the header block is still incomplete.
        let rest = buffer.split_off(40);
        let message = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, None);
        buffer.unsplit(rest);

        // With 80 bytes visible the header block is complete, but the body is not.
        let rest = buffer.split_off(80);
        let message = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, None);
        buffer.unsplit(rest);

        // The headers have been consumed by now; 16 bytes of body are still too few.
        let rest = buffer.split_off(16);
        let message = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, None);
        buffer.unsplit(rest);

        // The whole body is available now.
        let decoded: Value = serde_json::from_str(decoded).unwrap();
        let message = codec.decode(&mut buffer).unwrap();
        assert_eq!(message, Some(decoded));
    }
}
417}