reqwest/
connect.rs

1#[cfg(feature = "__tls")]
2use http::header::HeaderValue;
3use http::uri::{Authority, Scheme};
4use http::Uri;
5use hyper::rt::{Read, ReadBufCursor, Write};
6use hyper_util::client::legacy::connect::{Connected, Connection};
7#[cfg(any(feature = "socks", feature = "__tls"))]
8use hyper_util::rt::TokioIo;
9#[cfg(feature = "default-tls")]
10use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
11use tower_service::Service;
12
13use pin_project_lite::pin_project;
14use std::future::Future;
15use std::io::{self, IoSlice};
16use std::net::IpAddr;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22#[cfg(feature = "default-tls")]
23use self::native_tls_conn::NativeTlsConn;
24#[cfg(feature = "__rustls")]
25use self::rustls_tls_conn::RustlsTlsConn;
26use crate::dns::DynResolver;
27use crate::error::BoxError;
28use crate::proxy::{Proxy, ProxyScheme};
29
/// hyper-util's legacy `HttpConnector`, resolving host names through `DynResolver`.
pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector<DynResolver>;
31
/// Opens transport connections for the client, optionally through a proxy
/// and/or wrapped in TLS, depending on the compiled backend in `inner`.
#[derive(Clone)]
pub(crate) struct Connector {
    /// Underlying HTTP connector plus the TLS backend selected at build time.
    inner: Inner,
    /// Proxies consulted in order by `Service::call`; the first whose
    /// `intercept` matches the destination is used.
    proxies: Arc<Vec<Proxy>>,
    /// Wrapper applied to every produced connection; toggled by `set_verbose`.
    verbose: verbose::Wrapper,
    /// Optional overall connect timeout applied in `Service::call`.
    timeout: Option<Duration>,
    /// Requested `TCP_NODELAY` setting; with TLS, Nagle is disabled for the
    /// handshake and this value is restored afterwards.
    #[cfg(feature = "__tls")]
    nodelay: bool,
    /// Whether connections should expose `crate::tls::TlsInfo` via `Connected`.
    #[cfg(feature = "__tls")]
    tls_info: bool,
    /// `User-Agent` header sent with proxy `CONNECT` requests.
    #[cfg(feature = "__tls")]
    user_agent: Option<HeaderValue>,
}
45
/// Transport backend of a `Connector`, chosen by crate features.
#[derive(Clone)]
enum Inner {
    /// Plain TCP connector, used when `__tls` is disabled.
    #[cfg(not(feature = "__tls"))]
    Http(HttpConnector),
    /// TCP connector plus a `native-tls` connector.
    #[cfg(feature = "default-tls")]
    DefaultTls(HttpConnector, TlsConnector),
    /// TCP connector plus rustls client configurations.
    #[cfg(feature = "__rustls")]
    RustlsTls {
        http: HttpConnector,
        /// Config for TLS sessions with the destination server.
        tls: Arc<rustls::ClientConfig>,
        /// Config for the TLS session with an HTTPS proxy itself; ALPN is
        /// cleared when proxies are configured (see `new_rustls_tls`).
        tls_proxy: Arc<rustls::ClientConfig>,
    },
}
59
60impl Connector {
61    #[cfg(not(feature = "__tls"))]
62    pub(crate) fn new<T>(
63        mut http: HttpConnector,
64        proxies: Arc<Vec<Proxy>>,
65        local_addr: T,
66        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
67        interface: Option<&str>,
68        nodelay: bool,
69    ) -> Connector
70    where
71        T: Into<Option<IpAddr>>,
72    {
73        http.set_local_address(local_addr.into());
74        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
75        if let Some(interface) = interface {
76            http.set_interface(interface.to_owned());
77        }
78        http.set_nodelay(nodelay);
79
80        Connector {
81            inner: Inner::Http(http),
82            verbose: verbose::OFF,
83            proxies,
84            timeout: None,
85        }
86    }
87
88    #[cfg(feature = "default-tls")]
89    pub(crate) fn new_default_tls<T>(
90        http: HttpConnector,
91        tls: TlsConnectorBuilder,
92        proxies: Arc<Vec<Proxy>>,
93        user_agent: Option<HeaderValue>,
94        local_addr: T,
95        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
96        interface: Option<&str>,
97        nodelay: bool,
98        tls_info: bool,
99    ) -> crate::Result<Connector>
100    where
101        T: Into<Option<IpAddr>>,
102    {
103        let tls = tls.build().map_err(crate::error::builder)?;
104        Ok(Self::from_built_default_tls(
105            http,
106            tls,
107            proxies,
108            user_agent,
109            local_addr,
110            #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
111            interface,
112            nodelay,
113            tls_info,
114        ))
115    }
116
117    #[cfg(feature = "default-tls")]
118    pub(crate) fn from_built_default_tls<T>(
119        mut http: HttpConnector,
120        tls: TlsConnector,
121        proxies: Arc<Vec<Proxy>>,
122        user_agent: Option<HeaderValue>,
123        local_addr: T,
124        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
125        interface: Option<&str>,
126        nodelay: bool,
127        tls_info: bool,
128    ) -> Connector
129    where
130        T: Into<Option<IpAddr>>,
131    {
132        http.set_local_address(local_addr.into());
133        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
134        if let Some(interface) = interface {
135            http.set_interface(interface);
136        }
137        http.set_nodelay(nodelay);
138        http.enforce_http(false);
139
140        Connector {
141            inner: Inner::DefaultTls(http, tls),
142            proxies,
143            verbose: verbose::OFF,
144            timeout: None,
145            nodelay,
146            tls_info,
147            user_agent,
148        }
149    }
150
151    #[cfg(feature = "__rustls")]
152    pub(crate) fn new_rustls_tls<T>(
153        mut http: HttpConnector,
154        tls: rustls::ClientConfig,
155        proxies: Arc<Vec<Proxy>>,
156        user_agent: Option<HeaderValue>,
157        local_addr: T,
158        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
159        interface: Option<&str>,
160        nodelay: bool,
161        tls_info: bool,
162    ) -> Connector
163    where
164        T: Into<Option<IpAddr>>,
165    {
166        http.set_local_address(local_addr.into());
167        #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
168        if let Some(interface) = interface {
169            http.set_interface(interface.to_owned());
170        }
171        http.set_nodelay(nodelay);
172        http.enforce_http(false);
173
174        let (tls, tls_proxy) = if proxies.is_empty() {
175            let tls = Arc::new(tls);
176            (tls.clone(), tls)
177        } else {
178            let mut tls_proxy = tls.clone();
179            tls_proxy.alpn_protocols.clear();
180            (Arc::new(tls), Arc::new(tls_proxy))
181        };
182
183        Connector {
184            inner: Inner::RustlsTls {
185                http,
186                tls,
187                tls_proxy,
188            },
189            proxies,
190            verbose: verbose::OFF,
191            timeout: None,
192            nodelay,
193            tls_info,
194            user_agent,
195        }
196    }
197
198    pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
199        self.timeout = timeout;
200    }
201
202    pub(crate) fn set_verbose(&mut self, enabled: bool) {
203        self.verbose.0 = enabled;
204    }
205
206    #[cfg(feature = "socks")]
207    async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
208        let dns = match proxy {
209            ProxyScheme::Socks4 { .. } => socks::DnsResolve::Local,
210            ProxyScheme::Socks5 {
211                remote_dns: false, ..
212            } => socks::DnsResolve::Local,
213            ProxyScheme::Socks5 {
214                remote_dns: true, ..
215            } => socks::DnsResolve::Proxy,
216            ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
217                unreachable!("connect_socks is only called for socks proxies");
218            }
219        };
220
221        match &self.inner {
222            #[cfg(feature = "default-tls")]
223            Inner::DefaultTls(_http, tls) => {
224                if dst.scheme() == Some(&Scheme::HTTPS) {
225                    let host = dst.host().ok_or("no host in url")?.to_string();
226                    let conn = socks::connect(proxy, dst, dns).await?;
227                    let conn = TokioIo::new(conn);
228                    let conn = TokioIo::new(conn);
229                    let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
230                    let io = tls_connector.connect(&host, conn).await?;
231                    let io = TokioIo::new(io);
232                    return Ok(Conn {
233                        inner: self.verbose.wrap(NativeTlsConn { inner: io }),
234                        is_proxy: false,
235                        tls_info: self.tls_info,
236                    });
237                }
238            }
239            #[cfg(feature = "__rustls")]
240            Inner::RustlsTls { tls, .. } => {
241                if dst.scheme() == Some(&Scheme::HTTPS) {
242                    use std::convert::TryFrom;
243                    use tokio_rustls::TlsConnector as RustlsConnector;
244
245                    let tls = tls.clone();
246                    let host = dst.host().ok_or("no host in url")?.to_string();
247                    let conn = socks::connect(proxy, dst, dns).await?;
248                    let conn = TokioIo::new(conn);
249                    let conn = TokioIo::new(conn);
250                    let server_name =
251                        rustls_pki_types::ServerName::try_from(host.as_str().to_owned())
252                            .map_err(|_| "Invalid Server Name")?;
253                    let io = RustlsConnector::from(tls)
254                        .connect(server_name, conn)
255                        .await?;
256                    let io = TokioIo::new(io);
257                    return Ok(Conn {
258                        inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
259                        is_proxy: false,
260                        tls_info: false,
261                    });
262                }
263            }
264            #[cfg(not(feature = "__tls"))]
265            Inner::Http(_) => (),
266        }
267
268        socks::connect(proxy, dst, dns).await.map(|tcp| Conn {
269            inner: self.verbose.wrap(TokioIo::new(tcp)),
270            is_proxy: false,
271            tls_info: false,
272        })
273    }
274
275    async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> {
276        match self.inner {
277            #[cfg(not(feature = "__tls"))]
278            Inner::Http(mut http) => {
279                let io = http.call(dst).await?;
280                Ok(Conn {
281                    inner: self.verbose.wrap(io),
282                    is_proxy,
283                    tls_info: false,
284                })
285            }
286            #[cfg(feature = "default-tls")]
287            Inner::DefaultTls(http, tls) => {
288                let mut http = http.clone();
289
290                // Disable Nagle's algorithm for TLS handshake
291                //
292                // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
293                if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
294                    http.set_nodelay(true);
295                }
296
297                let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
298                let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
299                let io = http.call(dst).await?;
300
301                if let hyper_tls::MaybeHttpsStream::Https(stream) = io {
302                    if !self.nodelay {
303                        stream
304                            .inner()
305                            .get_ref()
306                            .get_ref()
307                            .get_ref()
308                            .inner()
309                            .inner()
310                            .set_nodelay(false)?;
311                    }
312                    Ok(Conn {
313                        inner: self.verbose.wrap(NativeTlsConn { inner: stream }),
314                        is_proxy,
315                        tls_info: self.tls_info,
316                    })
317                } else {
318                    Ok(Conn {
319                        inner: self.verbose.wrap(io),
320                        is_proxy,
321                        tls_info: false,
322                    })
323                }
324            }
325            #[cfg(feature = "__rustls")]
326            Inner::RustlsTls { http, tls, .. } => {
327                let mut http = http.clone();
328
329                // Disable Nagle's algorithm for TLS handshake
330                //
331                // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
332                if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
333                    http.set_nodelay(true);
334                }
335
336                let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
337                let io = http.call(dst).await?;
338
339                if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
340                    if !self.nodelay {
341                        let (io, _) = stream.inner().get_ref();
342                        io.inner().inner().set_nodelay(false)?;
343                    }
344                    Ok(Conn {
345                        inner: self.verbose.wrap(RustlsTlsConn { inner: stream }),
346                        is_proxy,
347                        tls_info: self.tls_info,
348                    })
349                } else {
350                    Ok(Conn {
351                        inner: self.verbose.wrap(io),
352                        is_proxy,
353                        tls_info: false,
354                    })
355                }
356            }
357        }
358    }
359
360    async fn connect_via_proxy(
361        self,
362        dst: Uri,
363        proxy_scheme: ProxyScheme,
364    ) -> Result<Conn, BoxError> {
365        log::debug!("proxy({proxy_scheme:?}) intercepts '{dst:?}'");
366
367        let (proxy_dst, _auth) = match proxy_scheme {
368            ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
369            ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
370            #[cfg(feature = "socks")]
371            ProxyScheme::Socks4 { .. } => return self.connect_socks(dst, proxy_scheme).await,
372            #[cfg(feature = "socks")]
373            ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
374        };
375
376        #[cfg(feature = "__tls")]
377        let auth = _auth;
378
379        match &self.inner {
380            #[cfg(feature = "default-tls")]
381            Inner::DefaultTls(http, tls) => {
382                if dst.scheme() == Some(&Scheme::HTTPS) {
383                    let host = dst.host().to_owned();
384                    let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
385                    let http = http.clone();
386                    let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
387                    let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
388                    let conn = http.call(proxy_dst).await?;
389                    log::trace!("tunneling HTTPS over proxy");
390                    let tunneled = tunnel(
391                        conn,
392                        host.ok_or("no host in url")?.to_string(),
393                        port,
394                        self.user_agent.clone(),
395                        auth,
396                    )
397                    .await?;
398                    let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
399                    let io = tls_connector
400                        .connect(host.ok_or("no host in url")?, TokioIo::new(tunneled))
401                        .await?;
402                    return Ok(Conn {
403                        inner: self.verbose.wrap(NativeTlsConn {
404                            inner: TokioIo::new(io),
405                        }),
406                        is_proxy: false,
407                        tls_info: false,
408                    });
409                }
410            }
411            #[cfg(feature = "__rustls")]
412            Inner::RustlsTls {
413                http,
414                tls,
415                tls_proxy,
416            } => {
417                if dst.scheme() == Some(&Scheme::HTTPS) {
418                    use rustls_pki_types::ServerName;
419                    use std::convert::TryFrom;
420                    use tokio_rustls::TlsConnector as RustlsConnector;
421
422                    let host = dst.host().ok_or("no host in url")?.to_string();
423                    let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
424                    let http = http.clone();
425                    let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
426                    let tls = tls.clone();
427                    let conn = http.call(proxy_dst).await?;
428                    log::trace!("tunneling HTTPS over proxy");
429                    let maybe_server_name = ServerName::try_from(host.as_str().to_owned())
430                        .map_err(|_| "Invalid Server Name");
431                    let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
432                    let server_name = maybe_server_name?;
433                    let io = RustlsConnector::from(tls)
434                        .connect(server_name, TokioIo::new(tunneled))
435                        .await?;
436
437                    return Ok(Conn {
438                        inner: self.verbose.wrap(RustlsTlsConn {
439                            inner: TokioIo::new(io),
440                        }),
441                        is_proxy: false,
442                        tls_info: false,
443                    });
444                }
445            }
446            #[cfg(not(feature = "__tls"))]
447            Inner::Http(_) => (),
448        }
449
450        self.connect_with_maybe_proxy(proxy_dst, true).await
451    }
452
453    pub fn set_keepalive(&mut self, dur: Option<Duration>) {
454        match &mut self.inner {
455            #[cfg(feature = "default-tls")]
456            Inner::DefaultTls(http, _tls) => http.set_keepalive(dur),
457            #[cfg(feature = "__rustls")]
458            Inner::RustlsTls { http, .. } => http.set_keepalive(dur),
459            #[cfg(not(feature = "__tls"))]
460            Inner::Http(http) => http.set_keepalive(dur),
461        }
462    }
463}
464
465fn into_uri(scheme: Scheme, host: Authority) -> Uri {
466    // TODO: Should the `http` crate get `From<(Scheme, Authority)> for Uri`?
467    http::Uri::builder()
468        .scheme(scheme)
469        .authority(host)
470        .path_and_query(http::uri::PathAndQuery::from_static("/"))
471        .build()
472        .expect("scheme and authority is valid Uri")
473}
474
475async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError>
476where
477    F: Future<Output = Result<T, BoxError>>,
478{
479    if let Some(to) = timeout {
480        match tokio::time::timeout(to, f).await {
481            Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
482            Ok(Ok(try_res)) => Ok(try_res),
483            Ok(Err(e)) => Err(e),
484        }
485    } else {
486        f.await
487    }
488}
489
490impl Service<Uri> for Connector {
491    type Response = Conn;
492    type Error = BoxError;
493    type Future = Connecting;
494
495    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
496        Poll::Ready(Ok(()))
497    }
498
499    fn call(&mut self, dst: Uri) -> Self::Future {
500        log::debug!("starting new connection: {dst:?}");
501        let timeout = self.timeout;
502        for prox in self.proxies.iter() {
503            if let Some(proxy_scheme) = prox.intercept(&dst) {
504                return Box::pin(with_timeout(
505                    self.clone().connect_via_proxy(dst, proxy_scheme),
506                    timeout,
507                ));
508            }
509        }
510
511        Box::pin(with_timeout(
512            self.clone().connect_with_maybe_proxy(dst, false),
513            timeout,
514        ))
515    }
516}
517
/// Extracts TLS session details from a connection so they can be exposed as
/// `crate::tls::TlsInfo`.
#[cfg(feature = "__tls")]
trait TlsInfoFactory {
    /// Returns this connection's TLS details, or `None` if it carries no TLS session.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo>;
}

