1use std::fmt;
2#[cfg(any(
3 feature = "gzip",
4 feature = "zstd",
5 feature = "brotli",
6 feature = "deflate"
7))]
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12#[cfg(feature = "gzip")]
13use async_compression::tokio::bufread::GzipDecoder;
14
15#[cfg(feature = "brotli")]
16use async_compression::tokio::bufread::BrotliDecoder;
17
18#[cfg(feature = "zstd")]
19use async_compression::tokio::bufread::ZstdDecoder;
20
21#[cfg(feature = "deflate")]
22use async_compression::tokio::bufread::ZlibDecoder;
23
24#[cfg(any(
25 feature = "gzip",
26 feature = "zstd",
27 feature = "brotli",
28 feature = "deflate",
29 feature = "blocking",
30))]
31use futures_core::Stream;
32
33use bytes::Bytes;
34use http::HeaderMap;
35use hyper::body::Body as HttpBody;
36use hyper::body::Frame;
37
38#[cfg(any(
39 feature = "gzip",
40 feature = "brotli",
41 feature = "zstd",
42 feature = "deflate"
43))]
44use tokio_util::codec::{BytesCodec, FramedRead};
45#[cfg(any(
46 feature = "gzip",
47 feature = "brotli",
48 feature = "zstd",
49 feature = "deflate"
50))]
51use tokio_util::io::StreamReader;
52
53use super::body::ResponseBody;
54
55#[derive(Clone, Copy, Debug)]
56pub(super) struct Accepts {
57 #[cfg(feature = "gzip")]
58 pub(super) gzip: bool,
59 #[cfg(feature = "brotli")]
60 pub(super) brotli: bool,
61 #[cfg(feature = "zstd")]
62 pub(super) zstd: bool,
63 #[cfg(feature = "deflate")]
64 pub(super) deflate: bool,
65}
66
67impl Accepts {
68 pub fn none() -> Self {
69 Self {
70 #[cfg(feature = "gzip")]
71 gzip: false,
72 #[cfg(feature = "brotli")]
73 brotli: false,
74 #[cfg(feature = "zstd")]
75 zstd: false,
76 #[cfg(feature = "deflate")]
77 deflate: false,
78 }
79 }
80}
81
82pub(crate) struct Decoder {
86 inner: Inner,
87}
88
/// The body byte stream wrapped so its first item can be inspected without
/// being consumed.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
type PeekableIoStream = futures_util::stream::Peekable<IoStream<ResponseBody>>;
96
/// `AsyncBufRead` view over the peekable body stream; this is the input the
/// async-compression decoders read from.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
type PeekableIoStreamReader = StreamReader<PeekableIoStream, bytes::Bytes>;
104
105enum Inner {
106 PlainText(ResponseBody),
108
109 #[cfg(feature = "gzip")]
111 Gzip(Pin<Box<FramedRead<GzipDecoder<PeekableIoStreamReader>, BytesCodec>>>),
112
113 #[cfg(feature = "brotli")]
115 Brotli(Pin<Box<FramedRead<BrotliDecoder<PeekableIoStreamReader>, BytesCodec>>>),
116
117 #[cfg(feature = "zstd")]
119 Zstd(Pin<Box<FramedRead<ZstdDecoder<PeekableIoStreamReader>, BytesCodec>>>),
120
121 #[cfg(feature = "deflate")]
123 Deflate(Pin<Box<FramedRead<ZlibDecoder<PeekableIoStreamReader>, BytesCodec>>>),
124
125 #[cfg(any(
127 feature = "brotli",
128 feature = "zstd",
129 feature = "gzip",
130 feature = "deflate"
131 ))]
132 Pending(Pin<Box<Pending>>),
133}
134
/// A not-yet-started decompression: the peekable body stream (field 0) plus
/// the decompressor to build over it (field 1).
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
struct Pending(PeekableIoStream, DecoderType);
143
/// Wrapper that exposes an HTTP body `B` (a `ResponseBody` unless stated
/// otherwise) through the `Stream` impl defined for it in this file.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
pub(crate) struct IoStream<B = ResponseBody>(B);
152
/// Selects which decompressor a `Pending` state builds once the body starts.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
enum DecoderType {
    /// `Content-Encoding: gzip`
    #[cfg(feature = "gzip")]
    Gzip,
    /// `Content-Encoding: br`
    #[cfg(feature = "brotli")]
    Brotli,
    /// `Content-Encoding: zstd`
    #[cfg(feature = "zstd")]
    Zstd,
    /// `Content-Encoding: deflate`
    #[cfg(feature = "deflate")]
    Deflate,
}
169
170impl fmt::Debug for Decoder {
171 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
172 f.debug_struct("Decoder").finish()
173 }
174}
175
176impl Decoder {
177 #[cfg(feature = "blocking")]
178 pub(crate) fn empty() -> Decoder {
179 Decoder {
180 inner: Inner::PlainText(empty()),
181 }
182 }
183
184 #[cfg(feature = "blocking")]
185 pub(crate) fn into_stream(self) -> IoStream<Self> {
186 IoStream(self)
187 }
188
189 fn plain_text(body: ResponseBody) -> Decoder {
193 Decoder {
194 inner: Inner::PlainText(body),
195 }
196 }
197
198 #[cfg(feature = "gzip")]
202 fn gzip(body: ResponseBody) -> Decoder {
203 use futures_util::StreamExt;
204
205 Decoder {
206 inner: Inner::Pending(Box::pin(Pending(
207 IoStream(body).peekable(),
208 DecoderType::Gzip,
209 ))),
210 }
211 }
212
213 #[cfg(feature = "brotli")]
217 fn brotli(body: ResponseBody) -> Decoder {
218 use futures_util::StreamExt;
219
220 Decoder {
221 inner: Inner::Pending(Box::pin(Pending(
222 IoStream(body).peekable(),
223 DecoderType::Brotli,
224 ))),
225 }
226 }
227
228 #[cfg(feature = "zstd")]
232 fn zstd(body: ResponseBody) -> Decoder {
233 use futures_util::StreamExt;
234
235 Decoder {
236 inner: Inner::Pending(Box::pin(Pending(
237 IoStream(body).peekable(),
238 DecoderType::Zstd,
239 ))),
240 }
241 }
242
243 #[cfg(feature = "deflate")]
247 fn deflate(body: ResponseBody) -> Decoder {
248 use futures_util::StreamExt;
249
250 Decoder {
251 inner: Inner::Pending(Box::pin(Pending(
252 IoStream(body).peekable(),
253 DecoderType::Deflate,
254 ))),
255 }
256 }
257
258 #[cfg(any(
259 feature = "brotli",
260 feature = "zstd",
261 feature = "gzip",
262 feature = "deflate"
263 ))]
264 fn detect_encoding(headers: &mut HeaderMap, encoding_str: &str) -> bool {
265 use http::header::{CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
266 use log::warn;
267
268 let mut is_content_encoded = {
269 headers
270 .get_all(CONTENT_ENCODING)
271 .iter()
272 .any(|enc| enc == encoding_str)
273 || headers
274 .get_all(TRANSFER_ENCODING)
275 .iter()
276 .any(|enc| enc == encoding_str)
277 };
278 if is_content_encoded {
279 if let Some(content_length) = headers.get(CONTENT_LENGTH) {
280 if content_length == "0" {
281 warn!("{encoding_str} response with content-length of 0");
282 is_content_encoded = false;
283 }
284 }
285 }
286 if is_content_encoded {
287 headers.remove(CONTENT_ENCODING);
288 headers.remove(CONTENT_LENGTH);
289 }
290 is_content_encoded
291 }
292
293 pub(super) fn detect(
300 _headers: &mut HeaderMap,
301 body: ResponseBody,
302 _accepts: Accepts,
303 ) -> Decoder {
304 #[cfg(feature = "gzip")]
305 {
306 if _accepts.gzip && Decoder::detect_encoding(_headers, "gzip") {
307 return Decoder::gzip(body);
308 }
309 }
310
311 #[cfg(feature = "brotli")]
312 {
313 if _accepts.brotli && Decoder::detect_encoding(_headers, "br") {
314 return Decoder::brotli(body);
315 }
316 }
317
318 #[cfg(feature = "zstd")]
319 {
320 if _accepts.zstd && Decoder::detect_encoding(_headers, "zstd") {
321 return Decoder::zstd(body);
322 }
323 }
324
325 #[cfg(feature = "deflate")]
326 {
327 if _accepts.deflate && Decoder::detect_encoding(_headers, "deflate") {
328 return Decoder::deflate(body);
329 }
330 }
331
332 Decoder::plain_text(body)
333 }
334}
335
336impl HttpBody for Decoder {
337 type Data = Bytes;
338 type Error = crate::Error;
339
340 fn poll_frame(
341 mut self: Pin<&mut Self>,
342 cx: &mut Context,
343 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
344 match self.inner {
345 #[cfg(any(
346 feature = "brotli",
347 feature = "zstd",
348 feature = "gzip",
349 feature = "deflate"
350 ))]
351 Inner::Pending(ref mut future) => match Pin::new(future).poll(cx) {
352 Poll::Ready(Ok(inner)) => {
353 self.inner = inner;
354 self.poll_frame(cx)
355 }
356 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(crate::error::decode_io(e)))),
357 Poll::Pending => Poll::Pending,
358 },
359 Inner::PlainText(ref mut body) => {
360 match futures_core::ready!(Pin::new(body).poll_frame(cx)) {
361 Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
362 Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode(err)))),
363 None => Poll::Ready(None),
364 }
365 }
366 #[cfg(feature = "gzip")]
367 Inner::Gzip(ref mut decoder) => {
368 match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
369 Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
370 Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
371 None => Poll::Ready(None),
372 }
373 }
374 #[cfg(feature = "brotli")]
375 Inner::Brotli(ref mut decoder) => {
376 match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
377 Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
378 Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
379 None => Poll::Ready(None),
380 }
381 }
382 #[cfg(feature = "zstd")]
383 Inner::Zstd(ref mut decoder) => {
384 match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
385 Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
386 Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
387 None => Poll::Ready(None),
388 }
389 }
390 #[cfg(feature = "deflate")]
391 Inner::Deflate(ref mut decoder) => {
392 match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
393 Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
394 Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
395 None => Poll::Ready(None),
396 }
397 }
398 }
399 }
400
401 fn size_hint(&self) -> http_body::SizeHint {
402 match self.inner {
403 Inner::PlainText(ref body) => HttpBody::size_hint(body),
404 #[cfg(any(
406 feature = "brotli",
407 feature = "zstd",
408 feature = "gzip",
409 feature = "deflate"
410 ))]
411 _ => http_body::SizeHint::default(),
412 }
413 }
414}
415
/// Builds a `ResponseBody` that ends immediately without yielding data.
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
fn empty() -> ResponseBody {
    use http_body_util::{BodyExt, Empty};

    // `Empty` cannot fail; its uninhabited error is eliminated by an
    // exhaustive match with no arms, letting it coerce to the boxed error type.
    Empty::new().map_err(|never| match never {}).boxed()
}
427
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate"
))]
impl Future for Pending {
    type Output = Result<Inner, std::io::Error>;

