native_tls/imp/openssl.rs

1extern crate openssl;
2extern crate openssl_probe;
3
4use self::openssl::error::ErrorStack;
5use self::openssl::hash::MessageDigest;
6use self::openssl::nid::Nid;
7use self::openssl::pkcs12::Pkcs12;
8use self::openssl::pkey::{PKey, Private};
9use self::openssl::ssl::{
10    self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
11    SslVerifyMode,
12};
13use self::openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509};
14use std::error;
15use std::fmt;
16use std::io;
17use std::sync::Once;
18
19use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
20
/// Restricts the context to the protocol range `[min, max]` using OpenSSL's
/// native min/max protocol-version setters.
///
/// A `None` bound is passed through to OpenSSL unchanged.
#[cfg(have_min_max_version)]
fn supported_protocols(
    min: Option<Protocol>,
    max: Option<Protocol>,
    ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
    use self::openssl::ssl::SslVersion;

    /// Maps a native-tls protocol onto the matching OpenSSL version constant.
    fn to_ssl_version(protocol: Protocol) -> SslVersion {
        match protocol {
            Protocol::Sslv3 => SslVersion::SSL3,
            Protocol::Tlsv10 => SslVersion::TLS1,
            Protocol::Tlsv11 => SslVersion::TLS1_1,
            Protocol::Tlsv12 => SslVersion::TLS1_2,
        }
    }

    let lower = min.map(to_ssl_version);
    let upper = max.map(to_ssl_version);

    ctx.set_min_proto_version(lower)?;
    ctx.set_max_proto_version(upper)
}
43
44#[cfg(not(have_min_max_version))]
45fn supported_protocols(
46    min: Option<Protocol>,
47    max: Option<Protocol>,
48    ctx: &mut SslContextBuilder,
49) -> Result<(), ErrorStack> {
50    use self::openssl::ssl::SslOptions;
51
52    let no_ssl_mask = SslOptions::NO_SSLV2
53        | SslOptions::NO_SSLV3
54        | SslOptions::NO_TLSV1
55        | SslOptions::NO_TLSV1_1
56        | SslOptions::NO_TLSV1_2;
57
58    ctx.clear_options(no_ssl_mask);
59    let mut options = SslOptions::empty();
60    options |= match min {
61        None => SslOptions::empty(),
62        Some(Protocol::Sslv3) => SslOptions::NO_SSLV2,
63        Some(Protocol::Tlsv10) => SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3,
64        Some(Protocol::Tlsv11) => {
65            SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1
66        }
67        Some(Protocol::Tlsv12) => {
68            SslOptions::NO_SSLV2
69                | SslOptions::NO_SSLV3
70                | SslOptions::NO_TLSV1
71                | SslOptions::NO_TLSV1_1
72        }
73    };
74    options |= match max {
75        None | Some(Protocol::Tlsv12) => SslOptions::empty(),
76        Some(Protocol::Tlsv11) => SslOptions::NO_TLSV1_2,
77        Some(Protocol::Tlsv10) => SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2,
78        Some(Protocol::Sslv3) => {
79            SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2
80        }
81    };
82
83    ctx.set_options(options);
84
85    Ok(())
86}
87
88fn init_trust() {
89    static ONCE: Once = Once::new();
90    ONCE.call_once(openssl_probe::init_ssl_cert_env_vars);
91}
92
/// Adds every PEM certificate found in `/system/etc/security/cacerts` to the
/// connector's certificate store.
///
/// This is best effort. A missing directory, unreadable entries and files
/// that do not parse as PEM are skipped silently. Store-insertion failures
/// are logged at debug level and skipped. The function currently always
/// returns `Ok`.
#[cfg(target_os = "android")]
fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Error> {
    use std::fs;

    let entries = match fs::read_dir("/system/etc/security/cacerts") {
        Ok(entries) => entries,
        Err(_) => return Ok(()),
    };

    for entry in entries.filter_map(Result::ok) {
        let pem = match fs::read(entry.path()) {
            Ok(bytes) => bytes,
            Err(_) => continue,
        };
        let cert = match X509::from_pem(&pem) {
            Ok(cert) => cert,
            Err(_) => continue,
        };
        if let Err(err) = connector.cert_store_mut().add_cert(cert) {
            debug!("load_android_root_certs error: {:?}", err);
        }
    }

    Ok(())
}
111
112#[derive(Debug)]
113pub enum Error {
114    Normal(ErrorStack),
115    Ssl(ssl::Error, X509VerifyResult),
116    EmptyChain,
117    NotPkcs8,
118}
119
120impl error::Error for Error {
121    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
122        match *self {
123            Error::Normal(ref e) => error::Error::source(e),
124            Error::Ssl(ref e, _) => error::Error::source(e),
125            Error::EmptyChain => None,
126            Error::NotPkcs8 => None,
127        }
128    }
129}
130
131impl fmt::Display for Error {
132    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
133        match *self {
134            Error::Normal(ref e) => fmt::Display::fmt(e, fmt),
135            Error::Ssl(ref e, X509VerifyResult::OK) => fmt::Display::fmt(e, fmt),
136            Error::Ssl(ref e, v) => write!(fmt, "{} ({})", e, v),
137            Error::EmptyChain => write!(
138                fmt,
139                "at least one certificate must be provided to create an identity"
140            ),
141            Error::NotPkcs8 => write!(fmt, "expected PKCS#8 PEM"),
142        }
143    }
144}
145
146impl From<ErrorStack> for Error {
147    fn from(err: ErrorStack) -> Error {
148        Error::Normal(err)
149    }
150}
151
152#[derive(Clone)]
153pub struct Identity {
154    pkey: PKey<Private>,
155    cert: X509,
156    chain: Vec<X509>,
157}
158
159impl Identity {
160    pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
161        let pkcs12 = Pkcs12::from_der(buf)?;
162        let parsed = pkcs12.parse2(pass)?;
163        Ok(Identity {
164            pkey: parsed.pkey.ok_or_else(|| Error::EmptyChain)?,
165            cert: parsed.cert.ok_or_else(|| Error::EmptyChain)?,
166            // > The stack is the reverse of what you might expect due to the way
167            // > PKCS12_parse is implemented, so we need to load it backwards.
168            // > https://github.com/sfackler/rust-native-tls/commit/05fb5e583be589ab63d9f83d986d095639f8ec44
169            chain: parsed.ca.into_iter().flatten().rev().collect(),
170        })
171    }
172
173    pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result<Identity, Error> {
174        if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
175            return Err(Error::NotPkcs8);
176        }
177
178        let pkey = PKey::private_key_from_pem(key)?;
179        let mut cert_chain = X509::stack_from_pem(buf)?.into_iter();
180        let cert = cert_chain.next().ok_or(Error::EmptyChain)?;
181        let chain = cert_chain.collect();
182        Ok(Identity { pkey, cert, chain })
183    }
184}
185
186#[derive(Clone)]
187pub struct Certificate(X509);
188
189impl Certificate {
190    pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
191        let cert = X509::from_der(buf)?;
192        Ok(Certificate(cert))
193    }
194
195    pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
196        let cert = X509::from_pem(buf)?;
197        Ok(Certificate(cert))
198    }
199
200    pub fn to_der(&self) -> Result<Vec<u8>, Error> {
201        let der = self.0.to_der()?;
202        Ok(der)
203    }
204}
205
206pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>);
207
208impl<S> fmt::Debug for MidHandshakeTlsStream<S>
209where
210    S: fmt::Debug,
211{
212    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
213        fmt::Debug::fmt(&self.0, fmt)
214    }
215}
216
217impl<S> MidHandshakeTlsStream<S> {
218    pub fn get_ref(&self) -> &S {
219        self.0.get_ref()
220    }
221
222    pub fn get_mut(&mut self) -> &mut S {
223        self.0.get_mut()
224    }
225}
226
227impl<S> MidHandshakeTlsStream<S>
228where
229    S: io::Read + io::Write,
230{
231    pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
232        match self.0.handshake() {
233            Ok(s) => Ok(TlsStream(s)),
234            Err(e) => Err(e.into()),
235        }
236    }
237}
238
239pub enum HandshakeError<S> {
240    Failure(Error),
241    WouldBlock(MidHandshakeTlsStream<S>),
242}
243
244impl<S> From<ssl::HandshakeError<S>> for HandshakeError<S> {
245    fn from(e: ssl::HandshakeError<S>) -> HandshakeError<S> {
246        match e {
247            ssl::HandshakeError::SetupFailure(e) => HandshakeError::Failure(e.into()),
248            ssl::HandshakeError::Failure(e) => {
249                let v = e.ssl().verify_result();
250                HandshakeError::Failure(Error::Ssl(e.into_error(), v))
251            }
252            ssl::HandshakeError::WouldBlock(s) => {
253                HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
254            }
255        }
256    }
257}
258
259impl<S> From<ErrorStack> for HandshakeError<S> {
260    fn from(e: ErrorStack) -> HandshakeError<S> {
261        HandshakeError::Failure(e.into())
262    }
263}
264
265#[derive(Clone)]
266pub struct TlsConnector {
267    connector: SslConnector,
268    use_sni: bool,
269    accept_invalid_hostnames: bool,
270    accept_invalid_certs: bool,
271}
272
273impl TlsConnector {
274    pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
275        init_trust();
276
277        let mut connector = SslConnector::builder(SslMethod::tls())?;
278        if let Some(ref identity) = builder.identity {
279            connector.set_certificate(&identity.0.cert)?;
280            connector.set_private_key(&identity.0.pkey)?;
281            for cert in identity.0.chain.iter() {
282                // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html
283                // specifies that "When sending a certificate chain, extra chain certificates are
284                // sent in order following the end entity certificate."
285                connector.add_extra_chain_cert(cert.to_owned())?;
286            }
287        }
288        supported_protocols(builder.min_protocol, builder.max_protocol, &mut connector)?;
289
290        if builder.disable_built_in_roots {
291            connector.set_cert_store(X509StoreBuilder::new()?.build());
292        }
293
294        for cert in &builder.root_certificates {
295            if let Err(err) = connector.cert_store_mut().add_cert((cert.0).0.clone()) {
296                debug!("add_cert error: {:?}", err);
297            }
298        }
299
300        #[cfg(feature = "alpn")]
301        {
302            if !builder.alpn.is_empty() {
303                // Wire format is each alpn preceded by its length as a byte.
304                let mut alpn_wire_format = Vec::with_capacity(
305                    builder
306                        .alpn
307                        .iter()
308                        .map(|s| s.as_bytes().len())
309                        .sum::<usize>()
310                        + builder.alpn.len(),
311                );
312                for alpn in builder.alpn.iter().map(|s| s.as_bytes()) {
313                    alpn_wire_format.push(alpn.len() as u8);
314                    alpn_wire_format.extend(alpn);
315                }
316                connector.set_alpn_protos(&alpn_wire_format)?;
317            }
318        }
319
320        #[cfg(target_os = "android")]
321        load_android_root_certs(&mut connector)?;
322
323        Ok(TlsConnector {
324            connector: connector.build(),
325            use_sni: builder.use_sni,
326            accept_invalid_hostnames: builder.accept_invalid_hostnames,
327            accept_invalid_certs: builder.accept_invalid_certs,
328        })
329    }
330
331    pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
332    where
333        S: io::Read + io::Write,
334    {
335        let mut ssl = self
336            .connector
337            .configure()?
338            .use_server_name_indication(self.use_sni)
339            .verify_hostname(!self.accept_invalid_hostnames);
340        if self.accept_invalid_certs {
341            ssl.set_verify(SslVerifyMode::NONE);
342        }
343
344        let s = ssl.connect(domain, stream)?;
345        Ok(TlsStream(s))
346    }
347}
348
349impl fmt::Debug for TlsConnector {
350    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
351        fmt.debug_struct("TlsConnector")
352            // n.b. SslConnector is a newtype on SslContext which implements a noop Debug so it's omitted
353            .field("use_sni", &self.use_sni)
354            .field("accept_invalid_hostnames", &self.accept_invalid_hostnames)
355            .field("accept_invalid_certs", &self.accept_invalid_certs)
356            .finish()
357    }
358}
359
360#[derive(Clone)]
361pub struct TlsAcceptor(SslAcceptor);
362
363impl TlsAcceptor {
364    pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
365        let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
366        acceptor.set_private_key(&builder.identity.0.pkey)?;
367        acceptor.set_certificate(&builder.identity.0.cert)?;
368        for cert in builder.identity.0.chain.iter() {
369            // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html
370            // specifies that "When sending a certificate chain, extra chain certificates are
371            // sent in order following the end entity certificate."
372            acceptor.add_extra_chain_cert(cert.to_owned())?;
373        }
374        supported_protocols(builder.min_protocol, builder.max_protocol, &mut acceptor)?;
375
376        Ok(TlsAcceptor(acceptor.build()))
377    }
378
379    pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
380    where
381        S: io::Read + io::Write,
382    {
383        let s = self.0.accept(stream)?;
384        Ok(TlsStream(s))
385    }
386}
387
388pub struct TlsStream<S>(ssl::SslStream<S>);
389
390impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
391    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
392        fmt::Debug::fmt(&self.0, fmt)
393    }
394}
395
396impl<S> TlsStream<S> {
397    pub fn get_ref(&self) -> &S {
398        self.0.get_ref()
399    }
400
401    pub fn get_mut(&mut self) -> &mut S {
402        self.0.get_mut()
403    }
404}
405
406impl<S: io::Read + io::Write> TlsStream<S> {
407    pub fn buffered_read_size(&self) -> Result<usize, Error> {
408        Ok(self.0.ssl().pending())
409    }
410
411    pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
412        Ok(self.0.ssl().peer_certificate().map(Certificate))
413    }
414
415    #[cfg(feature = "alpn")]
416    pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
417        Ok(self
418            .0
419            .ssl()
420            .selected_alpn_protocol()
421            .map(|alpn| alpn.to_vec()))
422    }
423
424    pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
425        let cert = if self.0.ssl().is_server() {
426            self.0.ssl().certificate().map(|x| x.to_owned())
427        } else {
428            self.0.ssl().peer_certificate()
429        };
430
431        let cert = match cert {
432            Some(cert) => cert,
433            None => return Ok(None),
434        };
435
436        let algo_nid = cert.signature_algorithm().object().nid();
437        let signature_algorithms = match algo_nid.signature_algorithms() {
438            Some(algs) => algs,
439            None => return Ok(None),
440        };
441
442        let md = match signature_algorithms.digest {
443            Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
444            nid => match MessageDigest::from_nid(nid) {
445                Some(md) => md,
446                None => return Ok(None),
447            },
448        };
449
450        let digest = cert.digest(md)?;
451
452        Ok(Some(digest.to_vec()))
453    }
454
455    pub fn shutdown(&mut self) -> io::Result<()> {
456        match self.0.shutdown() {
457            Ok(_) => Ok(()),
458            Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()),
459            Err(e) => Err(e
460                .into_io_error()
461                .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))),
462        }
463    }
464}
465
466impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
467    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
468        self.0.read(buf)
469    }
470}
471
472impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
473    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
474        self.0.write(buf)
475    }
476
477    fn flush(&mut self) -> io::Result<()> {
478        self.0.flush()
479    }
480}