Spaces:
Paused
Paused
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

"""asyncio library query support"""
| import asyncio | |
| import socket | |
| import sys | |
| import dns._asyncbackend | |
| import dns._features | |
| import dns.exception | |
| import dns.inet | |
# True when the interpreter is running on Windows ("win32" platform string).
# Consulted when creating UDP sockets, which get an explicit local bind there.
_is_win32 = (sys.platform == "win32")
def _get_running_loop():
    """Return the event loop for the current context.

    Uses ``asyncio.get_running_loop()``; if that lookup raises
    ``AttributeError`` the older ``asyncio.get_event_loop()`` is used
    instead.
    """
    try:
        loop = asyncio.get_running_loop()
    except AttributeError:  # pragma: no cover
        loop = asyncio.get_event_loop()
    return loop
class _DatagramProtocol:
    """asyncio datagram protocol that forwards events to one waiting future.

    ``recvfrom`` holds the future of the receiver that is currently waiting
    (or ``None``).  Events that arrive while nobody is waiting, or after the
    future has already completed, are dropped.
    """

    def __init__(self):
        self.transport = None  # set by connection_made()
        self.recvfrom = None  # future of the current receiver, if any

    def _pending(self):
        """Return the waiting future if it can still be completed, else None."""
        waiter = self.recvfrom
        if waiter and not waiter.done():
            return waiter
        return None

    def connection_made(self, transport):
        """Remember the transport so that close() can reach it later."""
        self.transport = transport

    def datagram_received(self, data, addr):
        """Complete the waiting future with the ``(data, addr)`` tuple."""
        waiter = self._pending()
        if waiter is not None:
            waiter.set_result((data, addr))

    def error_received(self, exc):  # pragma: no cover
        """Fail the waiting future with *exc*."""
        waiter = self._pending()
        if waiter is not None:
            waiter.set_exception(exc)

    def connection_lost(self, exc):
        """Fail the waiting future when the transport goes away.

        A ``None`` *exc* is treated as an EOF that this side triggered and is
        reported to the receiver as ``EOFError``; any other value is passed
        through unchanged.
        """
        waiter = self._pending()
        if waiter is not None:
            waiter.set_exception(EOFError() if exc is None else exc)

    def close(self):
        """Close the transport; a no-op if no connection was ever made."""
        if self.transport is not None:
            self.transport.close()
