1#[cfg(feature = "__tls")]
2use http::header::HeaderValue;
3use http::uri::{Authority, Scheme};
4use http::Uri;
5use hyper::rt::{Read, ReadBufCursor, Write};
6use hyper_util::client::legacy::connect::{Connected, Connection};
7#[cfg(any(feature = "socks", feature = "__tls"))]
8use hyper_util::rt::TokioIo;
9#[cfg(feature = "default-tls")]
10use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
11use tower_service::Service;
12
13use pin_project_lite::pin_project;
14use std::future::Future;
15use std::io::{self, IoSlice};
16use std::net::IpAddr;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::time::Duration;
21
22#[cfg(feature = "default-tls")]
23use self::native_tls_conn::NativeTlsConn;
24#[cfg(feature = "__rustls")]
25use self::rustls_tls_conn::RustlsTlsConn;
26use crate::dns::DynResolver;
27use crate::error::BoxError;
28use crate::proxy::{Proxy, ProxyScheme};
29
/// The plain TCP connector used underneath every TLS backend, resolving host
/// names through the client's configurable [`DynResolver`].
pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector<DynResolver>;
31
/// `Service<Uri>` that opens connections for the client, honouring the
/// configured proxies, TLS backend, connect timeout and verbose logging.
#[derive(Clone)]
pub(crate) struct Connector {
    /// Underlying HTTP connector, plus TLS configuration when a TLS backend is compiled in.
    inner: Inner,
    /// Proxies consulted in order by `call`; the first whose `intercept` matches is used.
    proxies: Arc<Vec<Proxy>>,
    /// Decides whether new connections get wrapped in trace-level I/O logging.
    verbose: verbose::Wrapper,
    /// Optional limit on how long establishing a connection may take.
    timeout: Option<Duration>,
    /// The caller's TCP_NODELAY choice. TLS connects force it on for the
    /// handshake and switch it back off afterwards when this is `false`.
    #[cfg(feature = "__tls")]
    nodelay: bool,
    /// Whether returned connections should attach TLS details to `Connected`.
    #[cfg(feature = "__tls")]
    tls_info: bool,
    /// `User-Agent` sent on `CONNECT` requests when tunnelling through a proxy.
    #[cfg(feature = "__tls")]
    user_agent: Option<HeaderValue>,
}
45
/// The transport stack selected at build time.
#[derive(Clone)]
enum Inner {
    /// Plain TCP only (no TLS backend compiled in).
    #[cfg(not(feature = "__tls"))]
    Http(HttpConnector),
    /// TCP plus a native-tls connector.
    #[cfg(feature = "default-tls")]
    DefaultTls(HttpConnector, TlsConnector),
    /// TCP plus rustls.
    #[cfg(feature = "__rustls")]
    RustlsTls {
        http: HttpConnector,
        /// Config used for the TLS session with the destination server.
        tls: Arc<rustls::ClientConfig>,
        /// Config used for the TLS session with an HTTPS proxy. When proxies
        /// are configured this is a copy of `tls` with ALPN cleared; otherwise
        /// it shares `tls`.
        tls_proxy: Arc<rustls::ClientConfig>,
    },
}
59
60impl Connector {
61 #[cfg(not(feature = "__tls"))]
62 pub(crate) fn new<T>(
63 mut http: HttpConnector,
64 proxies: Arc<Vec<Proxy>>,
65 local_addr: T,
66 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
67 interface: Option<&str>,
68 nodelay: bool,
69 ) -> Connector
70 where
71 T: Into<Option<IpAddr>>,
72 {
73 http.set_local_address(local_addr.into());
74 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
75 if let Some(interface) = interface {
76 http.set_interface(interface.to_owned());
77 }
78 http.set_nodelay(nodelay);
79
80 Connector {
81 inner: Inner::Http(http),
82 verbose: verbose::OFF,
83 proxies,
84 timeout: None,
85 }
86 }
87
88 #[cfg(feature = "default-tls")]
89 pub(crate) fn new_default_tls<T>(
90 http: HttpConnector,
91 tls: TlsConnectorBuilder,
92 proxies: Arc<Vec<Proxy>>,
93 user_agent: Option<HeaderValue>,
94 local_addr: T,
95 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
96 interface: Option<&str>,
97 nodelay: bool,
98 tls_info: bool,
99 ) -> crate::Result<Connector>
100 where
101 T: Into<Option<IpAddr>>,
102 {
103 let tls = tls.build().map_err(crate::error::builder)?;
104 Ok(Self::from_built_default_tls(
105 http,
106 tls,
107 proxies,
108 user_agent,
109 local_addr,
110 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
111 interface,
112 nodelay,
113 tls_info,
114 ))
115 }
116
117 #[cfg(feature = "default-tls")]
118 pub(crate) fn from_built_default_tls<T>(
119 mut http: HttpConnector,
120 tls: TlsConnector,
121 proxies: Arc<Vec<Proxy>>,
122 user_agent: Option<HeaderValue>,
123 local_addr: T,
124 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
125 interface: Option<&str>,
126 nodelay: bool,
127 tls_info: bool,
128 ) -> Connector
129 where
130 T: Into<Option<IpAddr>>,
131 {
132 http.set_local_address(local_addr.into());
133 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
134 if let Some(interface) = interface {
135 http.set_interface(interface);
136 }
137 http.set_nodelay(nodelay);
138 http.enforce_http(false);
139
140 Connector {
141 inner: Inner::DefaultTls(http, tls),
142 proxies,
143 verbose: verbose::OFF,
144 timeout: None,
145 nodelay,
146 tls_info,
147 user_agent,
148 }
149 }
150
151 #[cfg(feature = "__rustls")]
152 pub(crate) fn new_rustls_tls<T>(
153 mut http: HttpConnector,
154 tls: rustls::ClientConfig,
155 proxies: Arc<Vec<Proxy>>,
156 user_agent: Option<HeaderValue>,
157 local_addr: T,
158 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
159 interface: Option<&str>,
160 nodelay: bool,
161 tls_info: bool,
162 ) -> Connector
163 where
164 T: Into<Option<IpAddr>>,
165 {
166 http.set_local_address(local_addr.into());
167 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
168 if let Some(interface) = interface {
169 http.set_interface(interface.to_owned());
170 }
171 http.set_nodelay(nodelay);
172 http.enforce_http(false);
173
174 let (tls, tls_proxy) = if proxies.is_empty() {
175 let tls = Arc::new(tls);
176 (tls.clone(), tls)
177 } else {
178 let mut tls_proxy = tls.clone();
179 tls_proxy.alpn_protocols.clear();
180 (Arc::new(tls), Arc::new(tls_proxy))
181 };
182
183 Connector {
184 inner: Inner::RustlsTls {
185 http,
186 tls,
187 tls_proxy,
188 },
189 proxies,
190 verbose: verbose::OFF,
191 timeout: None,
192 nodelay,
193 tls_info,
194 user_agent,
195 }
196 }
197
198 pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
199 self.timeout = timeout;
200 }
201
202 pub(crate) fn set_verbose(&mut self, enabled: bool) {
203 self.verbose.0 = enabled;
204 }
205
206 #[cfg(feature = "socks")]
207 async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
208 let dns = match proxy {
209 ProxyScheme::Socks4 { .. } => socks::DnsResolve::Local,
210 ProxyScheme::Socks5 {
211 remote_dns: false, ..
212 } => socks::DnsResolve::Local,
213 ProxyScheme::Socks5 {
214 remote_dns: true, ..
215 } => socks::DnsResolve::Proxy,
216 ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
217 unreachable!("connect_socks is only called for socks proxies");
218 }
219 };
220
221 match &self.inner {
222 #[cfg(feature = "default-tls")]
223 Inner::DefaultTls(_http, tls) => {
224 if dst.scheme() == Some(&Scheme::HTTPS) {
225 let host = dst.host().ok_or("no host in url")?.to_string();
226 let conn = socks::connect(proxy, dst, dns).await?;
227 let conn = TokioIo::new(conn);
228 let conn = TokioIo::new(conn);
229 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
230 let io = tls_connector.connect(&host, conn).await?;
231 let io = TokioIo::new(io);
232 return Ok(Conn {
233 inner: self.verbose.wrap(NativeTlsConn { inner: io }),
234 is_proxy: false,
235 tls_info: self.tls_info,
236 });
237 }
238 }
239 #[cfg(feature = "__rustls")]
240 Inner::RustlsTls { tls, .. } => {
241 if dst.scheme() == Some(&Scheme::HTTPS) {
242 use std::convert::TryFrom;
243 use tokio_rustls::TlsConnector as RustlsConnector;
244
245 let tls = tls.clone();
246 let host = dst.host().ok_or("no host in url")?.to_string();
247 let conn = socks::connect(proxy, dst, dns).await?;
248 let conn = TokioIo::new(conn);
249 let conn = TokioIo::new(conn);
250 let server_name =
251 rustls_pki_types::ServerName::try_from(host.as_str().to_owned())
252 .map_err(|_| "Invalid Server Name")?;
253 let io = RustlsConnector::from(tls)
254 .connect(server_name, conn)
255 .await?;
256 let io = TokioIo::new(io);
257 return Ok(Conn {
258 inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
259 is_proxy: false,
260 tls_info: false,
261 });
262 }
263 }
264 #[cfg(not(feature = "__tls"))]
265 Inner::Http(_) => (),
266 }
267
268 socks::connect(proxy, dst, dns).await.map(|tcp| Conn {
269 inner: self.verbose.wrap(TokioIo::new(tcp)),
270 is_proxy: false,
271 tls_info: false,
272 })
273 }
274
275 async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> {
276 match self.inner {
277 #[cfg(not(feature = "__tls"))]
278 Inner::Http(mut http) => {
279 let io = http.call(dst).await?;
280 Ok(Conn {
281 inner: self.verbose.wrap(io),
282 is_proxy,
283 tls_info: false,
284 })
285 }
286 #[cfg(feature = "default-tls")]
287 Inner::DefaultTls(http, tls) => {
288 let mut http = http.clone();
289
290 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
294 http.set_nodelay(true);
295 }
296
297 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
298 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
299 let io = http.call(dst).await?;
300
301 if let hyper_tls::MaybeHttpsStream::Https(stream) = io {
302 if !self.nodelay {
303 stream
304 .inner()
305 .get_ref()
306 .get_ref()
307 .get_ref()
308 .inner()
309 .inner()
310 .set_nodelay(false)?;
311 }
312 Ok(Conn {
313 inner: self.verbose.wrap(NativeTlsConn { inner: stream }),
314 is_proxy,
315 tls_info: self.tls_info,
316 })
317 } else {
318 Ok(Conn {
319 inner: self.verbose.wrap(io),
320 is_proxy,
321 tls_info: false,
322 })
323 }
324 }
325 #[cfg(feature = "__rustls")]
326 Inner::RustlsTls { http, tls, .. } => {
327 let mut http = http.clone();
328
329 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
333 http.set_nodelay(true);
334 }
335
336 let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
337 let io = http.call(dst).await?;
338
339 if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
340 if !self.nodelay {
341 let (io, _) = stream.inner().get_ref();
342 io.inner().inner().set_nodelay(false)?;
343 }
344 Ok(Conn {
345 inner: self.verbose.wrap(RustlsTlsConn { inner: stream }),
346 is_proxy,
347 tls_info: self.tls_info,
348 })
349 } else {
350 Ok(Conn {
351 inner: self.verbose.wrap(io),
352 is_proxy,
353 tls_info: false,
354 })
355 }
356 }
357 }
358 }
359
360 async fn connect_via_proxy(
361 self,
362 dst: Uri,
363 proxy_scheme: ProxyScheme,
364 ) -> Result<Conn, BoxError> {
365 log::debug!("proxy({proxy_scheme:?}) intercepts '{dst:?}'");
366
367 let (proxy_dst, _auth) = match proxy_scheme {
368 ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
369 ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
370 #[cfg(feature = "socks")]
371 ProxyScheme::Socks4 { .. } => return self.connect_socks(dst, proxy_scheme).await,
372 #[cfg(feature = "socks")]
373 ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
374 };
375
376 #[cfg(feature = "__tls")]
377 let auth = _auth;
378
379 match &self.inner {
380 #[cfg(feature = "default-tls")]
381 Inner::DefaultTls(http, tls) => {
382 if dst.scheme() == Some(&Scheme::HTTPS) {
383 let host = dst.host().to_owned();
384 let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
385 let http = http.clone();
386 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
387 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
388 let conn = http.call(proxy_dst).await?;
389 log::trace!("tunneling HTTPS over proxy");
390 let tunneled = tunnel(
391 conn,
392 host.ok_or("no host in url")?.to_string(),
393 port,
394 self.user_agent.clone(),
395 auth,
396 )
397 .await?;
398 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
399 let io = tls_connector
400 .connect(host.ok_or("no host in url")?, TokioIo::new(tunneled))
401 .await?;
402 return Ok(Conn {
403 inner: self.verbose.wrap(NativeTlsConn {
404 inner: TokioIo::new(io),
405 }),
406 is_proxy: false,
407 tls_info: false,
408 });
409 }
410 }
411 #[cfg(feature = "__rustls")]
412 Inner::RustlsTls {
413 http,
414 tls,
415 tls_proxy,
416 } => {
417 if dst.scheme() == Some(&Scheme::HTTPS) {
418 use rustls_pki_types::ServerName;
419 use std::convert::TryFrom;
420 use tokio_rustls::TlsConnector as RustlsConnector;
421
422 let host = dst.host().ok_or("no host in url")?.to_string();
423 let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
424 let http = http.clone();
425 let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
426 let tls = tls.clone();
427 let conn = http.call(proxy_dst).await?;
428 log::trace!("tunneling HTTPS over proxy");
429 let maybe_server_name = ServerName::try_from(host.as_str().to_owned())
430 .map_err(|_| "Invalid Server Name");
431 let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
432 let server_name = maybe_server_name?;
433 let io = RustlsConnector::from(tls)
434 .connect(server_name, TokioIo::new(tunneled))
435 .await?;
436
437 return Ok(Conn {
438 inner: self.verbose.wrap(RustlsTlsConn {
439 inner: TokioIo::new(io),
440 }),
441 is_proxy: false,
442 tls_info: false,
443 });
444 }
445 }
446 #[cfg(not(feature = "__tls"))]
447 Inner::Http(_) => (),
448 }
449
450 self.connect_with_maybe_proxy(proxy_dst, true).await
451 }
452
453 pub fn set_keepalive(&mut self, dur: Option<Duration>) {
454 match &mut self.inner {
455 #[cfg(feature = "default-tls")]
456 Inner::DefaultTls(http, _tls) => http.set_keepalive(dur),
457 #[cfg(feature = "__rustls")]
458 Inner::RustlsTls { http, .. } => http.set_keepalive(dur),
459 #[cfg(not(feature = "__tls"))]
460 Inner::Http(http) => http.set_keepalive(dur),
461 }
462 }
463}
464
465fn into_uri(scheme: Scheme, host: Authority) -> Uri {
466 http::Uri::builder()
468 .scheme(scheme)
469 .authority(host)
470 .path_and_query(http::uri::PathAndQuery::from_static("/"))
471 .build()
472 .expect("scheme and authority is valid Uri")
473}
474
475async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError>
476where
477 F: Future<Output = Result<T, BoxError>>,
478{
479 if let Some(to) = timeout {
480 match tokio::time::timeout(to, f).await {
481 Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
482 Ok(Ok(try_res)) => Ok(try_res),
483 Ok(Err(e)) => Err(e),
484 }
485 } else {
486 f.await
487 }
488}
489
490impl Service<Uri> for Connector {
491 type Response = Conn;
492 type Error = BoxError;
493 type Future = Connecting;
494
495 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
496 Poll::Ready(Ok(()))
497 }
498
499 fn call(&mut self, dst: Uri) -> Self::Future {
500 log::debug!("starting new connection: {dst:?}");
501 let timeout = self.timeout;
502 for prox in self.proxies.iter() {
503 if let Some(proxy_scheme) = prox.intercept(&dst) {
504 return Box::pin(with_timeout(
505 self.clone().connect_via_proxy(dst, proxy_scheme),
506 timeout,
507 ));
508 }
509 }
510
511 Box::pin(with_timeout(
512 self.clone().connect_with_maybe_proxy(dst, false),
513 timeout,
514 ))
515 }
516}
517
/// Extracts TLS session details (currently the peer certificate) from a
/// connection. Plain transports report `None`.
#[cfg(feature = "__tls")]
trait TlsInfoFactory {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo>;
}