// A bare TCP stream has no TLS session, so there is nothing to report.
#[cfg(feature = "__tls")]
impl TlsInfoFactory for tokio::net::TcpStream {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        None
    }
}

// `TokioIo` is only an I/O adapter; delegate to the wrapped stream.
#[cfg(feature = "__tls")]
impl<T: TlsInfoFactory> TlsInfoFactory for TokioIo<T> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        self.inner().tls_info()
    }
}
536
// TLS over a plain TCP stream (direct and SOCKS connections).
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for tokio_native_tls::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
    /// Reports the peer certificate in DER form; it is `None` if the peer sent
    /// none or it could not be read or encoded.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let der = match self.get_ref().peer_certificate() {
            Ok(Some(cert)) => cert.to_der().ok(),
            Ok(None) | Err(_) => None,
        };
        Some(crate::tls::TlsInfo {
            peer_certificate: der,
        })
    }
}

// TLS tunneled through a (possibly TLS) proxy connection.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory
    for tokio_native_tls::TlsStream<
        TokioIo<hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
    >
{
    /// Reports the peer certificate of this (outer) TLS session in DER form.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let der = match self.get_ref().peer_certificate() {
            Ok(Some(cert)) => cert.to_der().ok(),
            Ok(None) | Err(_) => None,
        };
        Some(crate::tls::TlsInfo {
            peer_certificate: der,
        })
    }
}