async def _maybe_wait_for(awaitable, timeout):
    """Await *awaitable*, optionally bounded by *timeout* seconds.

    With ``timeout`` set to ``None`` the awaitable is awaited directly.
    Otherwise it is run under ``asyncio.wait_for`` and an asyncio timeout is
    re-raised as ``dns.exception.Timeout`` carrying the timeout value.
    """
    if timeout is None:
        return await awaitable
    try:
        return await asyncio.wait_for(awaitable, timeout)
    except asyncio.TimeoutError:
        raise dns.exception.Timeout(timeout=timeout)
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket backed by an asyncio transport/protocol pair."""

    def __init__(self, family, transport, protocol):
        """Keep the address family plus the asyncio transport and protocol."""
        super().__init__(family)
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):  # pragma: no cover
        """Hand *what* to the transport for *destination*; return ``len(what)``.

        *timeout* is accepted but unused.
        """
        # no timeout for asyncio sendto
        self.transport.sendto(what, destination)
        return len(what)

    async def recvfrom(self, size, timeout):
        """Wait for the next datagram and return what the protocol delivers.

        *size* is ignored.  Only one receive may be outstanding at a time;
        a second concurrent call raises ``AssertionError``.  The wait is
        bounded by *timeout* via ``_maybe_wait_for``.
        """
        # ignore size as there's no way I know to tell protocol about it
        #
        # The single-receiver check must happen *before* the try block:
        # otherwise the finally clause below would reset the slot and orphan
        # the future of the receiver that is legitimately waiting.  It is an
        # explicit raise (not ``assert``) so it still runs under ``python -O``.
        if self.protocol.recvfrom is not None:
            raise AssertionError("a recvfrom() is already pending on this socket")
        done = _get_running_loop().create_future()
        self.protocol.recvfrom = done
        try:
            await _maybe_wait_for(done, timeout)
            return done.result()
        finally:
            # Always release the slot, including on timeout or cancellation.
            self.protocol.recvfrom = None

    async def close(self):
        """Close the underlying protocol (and through it the transport)."""
        self.protocol.close()

    async def getpeername(self):
        """Return the transport's ``peername`` extra info."""
        return self.transport.get_extra_info("peername")

    async def getsockname(self):
        """Return the transport's ``sockname`` extra info."""
        return self.transport.get_extra_info("sockname")

    async def getpeercert(self, timeout):
        """Not supported for datagram sockets; always raises NotImplementedError."""
        raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket wrapping a reader/writer pair."""

    def __init__(self, af, reader, writer):
        """Record the address family and the reader and writer objects."""
        self.family = af
        self.reader = reader
        self.writer = writer

    async def sendall(self, what, timeout):
        """Write *what* and wait, bounded by *timeout*, for the writer to drain."""
        self.writer.write(what)
        flushed = self.writer.drain()
        return await _maybe_wait_for(flushed, timeout)

    async def recv(self, size, timeout):
        """Read up to *size* bytes from the reader, bounded by *timeout*."""
        pending_read = self.reader.read(size)
        return await _maybe_wait_for(pending_read, timeout)

    async def close(self):
        """Close the writer."""
        self.writer.close()

    async def getpeername(self):
        """Return the writer's ``peername`` extra info."""
        return self.writer.get_extra_info("peername")

    async def getsockname(self):
        """Return the writer's ``sockname`` extra info."""
        return self.writer.get_extra_info("sockname")

    async def getpeercert(self, timeout):
        """Return the writer's ``peercert`` extra info; *timeout* is unused."""
        return self.writer.get_extra_info("peercert")
# DNS-over-HTTPS support is optional: the real transport is only defined when
# the "doh" feature is available, otherwise a null transport stands in.
if dns._features.have("doh"):
    import anyio
    import httpcore
    import httpcore._backends.anyio
    import httpx

    _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
    _CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream

    from dns.query import _compute_times, _expiration_for_this_attempt, _remaining

    class _NetworkBackend(_CoreAsyncNetworkBackend):
        """httpcore network backend that resolves host names with dnspython."""

        def __init__(self, resolver, local_port, bootstrap_address, family):
            """Store the connection settings; only ``local_port == 0`` is supported."""
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family
            if local_port != 0:
                raise NotImplementedError(
                    "the asyncio transport for HTTPX cannot set the local port"
                )

        async def connect_tcp(
            self, host, port, timeout, local_address, socket_options=None
        ):  # pylint: disable=signature-differs
            """Connect to *host*:*port* and return the stream wrapped for httpcore.

            The candidate addresses are, in order of preference: *host*
            itself when it is already an address, the bootstrap address when
            one was configured, or whatever the resolver returns for *host*.
            Each candidate is tried in turn; failures are ignored and
            ``httpcore.ConnectError`` is raised once all have been tried.
            """
            _, expiration = _compute_times(timeout)
            if dns.inet.is_address(host):
                candidates = [host]
            elif self._bootstrap_address is not None:
                candidates = [self._bootstrap_address]
            else:
                lifetime = _remaining(expiration)
                # A local address, when given, dictates the address family.
                if local_address:
                    family = dns.inet.af_for_address(local_address)
                else:
                    family = self._family
                answers = await self._resolver.resolve_name(
                    host, family=family, lifetime=lifetime
                )
                candidates = answers.addresses()
            for candidate in candidates:
                try:
                    # NOTE(review): 2.0 looks like a per-attempt cap in
                    # seconds -- confirm against dns.query.
                    attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
                    budget = _remaining(attempt_expiration)
                    with anyio.fail_after(budget):
                        stream = await anyio.connect_tcp(
                            remote_host=candidate,
                            remote_port=port,
                            local_host=local_address,
                        )
                    return _CoreAnyIOStream(stream)
                except Exception:
                    # Best effort: move on to the next candidate address.
                    continue
            raise httpcore.ConnectError

        async def connect_unix_socket(
            self, path, timeout, socket_options=None
        ):  # pylint: disable=signature-differs
            """Unix sockets are not supported; always raises NotImplementedError."""
            raise NotImplementedError

        async def sleep(self, seconds):  # pylint: disable=signature-differs
            """Sleep for *seconds* using anyio."""
            await anyio.sleep(seconds)

    class _HTTPTransport(httpx.AsyncHTTPTransport):
        """httpx async transport whose pool connects through ``_NetworkBackend``."""

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            """Build the transport and install the dnspython-aware backend.

            When *resolver* is ``None`` a default ``dns.asyncresolver.Resolver``
            is created.  Remaining positional and keyword arguments go to
            ``httpx.AsyncHTTPTransport``.
            """
            if resolver is None:
                # pylint: disable=import-outside-toplevel,redefined-outer-name
                import dns.asyncresolver

                resolver = dns.asyncresolver.Resolver()
            super().__init__(*args, **kwargs)
            backend = _NetworkBackend(resolver, local_port, bootstrap_address, family)
            self._pool._network_backend = backend

