Spaces:
Running
Running
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license | |
"""trio async I/O library query support""" | |
import socket | |
import trio | |
import trio.socket # type: ignore | |
import dns._asyncbackend | |
import dns._features | |
import dns.exception | |
import dns.inet | |
if not dns._features.have("trio"):
    raise ImportError("trio not found or too old")


def _maybe_timeout(timeout):
    """Return a context manager that bounds the enclosed block by *timeout*.

    A *timeout* of ``None`` means "no limit", and a no-op
    ``dns._asyncbackend.NullContext`` is returned.  Otherwise the result is
    a ``trio.move_on_after`` cancel scope.  When that scope expires, the
    ``with`` block is exited without an exception, so the code that follows
    the ``with`` must treat "fell through" as "timed out".
    """
    if timeout is None:
        return dns._asyncbackend.NullContext()
    return trio.move_on_after(timeout)


# Short alias, used when turning (address, port) pairs into the tuples that
# socket bind/connect calls expect.
_lltuple = dns.inet.low_level_address_tuple
# pylint: disable=redefined-outer-name | |
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Adapter exposing a ``trio.socket`` datagram socket via dnspython's API."""

    def __init__(self, socket):
        super().__init__(socket.family)
        self.socket = socket

    async def _io_with_timeout(self, timeout, operation, *args):
        # Await operation(*args) under an optional trio deadline.  An expired
        # move_on_after scope leaves the `with` silently, so reaching the
        # raise below means the deadline was hit.
        with _maybe_timeout(timeout):
            return await operation(*args)
        raise dns.exception.Timeout(
            timeout=timeout
        )  # pragma: no cover lgtm[py/unreachable-statement]

    async def sendto(self, what, destination, timeout):
        """Send *what* to *destination*; raise Timeout after *timeout* seconds."""
        return await self._io_with_timeout(
            timeout, self.socket.sendto, what, destination
        )

    async def recvfrom(self, size, timeout):
        """Receive up to *size* bytes; raise Timeout after *timeout* seconds."""
        return await self._io_with_timeout(timeout, self.socket.recvfrom, size)

    async def close(self):
        self.socket.close()

    async def getpeername(self):
        return self.socket.getpeername()

    async def getsockname(self):
        return self.socket.getsockname()

    async def getpeercert(self, timeout):
        # Peer certificates are not available for datagram sockets here.
        raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Adapter exposing a trio stream via dnspython's stream-socket API.

    ``Backend.make_socket`` passes either a ``trio.SocketStream`` or, when
    *tls* is true, a ``trio.SSLStream`` wrapping one.
    """

    def __init__(self, family, stream, tls=False):
        self.family = family
        self.stream = stream
        self.tls = tls

    def _raw_socket(self):
        # For TLS, the trio socket is reached through the wrapped transport
        # stream rather than through the SSL stream itself.
        if self.tls:
            return self.stream.transport_stream.socket
        return self.stream.socket

    async def sendall(self, what, timeout):
        """Send all of *what*; raise Timeout after *timeout* seconds."""
        with _maybe_timeout(timeout):
            return await self.stream.send_all(what)
        raise dns.exception.Timeout(timeout=timeout)  # lgtm[py/unreachable-statement]

    async def recv(self, size, timeout):
        """Receive up to *size* bytes; raise Timeout after *timeout* seconds."""
        with _maybe_timeout(timeout):
            return await self.stream.receive_some(size)
        raise dns.exception.Timeout(timeout=timeout)  # lgtm[py/unreachable-statement]

    async def close(self):
        await self.stream.aclose()

    async def getpeername(self):
        return self._raw_socket().getpeername()

    async def getsockname(self):
        return self._raw_socket().getsockname()

    async def getpeercert(self, timeout):
        """Complete the TLS handshake if needed and return the peer certificate.

        Raises:
            NotImplementedError: if this is not a TLS stream.
            dns.exception.Timeout: if the handshake does not finish within
                *timeout* seconds.
        """
        if not self.tls:
            raise NotImplementedError
        handshake_done = False
        with _maybe_timeout(timeout):
            await self.stream.do_handshake()
            handshake_done = True
        # An expired move_on_after scope exits silently.  Without this check,
        # we would query the certificate of an unfinished handshake instead of
        # reporting the timeout.
        if not handshake_done:
            raise dns.exception.Timeout(timeout=timeout)
        return self.stream.getpeercert()
if dns._features.have("doh"):
    import httpcore
    import httpcore._backends.trio
    import httpx

    _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
    _CoreTrioStream = httpcore._backends.trio.TrioStream

    from dns.query import _compute_times, _expiration_for_this_attempt, _remaining

    class _NetworkBackend(_CoreAsyncNetworkBackend):
        """httpcore network backend that resolves names with dnspython.

        Connections are opened through this module's trio ``Backend``.
        """

        def __init__(self, resolver, local_port, bootstrap_address, family):
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family

        @staticmethod
        def _source_for(af, local_address, local_port):
            # Return the (address, port) to bind to, or None if no bind is
            # needed.  When only a port was requested, use the wildcard
            # address of the destination's family.
            # dns.inet.low_level_address_tuple cannot handle a None address
            # for AF_INET6.
            if local_address is None and local_port == 0:
                return None
            if local_address is None:
                if af == socket.AF_INET:
                    local_address = "0.0.0.0"
                elif af == socket.AF_INET6:
                    local_address = "::"
            return (local_address, local_port)

        async def connect_tcp(
            self, host, port, timeout, local_address, socket_options=None
        ):  # pylint: disable=signature-differs
            """Connect to *host*:*port*, trying each candidate address in turn.

            *host* is used directly if it is an IP address.  Otherwise the
            bootstrap address is used if one was configured; failing that,
            *host* is resolved with the configured resolver.  Each attempt is
            limited by ``_expiration_for_this_attempt(2.0, ...)`` within the
            overall *timeout*.

            Raises:
                httpcore.ConnectError: if no address could be connected.  The
                    last per-address failure, if any, is chained as the cause.
            """
            addresses = []
            _, expiration = _compute_times(timeout)
            if dns.inet.is_address(host):
                addresses.append(host)
            elif self._bootstrap_address is not None:
                addresses.append(self._bootstrap_address)
            else:
                timeout = _remaining(expiration)
                family = self._family
                if local_address:
                    family = dns.inet.af_for_address(local_address)
                answers = await self._resolver.resolve_name(
                    host, family=family, lifetime=timeout
                )
                addresses = answers.addresses()
            last_error = None
            for address in addresses:
                try:
                    af = dns.inet.af_for_address(address)
                    source = self._source_for(af, local_address, self._local_port)
                    destination = (address, port)
                    attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
                    timeout = _remaining(attempt_expiration)
                    sock = await Backend().make_socket(
                        af, socket.SOCK_STREAM, 0, source, destination, timeout
                    )
                    return _CoreTrioStream(sock.stream)
                except Exception as exc:
                    # Move on to the next address, but remember why this one
                    # failed so the final error is diagnosable.
                    last_error = exc
            raise httpcore.ConnectError from last_error

        async def connect_unix_socket(
            self, path, timeout, socket_options=None
        ):  # pylint: disable=signature-differs
            raise NotImplementedError

        async def sleep(self, seconds):  # pylint: disable=signature-differs
            await trio.sleep(seconds)

    class _HTTPTransport(httpx.AsyncHTTPTransport):
        """httpx transport whose connection pool uses ``_NetworkBackend``.

        *local_port*, *bootstrap_address*, *resolver* and *family* configure
        the network backend.  All other arguments are passed to
        ``httpx.AsyncHTTPTransport``.  A default
        ``dns.asyncresolver.Resolver`` is created when *resolver* is None.
        """

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            if resolver is None:
                # pylint: disable=import-outside-toplevel,redefined-outer-name
                import dns.asyncresolver

                resolver = dns.asyncresolver.Resolver()
            super().__init__(*args, **kwargs)
            # NOTE(review): this relies on httpx's private `_pool` attribute
            # and httpcore's `_network_backend`; re-check on library upgrades.
            self._pool._network_backend = _NetworkBackend(
                resolver, local_port, bootstrap_address, family
            )

else:
    # DoH dependencies are unavailable; use dnspython's NullTransport instead.
    _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore
class Backend(dns._asyncbackend.Backend):
    """dnspython asynchronous I/O backend implemented with trio."""

    def name(self):
        return "trio"

    async def make_socket(
        self,
        af,
        socktype,
        proto=0,
        source=None,
        destination=None,
        timeout=None,
        ssl_context=None,
        server_hostname=None,
    ):
        """Create a trio socket, optionally bound, and connected for streams.

        Args:
            af: address family passed to ``trio.socket.socket``.
            socktype: ``socket.SOCK_DGRAM`` or ``socket.SOCK_STREAM``.
            proto: protocol number passed to ``trio.socket.socket``.
            source: (address, port) to bind to, if truthy.
            destination: (address, port) to connect to (streams only).
            timeout: seconds allowed for the stream connect, or None.
            ssl_context: if truthy, wrap the stream in ``trio.SSLStream``.
            server_hostname: hostname passed to ``trio.SSLStream``.

        Returns:
            A ``DatagramSocket`` or ``StreamSocket``.

        Raises:
            dns.exception.Timeout: if the stream connect does not finish in
                time.
            NotImplementedError: for any other *socktype*.
        """
        s = trio.socket.socket(af, socktype, proto)
        try:
            if source:
                await s.bind(_lltuple(source, af))
            if socktype == socket.SOCK_STREAM:
                connected = False
                with _maybe_timeout(timeout):
                    await s.connect(_lltuple(destination, af))
                    connected = True
                if not connected:
                    raise dns.exception.Timeout(
                        timeout=timeout
                    )  # lgtm[py/unreachable-statement]
        except BaseException:  # pragma: no cover
            # BaseException rather than Exception: trio.Cancelled (and
            # KeyboardInterrupt) are not Exception subclasses.  The socket
            # must not leak when an outer cancel scope interrupts bind/connect.
            s.close()
            raise
        if socktype == socket.SOCK_DGRAM:
            return DatagramSocket(s)
        if socktype == socket.SOCK_STREAM:
            stream = trio.SocketStream(s)
            tls = False
            if ssl_context:
                tls = True
                try:
                    stream = trio.SSLStream(
                        stream, ssl_context, server_hostname=server_hostname
                    )
                except Exception:  # pragma: no cover
                    await stream.aclose()
                    raise
            return StreamSocket(af, stream, tls)
        # Nothing will take ownership of the socket, so close it before
        # reporting the unsupported type.
        s.close()
        raise NotImplementedError(
            "unsupported socket " + f"type {socktype}"
        )  # pragma: no cover

    async def sleep(self, interval):
        await trio.sleep(interval)

    def get_transport_class(self):
        return _HTTPTransport

    async def wait_for(self, awaitable, timeout):
        """Await *awaitable*, raising Timeout if it takes over *timeout* seconds."""
        with _maybe_timeout(timeout):
            return await awaitable
        raise dns.exception.Timeout(
            timeout=timeout
        )  # pragma: no cover lgtm[py/unreachable-statement]