1#[cfg(feature = "runtime-agnostic")]
4use async_codec_lite::{FramedRead, FramedWrite};
5#[cfg(feature = "runtime-agnostic")]
6use futures::io::{AsyncRead, AsyncWrite};
7
8#[cfg(feature = "runtime-tokio")]
9use tokio::io::{AsyncRead, AsyncWrite};
10#[cfg(feature = "runtime-tokio")]
11use tokio_util::codec::{FramedRead, FramedWrite};
12
13use futures::channel::mpsc;
14use futures::{future, join, stream, FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt};
15use tower::Service;
16use tracing::error;
17
18use crate::codec::{LanguageServerCodec, ParseError};
19use crate::jsonrpc::{Error, Id, Message, Request, Response};
20use crate::service::{ClientSocket, RequestStream, ResponseSink};
21
/// Capacity of the channel that queues request futures waiting to be polled by `serve`.
const MESSAGE_QUEUE_SIZE: usize = 100;
/// Number of request futures `Server` polls concurrently unless `concurrency_level` overrides it.
const DEFAULT_MAX_CONCURRENCY: usize = 4;
24
25pub trait Loopback {
29 type RequestStream: Stream<Item = Request>;
31 type ResponseSink: Sink<Response> + Unpin;
33
34 fn split(self) -> (Self::RequestStream, Self::ResponseSink);
38}
39
40impl Loopback for ClientSocket {
41 type RequestStream = RequestStream;
42 type ResponseSink = ResponseSink;
43
44 #[inline]
45 fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
46 self.split()
47 }
48}
49
/// Drives a language server over a pair of byte streams.
///
/// Framed messages are read from `stdin` and written to `stdout`; the `loopback` socket
/// supplies server-to-client requests and receives the client's responses.
#[derive(Debug)]
pub struct Server<I, O, L = ClientSocket> {
    /// Input byte stream; decoded with `LanguageServerCodec` while serving.
    stdin: I,
    /// Output byte stream; encoded with `LanguageServerCodec` while serving.
    stdout: O,
    /// Loopback socket, split into a request stream and a response sink while serving.
    loopback: L,
    /// Maximum number of request futures polled at the same time.
    max_concurrency: usize,
}
58
59impl<I, O, L> Server<I, O, L>
60where
61 I: AsyncRead + Unpin,
62 O: AsyncWrite,
63 L: Loopback,
64 <L::ResponseSink as Sink<Response>>::Error: std::error::Error,
65{
66 pub fn new(stdin: I, stdout: O, socket: L) -> Self {
68 Server {
69 stdin,
70 stdout,
71 loopback: socket,
72 max_concurrency: DEFAULT_MAX_CONCURRENCY,
73 }
74 }
75
76 pub fn concurrency_level(mut self, max: usize) -> Self {
97 self.max_concurrency = max;
98 self
99 }
100
101 pub async fn serve<T>(self, mut service: T)
103 where
104 T: Service<Request, Response = Option<Response>> + Send + 'static,
105 T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
106 T::Future: Send,
107 {
108 let (client_requests, mut client_responses) = self.loopback.split();
109 let (client_requests, client_abort) = stream::abortable(client_requests);
110 let (mut responses_tx, responses_rx) = mpsc::channel(0);
111 let (mut server_tasks_tx, server_tasks_rx) = mpsc::channel(MESSAGE_QUEUE_SIZE);
112
113 let mut framed_stdin = FramedRead::new(self.stdin, LanguageServerCodec::default());
114 let framed_stdout = FramedWrite::new(self.stdout, LanguageServerCodec::default());
115
116 let process_server_tasks = server_tasks_rx
117 .buffer_unordered(self.max_concurrency)
118 .filter_map(future::ready)
119 .map(|res| Ok(Message::Response(res)))
120 .forward(responses_tx.clone().sink_map_err(|_| unreachable!()))
121 .map(|_| ());
122
123 let print_output = stream::select(responses_rx, client_requests.map(Message::Request))
124 .map(Ok)
125 .forward(framed_stdout.sink_map_err(|e| error!("failed to encode message: {}", e)))
126 .map(|_| ());
127
128 let read_input = async {
129 while let Some(msg) = framed_stdin.next().await {
130 match msg {
131 Ok(Message::Request(req)) => {
132 if let Err(err) = future::poll_fn(|cx| service.poll_ready(cx)).await {
133 error!("{}", display_sources(err.into().as_ref()));
134 return;
135 }
136
137 let fut = service.call(req).unwrap_or_else(|err| {
138 error!("{}", display_sources(err.into().as_ref()));
139 None
140 });
141
142 server_tasks_tx.send(fut).await.unwrap();
143 }
144 Ok(Message::Response(res)) => {
145 if let Err(err) = client_responses.send(res).await {
146 error!("{}", display_sources(&err));
147 return;
148 }
149 }
150 Err(err) => {
151 error!("failed to decode message: {}", err);
152 let res = Response::from_error(Id::Null, to_jsonrpc_error(err));
153 responses_tx.send(Message::Response(res)).await.unwrap();
154 }
155 }
156 }
157
158 server_tasks_tx.disconnect();
159 responses_tx.disconnect();
160 client_abort.abort();
161 };
162
163 join!(print_output, read_input, process_server_tasks);
164 }
165}
166
/// Renders `error` followed by each error in its `source()` chain, joined with `": "`.
///
/// For an error `a` caused by `b` caused by `c`, this yields `"a: b: c"`.
fn display_sources(error: &dyn std::error::Error) -> String {
    let mut rendered = error.to_string();
    let mut next = error.source();
    while let Some(cause) = next {
        rendered.push_str(": ");
        rendered.push_str(&cause.to_string());
        next = cause.source();
    }
    rendered
}
174
#[cfg(feature = "runtime-tokio")]
#[inline]
fn to_jsonrpc_error(err: ParseError) -> Error {
    // Only a body error flagged `is_data()` is reported as Invalid Request; every other
    // decoding failure becomes a Parse error.
    let is_bad_body = matches!(&err, ParseError::Body(body) if body.is_data());
    if is_bad_body {
        Error::invalid_request()
    } else {
        Error::parse_error()
    }
}
183
#[cfg(feature = "runtime-agnostic")]
#[inline]
fn to_jsonrpc_error(err: impl std::error::Error) -> Error {
    // The codec's `ParseError` arrives as the source of the framing error; recover it first.
    let parse_error = err
        .source()
        .and_then(|source| source.downcast_ref::<ParseError>());

    if matches!(parse_error, Some(ParseError::Body(body)) if body.is_data()) {
        Error::invalid_request()
    } else {
        Error::parse_error()
    }
}
192
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    #[cfg(feature = "runtime-agnostic")]
    use futures::io::Cursor;
    #[cfg(feature = "runtime-tokio")]
    use std::io::Cursor;

    use futures::future::Ready;
    use futures::{future, sink, stream};

    use super::*;

    const REQUEST: &str = r#"{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}"#;
    const RESPONSE: &str = r#"{"jsonrpc":"2.0","result":{"capabilities":{}},"id":1}"#;

    /// Prefixes `body` with an LSP `Content-Length` header, producing one wire frame.
    fn frame(body: &str) -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", body.len(), body).into_bytes()
    }

    /// Service that answers every request with the canned `RESPONSE`.
    #[derive(Debug)]
    struct MockService;

    impl Service<Request> for MockService {
        type Response = Option<Response>;
        type Error = String;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request) -> Self::Future {
            let canned: Response = serde_json::from_str(RESPONSE).unwrap();
            future::ready(Ok(Some(canned)))
        }
    }

    /// Loopback that replays a fixed list of requests and discards every response.
    struct MockLoopback(Vec<Request>);

    impl Loopback for MockLoopback {
        type RequestStream = stream::Iter<std::vec::IntoIter<Request>>;
        type ResponseSink = sink::Drain<Response>;

        #[inline]
        fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
            let MockLoopback(requests) = self;
            (stream::iter(requests), sink::drain())
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_on_stdio() {
        let mut stdin = Cursor::new(frame(REQUEST));
        let mut stdout: Vec<u8> = Vec::new();

        Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]))
            .serve(MockService)
            .await;

        assert_eq!(stdin.position(), 80);
        assert_eq!(stdout, frame(RESPONSE));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn interleaves_messages() {
        let outgoing: Request = serde_json::from_str(REQUEST).unwrap();
        let socket = MockLoopback(vec![outgoing]);

        let mut stdin = Cursor::new(frame(REQUEST));
        let mut stdout: Vec<u8> = Vec::new();

        Server::new(&mut stdin, &mut stdout, socket)
            .serve(MockService)
            .await;

        assert_eq!(stdin.position(), 80);
        let mut expected = frame(REQUEST);
        expected.extend(frame(RESPONSE));
        assert_eq!(stdout, expected);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_invalid_json() {
        let truncated = r#"{"jsonrpc":"2.0","method":"#;
        let mut stdin = Cursor::new(frame(truncated));
        let mut stdout: Vec<u8> = Vec::new();

        Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]))
            .serve(MockService)
            .await;

        assert_eq!(stdin.position(), 48);
        let err = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":null}"#;
        assert_eq!(stdout, frame(err));
    }
}