else:
    _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore
class Backend(dns._asyncbackend.Backend):
    """Backend implementation built on asyncio."""

    def name(self):
        """Return the backend's name, ``"asyncio"``."""
        return "asyncio"

    async def make_socket(
        self,
        af,
        socktype,
        proto=0,
        source=None,
        destination=None,
        timeout=None,
        ssl_context=None,
        server_hostname=None,
    ):
        """Create a ``DatagramSocket`` or ``StreamSocket`` for *socktype*.

        Datagram sockets are created through the running loop's
        ``create_datagram_endpoint``.  Stream sockets are opened with
        ``asyncio.open_connection`` bounded by *timeout* and require a
        *destination*, otherwise ``ValueError`` is raised.  Any other socket
        type raises ``NotImplementedError``.
        """
        loop = _get_running_loop()

        if socktype == socket.SOCK_DGRAM:
            local_addr = source
            if local_addr is None and _is_win32:
                # Win32 wants explicit binding before recvfrom().  This is the
                # proper fix for [#637].
                local_addr = (dns.inet.any_for_af(af), 0)
            transport, protocol = await loop.create_datagram_endpoint(
                _DatagramProtocol,
                local_addr,
                family=af,
                proto=proto,
                remote_addr=destination,
            )
            return DatagramSocket(af, transport, protocol)

        if socktype != socket.SOCK_STREAM:
            raise NotImplementedError(
                "unsupported socket " + f"type {socktype}"
            )  # pragma: no cover

        if destination is None:
            # This shouldn't happen, but we check to make code analysis software
            # happier.
            raise ValueError("destination required for stream sockets")

        connecting = asyncio.open_connection(
            destination[0],
            destination[1],
            ssl=ssl_context,
            family=af,
            proto=proto,
            local_addr=source,
            server_hostname=server_hostname,
        )
        reader, writer = await _maybe_wait_for(connecting, timeout)
        return StreamSocket(af, reader, writer)

    async def sleep(self, interval):
        """Sleep for *interval* seconds on the asyncio loop."""
        await asyncio.sleep(interval)

    def datagram_connection_required(self):
        """Report that this backend does not require connected datagram sockets."""
        return False

    def get_transport_class(self):
        """Return the HTTP transport class selected at import time."""
        return _HTTPTransport

    async def wait_for(self, awaitable, timeout):
        """Await *awaitable*, applying *timeout* through ``_maybe_wait_for``."""
        return await _maybe_wait_for(awaitable, timeout)