Spaces:
Running
Running
""" | |
Module for using pyOpenSSL as a TLS backend. This module was relevant before | |
the standard library ``ssl`` module supported SNI, but now that we've dropped | |
support for Python 2.7 all relevant Python versions support SNI so | |
**this module is no longer recommended**. | |
This needs the following packages installed: | |
* `pyOpenSSL`_ (tested with 16.0.0) | |
* `cryptography`_ (minimum 1.3.4, from pyopenssl) | |
* `idna`_ (minimum 2.0) | |
However, pyOpenSSL depends on cryptography, so while we use all three directly here we | |
end up having relatively few packages required. | |
You can install them with the following command: | |
.. code-block:: bash | |
$ python -m pip install pyopenssl cryptography idna | |
To activate certificate checking, call | |
:func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code | |
before you begin making HTTP requests. This can be done in a ``sitecustomize`` | |
module, or at any other time before your application begins using ``urllib3``, | |
like this: | |
.. code-block:: python | |
try: | |
import urllib3.contrib.pyopenssl | |
urllib3.contrib.pyopenssl.inject_into_urllib3() | |
except ImportError: | |
pass | |
.. _pyopenssl: https://www.pyopenssl.org | |
.. _cryptography: https://cryptography.io | |
.. _idna: https://github.com/kjd/idna | |
""" | |
from __future__ import annotations | |
import OpenSSL.SSL # type: ignore[import-untyped] | |
from cryptography import x509 | |
try: | |
from cryptography.x509 import UnsupportedExtension # type: ignore[attr-defined] | |
except ImportError: | |
# UnsupportedExtension is gone in cryptography >= 2.1.0 | |
class UnsupportedExtension(Exception): # type: ignore[no-redef] | |
pass | |
import logging | |
import ssl | |
import typing | |
from io import BytesIO | |
from socket import socket as socket_cls | |
from socket import timeout | |
from .. import util | |
if typing.TYPE_CHECKING: | |
from OpenSSL.crypto import X509 # type: ignore[import-untyped] | |
__all__ = ["inject_into_urllib3", "extract_from_urllib3"] | |
# Map from urllib3 to PyOpenSSL compatible parameter-values. | |
_openssl_versions: dict[int, int] = { | |
util.ssl_.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined] | |
util.ssl_.PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined] | |
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, | |
} | |
if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"): | |
_openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD | |
if hasattr(ssl, "PROTOCOL_TLSv1_2") and hasattr(OpenSSL.SSL, "TLSv1_2_METHOD"): | |
_openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD | |
_stdlib_to_openssl_verify = { | |
ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE, | |
ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER, | |
ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER | |
+ OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, | |
} | |
_openssl_to_stdlib_verify = {v: k for k, v in _stdlib_to_openssl_verify.items()} | |
# The SSLvX values are the most likely to be missing in the future | |
# but we check them all just to be sure. | |
_OP_NO_SSLv2_OR_SSLv3: int = getattr(OpenSSL.SSL, "OP_NO_SSLv2", 0) | getattr( | |
OpenSSL.SSL, "OP_NO_SSLv3", 0 | |
) | |
_OP_NO_TLSv1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1", 0) | |
_OP_NO_TLSv1_1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_1", 0) | |
_OP_NO_TLSv1_2: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_2", 0) | |
_OP_NO_TLSv1_3: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_3", 0) | |
_openssl_to_ssl_minimum_version: dict[int, int] = { | |
ssl.TLSVersion.MINIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3, | |
ssl.TLSVersion.TLSv1: _OP_NO_SSLv2_OR_SSLv3, | |
ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1, | |
ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1, | |
ssl.TLSVersion.TLSv1_3: ( | |
_OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 | |
), | |
ssl.TLSVersion.MAXIMUM_SUPPORTED: ( | |
_OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 | |
), | |
} | |
_openssl_to_ssl_maximum_version: dict[int, int] = { | |
ssl.TLSVersion.MINIMUM_SUPPORTED: ( | |
_OP_NO_SSLv2_OR_SSLv3 | |
| _OP_NO_TLSv1 | |
| _OP_NO_TLSv1_1 | |
| _OP_NO_TLSv1_2 | |
| _OP_NO_TLSv1_3 | |
), | |
ssl.TLSVersion.TLSv1: ( | |
_OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3 | |
), | |
ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3, | |
ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_3, | |
ssl.TLSVersion.TLSv1_3: _OP_NO_SSLv2_OR_SSLv3, | |
ssl.TLSVersion.MAXIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3, | |
} | |
# OpenSSL will only write 16K at a time | |
SSL_WRITE_BLOCKSIZE = 16384 | |
orig_util_SSLContext = util.ssl_.SSLContext | |
log = logging.getLogger(__name__) | |
def inject_into_urllib3() -> None: | |
"Monkey-patch urllib3 with PyOpenSSL-backed SSL-support." | |
_validate_dependencies_met() | |
util.SSLContext = PyOpenSSLContext # type: ignore[assignment] | |
util.ssl_.SSLContext = PyOpenSSLContext # type: ignore[assignment] | |
util.IS_PYOPENSSL = True | |
util.ssl_.IS_PYOPENSSL = True | |
def extract_from_urllib3() -> None: | |
"Undo monkey-patching by :func:`inject_into_urllib3`." | |
util.SSLContext = orig_util_SSLContext | |
util.ssl_.SSLContext = orig_util_SSLContext | |
util.IS_PYOPENSSL = False | |
util.ssl_.IS_PYOPENSSL = False | |
def _validate_dependencies_met() -> None: | |
""" | |
Verifies that PyOpenSSL's package-level dependencies have been met. | |
Throws `ImportError` if they are not met. | |
""" | |
# Method added in `cryptography==1.1`; not available in older versions | |
from cryptography.x509.extensions import Extensions | |
if getattr(Extensions, "get_extension_for_class", None) is None: | |
raise ImportError( | |
"'cryptography' module missing required functionality. " | |
"Try upgrading to v1.3.4 or newer." | |
) | |
# pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509 | |
# attribute is only present on those versions. | |
from OpenSSL.crypto import X509 | |
x509 = X509() | |
if getattr(x509, "_x509", None) is None: | |
raise ImportError( | |
"'pyOpenSSL' module missing required functionality. " | |
"Try upgrading to v0.14 or newer." | |
) | |
def _dnsname_to_stdlib(name: str) -> str | None: | |
""" | |
Converts a dNSName SubjectAlternativeName field to the form used by the | |
standard library on the given Python version. | |
Cryptography produces a dNSName as a unicode string that was idna-decoded | |
from ASCII bytes. We need to idna-encode that string to get it back, and | |
then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib | |
uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8). | |
If the name cannot be idna-encoded then we return None signalling that | |
the name given should be skipped. | |
""" | |
def idna_encode(name: str) -> bytes | None: | |
""" | |
Borrowed wholesale from the Python Cryptography Project. It turns out | |
that we can't just safely call `idna.encode`: it can explode for | |
wildcard names. This avoids that problem. | |
""" | |
import idna | |
try: | |
for prefix in ["*.", "."]: | |
if name.startswith(prefix): | |
name = name[len(prefix) :] | |
return prefix.encode("ascii") + idna.encode(name) | |
return idna.encode(name) | |
except idna.core.IDNAError: | |
return None | |
# Don't send IPv6 addresses through the IDNA encoder. | |
if ":" in name: | |
return name | |
encoded_name = idna_encode(name) | |
if encoded_name is None: | |
return None | |
return encoded_name.decode("utf-8") | |
def get_subj_alt_name(peer_cert: X509) -> list[tuple[str, str]]: | |
""" | |
Given an PyOpenSSL certificate, provides all the subject alternative names. | |
""" | |
cert = peer_cert.to_cryptography() | |
# We want to find the SAN extension. Ask Cryptography to locate it (it's | |
# faster than looping in Python) | |
try: | |
ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value | |
except x509.ExtensionNotFound: | |
# No such extension, return the empty list. | |
return [] | |
except ( | |
x509.DuplicateExtension, | |
UnsupportedExtension, | |
x509.UnsupportedGeneralNameType, | |
UnicodeError, | |
) as e: | |
# A problem has been found with the quality of the certificate. Assume | |
# no SAN field is present. | |
log.warning( | |
"A problem was encountered with the certificate that prevented " | |
"urllib3 from finding the SubjectAlternativeName field. This can " | |
"affect certificate validation. The error was %s", | |
e, | |
) | |
return [] | |
# We want to return dNSName and iPAddress fields. We need to cast the IPs | |
# back to strings because the match_hostname function wants them as | |
# strings. | |
# Sadly the DNS names need to be idna encoded and then, on Python 3, UTF-8 | |
# decoded. This is pretty frustrating, but that's what the standard library | |
# does with certificates, and so we need to attempt to do the same. | |
# We also want to skip over names which cannot be idna encoded. | |
names = [ | |
("DNS", name) | |
for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName)) | |
if name is not None | |
] | |
names.extend( | |
("IP Address", str(name)) for name in ext.get_values_for_type(x509.IPAddress) | |
) | |
return names | |
class WrappedSocket: | |
"""API-compatibility wrapper for Python OpenSSL's Connection-class.""" | |
def __init__( | |
self, | |
connection: OpenSSL.SSL.Connection, | |
socket: socket_cls, | |
suppress_ragged_eofs: bool = True, | |
) -> None: | |
self.connection = connection | |
self.socket = socket | |
self.suppress_ragged_eofs = suppress_ragged_eofs | |
self._io_refs = 0 | |
self._closed = False | |
def fileno(self) -> int: | |
return self.socket.fileno() | |
# Copy-pasted from Python 3.5 source code | |
def _decref_socketios(self) -> None: | |
if self._io_refs > 0: | |
self._io_refs -= 1 | |
if self._closed: | |
self.close() | |
def recv(self, *args: typing.Any, **kwargs: typing.Any) -> bytes: | |
try: | |
data = self.connection.recv(*args, **kwargs) | |
except OpenSSL.SSL.SysCallError as e: | |
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): | |
return b"" | |
else: | |
raise OSError(e.args[0], str(e)) from e | |
except OpenSSL.SSL.ZeroReturnError: | |
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: | |
return b"" | |
else: | |
raise | |
except OpenSSL.SSL.WantReadError as e: | |
if not util.wait_for_read(self.socket, self.socket.gettimeout()): | |
raise timeout("The read operation timed out") from e | |
else: | |
return self.recv(*args, **kwargs) | |
# TLS 1.3 post-handshake authentication | |
except OpenSSL.SSL.Error as e: | |
raise ssl.SSLError(f"read error: {e!r}") from e | |
else: | |
return data # type: ignore[no-any-return] | |
def recv_into(self, *args: typing.Any, **kwargs: typing.Any) -> int: | |
try: | |
return self.connection.recv_into(*args, **kwargs) # type: ignore[no-any-return] | |
except OpenSSL.SSL.SysCallError as e: | |
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): | |
return 0 | |
else: | |
raise OSError(e.args[0], str(e)) from e | |
except OpenSSL.SSL.ZeroReturnError: | |
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: | |
return 0 | |
else: | |
raise | |
except OpenSSL.SSL.WantReadError as e: | |
if not util.wait_for_read(self.socket, self.socket.gettimeout()): | |
raise timeout("The read operation timed out") from e | |
else: | |
return self.recv_into(*args, **kwargs) | |
# TLS 1.3 post-handshake authentication | |
except OpenSSL.SSL.Error as e: | |
raise ssl.SSLError(f"read error: {e!r}") from e | |
def settimeout(self, timeout: float) -> None: | |
return self.socket.settimeout(timeout) | |
def _send_until_done(self, data: bytes) -> int: | |
while True: | |
try: | |
return self.connection.send(data) # type: ignore[no-any-return] | |
except OpenSSL.SSL.WantWriteError as e: | |
if not util.wait_for_write(self.socket, self.socket.gettimeout()): | |
raise timeout() from e | |
continue | |
except OpenSSL.SSL.SysCallError as e: | |
raise OSError(e.args[0], str(e)) from e | |
def sendall(self, data: bytes) -> None: | |
total_sent = 0 | |
while total_sent < len(data): | |
sent = self._send_until_done( | |
data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE] | |
) | |
total_sent += sent | |
def shutdown(self, how: int) -> None: | |
try: | |
self.connection.shutdown() | |
except OpenSSL.SSL.Error as e: | |
raise ssl.SSLError(f"shutdown error: {e!r}") from e | |
def close(self) -> None: | |
self._closed = True | |
if self._io_refs <= 0: | |
self._real_close() | |
def _real_close(self) -> None: | |
try: | |
return self.connection.close() # type: ignore[no-any-return] | |
except OpenSSL.SSL.Error: | |
return | |
def getpeercert( | |
self, binary_form: bool = False | |
) -> dict[str, list[typing.Any]] | None: | |
x509 = self.connection.get_peer_certificate() | |
if not x509: | |
return x509 # type: ignore[no-any-return] | |
if binary_form: | |
return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509) # type: ignore[no-any-return] | |
return { | |
"subject": ((("commonName", x509.get_subject().CN),),), # type: ignore[dict-item] | |
"subjectAltName": get_subj_alt_name(x509), | |
} | |
def version(self) -> str: | |
return self.connection.get_protocol_version_name() # type: ignore[no-any-return] | |
def selected_alpn_protocol(self) -> str | None: | |
alpn_proto = self.connection.get_alpn_proto_negotiated() | |
return alpn_proto.decode() if alpn_proto else None | |
WrappedSocket.makefile = socket_cls.makefile # type: ignore[attr-defined] | |
class PyOpenSSLContext: | |
""" | |
I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible | |
for translating the interface of the standard library ``SSLContext`` object | |
to calls into PyOpenSSL. | |
""" | |
def __init__(self, protocol: int) -> None: | |
self.protocol = _openssl_versions[protocol] | |
self._ctx = OpenSSL.SSL.Context(self.protocol) | |
self._options = 0 | |
self.check_hostname = False | |
self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED | |
self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED | |
def options(self) -> int: | |
return self._options | |
def options(self, value: int) -> None: | |
self._options = value | |
self._set_ctx_options() | |
def verify_mode(self) -> int: | |
return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()] | |
def verify_mode(self, value: ssl.VerifyMode) -> None: | |
self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback) | |
def set_default_verify_paths(self) -> None: | |
self._ctx.set_default_verify_paths() | |
def set_ciphers(self, ciphers: bytes | str) -> None: | |
if isinstance(ciphers, str): | |
ciphers = ciphers.encode("utf-8") | |
self._ctx.set_cipher_list(ciphers) | |
def load_verify_locations( | |
self, | |
cafile: str | None = None, | |
capath: str | None = None, | |
cadata: bytes | None = None, | |
) -> None: | |
if cafile is not None: | |
cafile = cafile.encode("utf-8") # type: ignore[assignment] | |
if capath is not None: | |
capath = capath.encode("utf-8") # type: ignore[assignment] | |
try: | |
self._ctx.load_verify_locations(cafile, capath) | |
if cadata is not None: | |
self._ctx.load_verify_locations(BytesIO(cadata)) | |
except OpenSSL.SSL.Error as e: | |
raise ssl.SSLError(f"unable to load trusted certificates: {e!r}") from e | |
def load_cert_chain( | |
self, | |
certfile: str, | |
keyfile: str | None = None, | |
password: str | None = None, | |
) -> None: | |
try: | |
self._ctx.use_certificate_chain_file(certfile) | |
if password is not None: | |
if not isinstance(password, bytes): | |
password = password.encode("utf-8") # type: ignore[assignment] | |
self._ctx.set_passwd_cb(lambda *_: password) | |
self._ctx.use_privatekey_file(keyfile or certfile) | |
except OpenSSL.SSL.Error as e: | |
raise ssl.SSLError(f"Unable to load certificate chain: {e!r}") from e | |
def set_alpn_protocols(self, protocols: list[bytes | str]) -> None: | |
protocols = [util.util.to_bytes(p, "ascii") for p in protocols] | |
return self._ctx.set_alpn_protos(protocols) # type: ignore[no-any-return] | |
def wrap_socket( | |
self, | |
sock: socket_cls, | |
server_side: bool = False, | |
do_handshake_on_connect: bool = True, | |
suppress_ragged_eofs: bool = True, | |
server_hostname: bytes | str | None = None, | |
) -> WrappedSocket: | |
cnx = OpenSSL.SSL.Connection(self._ctx, sock) | |
# If server_hostname is an IP, don't use it for SNI, per RFC6066 Section 3 | |
if server_hostname and not util.ssl_.is_ipaddress(server_hostname): | |
if isinstance(server_hostname, str): | |
server_hostname = server_hostname.encode("utf-8") | |
cnx.set_tlsext_host_name(server_hostname) | |
cnx.set_connect_state() | |
while True: | |
try: | |
cnx.do_handshake() | |
except OpenSSL.SSL.WantReadError as e: | |
if not util.wait_for_read(sock, sock.gettimeout()): | |
raise timeout("select timed out") from e | |
continue | |
except OpenSSL.SSL.Error as e: | |
raise ssl.SSLError(f"bad handshake: {e!r}") from e | |
break | |
return WrappedSocket(cnx, sock) | |
def _set_ctx_options(self) -> None: | |
self._ctx.set_options( | |
self._options | |
| _openssl_to_ssl_minimum_version[self._minimum_version] | |
| _openssl_to_ssl_maximum_version[self._maximum_version] | |
) | |
def minimum_version(self) -> int: | |
return self._minimum_version | |
def minimum_version(self, minimum_version: int) -> None: | |
self._minimum_version = minimum_version | |
self._set_ctx_options() | |
def maximum_version(self) -> int: | |
return self._maximum_version | |
def maximum_version(self, maximum_version: int) -> None: | |
self._maximum_version = maximum_version | |
self._set_ctx_options() | |
def _verify_callback( | |
cnx: OpenSSL.SSL.Connection, | |
x509: X509, | |
err_no: int, | |
err_depth: int, | |
return_code: int, | |
) -> bool: | |
return err_no == 0 | |