tower_lsp/service/pending.rs

1//! Types for tracking cancelable client-to-server JSON-RPC requests.
2
3use std::fmt::{self, Debug, Formatter};
4use std::future::Future;
5use std::sync::Arc;
6
7use dashmap::{mapref::entry::Entry, DashMap};
8use futures::future::{self, Either};
9use tracing::{debug, info};
10
11use super::ExitedError;
12use crate::jsonrpc::{Error, Id, Response};
13
14/// A hashmap containing pending server requests, keyed by request ID.
15pub struct Pending(Arc<DashMap<Id, future::AbortHandle>>);
16
17impl Pending {
18    /// Creates a new pending server requests map.
19    #[inline]
20    pub fn new() -> Self {
21        Pending(Arc::new(DashMap::new()))
22    }
23
24    /// Executes the given async request handler, keyed by the given request ID.
25    ///
26    /// If a cancel request is issued before the future is finished resolving, this will resolve to
27    /// a "canceled" error response, and the pending request handler future will be dropped.
28    pub fn execute<F>(
29        &self,
30        id: Id,
31        fut: F,
32    ) -> impl Future<Output = Result<Option<Response>, ExitedError>> + Send + 'static
33    where
34        F: Future<Output = Result<Option<Response>, ExitedError>> + Send + 'static,
35    {
36        if let Entry::Vacant(entry) = self.0.entry(id.clone()) {
37            let (handler_fut, abort_handle) = future::abortable(fut);
38            entry.insert(abort_handle);
39
40            let requests = self.0.clone();
41            Either::Left(async move {
42                let abort_result = handler_fut.await;
43                requests.remove(&id); // Remove abort handle now to avoid double cancellation.
44
45                if let Ok(handler_result) = abort_result {
46                    handler_result
47                } else {
48                    Ok(Some(Response::from_error(id, Error::request_cancelled())))
49                }
50            })
51        } else {
52            Either::Right(async { Ok(Some(Response::from_error(id, Error::invalid_request()))) })
53        }
54    }
55
56    /// Attempts to cancel the running request handler corresponding to this ID.
57    ///
58    /// This will force the future to resolve to a "canceled" error response. If the future has
59    /// already completed, this method call will do nothing.
60    pub fn cancel(&self, id: &Id) {
61        if let Some((_, handle)) = self.0.remove(id) {
62            handle.abort();
63            info!("successfully cancelled request with ID: {}", id);
64        } else {
65            debug!(
66                "client asked to cancel request {}, but no such pending request exists, ignoring",
67                id
68            );
69        }
70    }
71
72    /// Cancels all pending request handlers, if any.
73    #[inline]
74    pub fn cancel_all(&self) {
75        self.0.retain(|_, handle| {
76            handle.abort();
77            false
78        });
79    }
80}
81
82impl Debug for Pending {
83    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
84        f.debug_set()
85            .entries(self.0.iter().map(|entry| entry.key().clone()))
86            .finish()
87    }
88}
89
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[tokio::test(flavor = "current_thread")]
    async fn executes_server_request() {
        let pending = Pending::new();

        let id = Id::Number(1);
        let id2 = id.clone();
        let response = pending
            .execute(id.clone(), async {
                Ok(Some(Response::from_ok(id2, json!({}))))
            })
            .await;

        assert_eq!(response, Ok(Some(Response::from_ok(id, json!({})))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn cancels_server_request() {
        let pending = Pending::new();

        let id = Id::Number(1);
        let handler_fut = tokio::spawn(pending.execute(id.clone(), future::pending()));

        pending.cancel(&id);

        let res = handler_fut.await.expect("task panicked");
        assert_eq!(
            res,
            Ok(Some(Response::from_error(id, Error::request_cancelled())))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn cancels_all_server_requests() {
        let pending = Pending::new();

        let first = Id::Number(1);
        let second = Id::Number(2);
        let first_fut = tokio::spawn(pending.execute(first.clone(), future::pending()));
        let second_fut = tokio::spawn(pending.execute(second.clone(), future::pending()));

        pending.cancel_all();

        let res = first_fut.await.expect("task panicked");
        assert_eq!(
            res,
            Ok(Some(Response::from_error(first, Error::request_cancelled())))
        );
        let res = second_fut.await.expect("task panicked");
        assert_eq!(
            res,
            Ok(Some(Response::from_error(second, Error::request_cancelled())))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rejects_duplicate_request_id() {
        let pending = Pending::new();

        let id = Id::Number(1);
        // Hold the first request's future without polling it, so its ID stays registered.
        let _first = pending.execute(id.clone(), future::pending());
        let res = pending.execute(id.clone(), async { Ok(None) }).await;

        assert_eq!(
            res,
            Ok(Some(Response::from_error(id, Error::invalid_request())))
        );
    }
}