// Only the HTTPS variant has a TLS session to report on.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_tls::MaybeHttpsStream::Https(secured) = self {
            secured.tls_info()
        } else {
            None
        }
    }
}
576
// TLS over a plain TCP stream (direct and SOCKS connections).
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::client::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
    /// Reports the first certificate of the peer's chain, if one was presented.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_transport, session) = self.get_ref();
        let leaf = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|der| der.to_vec());
        Some(crate::tls::TlsInfo {
            peer_certificate: leaf,
        })
    }
}

// TLS tunneled through a (possibly TLS) proxy connection.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory
    for tokio_rustls::client::TlsStream<
        TokioIo<hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
    >
{
    /// Reports the first certificate of this (outer) session's peer chain.
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_transport, session) = self.get_ref();
        let leaf = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|der| der.to_vec());
        Some(crate::tls::TlsInfo {
            peer_certificate: leaf,
        })
    }
}

// Only the HTTPS variant has a TLS session to report on.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_rustls::MaybeHttpsStream::Https(secured) = self {
            secured.tls_info()
        } else {
            None
        }
    }
}
616
/// Bundle of the I/O and connection traits every connection handed to hyper
/// must implement.
pub(crate) trait AsyncConn:
    Read + Write + Connection + Send + Sync + Unpin + 'static
{
}