/// A bare TCP stream carries no TLS session.
#[cfg(feature = "__tls")]
impl TlsInfoFactory for tokio::net::TcpStream {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        None
    }
}

/// `TokioIo` is only an adapter; defer to the wrapped stream.
#[cfg(feature = "__tls")]
impl<T: TlsInfoFactory> TlsInfoFactory for TokioIo<T> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        self.inner().tls_info()
    }
}
536
/// native-tls session over a direct (or SOCKS) TCP stream: reports the DER
/// peer certificate when one is available.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for tokio_native_tls::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = match self.get_ref().peer_certificate() {
            Ok(Some(cert)) => cert.to_der().ok(),
            _ => None,
        };
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}

/// native-tls session running inside a proxy tunnel: reports the DER peer
/// certificate when one is available.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory
    for tokio_native_tls::TlsStream<
        TokioIo<hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
    >
{
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = match self.get_ref().peer_certificate() {
            Ok(Some(cert)) => cert.to_der().ok(),
            _ => None,
        };
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}

/// Only the HTTPS variant has TLS details to report.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_tls::MaybeHttpsStream::Https(stream) = self {
            stream.tls_info()
        } else {
            None
        }
    }
}
576
/// rustls session over a direct (or SOCKS) TCP stream: reports the first
/// certificate of the peer's chain, if any.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::client::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let peer_certificate = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|leaf| leaf.to_vec());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}