    /// Waits for the first body item, then resolves to the next `Inner` state.
    ///
    /// - The stream ended before any item: resolves to an empty plain-text body.
    /// - The first item is an error: resolves to that error.
    /// - Otherwise: resolves to the decompressor chosen by `self.1`, built
    ///   over the whole stream. The peeked item is still buffered inside the
    ///   stream.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        use futures_util::StreamExt;

        // Reduce the peek result to a flag so the borrow of `self.0` ends here.
        let first_is_err = match futures_core::ready!(Pin::new(&mut self.0).poll_peek(cx)) {
            None => return Poll::Ready(Ok(Inner::PlainText(empty()))),
            Some(first) => first.is_err(),
        };

        if first_is_err {
            // `poll_peek` only lends the error out; `poll_next` moves the
            // buffered item out so the owned error can be returned.
            let taken = futures_core::ready!(Pin::new(&mut self.0).poll_next(cx));
            return Poll::Ready(Err(taken.expect("just peeked Some").unwrap_err()));
        }

        // Take ownership of the stream, leaving an empty placeholder behind.
        let stream = std::mem::replace(&mut self.0, IoStream(empty()).peekable());

        let inner = match self.1 {
            #[cfg(feature = "gzip")]
            DecoderType::Gzip => Inner::Gzip(Box::pin(FramedRead::new(
                GzipDecoder::new(StreamReader::new(stream)),
                BytesCodec::new(),
            ))),
            #[cfg(feature = "brotli")]
            DecoderType::Brotli => Inner::Brotli(Box::pin(FramedRead::new(
                BrotliDecoder::new(StreamReader::new(stream)),
                BytesCodec::new(),
            ))),
            #[cfg(feature = "zstd")]
            DecoderType::Zstd => Inner::Zstd(Box::pin(FramedRead::new(
                ZstdDecoder::new(StreamReader::new(stream)),
                BytesCodec::new(),
            ))),
            #[cfg(feature = "deflate")]
            DecoderType::Deflate => Inner::Deflate(Box::pin(FramedRead::new(
                ZlibDecoder::new(StreamReader::new(stream)),
                BytesCodec::new(),
            ))),
        };

