|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Talk to a DNS server.""" |
|
|
|
import base64 |
|
import contextlib |
|
import enum |
|
import errno |
|
import os |
|
import os.path |
|
import selectors |
|
import socket |
|
import struct |
|
import time |
|
from typing import Any, Dict, Optional, Tuple, Union |
|
|
|
import dns._features |
|
import dns.exception |
|
import dns.inet |
|
import dns.message |
|
import dns.name |
|
import dns.quic |
|
import dns.rcode |
|
import dns.rdataclass |
|
import dns.rdatatype |
|
import dns.serial |
|
import dns.transaction |
|
import dns.tsig |
|
import dns.xfr |
|
|
|
|
|
def _remaining(expiration): |
|
if expiration is None: |
|
return None |
|
timeout = expiration - time.time() |
|
if timeout <= 0.0: |
|
raise dns.exception.Timeout |
|
return timeout |
|
|
|
|
|
def _expiration_for_this_attempt(timeout, expiration): |
|
if expiration is None: |
|
return None |
|
return min(time.time() + timeout, expiration) |
|
|
|
|
|
_have_httpx = dns._features.have("doh")
if _have_httpx:
    import httpcore._backends.sync
    import httpx

    _CoreNetworkBackend = httpcore.NetworkBackend
    _CoreSyncStream = httpcore._backends.sync.SyncStream

    class _NetworkBackend(_CoreNetworkBackend):
        """httpcore network backend that resolves and connects via dnspython.

        Host names are resolved with the given dnspython *resolver* unless a
        *bootstrap_address* is supplied, and outgoing connections can be bound
        to a particular local port.
        """

        def __init__(self, resolver, local_port, bootstrap_address, family):
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family

        def connect_tcp(
            self, host, port, timeout, local_address, socket_options=None
        ):
            """Open a TCP connection to *host*:*port*, trying each address in turn.

            *host* may be a literal address; otherwise the bootstrap address is
            used if configured, else *host* is resolved.  When an overall
            *timeout* is given, each connection attempt gets at most 2 seconds
            of it.

            Raises ``httpcore.ConnectError`` if no address could be connected
            to; the last connection failure (if any) is chained as its cause.
            """
            addresses = []
            _, expiration = _compute_times(timeout)
            if dns.inet.is_address(host):
                addresses.append(host)
            elif self._bootstrap_address is not None:
                addresses.append(self._bootstrap_address)
            else:
                timeout = _remaining(expiration)
                family = self._family
                if local_address:
                    # Only look for addresses reachable from the local address.
                    family = dns.inet.af_for_address(local_address)
                answers = self._resolver.resolve_name(
                    host, family=family, lifetime=timeout
                )
                addresses = answers.addresses()
            last_error = None
            for address in addresses:
                af = dns.inet.af_for_address(address)
                if local_address is not None or self._local_port != 0:
                    source = dns.inet.low_level_address_tuple(
                        (local_address, self._local_port), af
                    )
                else:
                    source = None
                sock = _make_socket(af, socket.SOCK_STREAM, source)
                attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
                try:
                    _connect(
                        sock,
                        dns.inet.low_level_address_tuple((address, port), af),
                        attempt_expiration,
                    )
                    return _CoreSyncStream(sock)
                except Exception as e:
                    # This address failed: release its socket (it would
                    # otherwise leak a file descriptor) and try the next one.
                    sock.close()
                    last_error = e
            raise httpcore.ConnectError from last_error

        def connect_unix_socket(
            self, path, timeout, socket_options=None
        ):  # pylint: disable=signature-differs
            raise NotImplementedError

    class _HTTPTransport(httpx.HTTPTransport):
        """httpx transport whose connection pool connects via _NetworkBackend."""

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            if resolver is None:
                # Imported here rather than at module level, presumably to
                # avoid an import cycle between dns.query and dns.resolver.
                import dns.resolver

                resolver = dns.resolver.Resolver()
            super().__init__(*args, **kwargs)
            # NOTE(review): relies on httpx's private ``_pool`` attribute.
            self._pool._network_backend = _NetworkBackend(
                resolver, local_port, bootstrap_address, family
            )

else:

    class _HTTPTransport:  # type: ignore
        """Placeholder used when DoH support is unavailable; not functional."""

        def connect_tcp(self, host, port, timeout, local_address):
            raise NotImplementedError


# Public flag: True when the DNS-over-HTTPS dependencies are importable.
have_doh = _have_httpx
|
|
|
try:
    import ssl
except ImportError:  # pragma: no cover

    class ssl:  # type: ignore
        """Minimal stand-in for the ``ssl`` module when Python lacks TLS support.

        It supplies just enough names for this module's ``isinstance`` checks,
        ``except`` clauses and type annotations to evaluate; actually creating
        a TLS context fails.
        """

        CERT_NONE = 0

        class WantReadException(Exception):
            pass

        class WantWriteException(Exception):
            pass

        # The socket I/O helpers (_net_read, _net_write, _tls_handshake) name
        # these in ``except`` clauses.  Those expressions are evaluated
        # whenever any exception (e.g. BlockingIOError) reaches the clause, so
        # without these aliases a plain non-blocking wait would raise
        # AttributeError.
        SSLWantReadError = WantReadException
        SSLWantWriteError = WantWriteException

        class SSLContext:
            pass

        class SSLSocket:
            pass

        @classmethod
        def create_default_context(cls, *args, **kwargs):
            raise Exception("no ssl support")
|
|
|
|
|
|
|
|
|
# Callable used by _make_socket to create every socket this module opens.
# It is a module-level name, presumably so it can be swapped out (e.g. in
# tests) — confirm intended use before relying on it.
socket_factory = socket.socket
|
|
|
|
|
# Raised by _matches_destination when a reply's source address/port does not
# match where the query was sent (unless unexpected sources are ignored).
class UnexpectedSource(dns.exception.DNSException):
    """A DNS query response came from an unexpected address or port."""


# Raised by https(), udp() and tcp() when the received message fails the
# query's is_response() check.
class BadResponse(dns.exception.FormError):
    """A DNS query response does not respond to the question asked."""


# Raised by https() when DNS-over-HTTPS support is not available.
class NoDOH(dns.exception.DNSException):
    """DNS over HTTPS (DOH) was requested but the httpx module is not
    available."""


# NOTE(review): not raised in the visible part of this file; presumably used
# by the DNS-over-QUIC query code elsewhere in the module.
class NoDOQ(dns.exception.DNSException):
    """DNS over QUIC (DOQ) was requested but the aioquic module is not
    available."""


# Alias of dns.xfr.TransferError, exposed under this module's namespace.
TransferError = dns.xfr.TransferError
|
|
|
|
|
def _compute_times(timeout): |
|
now = time.time() |
|
if timeout is None: |
|
return (now, None) |
|
else: |
|
return (now, now + timeout) |
|
|
|
|
|
def _wait_for(fd, readable, writable, _, expiration):
    """Wait until *fd* is readable and/or writable.

    *expiration* is an absolute ``time.time()`` deadline, or ``None`` to wait
    forever.  The fourth parameter is ignored (callers pass ``True``).

    Returns ``True`` immediately if *fd* is an SSL socket with decrypted data
    already buffered and *readable* is requested; otherwise returns ``None``.

    Raises ``dns.exception.Timeout`` if the deadline passes before *fd* is
    ready.
    """
    # Data already decrypted inside the SSL object is invisible to the OS
    # level selector, so report readability right away.
    if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
        return True
    events = 0
    if readable:
        events |= selectors.EVENT_READ
    if writable:
        events |= selectors.EVENT_WRITE
    # Close the selector when finished: selector classes such as epoll,
    # kqueue or devpoll (selectable via _set_selector_class) own a file
    # descriptor that would otherwise leak until garbage collection.
    with _selector_class() as sel:
        if events:
            sel.register(fd, events)
        if expiration is None:
            timeout = None
        else:
            timeout = expiration - time.time()
            if timeout <= 0.0:
                raise dns.exception.Timeout
        if not sel.select(timeout):
            raise dns.exception.Timeout
|
|
|
|
|
def _set_selector_class(selector_class):
    """Choose the ``selectors`` class instantiated for every socket wait.

    *selector_class* replaces the module-wide ``_selector_class`` used by
    ``_wait_for``.
    """
    global _selector_class
    _selector_class = selector_class


# Prefer poll() where the platform provides it (select() cannot watch
# descriptors numbered at or above FD_SETSIZE); fall back to select().
_set_selector_class(getattr(selectors, "PollSelector", selectors.SelectSelector))
|
|
|
|
|
def _wait_for_readable(s, expiration):
    """Block until *s* is readable; raise Timeout once *expiration* passes."""
    want_read, want_write, want_error = True, False, True
    _wait_for(s, want_read, want_write, want_error, expiration)
|
|
|
|
|
def _wait_for_writable(s, expiration):
    """Block until *s* is writable; raise Timeout once *expiration* passes."""
    want_read, want_write, want_error = False, True, True
    _wait_for(s, want_read, want_write, want_error, expiration)
|
|
|
|
|
def _addresses_equal(af, a1, a2):
    """Return whether address tuples *a1* and *a2* of family *af* are equal.

    The host parts (element 0) are compared in packed binary form, so two
    different textual spellings of the same address compare equal.  All
    remaining elements (port etc.) must match exactly.  If either host part
    cannot be parsed, the result is ``False``.
    """
    try:
        packed = [dns.inet.inet_pton(af, addr[0]) for addr in (a1, a2)]
    except dns.exception.SyntaxError:
        return False
    first, second = packed
    return first == second and a1[1:] == a2[1:]
|
|
|
|
|
def _matches_destination(af, from_address, destination, ignore_unexpected): |
|
|
|
|
|
if not destination: |
|
return True |
|
if _addresses_equal(af, from_address, destination) or ( |
|
dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:] |
|
): |
|
return True |
|
elif ignore_unexpected: |
|
return False |
|
raise UnexpectedSource( |
|
f"got a response from {from_address} instead of " f"{destination}" |
|
) |
|
|
|
|
|
def _destination_and_source(
    where, port, source, source_port, where_must_be_address=True
):
    """Work out the address family plus low-level destination/source tuples.

    *where* is the destination address.  If it cannot be parsed as an
    address, the parsing error propagates unless *where_must_be_address* is
    false, in which case there is no destination and the address family is
    taken from *source*.  When *source_port* is set but *source* is not, the
    wildcard address of the family is used as the source.

    Returns ``(af, destination, source)``.  *destination* and *source* are
    converted to low-level socket address tuples when they are set.

    Raises ``ValueError`` if *source* and *where* have different address
    families, or if *source_port* is given but no family is known.
    """
    af = None
    destination = None
    try:
        af, destination = dns.inet.af_for_address(where), where
    except Exception:
        if where_must_be_address:
            raise
    if source:
        source_af = dns.inet.af_for_address(source)
        if not af:
            # No destination family to check against; adopt the source's.
            af = source_af
        elif source_af != af:
            raise ValueError(
                "different address families for source and destination"
            )
    if source_port and not source:
        # Bind the requested port on the wildcard address of the family.
        try:
            source = dns.inet.any_for_af(af)
        except Exception:
            raise ValueError("source_port specified but address family is unknown")
    if destination:
        destination = dns.inet.low_level_address_tuple((destination, port), af)
    if source:
        source = dns.inet.low_level_address_tuple((source, source_port), af)
    return (af, destination, source)
|
|
|
|
|
def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
    """Create a non-blocking socket, optionally bound and wrapped in TLS.

    *af* and *type* select the socket kind; *source*, if not ``None``, is the
    low-level address to bind to.  When *ssl_context* is given, the socket is
    wrapped with it (handshake deferred, using *server_hostname*).

    The socket is closed again if any step of the setup fails.
    """
    sock = socket_factory(af, type)
    try:
        sock.setblocking(False)
        if source is not None:
            sock.bind(source)
        if not ssl_context:
            return sock
        # The handshake is performed later, once the socket is connected.
        return ssl_context.wrap_socket(
            sock,
            do_handshake_on_connect=False,
            server_hostname=server_hostname,
        )
    except Exception:
        sock.close()
        raise
|
|
|
|
|
def https(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 443,
    source: Optional[str] = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    session: Optional[Any] = None,
    path: str = "/dns-query",
    post: bool = True,
    bootstrap_address: Optional[str] = None,
    verify: Union[bool, str] = True,
    resolver: Optional["dns.resolver.Resolver"] = None,
    family: Optional[int] = socket.AF_UNSPEC,
) -> dns.message.Message:
    """Return the response obtained after sending a query via DNS-over-HTTPS.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str``, the nameserver IP address or the full URL. If an IP address is
    given, the URL will be constructed using the following schema:
    https://<IP-address>:<port>/<path>.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
    times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the query to. The default is 443.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
    address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message. The default is
    0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
    received message.

    *session*, an ``httpx.Client``. If provided, the client session to use to send the
    queries. The session's own transport settings are used, so *source*,
    *source_port*, *bootstrap_address*, *verify*, *resolver* and *family* have no
    effect in that case.

    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
    construct the URL to send the DNS query to.

    *post*, a ``bool``. If ``True``, the default, POST method will be used.

    *bootstrap_address*, a ``str``, the IP address to use to bypass resolution.

    *verify*, a ``bool`` or ``str``. If ``True``, then TLS certificate verification
    of the server is done using the default CA bundle; if ``False``, then no
    verification is done; if a ``str`` then it specifies the path to a certificate file
    or directory which will be used for verification.

    *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
    resolution of hostnames in URLs. If not specified, a new resolver with a default
    configuration will be used; note this is *not* the default resolver as that resolver
    might have been configured to use DoH causing a chicken-and-egg problem. This
    parameter only has an effect if the HTTP library is httpx.

    *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
    and AAAA records will be retrieved.

    Raises ``NoDOH`` if DoH support is unavailable, ``ValueError`` for a bad
    *session* or a non-2xx HTTP status, and ``BadResponse`` if the reply does not
    answer *q*.

    Returns a ``dns.message.Message``.
    """

    if not have_doh:
        raise NoDOH  # pragma: no cover
    if session and not isinstance(session, httpx.Client):
        raise ValueError("session parameter must be an httpx.Client")

    wire = q.to_wire()
    (af, _, the_source) = _destination_and_source(
        where, port, source, source_port, False
    )
    headers = {"accept": "application/dns-message"}
    if af is not None and dns.inet.is_address(where):
        # IPv6 literals must be bracketed inside a URL.
        host = f"[{where}]" if af == socket.AF_INET6 else where
        url = f"https://{host}:{port}{path}"
    else:
        url = where

    if session:
        cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
    else:
        # Only build a transport when we also own the client.  A caller's
        # session brings its own transport, and constructing ours anyway would
        # needlessly create a dns.resolver.Resolver (which can fail, e.g. when
        # no system resolver configuration exists) plus an unused pool.
        if the_source is None:
            local_address = None
            local_port = 0
        else:
            local_address = the_source[0]
            local_port = the_source[1]
        transport = _HTTPTransport(
            local_address=local_address,
            http1=True,
            http2=True,
            verify=verify,
            local_port=local_port,
            bootstrap_address=bootstrap_address,
            resolver=resolver,
            family=family,
        )
        cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)

    with cm as session:
        # See RFC 8484 section 4.1.1 for DoH GET and POST examples.
        if post:
            headers.update(
                {
                    "content-type": "application/dns-message",
                    "content-length": str(len(wire)),
                }
            )
            response = session.post(url, headers=headers, content=wire, timeout=timeout)
        else:
            # GET carries the query as unpadded base64url in the "dns" parameter.
            twire = base64.urlsafe_b64encode(wire).rstrip(b"=").decode()
            response = session.get(
                url, headers=headers, timeout=timeout, params={"dns": twire}
            )

        # See RFC 8484 section 4.2.1: a DNS response arrives with a 2xx status.
        if response.status_code < 200 or response.status_code > 299:
            raise ValueError(
                "{} responded with status code {}"
                "\nResponse body: {}".format(where, response.status_code, response.content)
            )
        # As in udp() and tcp(), the query's own TSIG MAC (q.mac, set when
        # q.to_wire() signed it) is what the response's TSIG is computed over.
        r = dns.message.from_wire(
            response.content,
            keyring=q.keyring,
            request_mac=q.mac,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
        )
        r.time = response.elapsed.total_seconds()
        if not q.is_response(r):
            raise BadResponse
        return r
|
|
|
|
|
def _udp_recv(sock, max_size, expiration): |
|
"""Reads a datagram from the socket. |
|
A Timeout exception will be raised if the operation is not completed |
|
by the expiration time. |
|
""" |
|
while True: |
|
try: |
|
return sock.recvfrom(max_size) |
|
except BlockingIOError: |
|
_wait_for_readable(sock, expiration) |
|
|
|
|
|
def _udp_send(sock, data, destination, expiration): |
|
"""Sends the specified datagram to destination over the socket. |
|
A Timeout exception will be raised if the operation is not completed |
|
by the expiration time. |
|
""" |
|
while True: |
|
try: |
|
if destination: |
|
return sock.sendto(data, destination) |
|
else: |
|
return sock.send(data) |
|
except BlockingIOError: |
|
_wait_for_writable(sock, expiration) |
|
|
|
|
|
def send_udp(
    sock: Any,
    what: Union[dns.message.Message, bytes],
    destination: Any,
    expiration: Optional[float] = None,
) -> Tuple[int, float]:
    """Send a DNS message as a single UDP datagram.

    *sock*, a ``socket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send; a
    message object is rendered to wire format first.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where to send the query.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent and the time just before
    sending.
    """
    wire = what.to_wire() if isinstance(what, dns.message.Message) else what
    sent_time = time.time()
    return (_udp_send(sock, wire, destination, expiration), sent_time)
|
|
|
|
|
def receive_udp(
    sock: Any,
    destination: Optional[Any] = None,
    expiration: Optional[float] = None,
    ignore_unexpected: bool = False,
    one_rr_per_rrset: bool = False,
    keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
    request_mac: Optional[bytes] = b"",
    ignore_trailing: bool = False,
    raise_on_truncation: bool = False,
    ignore_errors: bool = False,
    query: Optional[dns.message.Message] = None,
) -> Any:
    """Read a DNS message from a UDP socket.

    *sock*, a ``socket``.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where the message is expected to arrive from.
    When receiving a response, this would be where the associated query was
    sent.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
    the TC bit is set.

    *ignore_errors*, a ``bool``. If various format errors or response
    mismatches occur, ignore them and keep listening for a valid response.
    The default is ``False``.

    *query*, a ``dns.message.Message`` or ``None``. If not ``None`` and
    *ignore_errors* is ``True``, check that the received message is a response
    to this query, and if not keep listening for a valid response.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.

    If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
    tuple of the received message and the received time.

    If *destination* is ``None``, returns a
    ``(dns.message.Message, float, tuple)``
    tuple of the received message, the received time, and the address where
    the message arrived from.
    """
    while True:
        wire, from_address = _udp_recv(sock, 65535, expiration)
        if not _matches_destination(
            sock.family, from_address, destination, ignore_unexpected
        ):
            continue
        received_time = time.time()
        try:
            message = dns.message.from_wire(
                wire,
                keyring=keyring,
                request_mac=request_mac,
                one_rr_per_rrset=one_rr_per_rrset,
                ignore_trailing=ignore_trailing,
                raise_on_truncation=raise_on_truncation,
            )
        except dns.message.Truncated as exc:
            # A truncated reply that does not even answer our query is noise
            # when errors are being ignored; any other truncation is surfaced
            # so the caller can react (e.g. retry over TCP).
            unrelated = (
                ignore_errors
                and query is not None
                and not query.is_response(exc.message())
            )
            if not unrelated:
                raise
            continue
        except Exception:
            if not ignore_errors:
                raise
            continue
        if ignore_errors and query is not None and not query.is_response(message):
            continue
        if destination:
            return (message, received_time)
        return (message, received_time, from_address)
|
|
|
|
|
def udp(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 53,
    source: Optional[str] = None,
    source_port: int = 0,
    ignore_unexpected: bool = False,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    raise_on_truncation: bool = False,
    sock: Optional[Any] = None,
    ignore_errors: bool = False,
) -> dns.message.Message:
    """Return the response obtained after sending a query via UDP.

    *q*, a ``dns.message.Message``, the query to send

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port send the message to. The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
    the TC bit is set.

    *sock*, a ``socket.socket``, or ``None``, the socket to use for the
    query. If ``None``, the default, a socket is created. Note that
    if a socket is provided, it must be a nonblocking datagram socket,
    and the *source* and *source_port* are ignored.

    *ignore_errors*, a ``bool``. If various format errors or response
    mismatches occur, ignore them and keep listening for a valid response.
    The default is ``False``.

    Raises ``BadResponse`` if the reply does not answer *q*.

    Returns a ``dns.message.Message``.
    """
    wire = q.to_wire()
    af, destination, source = _destination_and_source(
        where, port, source, source_port
    )
    begin_time, expiration = _compute_times(timeout)
    if sock:
        # Caller owns the socket: use it without closing it afterwards.
        cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
    else:
        cm = _make_socket(af, socket.SOCK_DGRAM, source)
    with cm as s:
        send_udp(s, wire, destination, expiration)
        r, received_time = receive_udp(
            s,
            destination=destination,
            expiration=expiration,
            ignore_unexpected=ignore_unexpected,
            one_rr_per_rrset=one_rr_per_rrset,
            keyring=q.keyring,
            request_mac=q.mac,
            ignore_trailing=ignore_trailing,
            raise_on_truncation=raise_on_truncation,
            ignore_errors=ignore_errors,
            query=q,
        )
        r.time = received_time - begin_time
        # With ignore_errors, receive_udp() has already skipped replies that
        # do not answer q, so only check explicitly otherwise.
        if not ignore_errors and not q.is_response(r):
            raise BadResponse
        return r
    assert False  # not reached: the with-body always returns or raises
|
|
|
|
|
def udp_with_fallback(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 53,
    source: Optional[str] = None,
    source_port: int = 0,
    ignore_unexpected: bool = False,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    udp_sock: Optional[Any] = None,
    tcp_sock: Optional[Any] = None,
    ignore_errors: bool = False,
) -> Tuple[dns.message.Message, bool]:
    """Return the response to the query, trying UDP first and falling back
    to TCP if UDP results in a truncated response.

    *q*, a ``dns.message.Message``, the query to send

    *where*, a ``str`` containing an IPv4 or IPv6 address, where to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
    times out. If ``None``, the default, wait forever. The timeout applies to the
    UDP attempt and, separately, to the TCP retry.

    *port*, an ``int``, the port send the message to. The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
    address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message. The default is
    0.

    *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from unexpected
    sources.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
    received message.

    *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the UDP query.
    If ``None``, the default, a socket is created. Note that if a socket is provided,
    it must be a nonblocking datagram socket, and the *source* and *source_port* are
    ignored for the UDP query.

    *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
    TCP query. If ``None``, the default, a socket is created. Note that if a socket is
    provided, it must be a nonblocking connected stream socket, and *where*, *source*
    and *source_port* are ignored for the TCP query.

    *ignore_errors*, a ``bool``. If various format errors or response mismatches occur
    while listening for UDP, ignore them and keep listening for a valid response. The
    default is ``False``.

    Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if
    TCP was used.
    """
    try:
        udp_response = udp(
            q,
            where,
            timeout=timeout,
            port=port,
            source=source,
            source_port=source_port,
            ignore_unexpected=ignore_unexpected,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            raise_on_truncation=True,
            sock=udp_sock,
            ignore_errors=ignore_errors,
        )
    except dns.message.Truncated:
        # The UDP answer did not fit; repeat the query over TCP.
        tcp_response = tcp(
            q,
            where,
            timeout=timeout,
            port=port,
            source=source,
            source_port=source_port,
            one_rr_per_rrset=one_rr_per_rrset,
            ignore_trailing=ignore_trailing,
            sock=tcp_sock,
        )
        return (tcp_response, True)
    return (udp_response, False)
|
|
|
|
|
def _net_read(sock, count, expiration): |
|
"""Read the specified number of bytes from sock. Keep trying until we |
|
either get the desired amount, or we hit EOF. |
|
A Timeout exception will be raised if the operation is not completed |
|
by the expiration time. |
|
""" |
|
s = b"" |
|
while count > 0: |
|
try: |
|
n = sock.recv(count) |
|
if n == b"": |
|
raise EOFError |
|
count -= len(n) |
|
s += n |
|
except (BlockingIOError, ssl.SSLWantReadError): |
|
_wait_for_readable(sock, expiration) |
|
except ssl.SSLWantWriteError: |
|
_wait_for_writable(sock, expiration) |
|
return s |
|
|
|
|
|
def _net_write(sock, data, expiration): |
|
"""Write the specified data to the socket. |
|
A Timeout exception will be raised if the operation is not completed |
|
by the expiration time. |
|
""" |
|
current = 0 |
|
l = len(data) |
|
while current < l: |
|
try: |
|
current += sock.send(data[current:]) |
|
except (BlockingIOError, ssl.SSLWantWriteError): |
|
_wait_for_writable(sock, expiration) |
|
except ssl.SSLWantReadError: |
|
_wait_for_readable(sock, expiration) |
|
|
|
|
|
def send_tcp(
    sock: Any,
    what: Union[dns.message.Message, bytes],
    expiration: Optional[float] = None,
) -> Tuple[int, float]:
    """Send a DNS message over a TCP stream, prefixed with its length.

    *sock*, a ``socket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent (including the two-byte
    length prefix) and the time just before sending.
    """
    # DNS over TCP frames each message with a 2-byte big-endian length.  For
    # raw bytes this copies the payload, which is acceptable for DNS sizes.
    tcpmsg = (
        what.to_wire(prepend_length=True)
        if isinstance(what, dns.message.Message)
        else len(what).to_bytes(2, "big") + what
    )
    sent_time = time.time()
    _net_write(sock, tcpmsg, expiration)
    return (len(tcpmsg), sent_time)
|
|
|
|
|
def receive_tcp(
    sock: Any,
    expiration: Optional[float] = None,
    one_rr_per_rrset: bool = False,
    keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
    request_mac: Optional[bytes] = b"",
    ignore_trailing: bool = False,
) -> Tuple[dns.message.Message, float]:
    """Read one length-prefixed DNS message from a TCP stream.

    *sock*, a ``socket``.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised. If ``None``, no timeout will
    occur.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.

    Returns a ``(dns.message.Message, float)`` tuple of the received message
    and the received time.
    """
    # The first two bytes give the message length (big-endian).
    (length,) = struct.unpack("!H", _net_read(sock, 2, expiration))
    wire = _net_read(sock, length, expiration)
    received_time = time.time()
    message = dns.message.from_wire(
        wire,
        keyring=keyring,
        request_mac=request_mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    return (message, received_time)
|
|
|
|
|
def _connect(s, address, expiration): |
|
err = s.connect_ex(address) |
|
if err == 0: |
|
return |
|
if err in (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY): |
|
_wait_for_writable(s, expiration) |
|
err = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) |
|
if err != 0: |
|
raise OSError(err, os.strerror(err)) |
|
|
|
|
|
def tcp(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 53,
    source: Optional[str] = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    sock: Optional[Any] = None,
) -> dns.message.Message:
    """Return the response obtained after sending a query via TCP.

    *q*, a ``dns.message.Message``, the query to send

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port send the message to. The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
    junk at end of the received message.

    *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
    query. If ``None``, the default, a socket is created. Note that
    if a socket is provided, it must be a nonblocking connected stream
    socket, and *where*, *port*, *source* and *source_port* are ignored.

    Raises ``BadResponse`` if the reply does not answer *q*.

    Returns a ``dns.message.Message``.
    """
    wire = q.to_wire()
    begin_time, expiration = _compute_times(timeout)
    if sock:
        # Caller-supplied, already connected; leave it open afterwards.
        cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
    else:
        af, destination, source = _destination_and_source(
            where, port, source, source_port
        )
        cm = _make_socket(af, socket.SOCK_STREAM, source)
    with cm as s:
        if not sock:
            _connect(s, destination, expiration)
        send_tcp(s, wire, expiration)
        r, received_time = receive_tcp(
            s,
            expiration=expiration,
            one_rr_per_rrset=one_rr_per_rrset,
            keyring=q.keyring,
            request_mac=q.mac,
            ignore_trailing=ignore_trailing,
        )
        r.time = received_time - begin_time
        if not q.is_response(r):
            raise BadResponse
        return r
    assert False  # not reached: the with-body always returns or raises
|
|
|
|
|
def _tls_handshake(s, expiration):
    """Drive the TLS handshake on non-blocking SSL socket *s* to completion.

    Whenever the handshake needs more input or output, waits for the socket
    accordingly; a Timeout is raised if *expiration* passes first.
    """
    done = False
    while not done:
        try:
            s.do_handshake()
        except ssl.SSLWantReadError:
            _wait_for_readable(s, expiration)
        except ssl.SSLWantWriteError:
            _wait_for_writable(s, expiration)
        else:
            done = True
|
|
|
|
|
def _make_dot_ssl_context( |
|
server_hostname: Optional[str], verify: Union[bool, str] |
|
) -> ssl.SSLContext: |
|
cafile: Optional[str] = None |
|
capath: Optional[str] = None |
|
if isinstance(verify, str): |
|
if os.path.isfile(verify): |
|
cafile = verify |
|
elif os.path.isdir(verify): |
|
capath = verify |
|
else: |
|
raise ValueError("invalid verify string") |
|
ssl_context = ssl.create_default_context(cafile=cafile, capath=capath) |
|
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 |
|
if server_hostname is None: |
|
ssl_context.check_hostname = False |
|
ssl_context.set_alpn_protocols(["dot"]) |
|
if verify is False: |
|
ssl_context.verify_mode = ssl.CERT_NONE |
|
return ssl_context |
|
|
|
|
|
def tls(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 853,
    source: Optional[str] = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    sock: Optional[ssl.SSLSocket] = None,
    ssl_context: Optional[ssl.SSLContext] = None,
    server_hostname: Optional[str] = None,
    verify: Union[bool, str] = True,
) -> dns.message.Message:
    """Send *q* over DNS-over-TLS and return the server's response.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` holding the IPv4 or IPv6 address of the server.

    *timeout*, a ``float`` or ``None``, the number of seconds allowed for the
    whole query before it times out. ``None`` (the default) waits forever.

    *port*, an ``int``, the server port to use. Defaults to 853.

    *source*, a ``str`` with an IPv4 or IPv6 address to send from. Defaults
    to the wildcard address.

    *source_port*, an ``int``, the local port to send from. Defaults to 0.

    *one_rr_per_rrset*, a ``bool``. When ``True``, every RR is placed in its
    own RRset.

    *ignore_trailing*, a ``bool``. When ``True``, junk following the response
    message is ignored.

    *sock*, an ``ssl.SSLSocket`` or ``None``. If given, it must be a
    nonblocking, connected SSL stream socket. The query is then handed
    straight to ``tcp()``, and *where*, *port*, *source*, *source_port*,
    *ssl_context*, *server_hostname* and *verify* are all ignored. If
    ``None`` (the default), a new socket is created.

    *ssl_context*, an ``ssl.SSLContext`` or ``None``, the context used to set
    up the TLS connection. If ``None``, a context is built from
    *server_hostname* and *verify*.

    *server_hostname*, a ``str`` or ``None``, the server's hostname. When
    ``None`` (the default) and the context is built here, hostname checking
    is turned off.

    *verify*, a ``bool`` or ``str``, used only when the context is built
    here. ``True`` checks the server certificate against the default CA
    bundle, ``False`` skips certificate verification, and a ``str`` is the
    path to a CA file or directory to verify against.

    Raises ``BadResponse`` if the reply is not a response to *q*.

    Returns a ``dns.message.Message``.
    """
    if sock:
        # The caller supplied a ready SSL socket: reuse the plain TCP path,
        # which only sends and receives on it.
        return tcp(
            q,
            where,
            timeout,
            port,
            source,
            source_port,
            one_rr_per_rrset,
            ignore_trailing,
            sock,
        )

    query_wire = q.to_wire()
    begin_time, expiration = _compute_times(timeout)
    af, destination, source = _destination_and_source(
        where, port, source, source_port
    )
    if ssl_context is None:
        ssl_context = _make_dot_ssl_context(server_hostname, verify)

    with _make_socket(
        af,
        socket.SOCK_STREAM,
        source,
        ssl_context=ssl_context,
        server_hostname=server_hostname,
    ) as tls_sock:
        _connect(tls_sock, destination, expiration)
        _tls_handshake(tls_sock, expiration)
        send_tcp(tls_sock, query_wire, expiration)
        response, received_time = receive_tcp(
            tls_sock,
            expiration,
            one_rr_per_rrset,
            q.keyring,
            q.mac,
            ignore_trailing,
        )
        response.time = received_time - begin_time
        if not q.is_response(response):
            raise BadResponse
        return response
    assert False  # unreachable: the with-block always returns or raises
|
|
|
|
|
def quic(
    q: dns.message.Message,
    where: str,
    timeout: Optional[float] = None,
    port: int = 853,
    source: Optional[str] = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    connection: Optional[dns.quic.SyncQuicConnection] = None,
    verify: Union[bool, str] = True,
    server_hostname: Optional[str] = None,
) -> dns.message.Message:
    """Return the response obtained after sending a query via DNS-over-QUIC.

    *q*, a ``dns.message.Message``, the query to send. Its ID is set to 0
    before rendering, which modifies the caller's message.

    *where*, a ``str``, the nameserver IP address.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before
    the query times out. If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the query to. The default is 853.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the
    source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message. The
    default is 0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end
    of the received message.

    *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the
    connection to use to send the query. It is not closed by this function.
    *where*, *port*, *source*, *source_port*, *verify* and *server_hostname*
    are then not used to connect.

    *verify*, a ``bool`` or ``str``. If ``True``, then TLS certificate
    verification of the server is done using the default CA bundle; if
    ``False``, then no verification is done; if a ``str``, then it specifies
    the path to a certificate file or directory which will be used for
    verification.

    *server_hostname*, a ``str`` containing the server's hostname. The
    default is ``None``, which means that no hostname is known, and if an
    SSL context is created, hostname checking will be disabled.

    Raises ``NoDOQ`` if DNS-over-QUIC support is unavailable, and
    ``BadResponse`` if the reply is not a response to *q*.

    Returns a ``dns.message.Message``.
    """
    if not dns.quic.have_quic:
        raise NoDOQ("DNS-over-QUIC is not available.")

    # DoQ queries are sent with message ID 0 (this mutates the caller's q).
    q.id = 0
    wire = q.to_wire()
    the_connection: dns.quic.SyncQuicConnection
    the_manager: dns.quic.SyncQuicManager
    if connection:
        # The caller owns the connection, so use a no-op context manager and
        # leave its lifetime to the caller.
        manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
        the_connection = connection
    else:
        manager = dns.quic.SyncQuicManager(
            verify_mode=verify, server_name=server_hostname
        )
        the_manager = manager  # typed alias so .connect() type-checks

    with manager:
        if not connection:
            the_connection = the_manager.connect(where, port, source, source_port)
        (start, expiration) = _compute_times(timeout)
        with the_connection.make_stream(timeout) as stream:
            stream.send(wire, True)
            wire = stream.receive(_remaining(expiration))
        finish = time.time()
    r = dns.message.from_wire(
        wire,
        keyring=q.keyring,
        # Validate TSIG against the query's own MAC, exactly as tcp(), tls()
        # and xfr() in this module do. q.request_mac (used previously) is
        # presumably empty on an outgoing query, which would break TSIG
        # verification of the response — NOTE(review): confirm against
        # dns.message.Message.
        request_mac=q.mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    # Clamp at zero in case the clock moved backwards.
    r.time = max(finish - start, 0.0)
    if not q.is_response(r):
        raise BadResponse
    return r
|
|
|
|
|
def xfr(
    where: str,
    zone: Union[dns.name.Name, str],
    rdtype: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.AXFR,
    rdclass: Union[dns.rdataclass.RdataClass, str] = dns.rdataclass.IN,
    timeout: Optional[float] = None,
    port: int = 53,
    keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]] = None,
    keyname: Optional[Union[dns.name.Name, str]] = None,
    relativize: bool = True,
    lifetime: Optional[float] = None,
    source: Optional[str] = None,
    source_port: int = 0,
    serial: int = 0,
    use_udp: bool = False,
    keyalgorithm: Union[dns.name.Name, str] = dns.tsig.default_algorithm,
) -> Any:
    """Return a generator for the responses to a zone transfer.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *zone*, a ``dns.name.Name`` or ``str``, the name of the zone to transfer.

    *rdtype*, an ``int`` or ``str``, the type of zone transfer. The
    default is ``dns.rdatatype.AXFR``. ``dns.rdatatype.IXFR`` can be
    used to do an incremental transfer instead.

    *rdclass*, an ``int`` or ``str``, the class of the zone transfer.
    The default is ``dns.rdataclass.IN``. It is used both for the question
    and for the SOA sent in the authority section of an IXFR request.

    *timeout*, a ``float``, the number of seconds to wait for each
    response message. If None, the default, wait forever.

    *port*, an ``int``, the port to send the message to. The default is 53.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG
    key to use.

    *relativize*, a ``bool``. If ``True``, all names in the zone will be
    relativized to the zone origin. It is essential that the
    relativize setting matches the one specified to
    ``dns.zone.from_xfr()`` if using this generator to make a zone.

    *lifetime*, a ``float``, the total number of seconds to spend
    doing the transfer. If ``None``, the default, then there is no
    limit on the time the transfer may take.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address. The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *serial*, an ``int``, the SOA serial number to use as the base for
    an IXFR diff sequence (only meaningful if *rdtype* is
    ``dns.rdatatype.IXFR``).

    *use_udp*, a ``bool``. If ``True``, use UDP (only meaningful for IXFR).

    *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use.

    Raises on errors, and so does the generator: ``ValueError`` for a UDP
    AXFR, ``TransferError`` for a non-NOERROR rcode, and
    ``dns.exception.FormError`` for malformed transfer sequences or a
    missing TSIG.

    Returns a generator of ``dns.message.Message`` objects.
    """
    if isinstance(zone, str):
        zone = dns.name.from_text(zone)
    rdtype = dns.rdatatype.RdataType.make(rdtype)
    q = dns.message.make_query(zone, rdtype, rdclass)
    if rdtype == dns.rdatatype.IXFR:
        # The IXFR request carries the client's SOA in the authority section.
        # Only the serial is filled in; the other SOA fields are placeholders.
        # Use the requested class rather than a hard-coded IN so that the
        # SOA agrees with the question.
        rrset = dns.rrset.from_text(
            zone, 0, rdclass, "SOA", ". . %u 0 0 0 0" % serial
        )
        q.authority.append(rrset)
    if keyring is not None:
        q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
    wire = q.to_wire()
    (af, destination, source) = _destination_and_source(
        where, port, source, source_port
    )
    if use_udp and rdtype != dns.rdatatype.IXFR:
        raise ValueError("cannot do a UDP AXFR")
    sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
    with _make_socket(af, sock_type, source) as s:
        (_, expiration) = _compute_times(lifetime)
        _connect(s, destination, expiration)
        if use_udp:
            _udp_send(s, wire, None, expiration)
        else:
            # DNS over TCP prefixes each message with a 2-byte length.
            tcpmsg = struct.pack("!H", len(wire)) + wire
            _net_write(s, tcpmsg, expiration)
        done = False
        delete_mode = True
        expecting_SOA = False
        soa_rrset = None
        if relativize:
            # Names are relativized, so the zone apex appears as the empty name.
            origin = zone
            oname = dns.name.empty
        else:
            origin = None
            oname = zone
        tsig_ctx = None
        while not done:
            # Per-message deadline, capped by the overall transfer deadline.
            (_, mexpiration) = _compute_times(timeout)
            if mexpiration is None or (
                expiration is not None and mexpiration > expiration
            ):
                mexpiration = expiration
            if use_udp:
                (wire, _) = _udp_recv(s, 65535, mexpiration)
            else:
                ldata = _net_read(s, 2, mexpiration)
                (length,) = struct.unpack("!H", ldata)
                wire = _net_read(s, length, mexpiration)
            is_ixfr = rdtype == dns.rdatatype.IXFR
            r = dns.message.from_wire(
                wire,
                keyring=q.keyring,
                request_mac=q.mac,
                xfr=True,
                origin=origin,
                tsig_ctx=tsig_ctx,
                multi=True,
                one_rr_per_rrset=is_ixfr,
            )
            rcode = r.rcode()
            if rcode != dns.rcode.NOERROR:
                raise TransferError(rcode)
            tsig_ctx = r.tsig_ctx
            answer_index = 0
            if soa_rrset is None:
                # The first message must start with the zone's SOA.
                if not r.answer or r.answer[0].name != oname:
                    raise dns.exception.FormError("No answer or RRset not for qname")
                rrset = r.answer[0]
                if rrset.rdtype != dns.rdatatype.SOA:
                    raise dns.exception.FormError("first RRset is not an SOA")
                answer_index = 1
                soa_rrset = rrset.copy()
                if rdtype == dns.rdatatype.IXFR:
                    if dns.serial.Serial(soa_rrset[0].serial) <= serial:
                        # Server's serial is not newer: already up to date.
                        done = True
                    else:
                        expecting_SOA = True
            # Process SOAs in the answer section (other than the initial SOA
            # in the first message).
            for rrset in r.answer[answer_index:]:
                if done:
                    raise dns.exception.FormError("answers after final SOA")
                if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
                    if expecting_SOA:
                        if rrset[0].serial != serial:
                            raise dns.exception.FormError("IXFR base serial mismatch")
                        expecting_SOA = False
                    elif rdtype == dns.rdatatype.IXFR:
                        # Each further SOA toggles between the deletion and
                        # addition parts of an IXFR difference sequence.
                        delete_mode = not delete_mode
                    # Seeing the first SOA again ends the transfer; for IXFR
                    # it must also appear where a deletion part would start.
                    if rrset == soa_rrset and (
                        rdtype == dns.rdatatype.AXFR
                        or (rdtype == dns.rdatatype.IXFR and delete_mode)
                    ):
                        done = True
                elif expecting_SOA:
                    # We asked for IXFR and expected a second SOA but got
                    # something else, so the server answered with an AXFR.
                    rdtype = dns.rdatatype.AXFR
                    expecting_SOA = False
            if done and q.keyring and not r.had_tsig:
                raise dns.exception.FormError("missing TSIG")
            yield r
|
|
|
|
|
class UDPMode(enum.IntEnum):
    """Policy for using UDP during an IXFR run by :py:func:`inbound_xfr()`.

    ``NEVER`` uses TCP only. ``TRY_FIRST`` attempts UDP and falls back to TCP
    when UDP is not sufficient. ``ONLY`` attempts UDP and raises
    ``dns.xfr.UseTCP`` instead of falling back.
    """

    # TCP for every transfer; UDP is never attempted.
    NEVER = 0
    # Start with UDP and switch to TCP if the transfer needs it.
    TRY_FIRST = 1
    # UDP only; propagate dns.xfr.UseTCP rather than retrying over TCP.
    ONLY = 2
|
|
|
|
|
def inbound_xfr(
    where: str,
    txn_manager: dns.transaction.TransactionManager,
    query: Optional[dns.message.Message] = None,
    port: int = 53,
    timeout: Optional[float] = None,
    lifetime: Optional[float] = None,
    source: Optional[str] = None,
    source_port: int = 0,
    udp_mode: UDPMode = UDPMode.NEVER,
) -> None:
    """Run an inbound zone transfer and apply it through a transaction
    obtained from *txn_manager*.

    *where*, a ``str`` holding the IPv4 or IPv6 address of the server.

    *txn_manager*, a ``dns.transaction.TransactionManager`` (typically a
    ``dns.zone.Zone``) that receives the transferred data.

    *query*, the query to send. If omitted, one is built from
    *txn_manager* via ``dns.xfr.make_query()``.

    *port*, an ``int``, the server port. Defaults to 53.

    *timeout*, a ``float``, the number of seconds to wait for each response
    message. ``None`` (the default) waits forever.

    *lifetime*, a ``float``, the total number of seconds the transfer may
    take. ``None`` (the default) imposes no limit.

    *source*, a ``str`` with an IPv4 or IPv6 address to send from. Defaults
    to the wildcard address.

    *source_port*, an ``int``, the local port to send from. Defaults to 0.

    *udp_mode*, a ``dns.query.UDPMode`` controlling UDP use for IXFRs. The
    default, ``dns.query.UDPMode.NEVER``, uses only TCP.
    ``dns.query.UDPMode.TRY_FIRST`` tries UDP and falls back to TCP when
    needed. ``dns.query.UDPMode.ONLY`` tries UDP and raises
    ``dns.xfr.UseTCP`` if that does not succeed.

    Raises on errors, including ``dns.exception.FormError`` when a keyring
    is in use but the final message carried no TSIG.
    """
    if query is not None:
        serial = dns.xfr.extract_serial_from_query(query)
    else:
        query, serial = dns.xfr.make_query(txn_manager)
    rdtype = query.question[0].rdtype
    is_ixfr = rdtype == dns.rdatatype.IXFR
    origin = txn_manager.from_wire_origin()
    query_wire = query.to_wire()
    af, destination, source = _destination_and_source(
        where, port, source, source_port
    )
    _, expiration = _compute_times(lifetime)

    while True:
        fall_back_to_tcp = False
        # UDP is only ever attempted for IXFR, and only if the mode allows it.
        is_udp = is_ixfr and udp_mode != UDPMode.NEVER
        sock_type = socket.SOCK_DGRAM if is_udp else socket.SOCK_STREAM
        with _make_socket(af, sock_type, source) as s:
            _connect(s, destination, expiration)
            if is_udp:
                _udp_send(s, query_wire, None, expiration)
            else:
                # Length-prefixed framing for DNS over TCP.
                framed = struct.pack("!H", len(query_wire)) + query_wire
                _net_write(s, framed, expiration)
            with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
                done = False
                tsig_ctx = None
                while not done:
                    # Per-message deadline, never later than the overall one.
                    _, per_message = _compute_times(timeout)
                    if per_message is None:
                        deadline = expiration
                    elif expiration is None:
                        deadline = per_message
                    else:
                        deadline = min(per_message, expiration)
                    if is_udp:
                        rwire, _ = _udp_recv(s, 65535, deadline)
                    else:
                        (msg_len,) = struct.unpack("!H", _net_read(s, 2, deadline))
                        rwire = _net_read(s, msg_len, deadline)
                    r = dns.message.from_wire(
                        rwire,
                        keyring=query.keyring,
                        request_mac=query.mac,
                        xfr=True,
                        origin=origin,
                        tsig_ctx=tsig_ctx,
                        multi=(not is_udp),
                        one_rr_per_rrset=is_ixfr,
                    )
                    try:
                        done = inbound.process_message(r)
                    except dns.xfr.UseTCP:
                        assert is_udp  # UseTCP is only expected over UDP
                        if udp_mode == UDPMode.ONLY:
                            raise
                        # Abandon this attempt and redo the transfer over TCP.
                        done = True
                        fall_back_to_tcp = True
                        udp_mode = UDPMode.NEVER
                    else:
                        tsig_ctx = r.tsig_ctx
                if not fall_back_to_tcp and query.keyring and not r.had_tsig:
                    raise dns.exception.FormError("missing TSIG")
        if not fall_back_to_tcp:
            return
|
|