/// rustls session running inside a proxy tunnel: reports the first
/// certificate of the peer's chain, if any.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory
    for tokio_rustls::client::TlsStream<
        TokioIo<hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
    >
{
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let peer_certificate = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|leaf| leaf.to_vec());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}

/// Only the HTTPS variant has TLS details to report.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_rustls::MaybeHttpsStream::Https(stream) = self {
            stream.tls_info()
        } else {
            None
        }
    }
}
616
/// Everything hyper needs from a connection, bundled into one trait so it can
/// be used as a trait object.
pub(crate) trait AsyncConn:
    Read + Write + Connection + Send + Sync + Unpin + 'static
{
}

impl<T: Read + Write + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}

/// `AsyncConn`, plus `TlsInfoFactory` when a TLS backend is compiled in.
#[cfg(feature = "__tls")]
trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {}
#[cfg(not(feature = "__tls"))]
trait AsyncConnWithInfo: AsyncConn {}

#[cfg(feature = "__tls")]
impl<T: AsyncConn + TlsInfoFactory> AsyncConnWithInfo for T {}
#[cfg(not(feature = "__tls"))]
impl<T: AsyncConn> AsyncConnWithInfo for T {}

/// Type-erased connection stored inside `Conn`.
type BoxConn = Box<dyn AsyncConnWithInfo>;
635
pin_project! {
    // A connection handed to hyper.
    //
    // `is_proxy` is forwarded via `Connected::proxy`. It is `true` only when the
    // request itself is sent to a proxy, not when a CONNECT tunnel or SOCKS is used.
    //
    // `tls_info` selects whether `Connected` carries TLS details; it is only read
    // when the `__tls` feature is enabled.
    pub(crate) struct Conn {
        #[pin]
        inner: BoxConn,
        is_proxy: bool,
        tls_info: bool,
    }
}
649
650impl Connection for Conn {
651 fn connected(&self) -> Connected {
652 let connected = self.inner.connected().proxy(self.is_proxy);
653 #[cfg(feature = "__tls")]
654 if self.tls_info {
655 if let Some(tls_info) = self.inner.tls_info() {
656 connected.extra(tls_info)
657 } else {
658 connected
659 }
660 } else {
661 connected
662 }
663 #[cfg(not(feature = "__tls"))]
664 connected
665 }
666}
667
668impl Read for Conn {
669 fn poll_read(
670 self: Pin<&mut Self>,
671 cx: &mut Context,
672 buf: ReadBufCursor<'_>,
673 ) -> Poll<io::Result<()>> {
674 let this = self.project();
675 Read::poll_read(this.inner, cx, buf)
676 }
677}
678
679impl Write for Conn {
680 fn poll_write(
681 self: Pin<&mut Self>,
682 cx: &mut Context,
683 buf: &[u8],
684 ) -> Poll<Result<usize, io::Error>> {
685 let this = self.project();
686 Write::poll_write(this.inner, cx, buf)
687 }
688
689 fn poll_write_vectored(
690 self: Pin<&mut Self>,
691 cx: &mut Context<'_>,
692 bufs: &[IoSlice<'_>],
693 ) -> Poll<Result<usize, io::Error>> {
694 let this = self.project();
695 Write::poll_write_vectored(this.inner, cx, bufs)
696 }
697
698 fn is_write_vectored(&self) -> bool {
699 self.inner.is_write_vectored()
700 }
701
702 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
703 let this = self.project();
704 Write::poll_flush(this.inner, cx)
705 }
706
707 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
708 let this = self.project();
709 Write::poll_shutdown(this.inner, cx)
710 }
711}
712
/// Boxed, `Send` future returned by `Connector`'s `Service::call`.
pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
714
/// Establishes an HTTP `CONNECT` tunnel to `host:port` over `conn`, which must
/// already be connected to the proxy.
///
/// `user_agent` and `auth` are sent as `User-Agent` and `Proxy-Authorization`
/// headers when present. On a `200` response `conn` is handed back, ready to
/// carry the tunnelled bytes.
///
/// # Errors
/// - the proxy closes the connection before the response headers end;
/// - the proxy answers `407` ("proxy authentication required");
/// - the proxy answers anything other than `200`/`407` ("unsuccessful tunnel");
/// - a `200` response whose headers fill the 8 KiB read buffer;
/// - any I/O error while writing the request or reading the response.
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: Read + Write + Unpin,
{
    use hyper_util::rt::TokioIo;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Length of `HTTP/1.x NNN`: the bytes needed before the status can be classified.
    const STATUS_PREFIX_LEN: usize = 12;

    let mut buf = format!(
        "\
         CONNECT {host}:{port} HTTP/1.1\r\n\
         Host: {host}:{port}\r\n\
         "
    )
    .into_bytes();

    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    if let Some(value) = auth {
        log::debug!("tunnel to {host}:{port} using basic auth");
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // Blank line terminating the request headers.
    buf.extend_from_slice(b"\r\n");

    let mut tokio_conn = TokioIo::new(&mut conn);

    tokio_conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = tokio_conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        let recvd = &buf[..pos];

        // A short read can deliver only part of the status line. Keep reading while
        // the bytes so far are still consistent with an `HTTP/1.x` response, rather
        // than misreporting the tunnel as unsuccessful.
        if recvd.len() < STATUS_PREFIX_LEN
            && recvd.iter().zip(b"HTTP/1.").all(|(got, want)| got == want)
        {
            continue;
        }

        if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
        } else if recvd.starts_with(b"HTTP/1.1 407") || recvd.starts_with(b"HTTP/1.0 407") {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
786
/// Error returned when the proxy hangs up before finishing its CONNECT response.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
791
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use super::TlsInfoFactory;
    use hyper::rt::{Read, ReadBufCursor, Write};
    use hyper_tls::MaybeHttpsStream;
    use hyper_util::client::legacy::connect::{Connected, Connection};
    use hyper_util::rt::TokioIo;
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_native_tls::TlsStream;

    pin_project! {
        // hyper-facing wrapper around a native-tls stream running over transport `T`.
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TokioIo<TlsStream<T>>,
        }
    }

    // TLS directly over TCP (or a SOCKS stream).
    impl Connection for NativeTlsConn<TokioIo<TokioIo<TcpStream>>> {
        fn connected(&self) -> Connected {
            let session = self.inner.inner().get_ref();
            let connected = session.get_ref().get_ref().inner().connected();
            #[cfg(feature = "native-tls-alpn")]
            match session.negotiated_alpn().ok() {
                Some(Some(protocol)) if protocol == b"h2" => connected.negotiated_h2(),
                _ => connected,
            }
            #[cfg(not(feature = "native-tls-alpn"))]
            connected
        }
    }

    // TLS running inside a CONNECT tunnel to a (possibly HTTPS) proxy.
    impl Connection for NativeTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
        fn connected(&self) -> Connected {
            let session = self.inner.inner().get_ref();
            let connected = session.get_ref().get_ref().inner().connected();
            #[cfg(feature = "native-tls-alpn")]
            match session.negotiated_alpn().ok() {
                Some(Some(protocol)) if protocol == b"h2" => connected.negotiated_h2(),
                _ => connected,
            }
            #[cfg(not(feature = "native-tls-alpn"))]
            connected
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> Read for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: ReadBufCursor<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            Read::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> Write for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            Write::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            Write::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_shutdown(self.project().inner, cx)
        }
    }

    // TLS details come from the wrapped stream whenever its type supports them.
    impl<T> TlsInfoFactory for NativeTlsConn<T>
    where
        TokioIo<TlsStream<T>>: TlsInfoFactory,
    {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }
}
915
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use super::TlsInfoFactory;
    use hyper::rt::{Read, ReadBufCursor, Write};
    use hyper_rustls::MaybeHttpsStream;
    use hyper_util::client::legacy::connect::{Connected, Connection};
    use hyper_util::rt::TokioIo;
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_rustls::client::TlsStream;

    pin_project! {
        // hyper-facing wrapper around a rustls client stream running over transport `T`.
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TokioIo<TlsStream<T>>,
        }
    }

    // TLS directly over TCP (or a SOCKS stream).
    impl Connection for RustlsTlsConn<TokioIo<TokioIo<TcpStream>>> {
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.inner().get_ref();
            let connected = transport.inner().connected();
            if session.alpn_protocol() == Some(b"h2") {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    // TLS running inside a CONNECT tunnel to a (possibly HTTPS) proxy.
    impl Connection for RustlsTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
        fn connected(&self) -> Connected {
            let (transport, session) = self.inner.inner().get_ref();
            let connected = transport.inner().connected();
            if session.alpn_protocol() == Some(b"h2") {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> Read for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: ReadBufCursor<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            Read::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> Write for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            Write::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            Write::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_shutdown(self.project().inner, cx)
        }
    }

    // TLS details come from the wrapped stream whenever its type supports them.
    impl<T> TlsInfoFactory for RustlsTlsConn<T>
    where
        TokioIo<TlsStream<T>>: TlsInfoFactory,
    {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }
}
1029
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::TcpStream;
    use tokio_socks::tcp::{Socks4Stream, Socks5Stream};

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination host name gets resolved.
    pub(super) enum DnsResolve {
        /// Resolve locally and hand the proxy an IP address (if resolution yields one).
        Local,
        /// Hand the proxy the host name and let it resolve it.
        Proxy,
    }

    /// Opens a TCP stream to `dst` through the SOCKS4/SOCKS5 proxy in `proxy`.
    ///
    /// The destination port defaults to 443 for `https` and 80 otherwise.
    ///
    /// # Errors
    /// Fails if `dst` has no host, if local DNS resolution fails, or if the
    /// SOCKS handshake fails.
    ///
    /// # Panics
    /// If `proxy` is not a SOCKS scheme.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        let mut host = original_host.to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            // Async lookup: `std::net::ToSocketAddrs` would block the executor
            // thread for the whole DNS query.
            let maybe_new_target = tokio::net::lookup_host((host.as_str(), port)).await?.next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        match proxy {
            ProxyScheme::Socks4 { addr } => {
                let stream = Socks4Stream::connect(addr, (host.as_str(), port))
                    .await
                    .map_err(|e| format!("socks connect error: {e}"))?;
                Ok(stream.into_inner())
            }
            ProxyScheme::Socks5 { addr, ref auth, .. } => {
                let stream = if let Some((username, password)) = auth {
                    Socks5Stream::connect_with_password(
                        addr,
                        (host.as_str(), port),
                        &username,
                        &password,
                    )
                    .await
                    .map_err(|e| format!("socks connect error: {e}"))?
                } else {
                    Socks5Stream::connect(addr, (host.as_str(), port))
                        .await
                        .map_err(|e| format!("socks connect error: {e}"))?
                };

                Ok(stream.into_inner())
            }
            _ => unreachable!("socks::connect called with a non-SOCKS proxy scheme"),
        }
    }
}
1099
1100mod verbose {
1101 use hyper::rt::{Read, ReadBufCursor, Write};
1102 use hyper_util::client::legacy::connect::{Connected, Connection};
1103 use std::cmp::min;
1104 use std::fmt;
1105 use std::io::{self, IoSlice};
1106 use std::pin::Pin;
1107 use std::task::{Context, Poll};
1108
1109 pub(super) const OFF: Wrapper = Wrapper(false);
1110
1111 #[derive(Clone, Copy)]
1112 pub(super) struct Wrapper(pub(super) bool);
1113
1114 impl Wrapper {
1115 pub(super) fn wrap<T: super::AsyncConnWithInfo>(&self, conn: T) -> super::BoxConn {
1116 if self.0 && log::log_enabled!(log::Level::Trace) {
1117 Box::new(Verbose {
1118 id: crate::util::fast_random() as u32,
1120 inner: conn,
1121 })
1122 } else {
1123 Box::new(conn)
1124 }
1125 }
1126 }
1127
1128 struct Verbose<T> {
1129 id: u32,
1130 inner: T,
1131 }
1132
1133 impl<T: Connection + Read + Write + Unpin> Connection for Verbose<T> {
1134 fn connected(&self) -> Connected {
1135 self.inner.connected()
1136 }
1137 }
1138
1139 impl<T: Read + Write + Unpin> Read for Verbose<T> {
1140 fn poll_read(
1141 mut self: Pin<&mut Self>,
1142 cx: &mut Context,
1143 mut buf: ReadBufCursor<'_>,
1144 ) -> Poll<std::io::Result<()>> {
1145 let mut vbuf = hyper::rt::ReadBuf::uninit(unsafe { buf.as_mut() });
1149 match Pin::new(&mut self.inner).poll_read(cx, vbuf.unfilled()) {
1150 Poll::Ready(Ok(())) => {
1151 log::trace!("{:08x} read: {:?}", self.id, Escape(vbuf.filled()));
1152 let len = vbuf.filled().len();
1153 unsafe {
1156 buf.advance(len);
1157 }
1158 Poll::Ready(Ok(()))
1159 }
1160 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1161 Poll::Pending => Poll::Pending,
1162 }
1163 }
1164 }
1165
1166 impl<T: Read + Write + Unpin> Write for Verbose<T> {
1167 fn poll_write(
1168 mut self: Pin<&mut Self>,
1169 cx: &mut Context,
1170 buf: &[u8],
1171 ) -> Poll<Result<usize, std::io::Error>> {
1172 match Pin::new(&mut self.inner).poll_write(cx, buf) {
1173 Poll::Ready(Ok(n)) => {
1174 log::trace!("{:08x} write: {:?}", self.id, Escape(&buf[..n]));
1175 Poll::Ready(Ok(n))
1176 }
1177 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1178 Poll::Pending => Poll::Pending,
1179 }
1180 }
1181
1182 fn poll_write_vectored(
1183 mut self: Pin<&mut Self>,
1184 cx: &mut Context<'_>,
1185 bufs: &[IoSlice<'_>],
1186 ) -> Poll<Result<usize, io::Error>> {
1187 match Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) {
1188 Poll::Ready(Ok(nwritten)) => {
1189 log::trace!(
1190 "{:08x} write (vectored): {:?}",
1191 self.id,
1192 Vectored { bufs, nwritten }
1193 );
1194 Poll::Ready(Ok(nwritten))
1195 }
1196 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1197 Poll::Pending => Poll::Pending,
1198 }
1199 }
1200
1201 fn is_write_vectored(&self) -> bool {
1202 self.inner.is_write_vectored()
1203 }
1204
1205 fn poll_flush(
1206 mut self: Pin<&mut Self>,
1207 cx: &mut Context,
1208 ) -> Poll<Result<(), std::io::Error>> {
1209 Pin::new(&mut self.inner).poll_flush(cx)
1210 }
1211
1212 fn poll_shutdown(
1213 mut self: Pin<&mut Self>,
1214 cx: &mut Context,
1215 ) -> Poll<Result<(), std::io::Error>> {
1216 Pin::new(&mut self.inner).poll_shutdown(cx)
1217 }
1218 }
1219
1220 #[cfg(feature = "__tls")]
1221 impl<T: super::TlsInfoFactory> super::TlsInfoFactory for Verbose<T> {
1222 fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
1223 self.inner.tls_info()
1224 }
1225 }
1226
1227 struct Escape<'a>(&'a [u8]);
1228
1229 impl fmt::Debug for Escape<'_> {
1230 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1231 write!(f, "b\"")?;
1232 for &c in self.0 {
1233 if c == b'\n' {
1235 write!(f, "\\n")?;
1236 } else if c == b'\r' {
1237 write!(f, "\\r")?;
1238 } else if c == b'\t' {
1239 write!(f, "\\t")?;
1240 } else if c == b'\\' || c == b'"' {
1241 write!(f, "\\{}", c as char)?;
1242 } else if c == b'\0' {
1243 write!(f, "\\0")?;
1244 } else if c >= 0x20 && c < 0x7f {
1246 write!(f, "{}", c as char)?;
1247 } else {
1248 write!(f, "\\x{c:02x}")?;
1249 }
1250 }
1251 write!(f, "\"")?;
1252 Ok(())
1253 }
1254 }
1255
1256 struct Vectored<'a, 'b> {
1257 bufs: &'a [IoSlice<'b>],
1258 nwritten: usize,
1259 }
1260
1261 impl fmt::Debug for Vectored<'_, '_> {
1262 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1263 let mut left = self.nwritten;
1264 for buf in self.bufs.iter() {
1265 if left == 0 {
1266 break;
1267 }
1268 let n = min(left, buf.len());
1269 Escape(&buf[..n]).fmt(f)?;
1270 left -= n;
1271 }
1272 Ok(())
1273 }
1274 }
1275}
1276
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::tunnel;
    use crate::proxy;
    use hyper_util::rt::TokioIo;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    /// Spawns a one-shot mock proxy on an ephemeral loopback port.
    ///
    /// The mock accepts a single connection and asserts that the client sent
    /// exactly the expected `CONNECT` request, with `$auth` inserted as extra
    /// header lines. It then replies with `$write` and closes the socket.
    /// The macro evaluates to the listener's address.
    macro_rules! mock_tunnel {
        () => {{
            mock_tunnel!(TUNNEL_OK)
        }};
        ($write:expr) => {{
            mock_tunnel!($write, "")
        }};
        ($write:expr, $auth:expr) => {{
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let connect_expected = format!(
                "\
                 CONNECT {0}:{1} HTTP/1.1\r\n\
                 Host: {0}:{1}\r\n\
                 User-Agent: {2}\r\n\
                 {3}\
                 \r\n\
                 ",
                addr.ip(),
                addr.port(),
                TUNNEL_UA,
                $auth
            )
            .into_bytes();

            thread::spawn(move || {
                let (mut sock, _) = listener.accept().unwrap();
                // TCP is a byte stream: one `read` may return only part of the
                // request, so keep reading until the blank line that ends the
                // header block (or until the peer closes the connection).
                let mut received: Vec<u8> = Vec::new();
                let mut chunk = [0u8; 4096];
                while !received.ends_with(b"\r\n\r\n") {
                    let n = sock.read(&mut chunk).unwrap();
                    if n == 0 {
                        break;
                    }
                    received.extend_from_slice(&chunk[..n]);
                }
                assert_eq!(received, connect_expected);

                sock.write_all($write).unwrap();
            });
            addr
        }};
    }

    /// The `User-Agent` value every tunnel test sends and the mock expects.
    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    /// Drives `fut` to completion on a fresh current-thread Tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt")
            .block_on(fut)
    }

    /// Connects to the mock proxy at `$addr` and runs `tunnel` through it,
    /// targeting that same address and passing `$auth` (default `None`).
    /// The macro evaluates to `tunnel`'s result.
    macro_rules! run_tunnel {
        ($addr:expr) => {
            run_tunnel!($addr, None)
        };
        ($addr:expr, $auth:expr) => {{
            let addr = $addr;
            block_on(async move {
                let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
                let host = addr.ip().to_string();
                let port = addr.port();
                tunnel(tcp, host, port, ua(), $auth).await
            })
        }};
    }

    #[test]
    fn test_tunnel() {
        let addr = mock_tunnel!();
        run_tunnel!(addr).unwrap();
    }

    #[test]
    fn test_tunnel_eof() {
        let addr = mock_tunnel!(b"HTTP/1.1 200 OK");
        run_tunnel!(addr).unwrap_err();
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let addr = mock_tunnel!(b"foo bar baz hallo");
        run_tunnel!(addr).unwrap_err();
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let addr = mock_tunnel!(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
            "
        );

        let error = run_tunnel!(addr).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let addr = mock_tunnel!(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
        );

        run_tunnel!(
            addr,
            Some(proxy::encode_basic_auth("Aladdin", "open sesame"))
        )
        .unwrap();
    }
}