tower_lsp/service/
layers.rs

1//! Assorted middleware that implements LSP server semantics.
2
3use std::marker::PhantomData;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use futures::future::{self, BoxFuture, FutureExt};
8use tower::{Layer, Service};
9use tracing::{info, warn};
10
11use super::ExitedError;
12use crate::jsonrpc::{not_initialized_error, Error, Id, Request, Response};
13
14use super::client::Client;
15use super::pending::Pending;
16use super::state::{ServerState, State};
17
18/// Middleware which implements `initialize` request semantics.
19///
20/// # Specification
21///
22/// https://microsoft.github.io/language-server-protocol/specification#initialize
23pub struct Initialize {
24    state: Arc<ServerState>,
25    pending: Arc<Pending>,
26}
27
28impl Initialize {
29    #[inline]
30    pub const fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
31        Initialize { state, pending }
32    }
33}
34
35impl<S> Layer<S> for Initialize {
36    type Service = InitializeService<S>;
37
38    fn layer(&self, inner: S) -> Self::Service {
39        InitializeService {
40            inner: Cancellable::new(inner, self.pending.clone()),
41            state: self.state.clone(),
42        }
43    }
44}
45
46/// Service created from [`Initialize`] layer.
47pub struct InitializeService<S> {
48    inner: Cancellable<S>,
49    state: Arc<ServerState>,
50}
51
52impl<S> Service<Request> for InitializeService<S>
53where
54    S: Service<Request, Response = Option<Response>, Error = ExitedError>,
55    S::Future: Send + 'static,
56{
57    type Response = S::Response;
58    type Error = S::Error;
59    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
60
61    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62        self.inner.poll_ready(cx)
63    }
64
65    fn call(&mut self, req: Request) -> Self::Future {
66        if self.state.get() == State::Uninitialized {
67            let state = self.state.clone();
68            let fut = self.inner.call(req);
69
70            Box::pin(async move {
71                let response = fut.await?;
72
73                match &response {
74                    Some(res) if res.is_ok() => state.set(State::Initialized),
75                    _ => state.set(State::Uninitialized),
76                }
77
78                Ok(response)
79            })
80        } else {
81            warn!("received duplicate `initialize` request, ignoring");
82            let (_, id, _) = req.into_parts();
83            future::ok(id.map(|id| Response::from_error(id, Error::invalid_request()))).boxed()
84        }
85    }
86}
87
88/// Middleware which implements `shutdown` request semantics.
89///
90/// # Specification
91///
92/// https://microsoft.github.io/language-server-protocol/specification#shutdown
93pub struct Shutdown {
94    state: Arc<ServerState>,
95    pending: Arc<Pending>,
96}
97
98impl Shutdown {
99    #[inline]
100    pub const fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
101        Shutdown { state, pending }
102    }
103}
104
105impl<S> Layer<S> for Shutdown {
106    type Service = ShutdownService<S>;
107
108    fn layer(&self, inner: S) -> Self::Service {
109        ShutdownService {
110            inner: Cancellable::new(inner, self.pending.clone()),
111            state: self.state.clone(),
112        }
113    }
114}
115
116/// Service created from [`Shutdown`] layer.
117pub struct ShutdownService<S> {
118    inner: Cancellable<S>,
119    state: Arc<ServerState>,
120}
121
122impl<S> Service<Request> for ShutdownService<S>
123where
124    S: Service<Request, Response = Option<Response>, Error = ExitedError>,
125    S::Future: Into<BoxFuture<'static, Result<Option<Response>, S::Error>>> + Send + 'static,
126{
127    type Response = S::Response;
128    type Error = S::Error;
129    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
130
131    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
132        self.inner.poll_ready(cx)
133    }
134
135    fn call(&mut self, req: Request) -> Self::Future {
136        match self.state.get() {
137            State::Initialized => {
138                info!("shutdown request received, shutting down");
139                self.state.set(State::ShutDown);
140                self.inner.call(req)
141            }
142            cur_state => {
143                let (_, id, _) = req.into_parts();
144                future::ok(not_initialized_response(id, cur_state)).boxed()
145            }
146        }
147    }
148}
149
150/// Middleware which implements `exit` notification semantics.
151///
152/// # Specification
153///
154/// https://microsoft.github.io/language-server-protocol/specification#exit
155pub struct Exit {
156    state: Arc<ServerState>,
157    pending: Arc<Pending>,
158    client: Client,
159}
160
161impl Exit {
162    #[inline]
163    pub const fn new(state: Arc<ServerState>, pending: Arc<Pending>, client: Client) -> Self {
164        Exit {
165            state,
166            pending,
167            client,
168        }
169    }
170}
171
172impl<S> Layer<S> for Exit {
173    type Service = ExitService<S>;
174
175    fn layer(&self, _: S) -> Self::Service {
176        ExitService {
177            state: self.state.clone(),
178            pending: self.pending.clone(),
179            client: self.client.clone(),
180            _marker: PhantomData,
181        }
182    }
183}
184
185/// Service created from [`Exit`] layer.
186pub struct ExitService<S> {
187    state: Arc<ServerState>,
188    pending: Arc<Pending>,
189    client: Client,
190    _marker: PhantomData<S>,
191}
192
193impl<S> Service<Request> for ExitService<S> {
194    type Response = Option<Response>;
195    type Error = ExitedError;
196    type Future = future::Ready<Result<Self::Response, Self::Error>>;
197
198    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
199        if self.state.get() == State::Exited {
200            Poll::Ready(Err(ExitedError(())))
201        } else {
202            Poll::Ready(Ok(()))
203        }
204    }
205
206    fn call(&mut self, _: Request) -> Self::Future {
207        info!("exit notification received, stopping");
208        self.state.set(State::Exited);
209        self.pending.cancel_all();
210        self.client.close();
211        future::ok(None)
212    }
213}
214
215/// Middleware which implements LSP semantics for all other kinds of requests.
216pub struct Normal {
217    state: Arc<ServerState>,
218    pending: Arc<Pending>,
219}
220
221impl Normal {
222    #[inline]
223    pub const fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
224        Normal { state, pending }
225    }
226}
227
228impl<S> Layer<S> for Normal {
229    type Service = NormalService<S>;
230
231    fn layer(&self, inner: S) -> Self::Service {
232        NormalService {
233            inner: Cancellable::new(inner, self.pending.clone()),
234            state: self.state.clone(),
235        }
236    }
237}
238
239/// Service created from [`Normal`] layer.
240pub struct NormalService<S> {
241    inner: Cancellable<S>,
242    state: Arc<ServerState>,
243}
244
245impl<S> Service<Request> for NormalService<S>
246where
247    S: Service<Request, Response = Option<Response>, Error = ExitedError>,
248    S::Future: Into<BoxFuture<'static, Result<Option<Response>, S::Error>>> + Send + 'static,
249{
250    type Response = S::Response;
251    type Error = S::Error;
252    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
253
254    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
255        self.inner.poll_ready(cx)
256    }
257
258    fn call(&mut self, req: Request) -> Self::Future {
259        match self.state.get() {
260            State::Initialized => self.inner.call(req),
261            cur_state => {
262                let (_, id, _) = req.into_parts();
263                future::ok(not_initialized_response(id, cur_state)).boxed()
264            }
265        }
266    }
267}
268
269/// Wraps an inner service `S` and implements `$/cancelRequest` semantics for all requests.
270///
271/// # Specification
272///
273/// https://microsoft.github.io/language-server-protocol/specification#cancelRequest
274struct Cancellable<S> {
275    inner: S,
276    pending: Arc<Pending>,
277}
278
279impl<S> Cancellable<S> {
280    #[inline]
281    const fn new(inner: S, pending: Arc<Pending>) -> Self {
282        Cancellable { inner, pending }
283    }
284}
285
286impl<S> Service<Request> for Cancellable<S>
287where
288    S: Service<Request, Response = Option<Response>, Error = ExitedError>,
289    S::Future: Send + 'static,
290{
291    type Response = S::Response;
292    type Error = S::Error;
293    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
294
295    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
296        self.inner.poll_ready(cx)
297    }
298
299    fn call(&mut self, req: Request) -> Self::Future {
300        match req.id().cloned() {
301            Some(id) => self.pending.execute(id, self.inner.call(req)).boxed(),
302            None => self.inner.call(req).boxed(),
303        }
304    }
305}
306
307fn not_initialized_response(id: Option<Id>, server_state: State) -> Option<Response> {
308    let id = id?;
309    let error = match server_state {
310        State::Uninitialized | State::Initializing => not_initialized_error(),
311        _ => Error::invalid_request(),
312    };
313
314    Some(Response::from_error(id, error))
315}
316
317// TODO: Add some `tower-test` middleware tests for each middleware.