1use std::marker::PhantomData;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use futures::future::{self, BoxFuture, FutureExt};
8use tower::{Layer, Service};
9use tracing::{info, warn};
10
11use super::ExitedError;
12use crate::jsonrpc::{not_initialized_error, Error, Id, Request, Response};
13
14use super::client::Client;
15use super::pending::Pending;
16use super::state::{ServerState, State};
17
18pub struct Initialize {
24 state: Arc<ServerState>,
25 pending: Arc<Pending>,
26}
27
28impl Initialize {
29 #[inline]
30 pub const fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
31 Initialize { state, pending }
32 }
33}
34
35impl<S> Layer<S> for Initialize {
36 type Service = InitializeService<S>;
37
38 fn layer(&self, inner: S) -> Self::Service {
39 InitializeService {
40 inner: Cancellable::new(inner, self.pending.clone()),
41 state: self.state.clone(),
42 }
43 }
44}
45
46pub struct InitializeService<S> {
48 inner: Cancellable<S>,
49 state: Arc<ServerState>,
50}
51
52impl<S> Service<Request> for InitializeService<S>
53where
54 S: Service<Request, Response = Option<Response>, Error = ExitedError>,
55 S::Future: Send + 'static,
56{
57 type Response = S::Response;
58 type Error = S::Error;
59 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
60
61 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
62 self.inner.poll_ready(cx)
63 }
64
65 fn call(&mut self, req: Request) -> Self::Future {
66 if self.state.get() == State::Uninitialized {
67 let state = self.state.clone();
68 let fut = self.inner.call(req);
69
70 Box::pin(async move {
71 let response = fut.await?;
72
73 match &response {
74 Some(res) if res.is_ok() => state.set(State::Initialized),
75 _ => state.set(State::Uninitialized),
76 }
77
78 Ok(response)
79 })
80 } else {
81 warn!("received duplicate `initialize` request, ignoring");
82 let (_, id, _) = req.into_parts();
83 future::ok(id.map(|id| Response::from_error(id, Error::invalid_request()))).boxed()
84 }
85 }
86}
87
88pub struct Shutdown {
94 state: Arc<ServerState>,
95 pending: Arc<Pending>,
96}
97
98impl Shutdown {
99 #[inline]
100 pub const fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
101 Shutdown { state, pending }
102 }
103}
104
105impl<S> Layer<S> for Shutdown {
106 type Service = ShutdownService<S>;
107
108 fn layer(&self, inner: S) -> Self::Service {
109 ShutdownService {
110 inner: Cancellable::new(inner, self.pending.clone()),
111 state: self.state.clone(),
112 }
113 }
114}
115
116pub struct ShutdownService<S> {
118 inner: Cancellable<S>,
119 state: Arc<ServerState>,
120}
121
122impl<S> Service<Request> for ShutdownService<S>
123where
124 S: Service<Request, Response = Option<Response>, Error = ExitedError>,
125 S::Future: Into<BoxFuture<'static, Result<Option<Response>, S::Error>>> + Send + 'static,
126{
127 type Response = S::Response;
128 type Error = S::Error;
129 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
130
131 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
132 self.inner.poll_ready(cx)
133 }
134
135 fn call(&mut self, req: Request) -> Self::Future {
136 match self.state.get() {
137 State::Initialized => {
138 info!("shutdown request received, shutting down");
139 self.state.set(State::ShutDown);
140 self.inner.call(req)
141 }
142 cur_state => {
143 let (_, id, _) = req.into_parts();
144 future::ok(not_initialized_response(id, cur_state)).boxed()
145 }
146 }
147 }
148}
149
150pub struct Exit {
156 state: Arc<ServerState>,
157 pending: Arc<Pending>,
158 client: Client,
159}
160
161impl Exit {
162 #[inline]
163 pub const fn new(state: Arc<ServerState>, pending: Arc<Pending>, client: Client) -> Self {
164 Exit {
165 state,
166 pending,
167 client,
168 }
169 }
170}
171
172impl<S> Layer<S> for Exit {
173 type Service = ExitService<S>;
174
175 fn layer(&self, _: S) -> Self::Service {
176 ExitService {
177 state: self.state.clone(),
178 pending: self.pending.clone(),
179 client: self.client.clone(),
180 _marker: PhantomData,
181 }
182 }
183}
184
185pub struct ExitService<S> {
187 state: Arc<ServerState>,
188 pending: Arc<Pending>,
189 client: Client,
190 _marker: PhantomData<S>,
191}
192
193impl<S> Service<Request> for ExitService<S> {
194 type Response = Option<Response>;
195 type Error = ExitedError;
196 type Future = future::Ready<Result<Self::Response, Self::Error>>;
197
198 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
199 if self.state.get() == State::Exited {
200 Poll::Ready(Err(ExitedError(())))
201 } else {
202 Poll::Ready(Ok(()))
203 }
204 }
205
206 fn call(&mut self, _: Request) -> Self::Future {
207 info!("exit notification received, stopping");
208 self.state.set(State::Exited);
209 self.pending.cancel_all();
210 self.client.close();
211 future::ok(None)
212 }
213}
214
215pub struct Normal {
217 state: Arc<ServerState>,
218 pending: Arc<Pending>,
219}
220
221impl Normal {
222 #[inline]
223 pub const fn new(state: Arc<ServerState>, pending: Arc<Pending>) -> Self {
224 Normal { state, pending }
225 }
226}
227
228impl<S> Layer<S> for Normal {
229 type Service = NormalService<S>;
230
231 fn layer(&self, inner: S) -> Self::Service {
232 NormalService {
233 inner: Cancellable::new(inner, self.pending.clone()),
234 state: self.state.clone(),
235 }
236 }
237}
238
239pub struct NormalService<S> {
241 inner: Cancellable<S>,
242 state: Arc<ServerState>,
243}
244
245impl<S> Service<Request> for NormalService<S>
246where
247 S: Service<Request, Response = Option<Response>, Error = ExitedError>,
248 S::Future: Into<BoxFuture<'static, Result<Option<Response>, S::Error>>> + Send + 'static,
249{
250 type Response = S::Response;
251 type Error = S::Error;
252 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
253
254 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
255 self.inner.poll_ready(cx)
256 }
257
258 fn call(&mut self, req: Request) -> Self::Future {
259 match self.state.get() {
260 State::Initialized => self.inner.call(req),
261 cur_state => {
262 let (_, id, _) = req.into_parts();
263 future::ok(not_initialized_response(id, cur_state)).boxed()
264 }
265 }
266 }
267}
268
269struct Cancellable<S> {
275 inner: S,
276 pending: Arc<Pending>,
277}
278
279impl<S> Cancellable<S> {
280 #[inline]
281 const fn new(inner: S, pending: Arc<Pending>) -> Self {
282 Cancellable { inner, pending }
283 }
284}
285
286impl<S> Service<Request> for Cancellable<S>
287where
288 S: Service<Request, Response = Option<Response>, Error = ExitedError>,
289 S::Future: Send + 'static,
290{
291 type Response = S::Response;
292 type Error = S::Error;
293 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
294
295 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
296 self.inner.poll_ready(cx)
297 }
298
299 fn call(&mut self, req: Request) -> Self::Future {
300 match req.id().cloned() {
301 Some(id) => self.pending.execute(id, self.inner.call(req)).boxed(),
302 None => self.inner.call(req).boxed(),
303 }
304 }
305}
306
307fn not_initialized_response(id: Option<Id>, server_state: State) -> Option<Response> {
308 let id = id?;
309 let error = match server_state {
310 State::Uninitialized | State::Initializing => not_initialized_error(),
311 _ => Error::invalid_request(),
312 };
313
314 Some(Response::from_error(id, error))
315}
316
317