1extern crate openssl;
2extern crate openssl_probe;
3
4use self::openssl::error::ErrorStack;
5use self::openssl::hash::MessageDigest;
6use self::openssl::nid::Nid;
7use self::openssl::pkcs12::Pkcs12;
8use self::openssl::pkey::{PKey, Private};
9use self::openssl::ssl::{
10 self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
11 SslVerifyMode,
12};
13use self::openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509};
14use std::error;
15use std::fmt;
16use std::io;
17use std::sync::Once;
18
19use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
20
/// Restricts `ctx` to the protocol range `[min, max]` using OpenSSL's
/// `set_min_proto_version` / `set_max_proto_version`.
///
/// A `None` bound is forwarded unchanged to the corresponding setter.
#[cfg(have_min_max_version)]
fn supported_protocols(
    min: Option<Protocol>,
    max: Option<Protocol>,
    ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
    use self::openssl::ssl::SslVersion;

    /// Maps a crate-level `Protocol` onto OpenSSL's `SslVersion`.
    fn to_ssl_version(protocol: Protocol) -> SslVersion {
        match protocol {
            Protocol::Sslv3 => SslVersion::SSL3,
            Protocol::Tlsv10 => SslVersion::TLS1,
            Protocol::Tlsv11 => SslVersion::TLS1_1,
            Protocol::Tlsv12 => SslVersion::TLS1_2,
        }
    }

    let lower_bound = min.map(to_ssl_version);
    ctx.set_min_proto_version(lower_bound)?;

    let upper_bound = max.map(to_ssl_version);
    ctx.set_max_proto_version(upper_bound)?;

    Ok(())
}
43
44#[cfg(not(have_min_max_version))]
45fn supported_protocols(
46 min: Option<Protocol>,
47 max: Option<Protocol>,
48 ctx: &mut SslContextBuilder,
49) -> Result<(), ErrorStack> {
50 use self::openssl::ssl::SslOptions;
51
52 let no_ssl_mask = SslOptions::NO_SSLV2
53 | SslOptions::NO_SSLV3
54 | SslOptions::NO_TLSV1
55 | SslOptions::NO_TLSV1_1
56 | SslOptions::NO_TLSV1_2;
57
58 ctx.clear_options(no_ssl_mask);
59 let mut options = SslOptions::empty();
60 options |= match min {
61 None => SslOptions::empty(),
62 Some(Protocol::Sslv3) => SslOptions::NO_SSLV2,
63 Some(Protocol::Tlsv10) => SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3,
64 Some(Protocol::Tlsv11) => {
65 SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1
66 }
67 Some(Protocol::Tlsv12) => {
68 SslOptions::NO_SSLV2
69 | SslOptions::NO_SSLV3
70 | SslOptions::NO_TLSV1
71 | SslOptions::NO_TLSV1_1
72 }
73 };
74 options |= match max {
75 None | Some(Protocol::Tlsv12) => SslOptions::empty(),
76 Some(Protocol::Tlsv11) => SslOptions::NO_TLSV1_2,
77 Some(Protocol::Tlsv10) => SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2,
78 Some(Protocol::Sslv3) => {
79 SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2
80 }
81 };
82
83 ctx.set_options(options);
84
85 Ok(())
86}
87
88fn init_trust() {
89 static ONCE: Once = Once::new();
90 ONCE.call_once(openssl_probe::init_ssl_cert_env_vars);
91}
92
/// Adds every PEM certificate found in Android's system CA directory to the context's
/// certificate store.
///
/// This is best-effort: a missing directory, unreadable entries, and files that do not
/// parse as PEM are skipped. Store insertion failures are only logged. It always returns
/// `Ok(())`.
#[cfg(target_os = "android")]
fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Error> {
    use std::fs;

    let entries = match fs::read_dir("/system/etc/security/cacerts") {
        Ok(entries) => entries,
        Err(_) => return Ok(()),
    };

    for entry in entries {
        let entry = match entry {
            Ok(entry) => entry,
            Err(_) => continue,
        };
        let contents = match fs::read(entry.path()) {
            Ok(contents) => contents,
            Err(_) => continue,
        };
        let root = match X509::from_pem(&contents) {
            Ok(root) => root,
            Err(_) => continue,
        };
        if let Err(err) = connector.cert_store_mut().add_cert(root) {
            debug!("load_android_root_certs error: {:?}", err);
        }
    }

    Ok(())
}
111
112#[derive(Debug)]
113pub enum Error {
114 Normal(ErrorStack),
115 Ssl(ssl::Error, X509VerifyResult),
116 EmptyChain,
117 NotPkcs8,
118}
119
120impl error::Error for Error {
121 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
122 match *self {
123 Error::Normal(ref e) => error::Error::source(e),
124 Error::Ssl(ref e, _) => error::Error::source(e),
125 Error::EmptyChain => None,
126 Error::NotPkcs8 => None,
127 }
128 }
129}
130
131impl fmt::Display for Error {
132 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
133 match *self {
134 Error::Normal(ref e) => fmt::Display::fmt(e, fmt),
135 Error::Ssl(ref e, X509VerifyResult::OK) => fmt::Display::fmt(e, fmt),
136 Error::Ssl(ref e, v) => write!(fmt, "{} ({})", e, v),
137 Error::EmptyChain => write!(
138 fmt,
139 "at least one certificate must be provided to create an identity"
140 ),
141 Error::NotPkcs8 => write!(fmt, "expected PKCS#8 PEM"),
142 }
143 }
144}
145
146impl From<ErrorStack> for Error {
147 fn from(err: ErrorStack) -> Error {
148 Error::Normal(err)
149 }
150}
151
152#[derive(Clone)]
153pub struct Identity {
154 pkey: PKey<Private>,
155 cert: X509,
156 chain: Vec<X509>,
157}
158
159impl Identity {
160 pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
161 let pkcs12 = Pkcs12::from_der(buf)?;
162 let parsed = pkcs12.parse2(pass)?;
163 Ok(Identity {
164 pkey: parsed.pkey.ok_or_else(|| Error::EmptyChain)?,
165 cert: parsed.cert.ok_or_else(|| Error::EmptyChain)?,
166 chain: parsed.ca.into_iter().flatten().rev().collect(),
170 })
171 }
172
173 pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result<Identity, Error> {
174 if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
175 return Err(Error::NotPkcs8);
176 }
177
178 let pkey = PKey::private_key_from_pem(key)?;
179 let mut cert_chain = X509::stack_from_pem(buf)?.into_iter();
180 let cert = cert_chain.next().ok_or(Error::EmptyChain)?;
181 let chain = cert_chain.collect();
182 Ok(Identity { pkey, cert, chain })
183 }
184}
185
186#[derive(Clone)]
187pub struct Certificate(X509);
188
189impl Certificate {
190 pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
191 let cert = X509::from_der(buf)?;
192 Ok(Certificate(cert))
193 }
194
195 pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
196 let cert = X509::from_pem(buf)?;
197 Ok(Certificate(cert))
198 }
199
200 pub fn to_der(&self) -> Result<Vec<u8>, Error> {
201 let der = self.0.to_der()?;
202 Ok(der)
203 }
204}
205
206pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>);
207
208impl<S> fmt::Debug for MidHandshakeTlsStream<S>
209where
210 S: fmt::Debug,
211{
212 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
213 fmt::Debug::fmt(&self.0, fmt)
214 }
215}
216
217impl<S> MidHandshakeTlsStream<S> {
218 pub fn get_ref(&self) -> &S {
219 self.0.get_ref()
220 }
221
222 pub fn get_mut(&mut self) -> &mut S {
223 self.0.get_mut()
224 }
225}
226
227impl<S> MidHandshakeTlsStream<S>
228where
229 S: io::Read + io::Write,
230{
231 pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
232 match self.0.handshake() {
233 Ok(s) => Ok(TlsStream(s)),
234 Err(e) => Err(e.into()),
235 }
236 }
237}
238
239pub enum HandshakeError<S> {
240 Failure(Error),
241 WouldBlock(MidHandshakeTlsStream<S>),
242}
243
244impl<S> From<ssl::HandshakeError<S>> for HandshakeError<S> {
245 fn from(e: ssl::HandshakeError<S>) -> HandshakeError<S> {
246 match e {
247 ssl::HandshakeError::SetupFailure(e) => HandshakeError::Failure(e.into()),
248 ssl::HandshakeError::Failure(e) => {
249 let v = e.ssl().verify_result();
250 HandshakeError::Failure(Error::Ssl(e.into_error(), v))
251 }
252 ssl::HandshakeError::WouldBlock(s) => {
253 HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
254 }
255 }
256 }
257}
258
259impl<S> From<ErrorStack> for HandshakeError<S> {
260 fn from(e: ErrorStack) -> HandshakeError<S> {
261 HandshakeError::Failure(e.into())
262 }
263}
264
265#[derive(Clone)]
266pub struct TlsConnector {
267 connector: SslConnector,
268 use_sni: bool,
269 accept_invalid_hostnames: bool,
270 accept_invalid_certs: bool,
271}
272
273impl TlsConnector {
274 pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
275 init_trust();
276
277 let mut connector = SslConnector::builder(SslMethod::tls())?;
278 if let Some(ref identity) = builder.identity {
279 connector.set_certificate(&identity.0.cert)?;
280 connector.set_private_key(&identity.0.pkey)?;
281 for cert in identity.0.chain.iter() {
282 connector.add_extra_chain_cert(cert.to_owned())?;
286 }
287 }
288 supported_protocols(builder.min_protocol, builder.max_protocol, &mut connector)?;
289
290 if builder.disable_built_in_roots {
291 connector.set_cert_store(X509StoreBuilder::new()?.build());
292 }
293
294 for cert in &builder.root_certificates {
295 if let Err(err) = connector.cert_store_mut().add_cert((cert.0).0.clone()) {
296 debug!("add_cert error: {:?}", err);
297 }
298 }
299
300 #[cfg(feature = "alpn")]
301 {
302 if !builder.alpn.is_empty() {
303 let mut alpn_wire_format = Vec::with_capacity(
305 builder
306 .alpn
307 .iter()
308 .map(|s| s.as_bytes().len())
309 .sum::<usize>()
310 + builder.alpn.len(),
311 );
312 for alpn in builder.alpn.iter().map(|s| s.as_bytes()) {
313 alpn_wire_format.push(alpn.len() as u8);
314 alpn_wire_format.extend(alpn);
315 }
316 connector.set_alpn_protos(&alpn_wire_format)?;
317 }
318 }
319
320 #[cfg(target_os = "android")]
321 load_android_root_certs(&mut connector)?;
322
323 Ok(TlsConnector {
324 connector: connector.build(),
325 use_sni: builder.use_sni,
326 accept_invalid_hostnames: builder.accept_invalid_hostnames,
327 accept_invalid_certs: builder.accept_invalid_certs,
328 })
329 }
330
331 pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
332 where
333 S: io::Read + io::Write,
334 {
335 let mut ssl = self
336 .connector
337 .configure()?
338 .use_server_name_indication(self.use_sni)
339 .verify_hostname(!self.accept_invalid_hostnames);
340 if self.accept_invalid_certs {
341 ssl.set_verify(SslVerifyMode::NONE);
342 }
343
344 let s = ssl.connect(domain, stream)?;
345 Ok(TlsStream(s))
346 }
347}
348
349impl fmt::Debug for TlsConnector {
350 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
351 fmt.debug_struct("TlsConnector")
352 .field("use_sni", &self.use_sni)
354 .field("accept_invalid_hostnames", &self.accept_invalid_hostnames)
355 .field("accept_invalid_certs", &self.accept_invalid_certs)
356 .finish()
357 }
358}
359
360#[derive(Clone)]
361pub struct TlsAcceptor(SslAcceptor);
362
363impl TlsAcceptor {
364 pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
365 let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
366 acceptor.set_private_key(&builder.identity.0.pkey)?;
367 acceptor.set_certificate(&builder.identity.0.cert)?;
368 for cert in builder.identity.0.chain.iter() {
369 acceptor.add_extra_chain_cert(cert.to_owned())?;
373 }
374 supported_protocols(builder.min_protocol, builder.max_protocol, &mut acceptor)?;
375
376 Ok(TlsAcceptor(acceptor.build()))
377 }
378
379 pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
380 where
381 S: io::Read + io::Write,
382 {
383 let s = self.0.accept(stream)?;
384 Ok(TlsStream(s))
385 }
386}
387
388pub struct TlsStream<S>(ssl::SslStream<S>);
389
390impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
391 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
392 fmt::Debug::fmt(&self.0, fmt)
393 }
394}
395
396impl<S> TlsStream<S> {
397 pub fn get_ref(&self) -> &S {
398 self.0.get_ref()
399 }
400
401 pub fn get_mut(&mut self) -> &mut S {
402 self.0.get_mut()
403 }
404}
405
406impl<S: io::Read + io::Write> TlsStream<S> {
407 pub fn buffered_read_size(&self) -> Result<usize, Error> {
408 Ok(self.0.ssl().pending())
409 }
410
411 pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
412 Ok(self.0.ssl().peer_certificate().map(Certificate))
413 }
414
415 #[cfg(feature = "alpn")]
416 pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
417 Ok(self
418 .0
419 .ssl()
420 .selected_alpn_protocol()
421 .map(|alpn| alpn.to_vec()))
422 }
423
424 pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
425 let cert = if self.0.ssl().is_server() {
426 self.0.ssl().certificate().map(|x| x.to_owned())
427 } else {
428 self.0.ssl().peer_certificate()
429 };
430
431 let cert = match cert {
432 Some(cert) => cert,
433 None => return Ok(None),
434 };
435
436 let algo_nid = cert.signature_algorithm().object().nid();
437 let signature_algorithms = match algo_nid.signature_algorithms() {
438 Some(algs) => algs,
439 None => return Ok(None),
440 };
441
442 let md = match signature_algorithms.digest {
443 Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
444 nid => match MessageDigest::from_nid(nid) {
445 Some(md) => md,
446 None => return Ok(None),
447 },
448 };
449
450 let digest = cert.digest(md)?;
451
452 Ok(Some(digest.to_vec()))
453 }
454
455 pub fn shutdown(&mut self) -> io::Result<()> {
456 match self.0.shutdown() {
457 Ok(_) => Ok(()),
458 Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()),
459 Err(e) => Err(e
460 .into_io_error()
461 .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))),
462 }
463 }
464}
465
466impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
467 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
468 self.0.read(buf)
469 }
470}
471
472impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
473 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
474 self.0.write(buf)
475 }
476
477 fn flush(&mut self) -> io::Result<()> {
478 self.0.flush()
479 }
480}