// Blanket impl: any type meeting the bounds is an `AsyncConn`.
impl<T: Read + Write + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}

// `AsyncConnWithInfo` additionally requires `TlsInfoFactory` when a TLS backend
// is compiled in, so a boxed connection can still report TLS details.
#[cfg(feature = "__tls")]
trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {}
#[cfg(not(feature = "__tls"))]
trait AsyncConnWithInfo: AsyncConn {}

#[cfg(feature = "__tls")]
impl<T: AsyncConn + TlsInfoFactory> AsyncConnWithInfo for T {}
#[cfg(not(feature = "__tls"))]
impl<T: AsyncConn> AsyncConnWithInfo for T {}

/// Type-erased transport stored inside `Conn`.
type BoxConn = Box<dyn AsyncConnWithInfo>;
635
pin_project! {
    /// A connection produced by `Connector`, wrapping a type-erased transport.
    ///
    /// Note: the `is_proxy` member means *is plain text HTTP proxy*.
    /// This tells hyper whether the URI should be written in
    /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or
    /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise.
    pub(crate) struct Conn {
        // Pinned so reads and writes can be forwarded through the projection.
        #[pin]
        inner: BoxConn,
        is_proxy: bool,
        // Whether `connected()` should attach TLS info taken from `inner`.
        // Only needed for __tls, but #[cfg()] on fields breaks pin_project!
        tls_info: bool,
    }
}
649
650impl Connection for Conn {
651    fn connected(&self) -> Connected {
652        let connected = self.inner.connected().proxy(self.is_proxy);
653        #[cfg(feature = "__tls")]
654        if self.tls_info {
655            if let Some(tls_info) = self.inner.tls_info() {
656                connected.extra(tls_info)
657            } else {
658                connected
659            }
660        } else {
661            connected
662        }
663        #[cfg(not(feature = "__tls"))]
664        connected
665    }
666}
667
668impl Read for Conn {
669    fn poll_read(
670        self: Pin<&mut Self>,
671        cx: &mut Context,
672        buf: ReadBufCursor<'_>,
673    ) -> Poll<io::Result<()>> {
674        let this = self.project();
675        Read::poll_read(this.inner, cx, buf)
676    }
677}
678
679impl Write for Conn {
680    fn poll_write(
681        self: Pin<&mut Self>,
682        cx: &mut Context,
683        buf: &[u8],
684    ) -> Poll<Result<usize, io::Error>> {
685        let this = self.project();
686        Write::poll_write(this.inner, cx, buf)
687    }
688
689    fn poll_write_vectored(
690        self: Pin<&mut Self>,
691        cx: &mut Context<'_>,
692        bufs: &[IoSlice<'_>],
693    ) -> Poll<Result<usize, io::Error>> {
694        let this = self.project();
695        Write::poll_write_vectored(this.inner, cx, bufs)
696    }
697
698    fn is_write_vectored(&self) -> bool {
699        self.inner.is_write_vectored()
700    }
701
702    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
703        let this = self.project();
704        Write::poll_flush(this.inner, cx)
705    }
706
707    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
708        let this = self.project();
709        Write::poll_shutdown(this.inner, cx)
710    }
711}
712
/// Boxed, `Send` future returned by `Connector`'s `Service::call`.
pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
714
/// Opens an HTTP `CONNECT` tunnel through an already-connected proxy stream.
///
/// Sends `CONNECT {host}:{port}` with an optional `User-Agent` and
/// `Proxy-Authorization` header, then reads the proxy's reply. On a `200`
/// status whose header block ends the received data, `conn` is returned so
/// the caller can run TLS over it.
///
/// # Errors
///
/// * I/O errors while writing the request or reading the reply.
/// * `tunnel_eof()` if the proxy closes the stream before a complete reply.
/// * `"proxy headers too long for tunnel"` if a `200` reply overflows the 8 KiB buffer.
/// * `"proxy authentication required"` on a `407` reply (HTTP/1.0 or 1.1).
/// * `"unsuccessful tunnel"` for any other status.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: Read + Write + Unpin,
{
    use hyper_util::rt::TokioIo;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Length of a status-line prefix such as `HTTP/1.1 200`. A single read may
    // return only part of the status line, so the reply must not be classified
    // before at least this many bytes have arrived.
    const STATUS_PREFIX_LEN: usize = b"HTTP/1.1 200".len();

    let mut buf = format!(
        "\
         CONNECT {host}:{port} HTTP/1.1\r\n\
         Host: {host}:{port}\r\n\
         "
    )
    .into_bytes();

    // user-agent
    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // proxy-authorization
    if let Some(value) = auth {
        log::debug!("tunnel to {host}:{port} using basic auth");
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // headers end
    buf.extend_from_slice(b"\r\n");

    let mut tokio_conn = TokioIo::new(&mut conn);

    tokio_conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = tokio_conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        let recvd = &buf[..pos];
        if recvd.len() < STATUS_PREFIX_LEN {
            // Status line still incomplete: read more before deciding.
            continue;
        }

        if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
            // else read more
        } else if recvd.starts_with(b"HTTP/1.1 407") || recvd.starts_with(b"HTTP/1.0 407") {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
786
/// Error used when the proxy closes the stream before completing its reply to `CONNECT`.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    "unexpected eof while tunneling".into()
}
791
/// Adapters that expose `tokio-native-tls` streams through hyper's I/O and
/// connection traits.
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use super::TlsInfoFactory;
    use hyper::rt::{Read, ReadBufCursor, Write};
    use hyper_tls::MaybeHttpsStream;
    use hyper_util::client::legacy::connect::{Connected, Connection};
    use hyper_util::rt::TokioIo;
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_native_tls::TlsStream;

    pin_project! {
        // A native-tls session over transport `T`, adapted to hyper's
        // `Read`/`Write` through `TokioIo`.
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TokioIo<TlsStream<T>>,
        }
    }

    // TLS directly over TCP (used for direct and SOCKS connections).
    impl Connection for NativeTlsConn<TokioIo<TokioIo<TcpStream>>> {
        fn connected(&self) -> Connected {
            // Peel off the TLS and adapter layers to ask the underlying
            // TCP transport for its connection metadata.
            let connected = self
                .inner
                .inner()
                .get_ref()
                .get_ref()
                .get_ref()
                .inner()
                .connected();
            // With ALPN support compiled in, flag HTTP/2 when the server chose `h2`.
            #[cfg(feature = "native-tls-alpn")]
            match self.inner.inner().get_ref().negotiated_alpn().ok() {
                Some(Some(alpn_protocol)) if alpn_protocol == b"h2" => connected.negotiated_h2(),
                _ => connected,
            }
            #[cfg(not(feature = "native-tls-alpn"))]
            connected
        }
    }

    // TLS tunneled through a proxy connection (`MaybeHttpsStream`).
    impl Connection for NativeTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
        fn connected(&self) -> Connected {
            // Peel off the outer TLS and adapter layers to reach the proxy
            // stream and ask it for its connection metadata.
            let connected = self
                .inner
                .inner()
                .get_ref()
                .get_ref()
                .get_ref()
                .inner()
                .connected();
            // With ALPN support compiled in, flag HTTP/2 when the server chose `h2`.
            #[cfg(feature = "native-tls-alpn")]
            match self.inner.inner().get_ref().negotiated_alpn().ok() {
                Some(Some(alpn_protocol)) if alpn_protocol == b"h2" => connected.negotiated_h2(),
                _ => connected,
            }
            #[cfg(not(feature = "native-tls-alpn"))]
            connected
        }
    }

    // Reads are forwarded to the pinned inner stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> Read for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: ReadBufCursor<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            let this = self.project();
            Read::poll_read(this.inner, cx, buf)
        }
    }

    // Writes, flushes and shutdowns are forwarded to the pinned inner stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> Write for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            let this = self.project();
            Write::poll_write(this.inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            let this = self.project();
            Write::poll_write_vectored(this.inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            let this = self.project();
            Write::poll_flush(this.inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            let this = self.project();
            Write::poll_shutdown(this.inner, cx)
        }
    }

    // TLS info comes from whichever `TlsInfoFactory` impl the wrapped stream type has.
    impl<T> TlsInfoFactory for NativeTlsConn<T>
    where
        TokioIo<TlsStream<T>>: TlsInfoFactory,
    {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }
}
915
916#[cfg(feature = "__rustls")]
917mod rustls_tls_conn {
918    use super::TlsInfoFactory;
919    use hyper::rt::{Read, ReadBufCursor, Write};
920    use hyper_rustls::MaybeHttpsStream;
921    use hyper_util::client::legacy::connect::{Connected, Connection};
922    use hyper_util::rt::TokioIo;
923    use pin_project_lite::pin_project;
924    use std::{
925        io::{self, IoSlice},
926        pin::Pin,
927        task::{Context, Poll},
928    };
929    use tokio::io::{AsyncRead, AsyncWrite};
930    use tokio::net::TcpStream;
931    use tokio_rustls::client::TlsStream;
932
    // rustls-backed TLS connection wrapper. All `Read`/`Write`/`TlsInfoFactory`
    // calls below are delegated to the pinned `inner` stream.
    pin_project! {
        pub(super) struct RustlsTlsConn<T> {
            // The TLS stream over transport `T`, adapted to hyper's I/O traits.
            #[pin] pub(super) inner: TokioIo<TlsStream<T>>,
        }
    }
938
939    impl Connection for RustlsTlsConn<TokioIo<TokioIo<TcpStream>>> {
940        fn connected(&self) -> Connected {
941            if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") {
942                self.inner
943                    .inner()
944                    .get_ref()
945                    .0
946                    .inner()
947                    .connected()
948                    .negotiated_h2()
949            } else {
950                self.inner.inner().get_ref().0.inner().connected()
951            }
952        }
953    }
954    impl Connection for RustlsTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
955        fn connected(&self) -> Connected {
956            if self.inner.inner().get_ref().1.alpn_protocol() == Some(b"h2") {
957                self.inner
958                    .inner()
959                    .get_ref()
960                    .0
961                    .inner()
962                    .connected()
963                    .negotiated_h2()
964            } else {
965                self.inner.inner().get_ref().0.inner().connected()
966            }
967        }
968    }
969
970    impl<T: AsyncRead + AsyncWrite + Unpin> Read for RustlsTlsConn<T> {
971        fn poll_read(
972            self: Pin<&mut Self>,
973            cx: &mut Context,
974            buf: ReadBufCursor<'_>,
975        ) -> Poll<tokio::io::Result<()>> {
976            let this = self.project();
977            Read::poll_read(this.inner, cx, buf)
978        }
979    }
980
981    impl<T: AsyncRead + AsyncWrite + Unpin> Write for RustlsTlsConn<T> {
982        fn poll_write(
983            self: Pin<&mut Self>,
984            cx: &mut Context,
985            buf: &[u8],
986        ) -> Poll<Result<usize, tokio::io::Error>> {
987            let this = self.project();
988            Write::poll_write(this.inner, cx, buf)
989        }
990
991        fn poll_write_vectored(
992            self: Pin<&mut Self>,
993            cx: &mut Context<'_>,
994            bufs: &[IoSlice<'_>],
995        ) -> Poll<Result<usize, io::Error>> {
996            let this = self.project();
997            Write::poll_write_vectored(this.inner, cx, bufs)
998        }
999
1000        fn is_write_vectored(&self) -> bool {
1001            self.inner.is_write_vectored()
1002        }
1003
1004        fn poll_flush(
1005            self: Pin<&mut Self>,
1006            cx: &mut Context,
1007        ) -> Poll<Result<(), tokio::io::Error>> {
1008            let this = self.project();
1009            Write::poll_flush(this.inner, cx)
1010        }
1011
1012        fn poll_shutdown(
1013            self: Pin<&mut Self>,
1014            cx: &mut Context,
1015        ) -> Poll<Result<(), tokio::io::Error>> {
1016            let this = self.project();
1017            Write::poll_shutdown(this.inner, cx)
1018        }
1019    }
1020    impl<T> TlsInfoFactory for RustlsTlsConn<T>
1021    where
1022        TokioIo<TlsStream<T>>: TlsInfoFactory,
1023    {
1024        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
1025            self.inner.tls_info()
1026        }
1027    }
1028}
1029
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::{lookup_host, TcpStream};
    use tokio_socks::tcp::{Socks4Stream, Socks5Stream};

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination host of a SOCKS request is resolved.
    pub(super) enum DnsResolve {
        /// Resolve the host locally and send the first resolved IP to the proxy.
        Local,
        /// Pass the host name to the proxy unchanged.
        Proxy,
    }

    /// Opens a TCP stream to `dst` tunnelled through the SOCKS proxy `proxy`.
    ///
    /// The destination port defaults to 443 for `https` URIs and 80 otherwise.
    /// With [`DnsResolve::Local`], the host is resolved here and, if resolution
    /// yields any address, the first one's IP replaces the host name.
    ///
    /// # Errors
    /// Fails if `dst` has no host, if local resolution fails, or if the SOCKS
    /// handshake fails.
    ///
    /// # Panics
    /// Hits `unreachable!` if `proxy` is neither a SOCKS4 nor a SOCKS5 scheme.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        // Build the error lazily: `io::Error::new` allocates.
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        let mut host = original_host.to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            // Use tokio's resolver: `std::net::ToSocketAddrs` performs a blocking
            // `getaddrinfo` call, which would stall the async executor thread.
            let maybe_new_target = lookup_host((host.as_str(), port)).await?.next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        match proxy {
            ProxyScheme::Socks4 { addr } => {
                let stream = Socks4Stream::connect(addr, (host.as_str(), port))
                    .await
                    .map_err(|e| format!("socks connect error: {e}"))?;
                Ok(stream.into_inner())
            }
            ProxyScheme::Socks5 { addr, ref auth, .. } => {
                let stream = if let Some((username, password)) = auth {
                    Socks5Stream::connect_with_password(
                        addr,
                        (host.as_str(), port),
                        username,
                        password,
                    )
                    .await
                    .map_err(|e| format!("socks connect error: {e}"))?
                } else {
                    Socks5Stream::connect(addr, (host.as_str(), port))
                        .await
                        .map_err(|e| format!("socks connect error: {e}"))?
                };

                Ok(stream.into_inner())
            }
            _ => unreachable!(),
        }
    }
}
1099
1100mod verbose {
1101    use hyper::rt::{Read, ReadBufCursor, Write};
1102    use hyper_util::client::legacy::connect::{Connected, Connection};
1103    use std::cmp::min;
1104    use std::fmt;
1105    use std::io::{self, IoSlice};
1106    use std::pin::Pin;
1107    use std::task::{Context, Poll};
1108
1109    pub(super) const OFF: Wrapper = Wrapper(false);
1110
1111    #[derive(Clone, Copy)]
1112    pub(super) struct Wrapper(pub(super) bool);
1113
1114    impl Wrapper {
1115        pub(super) fn wrap<T: super::AsyncConnWithInfo>(&self, conn: T) -> super::BoxConn {
1116            if self.0 && log::log_enabled!(log::Level::Trace) {
1117                Box::new(Verbose {
1118                    // truncate is fine
1119                    id: crate::util::fast_random() as u32,
1120                    inner: conn,
1121                })
1122            } else {
1123                Box::new(conn)
1124            }
1125        }
1126    }
1127
    /// Connection adapter that traces the bytes of each successful read and
    /// write at `trace` level, tagged with a per-connection id.
    struct Verbose<T> {
        // Random identifier printed as 8 hex digits at the start of each trace line.
        id: u32,
        // The wrapped connection that performs the actual I/O.
        inner: T,
    }
1132
1133    impl<T: Connection + Read + Write + Unpin> Connection for Verbose<T> {
1134        fn connected(&self) -> Connected {
1135            self.inner.connected()
1136        }
1137    }
1138
    impl<T: Read + Write + Unpin> Read for Verbose<T> {
        /// Reads from the inner connection through a temporary `ReadBuf` view of
        /// the caller's unfilled region, traces the bytes that were filled, and
        /// then advances the caller's cursor by the same count. Errors and
        /// `Pending` are passed through without logging.
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context,
            mut buf: ReadBufCursor<'_>,
        ) -> Poll<std::io::Result<()>> {
            // TODO: This _does_ forget the `init` len, so it could result in
            // re-initializing twice. Needs upstream support, perhaps.
            // SAFETY: Passing to a ReadBuf will never de-initialize any bytes.
            let mut vbuf = hyper::rt::ReadBuf::uninit(unsafe { buf.as_mut() });
            match Pin::new(&mut self.inner).poll_read(cx, vbuf.unfilled()) {
                Poll::Ready(Ok(())) => {
                    log::trace!("{:08x} read: {:?}", self.id, Escape(vbuf.filled()));
                    let len = vbuf.filled().len();
                    // SAFETY: The two cursors were for the same buffer. What was
                    // filled in one is safe in the other.
                    unsafe {
                        buf.advance(len);
                    }
                    Poll::Ready(Ok(()))
                }
                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
                Poll::Pending => Poll::Pending,
            }
        }
    }
