reqwest/async_impl/
decoder.rs

1use std::fmt;
2#[cfg(any(
3    feature = "gzip",
4    feature = "zstd",
5    feature = "brotli",
6    feature = "deflate"
7))]
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12#[cfg(feature = "gzip")]
13use async_compression::tokio::bufread::GzipDecoder;
14
15#[cfg(feature = "brotli")]
16use async_compression::tokio::bufread::BrotliDecoder;
17
18#[cfg(feature = "zstd")]
19use async_compression::tokio::bufread::ZstdDecoder;
20
21#[cfg(feature = "deflate")]
22use async_compression::tokio::bufread::ZlibDecoder;
23
24#[cfg(any(
25    feature = "gzip",
26    feature = "zstd",
27    feature = "brotli",
28    feature = "deflate",
29    feature = "blocking",
30))]
31use futures_core::Stream;
32
33use bytes::Bytes;
34use http::HeaderMap;
35use hyper::body::Body as HttpBody;
36use hyper::body::Frame;
37
38#[cfg(any(
39    feature = "gzip",
40    feature = "brotli",
41    feature = "zstd",
42    feature = "deflate"
43))]
44use tokio_util::codec::{BytesCodec, FramedRead};
45#[cfg(any(
46    feature = "gzip",
47    feature = "brotli",
48    feature = "zstd",
49    feature = "deflate"
50))]
51use tokio_util::io::StreamReader;
52
53use super::body::ResponseBody;
54
55#[derive(Clone, Copy, Debug)]
56pub(super) struct Accepts {
57    #[cfg(feature = "gzip")]
58    pub(super) gzip: bool,
59    #[cfg(feature = "brotli")]
60    pub(super) brotli: bool,
61    #[cfg(feature = "zstd")]
62    pub(super) zstd: bool,
63    #[cfg(feature = "deflate")]
64    pub(super) deflate: bool,
65}
66
67impl Accepts {
68    pub fn none() -> Self {
69        Self {
70            #[cfg(feature = "gzip")]
71            gzip: false,
72            #[cfg(feature = "brotli")]
73            brotli: false,
74            #[cfg(feature = "zstd")]
75            zstd: false,
76            #[cfg(feature = "deflate")]
77            deflate: false,
78        }
79    }
80}
81
82/// A response decompressor over a non-blocking stream of chunks.
83///
84/// The inner decoder may be constructed asynchronously.
85pub(crate) struct Decoder {
86    inner: Inner,
87}
88
/// An [`IoStream`] over the response body whose next item can be inspected
/// without being consumed.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
type PeekableIoStream = futures_util::stream::Peekable<IoStream>;
96
/// Reader adapter over a [`PeekableIoStream`]; this is the input type handed
/// to each of the decompressors.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
type PeekableIoStreamReader = StreamReader<PeekableIoStream, Bytes>;
104
105enum Inner {
106    /// A `PlainText` decoder just returns the response content as is.
107    PlainText(ResponseBody),
108
109    /// A `Gzip` decoder will uncompress the gzipped response content before returning it.
110    #[cfg(feature = "gzip")]
111    Gzip(Pin<Box<FramedRead<GzipDecoder<PeekableIoStreamReader>, BytesCodec>>>),
112
113    /// A `Brotli` decoder will uncompress the brotlied response content before returning it.
114    #[cfg(feature = "brotli")]
115    Brotli(Pin<Box<FramedRead<BrotliDecoder<PeekableIoStreamReader>, BytesCodec>>>),
116
117    /// A `Zstd` decoder will uncompress the zstd compressed response content before returning it.
118    #[cfg(feature = "zstd")]
119    Zstd(Pin<Box<FramedRead<ZstdDecoder<PeekableIoStreamReader>, BytesCodec>>>),
120
121    /// A `Deflate` decoder will uncompress the deflated response content before returning it.
122    #[cfg(feature = "deflate")]
123    Deflate(Pin<Box<FramedRead<ZlibDecoder<PeekableIoStreamReader>, BytesCodec>>>),
124
125    /// A decoder that doesn't have a value yet.
126    #[cfg(any(
127        feature = "brotli",
128        feature = "zstd",
129        feature = "gzip",
130        feature = "deflate"
131    ))]
132    Pending(Pin<Box<Pending>>),
133}
134
/// Future that peeks at the first item of the response body and then decides
/// which [`Inner`] state to use: pass-through for an already finished body,
/// otherwise the decompressor named by the stored [`DecoderType`].
///
/// Field `0` is the peekable body, field `1` the decoder to construct.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
struct Pending(PeekableIoStream, DecoderType);
143
/// Adapter that exposes an HTTP body as a stream of data chunks whose errors
/// are `std::io::Error`s.
///
/// The wrapped body defaults to [`ResponseBody`].
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
pub(crate) struct IoStream<B = ResponseBody>(B);
152
/// Names the decompressor a [`Pending`] future should build once the body
/// turns out to be non-empty.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
enum DecoderType {
    /// Build a gzip decoder.
    #[cfg(feature = "gzip")]
    Gzip,
    /// Build a brotli decoder.
    #[cfg(feature = "brotli")]
    Brotli,
    /// Build a zstd decoder.
    #[cfg(feature = "zstd")]
    Zstd,
    /// Build a deflate (zlib) decoder.
    #[cfg(feature = "deflate")]
    Deflate,
}
169
170impl fmt::Debug for Decoder {
171    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
172        f.debug_struct("Decoder").finish()
173    }
174}
175
176impl Decoder {
177    #[cfg(feature = "blocking")]
178    pub(crate) fn empty() -> Decoder {
179        Decoder {
180            inner: Inner::PlainText(empty()),
181        }
182    }
183
184    #[cfg(feature = "blocking")]
185    pub(crate) fn into_stream(self) -> IoStream<Self> {
186        IoStream(self)
187    }
188
189    /// A plain text decoder.
190    ///
191    /// This decoder will emit the underlying chunks as-is.
192    fn plain_text(body: ResponseBody) -> Decoder {
193        Decoder {
194            inner: Inner::PlainText(body),
195        }
196    }
197
198    /// A gzip decoder.
199    ///
200    /// This decoder will buffer and decompress chunks that are gzipped.
201    #[cfg(feature = "gzip")]
202    fn gzip(body: ResponseBody) -> Decoder {
203        use futures_util::StreamExt;
204
205        Decoder {
206            inner: Inner::Pending(Box::pin(Pending(
207                IoStream(body).peekable(),
208                DecoderType::Gzip,
209            ))),
210        }
211    }
212
213    /// A brotli decoder.
214    ///
215    /// This decoder will buffer and decompress chunks that are brotlied.
216    #[cfg(feature = "brotli")]
217    fn brotli(body: ResponseBody) -> Decoder {
218        use futures_util::StreamExt;
219
220        Decoder {
221            inner: Inner::Pending(Box::pin(Pending(
222                IoStream(body).peekable(),
223                DecoderType::Brotli,
224            ))),
225        }
226    }
227
228    /// A zstd decoder.
229    ///
230    /// This decoder will buffer and decompress chunks that are zstd compressed.
231    #[cfg(feature = "zstd")]
232    fn zstd(body: ResponseBody) -> Decoder {
233        use futures_util::StreamExt;
234
235        Decoder {
236            inner: Inner::Pending(Box::pin(Pending(
237                IoStream(body).peekable(),
238                DecoderType::Zstd,
239            ))),
240        }
241    }
242
243    /// A deflate decoder.
244    ///
245    /// This decoder will buffer and decompress chunks that are deflated.
246    #[cfg(feature = "deflate")]
247    fn deflate(body: ResponseBody) -> Decoder {
248        use futures_util::StreamExt;
249
250        Decoder {
251            inner: Inner::Pending(Box::pin(Pending(
252                IoStream(body).peekable(),
253                DecoderType::Deflate,
254            ))),
255        }
256    }
257
258    #[cfg(any(
259        feature = "brotli",
260        feature = "zstd",
261        feature = "gzip",
262        feature = "deflate"
263    ))]
264    fn detect_encoding(headers: &mut HeaderMap, encoding_str: &str) -> bool {
265        use http::header::{CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
266        use log::warn;
267
268        let mut is_content_encoded = {
269            headers
270                .get_all(CONTENT_ENCODING)
271                .iter()
272                .any(|enc| enc == encoding_str)
273                || headers
274                    .get_all(TRANSFER_ENCODING)
275                    .iter()
276                    .any(|enc| enc == encoding_str)
277        };
278        if is_content_encoded {
279            if let Some(content_length) = headers.get(CONTENT_LENGTH) {
280                if content_length == "0" {
281                    warn!("{encoding_str} response with content-length of 0");
282                    is_content_encoded = false;
283                }
284            }
285        }
286        if is_content_encoded {
287            headers.remove(CONTENT_ENCODING);
288            headers.remove(CONTENT_LENGTH);
289        }
290        is_content_encoded
291    }
292
293    /// Constructs a Decoder from a hyper request.
294    ///
295    /// A decoder is just a wrapper around the hyper request that knows
296    /// how to decode the content body of the request.
297    ///
298    /// Uses the correct variant by inspecting the Content-Encoding header.
299    pub(super) fn detect(
300        _headers: &mut HeaderMap,
301        body: ResponseBody,
302        _accepts: Accepts,
303    ) -> Decoder {
304        #[cfg(feature = "gzip")]
305        {
306            if _accepts.gzip && Decoder::detect_encoding(_headers, "gzip") {
307                return Decoder::gzip(body);
308            }
309        }
310
311        #[cfg(feature = "brotli")]
312        {
313            if _accepts.brotli && Decoder::detect_encoding(_headers, "br") {
314                return Decoder::brotli(body);
315            }
316        }
317
318        #[cfg(feature = "zstd")]
319        {
320            if _accepts.zstd && Decoder::detect_encoding(_headers, "zstd") {
321                return Decoder::zstd(body);
322            }
323        }
324
325        #[cfg(feature = "deflate")]
326        {
327            if _accepts.deflate && Decoder::detect_encoding(_headers, "deflate") {
328                return Decoder::deflate(body);
329            }
330        }
331
332        Decoder::plain_text(body)
333    }
334}
335
336impl HttpBody for Decoder {
337    type Data = Bytes;
338    type Error = crate::Error;
339
340    fn poll_frame(
341        mut self: Pin<&mut Self>,
342        cx: &mut Context,
343    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
344        match self.inner {
345            #[cfg(any(
346                feature = "brotli",
347                feature = "zstd",
348                feature = "gzip",
349                feature = "deflate"
350            ))]
351            Inner::Pending(ref mut future) => match Pin::new(future).poll(cx) {
352                Poll::Ready(Ok(inner)) => {
353                    self.inner = inner;
354                    self.poll_frame(cx)
355                }
356                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(crate::error::decode_io(e)))),
357                Poll::Pending => Poll::Pending,
358            },
359            Inner::PlainText(ref mut body) => {
360                match futures_core::ready!(Pin::new(body).poll_frame(cx)) {
361                    Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
362                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode(err)))),
363                    None => Poll::Ready(None),
364                }
365            }
366            #[cfg(feature = "gzip")]
367            Inner::Gzip(ref mut decoder) => {
368                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
369                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
370                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
371                    None => Poll::Ready(None),
372                }
373            }
374            #[cfg(feature = "brotli")]
375            Inner::Brotli(ref mut decoder) => {
376                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
377                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
378                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
379                    None => Poll::Ready(None),
380                }
381            }
382            #[cfg(feature = "zstd")]
383            Inner::Zstd(ref mut decoder) => {
384                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
385                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
386                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
387                    None => Poll::Ready(None),
388                }
389            }
390            #[cfg(feature = "deflate")]
391            Inner::Deflate(ref mut decoder) => {
392                match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
393                    Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
394                    Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
395                    None => Poll::Ready(None),
396                }
397            }
398        }
399    }
400
401    fn size_hint(&self) -> http_body::SizeHint {
402        match self.inner {
403            Inner::PlainText(ref body) => HttpBody::size_hint(body),
404            // the rest are "unknown", so default
405            #[cfg(any(
406                feature = "brotli",
407                feature = "zstd",
408                feature = "gzip",
409                feature = "deflate"
410            ))]
411            _ => http_body::SizeHint::default(),
412        }
413    }
414}
415
/// Builds a boxed response body that contains no data.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
fn empty() -> ResponseBody {
    use http_body_util::{combinators::BoxBody, BodyExt, Empty};

    // The empty body has an uninhabited error type, so the mapping closure
    // can never actually run; it only adapts the error type.
    let nothing = Empty::<Bytes>::new().map_err(|never| match never {});
    BoxBody::new(nothing)
}
427
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
impl Future for Pending {
    type Output = Result<Inner, std::io::Error>;