        Poll::Ready(Ok(inner))
    }
}
481
#[cfg(any(
    feature = "gzip",
    feature = "zstd",
    feature = "brotli",
    feature = "deflate",
    feature = "blocking",
))]
impl<B> Stream for IoStream<B>
where
    B: HttpBody<Data = Bytes> + Unpin,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Item = Result<Bytes, std::io::Error>;

    /// Yields the body's data chunks as `io::Result<Bytes>`.
    ///
    /// Frames that carry no data are skipped. Body errors are converted
    /// with `crate::error::into_io`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        loop {
            let frame = match futures_core::ready!(Pin::new(&mut self.0).poll_frame(cx)) {
                None => return Poll::Ready(None),
                Some(Err(err)) => {
                    return Poll::Ready(Some(Err(crate::error::into_io(err.into()))));
                }
                Some(Ok(frame)) => frame,
            };

            if let Ok(buf) = frame.into_data() {
                return Poll::Ready(Some(Ok(buf)));
            }
            // Not a data frame: keep polling for the next one.
        }
    }
}
513
514impl Accepts {
517 pub(super) fn as_str(&self) -> Option<&'static str> {
533 match (
534 self.is_gzip(),
535 self.is_brotli(),
536 self.is_zstd(),
537 self.is_deflate(),
538 ) {
539 (true, true, true, true) => Some("gzip, br, zstd, deflate"),
540 (true, true, false, true) => Some("gzip, br, deflate"),
541 (true, true, true, false) => Some("gzip, br, zstd"),
542 (true, true, false, false) => Some("gzip, br"),
543 (true, false, true, true) => Some("gzip, zstd, deflate"),
544 (true, false, false, true) => Some("gzip, deflate"),
545 (false, true, true, true) => Some("br, zstd, deflate"),
546 (false, true, false, true) => Some("br, deflate"),
547 (true, false, true, false) => Some("gzip, zstd"),
548 (true, false, false, false) => Some("gzip"),
549 (false, true, true, false) => Some("br, zstd"),
550 (false, true, false, false) => Some("br"),
551 (false, false, true, true) => Some("zstd, deflate"),
552 (false, false, true, false) => Some("zstd"),
553 (false, false, false, true) => Some("deflate"),
554 (false, false, false, false) => None,
555 }
556 }
557
558 fn is_gzip(&self) -> bool {
559 #[cfg(feature = "gzip")]
560 {
561 self.gzip
562 }
563
564 #[cfg(not(feature = "gzip"))]
565 {
566 false
567 }
568 }
569
570 fn is_brotli(&self) -> bool {
571 #[cfg(feature = "brotli")]
572 {
573 self.brotli
574 }
575
576 #[cfg(not(feature = "brotli"))]
577 {
578 false
579 }
580 }
581
582 fn is_zstd(&self) -> bool {
583 #[cfg(feature = "zstd")]
584 {
585 self.zstd
586 }
587
588 #[cfg(not(feature = "zstd"))]
589 {
590 false
591 }
592 }
593
594 fn is_deflate(&self) -> bool {
595 #[cfg(feature = "deflate")]
596 {
597 self.deflate
598 }
599
600 #[cfg(not(feature = "deflate"))]
601 {
602 false
603 }
604 }
605}
606
607impl Default for Accepts {
608 fn default() -> Accepts {
609 Accepts {
610 #[cfg(feature = "gzip")]
611 gzip: true,
612 #[cfg(feature = "brotli")]
613 brotli: true,
614 #[cfg(feature = "zstd")]
615 zstd: true,
616 #[cfg(feature = "deflate")]
617 deflate: true,
618 }
619 }
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn accepts_as_str() {
        // Reference rendering: enabled codings in the order gzip, br, zstd,
        // deflate, joined with ", ".
        fn format_accept_encoding(accepts: &Accepts) -> String {
            let candidates = [
                (accepts.is_gzip(), "gzip"),
                (accepts.is_brotli(), "br"),
                (accepts.is_zstd(), "zstd"),
                (accepts.is_deflate(), "deflate"),
            ];
            candidates
                .iter()
                .filter(|(enabled, _)| *enabled)
                .map(|(_, name)| *name)
                .collect::<Vec<_>>()
                .join(", ")
        }

        // Walk all 16 on/off combinations with a 4-bit counter.
        for bits in 0u8..16 {
            // Some flags go unused when their cargo feature is disabled.
            #[allow(unused_variables)]
            let (gzip, brotli, zstd, deflate) = (
                (bits & 0b1000) != 0,
                (bits & 0b0100) != 0,
                (bits & 0b0010) != 0,
                (bits & 0b0001) != 0,
            );

            let accepts = Accepts {
                #[cfg(feature = "gzip")]
                gzip,
                #[cfg(feature = "brotli")]
                brotli,
                #[cfg(feature = "zstd")]
                zstd,
                #[cfg(feature = "deflate")]
                deflate,
            };

            let expected = format_accept_encoding(&accepts);
            assert_eq!(accepts.as_str().unwrap_or(""), expected.as_str());
        }
    }
}
674}