1165
1166    impl<T: Read + Write + Unpin> Write for Verbose<T> {
1167        fn poll_write(
1168            mut self: Pin<&mut Self>,
1169            cx: &mut Context,
1170            buf: &[u8],
1171        ) -> Poll<Result<usize, std::io::Error>> {
1172            match Pin::new(&mut self.inner).poll_write(cx, buf) {
1173                Poll::Ready(Ok(n)) => {
1174                    log::trace!("{:08x} write: {:?}", self.id, Escape(&buf[..n]));
1175                    Poll::Ready(Ok(n))
1176                }
1177                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1178                Poll::Pending => Poll::Pending,
1179            }
1180        }
1181
1182        fn poll_write_vectored(
1183            mut self: Pin<&mut Self>,
1184            cx: &mut Context<'_>,
1185            bufs: &[IoSlice<'_>],
1186        ) -> Poll<Result<usize, io::Error>> {
1187            match Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) {
1188                Poll::Ready(Ok(nwritten)) => {
1189                    log::trace!(
1190                        "{:08x} write (vectored): {:?}",
1191                        self.id,
1192                        Vectored { bufs, nwritten }
1193                    );
1194                    Poll::Ready(Ok(nwritten))
1195                }
1196                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1197                Poll::Pending => Poll::Pending,
1198            }
1199        }
1200
1201        fn is_write_vectored(&self) -> bool {
1202            self.inner.is_write_vectored()
1203        }
1204
1205        fn poll_flush(
1206            mut self: Pin<&mut Self>,
1207            cx: &mut Context,
1208        ) -> Poll<Result<(), std::io::Error>> {
1209            Pin::new(&mut self.inner).poll_flush(cx)
1210        }
1211
1212        fn poll_shutdown(
1213            mut self: Pin<&mut Self>,
1214            cx: &mut Context,
1215        ) -> Poll<Result<(), std::io::Error>> {
1216            Pin::new(&mut self.inner).poll_shutdown(cx)
1217        }
1218    }
1219
1220    #[cfg(feature = "__tls")]
1221    impl<T: super::TlsInfoFactory> super::TlsInfoFactory for Verbose<T> {
1222        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
1223            self.inner.tls_info()
1224        }
1225    }
1226
1227    struct Escape<'a>(&'a [u8]);
1228
1229    impl fmt::Debug for Escape<'_> {
1230        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1231            write!(f, "b\"")?;
1232            for &c in self.0 {
1233                // https://doc.rust-lang.org/reference.html#byte-escapes
1234                if c == b'\n' {
1235                    write!(f, "\\n")?;
1236                } else if c == b'\r' {
1237                    write!(f, "\\r")?;
1238                } else if c == b'\t' {
1239                    write!(f, "\\t")?;
1240                } else if c == b'\\' || c == b'"' {
1241                    write!(f, "\\{}", c as char)?;
1242                } else if c == b'\0' {
1243                    write!(f, "\\0")?;
1244                // ASCII printable
1245                } else if c >= 0x20 && c < 0x7f {
1246                    write!(f, "{}", c as char)?;
1247                } else {
1248                    write!(f, "\\x{c:02x}")?;
1249                }
1250            }
1251            write!(f, "\"")?;
1252            Ok(())
1253        }
1254    }
1255
1256    struct Vectored<'a, 'b> {
1257        bufs: &'a [IoSlice<'b>],
1258        nwritten: usize,
1259    }
1260
1261    impl fmt::Debug for Vectored<'_, '_> {
1262        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1263            let mut left = self.nwritten;
1264            for buf in self.bufs.iter() {
1265                if left == 0 {
1266                    break;
1267                }
1268                let n = min(left, buf.len());
1269                Escape(&buf[..n]).fmt(f)?;
1270                left -= n;
1271            }
1272            Ok(())
1273        }
1274    }
1275}
1276
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::tunnel;
    use crate::proxy;
    use hyper_util::rt::TokioIo;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    // Spawns a one-shot mock proxy on an ephemeral localhost port and returns
    // its address. The mock reads the client's request up to the blank line
    // that ends the header block, asserts it is exactly the expected CONNECT
    // request (including the optional `$auth` header line), then replies with
    // `$write` and closes the socket.
    macro_rules! mock_tunnel {
        () => {{
            mock_tunnel!(TUNNEL_OK)
        }};
        ($write:expr) => {{
            mock_tunnel!($write, "")
        }};
        ($write:expr, $auth:expr) => {{
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let connect_expected = format!(
                "\
                 CONNECT {0}:{1} HTTP/1.1\r\n\
                 Host: {0}:{1}\r\n\
                 User-Agent: {2}\r\n\
                 {3}\
                 \r\n\
                 ",
                addr.ip(),
                addr.port(),
                TUNNEL_UA,
                $auth
            )
            .into_bytes();

            thread::spawn(move || {
                let (mut sock, _) = listener.accept().unwrap();
                // TCP is a byte stream: the request may arrive across several
                // reads, so accumulate until the header terminator (or EOF)
                // instead of asserting on a single read.
                let mut received: Vec<u8> = Vec::new();
                let mut chunk = [0u8; 4096];
                while !received.ends_with(b"\r\n\r\n") {
                    let n = sock.read(&mut chunk).unwrap();
                    if n == 0 {
                        break;
                    }
                    received.extend_from_slice(&chunk[..n]);
                }
                assert_eq!(received, connect_expected);

                sock.write_all($write).unwrap();
            });
            addr
        }};
    }

    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    /// Builds the single-threaded Tokio runtime each test drives its client on.
    fn rt() -> runtime::Runtime {
        runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt")
    }

    #[test]
    fn test_tunnel() {
        let addr = mock_tunnel!();

        let f = async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt().block_on(f).unwrap();
    }

    #[test]
    fn test_tunnel_eof() {
        let addr = mock_tunnel!(b"HTTP/1.1 200 OK");

        let f = async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt().block_on(f).unwrap_err();
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let addr = mock_tunnel!(b"foo bar baz hallo");

        let f = async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        rt().block_on(f).unwrap_err();
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let addr = mock_tunnel!(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
        "
        );

        let f = async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), None).await
        };

        let error = rt().block_on(f).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let addr = mock_tunnel!(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
        );

        let f = async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(
                tcp,
                host,
                port,
                ua(),
                Some(proxy::encode_basic_auth("Aladdin", "open sesame")),
            )
            .await
        };

        rt().block_on(f).unwrap();
    }
}