    /// Peeks at the first body item and resolves to the decoding state to use.
    ///
    /// * body already finished: resolves to a pass-through over an empty body;
    /// * first item is an error: that error is taken out of the stream and returned;
    /// * otherwise: the body is moved into the decompressor selected by the
    ///   stored `DecoderType`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        use futures_util::StreamExt;

        let first_is_err = match futures_core::ready!(Pin::new(&mut self.0).poll_peek(cx)) {
            None => return Poll::Ready(Ok(Inner::PlainText(empty()))),
            Some(peeked) => peeked.is_err(),
        };

        if first_is_err {
            // Peeking only hands out a reference; pull the item for real to own the error.
            let item = futures_core::ready!(Pin::new(&mut self.0).poll_next(cx));
            return Poll::Ready(Err(item.expect("just peeked Some").unwrap_err()));
        }

        // Move the body out, leaving an empty placeholder stream behind.
        let body = std::mem::replace(&mut self.0, IoStream(empty()).peekable());
        let reader = StreamReader::new(body);

        let inner = match self.1 {
            #[cfg(feature = "brotli")]
            DecoderType::Brotli => Inner::Brotli(Box::pin(FramedRead::new(
                BrotliDecoder::new(reader),
                BytesCodec::new(),
            ))),
            #[cfg(feature = "zstd")]
            DecoderType::Zstd => Inner::Zstd(Box::pin(FramedRead::new(
                ZstdDecoder::new(reader),
                BytesCodec::new(),
            ))),
            #[cfg(feature = "gzip")]
            DecoderType::Gzip => Inner::Gzip(Box::pin(FramedRead::new(
                GzipDecoder::new(reader),
                BytesCodec::new(),
            ))),
            #[cfg(feature = "deflate")]
            DecoderType::Deflate => Inner::Deflate(Box::pin(FramedRead::new(
                ZlibDecoder::new(reader),
                BytesCodec::new(),
            ))),
        };

        Poll::Ready(Ok(inner))
    }
}
481
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
impl<B> Stream for IoStream<B>
where
    B: HttpBody<Data = Bytes> + Unpin,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Item = Result<Bytes, std::io::Error>;

    /// Yields the next data chunk of the wrapped body.
    ///
    /// Frames that carry no data are dropped and polling continues with the
    /// following frame. Body errors are converted into `std::io::Error`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        loop {
            match futures_core::ready!(Pin::new(&mut self.0).poll_frame(cx)) {
                None => return Poll::Ready(None),
                Some(Err(err)) => {
                    return Poll::Ready(Some(Err(crate::error::into_io(err.into()))));
                }
                Some(Ok(frame)) => {
                    if let Ok(buf) = frame.into_data() {
                        return Poll::Ready(Some(Ok(buf)));
                    }
                    // Not a data frame: fall through and poll the next one.
                }
            }
        }
    }
}
513
514// ===== impl Accepts =====
515
516impl Accepts {
517    /*
518    pub(super) fn none() -> Self {
519        Accepts {
520            #[cfg(feature = "gzip")]
521            gzip: false,
522            #[cfg(feature = "brotli")]
523            brotli: false,
524            #[cfg(feature = "zstd")]
525            zstd: false,
526            #[cfg(feature = "deflate")]
527            deflate: false,
528        }
529    }
530    */
531
532    pub(super) fn as_str(&self) -> Option<&'static str> {
533        match (
534            self.is_gzip(),
535            self.is_brotli(),
536            self.is_zstd(),
537            self.is_deflate(),
538        ) {
539            (true, true, true, true) => Some("gzip, br, zstd, deflate"),
540            (true, true, false, true) => Some("gzip, br, deflate"),
541            (true, true, true, false) => Some("gzip, br, zstd"),
542            (true, true, false, false) => Some("gzip, br"),
543            (true, false, true, true) => Some("gzip, zstd, deflate"),
544            (true, false, false, true) => Some("gzip, deflate"),
545            (false, true, true, true) => Some("br, zstd, deflate"),
546            (false, true, false, true) => Some("br, deflate"),
547            (true, false, true, false) => Some("gzip, zstd"),
548            (true, false, false, false) => Some("gzip"),
549            (false, true, true, false) => Some("br, zstd"),
550            (false, true, false, false) => Some("br"),
551            (false, false, true, true) => Some("zstd, deflate"),
552            (false, false, true, false) => Some("zstd"),
553            (false, false, false, true) => Some("deflate"),
554            (false, false, false, false) => None,
555        }
556    }
557
558    fn is_gzip(&self) -> bool {
559        #[cfg(feature = "gzip")]
560        {
561            self.gzip
562        }
563
564        #[cfg(not(feature = "gzip"))]
565        {
566            false
567        }
568    }
569
570    fn is_brotli(&self) -> bool {
571        #[cfg(feature = "brotli")]
572        {
573            self.brotli
574        }
575
576        #[cfg(not(feature = "brotli"))]
577        {
578            false
579        }
580    }
581
582    fn is_zstd(&self) -> bool {
583        #[cfg(feature = "zstd")]
584        {
585            self.zstd
586        }
587
588        #[cfg(not(feature = "zstd"))]
589        {
590            false
591        }
592    }
593
594    fn is_deflate(&self) -> bool {
595        #[cfg(feature = "deflate")]
596        {
597            self.deflate
598        }
599
600        #[cfg(not(feature = "deflate"))]
601        {
602            false
603        }
604    }
605}
606
607impl Default for Accepts {
608    fn default() -> Accepts {
609        Accepts {
610            #[cfg(feature = "gzip")]
611            gzip: true,
612            #[cfg(feature = "brotli")]
613            brotli: true,
614            #[cfg(feature = "zstd")]
615            zstd: true,
616            #[cfg(feature = "deflate")]
617            deflate: true,
618        }
619    }
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `Accepts::as_str` against an independently assembled string for
    /// every combination of the four flags.
    #[test]
    fn accepts_as_str() {
        /// Reference implementation: joins the names of the enabled codings.
        fn format_accept_encoding(accepts: &Accepts) -> String {
            [
                (accepts.is_gzip(), "gzip"),
                (accepts.is_brotli(), "br"),
                (accepts.is_zstd(), "zstd"),
                (accepts.is_deflate(), "deflate"),
            ]
            .iter()
            .filter(|(enabled, _)| *enabled)
            .map(|(_, name)| *name)
            .collect::<Vec<_>>()
            .join(", ")
        }

        let flags = [true, false];
        let mut cases = Vec::with_capacity(16);

        // A disabled feature compiles its field out, which leaves the
        // corresponding loop variable unused.
        #[allow(unused_variables)]
        for gzip in flags {
            for brotli in flags {
                for zstd in flags {
                    for deflate in flags {
                        cases.push(Accepts {
                            #[cfg(feature = "gzip")]
                            gzip,
                            #[cfg(feature = "brotli")]
                            brotli,
                            #[cfg(feature = "zstd")]
                            zstd,
                            #[cfg(feature = "deflate")]
                            deflate,
                        });
                    }
                }
            }
        }

        for accepts in &cases {
            let want = format_accept_encoding(accepts);
            let have = accepts.as_str().unwrap_or("");
            assert_eq!(have, want.as_str());
        }
    }
}
674}