Surbhi123's picture
Upload folder using huggingface_hub
64772a4 verified
import asyncio
import functools
import pycares
import socket
import sys
from typing import (
Any,
Optional,
Set,
Sequence,
Tuple,
Union
)
from . import error
__version__ = '3.2.0'
__all__ = ('DNSResolver', 'error')

# Direction tags handed to DNSResolver._handle_event by the loop's
# reader/writer callbacks, telling it which side of the socket is ready.
READ = 1
WRITE = 2

# Record-type names accepted by DNSResolver.query(), mapped to pycares codes.
query_type_map = {
    'A': pycares.QUERY_TYPE_A,
    'AAAA': pycares.QUERY_TYPE_AAAA,
    'ANY': pycares.QUERY_TYPE_ANY,
    'CAA': pycares.QUERY_TYPE_CAA,
    'CNAME': pycares.QUERY_TYPE_CNAME,
    'MX': pycares.QUERY_TYPE_MX,
    'NAPTR': pycares.QUERY_TYPE_NAPTR,
    'NS': pycares.QUERY_TYPE_NS,
    'PTR': pycares.QUERY_TYPE_PTR,
    'SOA': pycares.QUERY_TYPE_SOA,
    'SRV': pycares.QUERY_TYPE_SRV,
    'TXT': pycares.QUERY_TYPE_TXT,
}

# Query-class names accepted by DNSResolver.query(qclass=...), mapped to
# pycares codes.
query_class_map = {
    'IN': pycares.QUERY_CLASS_IN,
    'CHAOS': pycares.QUERY_CLASS_CHAOS,
    'HS': pycares.QUERY_CLASS_HS,
    'NONE': pycares.QUERY_CLASS_NONE,
    'ANY': pycares.QUERY_CLASS_ANY,
}
class DNSResolver:
    """Asynchronous DNS resolver that drives a ``pycares.Channel`` from asyncio.

    Every lookup method returns an ``asyncio.Future`` created on ``self.loop``.
    The future is resolved with the result pycares hands to its callback, or
    failed with ``error.DNSError`` when pycares reports an error code.

    The channel's sockets are watched with the loop's ``add_reader`` /
    ``add_writer``. While any socket is active, a periodic timer also pokes
    the channel so that pending queries can time out.
    """

    def __init__(self, nameservers: Optional[Sequence[str]] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None,
                 **kwargs: Any) -> None:
        """Create the resolver and its underlying ``pycares.Channel``.

        :param nameservers: if given and non-empty, replaces the channel's
            configured nameservers.
        :param loop: event loop to use; defaults to ``asyncio.get_event_loop()``.
        :param kwargs: forwarded to ``pycares.Channel``. A caller-supplied
            ``sock_state_cb`` is discarded (the resolver installs its own).
            ``timeout`` is forwarded and also used to size the internal timer.
        :raises RuntimeError: on Windows when the loop is neither a
            ``SelectorEventLoop`` nor a ``winloop.Loop``.
        """
        self.loop = loop or asyncio.get_event_loop()
        assert self.loop is not None
        if sys.platform == 'win32' and not isinstance(self.loop, asyncio.SelectorEventLoop):
            # winloop is an optional third-party loop that is also accepted.
            try:
                import winloop
            except ModuleNotFoundError:
                winloop = None
            if winloop is None or not isinstance(self.loop, winloop.Loop):
                raise RuntimeError(
                    'aiodns needs a SelectorEventLoop on Windows. See more: https://github.com/saghul/aiodns/issues/86')
        # Socket/timer bookkeeping used by _sock_state_cb. It is set up before
        # the channel that holds that callback is created, so it is always
        # present when a socket-state notification arrives.
        self._read_fds: Set[int] = set()
        self._write_fds: Set[int] = set()
        self._timer: Optional[asyncio.TimerHandle] = None
        kwargs.pop('sock_state_cb', None)
        timeout = kwargs.pop('timeout', None)
        self._timeout = timeout
        self._channel = pycares.Channel(sock_state_cb=self._sock_state_cb,
                                        timeout=timeout,
                                        **kwargs)
        if nameservers:
            self.nameservers = nameservers

    @property
    def nameservers(self) -> Sequence[str]:
        """Nameservers currently configured on the underlying channel."""
        return self._channel.servers

    @nameservers.setter
    def nameservers(self, value: Sequence[str]) -> None:
        # Delegates straight to the pycares channel.
        self._channel.servers = value

    @staticmethod
    def _callback(fut: asyncio.Future, result: Any, errorno: Optional[int]) -> None:
        """Complete *fut* from a pycares completion callback.

        A non-None *errorno* fails the future with ``error.DNSError``;
        otherwise *result* becomes the future's result. A future that is
        already done (for example cancelled by the caller) is left untouched,
        so setting it can never raise ``InvalidStateError``.
        """
        if fut.done():
            return
        if errorno is not None:
            fut.set_exception(error.DNSError(errorno, pycares.errno.strerror(errorno)))
        else:
            fut.set_result(result)

    def _new_future(self) -> Tuple[asyncio.Future, functools.partial]:
        """Return a new future on ``self.loop`` and a pycares callback completing it."""
        fut = self.loop.create_future()
        return fut, functools.partial(self._callback, fut)

    def query(self, host: str, qtype: str, qclass: Optional[str] = None) -> asyncio.Future:
        """Start a DNS query for *host* and return a future for its result.

        :param host: name to query.
        :param qtype: record type name, a key of ``query_type_map`` (e.g. 'A').
        :param qclass: optional class name, a key of ``query_class_map``.
        :raises ValueError: if *qtype* or *qclass* is not a known name.
        """
        try:
            qtype_code = query_type_map[qtype]
        except KeyError:
            raise ValueError('invalid query type: {}'.format(qtype)) from None
        qclass_code = None
        if qclass is not None:
            try:
                qclass_code = query_class_map[qclass]
            except KeyError:
                raise ValueError('invalid query class: {}'.format(qclass)) from None
        fut, cb = self._new_future()
        self._channel.query(host, qtype_code, cb, query_class=qclass_code)
        return fut

    def gethostbyname(self, host: str, family: socket.AddressFamily) -> asyncio.Future:
        """Return a future for ``pycares.Channel.gethostbyname(host, family)``."""
        fut, cb = self._new_future()
        self._channel.gethostbyname(host, family, cb)
        return fut

    def getaddrinfo(self, host: str, family: socket.AddressFamily = socket.AF_UNSPEC, port: Optional[int] = None, proto: int = 0, type: int = 0, flags: int = 0) -> asyncio.Future:
        """Return a future for ``pycares.Channel.getaddrinfo``.

        All arguments are forwarded unchanged to pycares.
        """
        fut, cb = self._new_future()
        self._channel.getaddrinfo(host, port, cb, family=family, type=type, proto=proto, flags=flags)
        return fut

    def getnameinfo(self, sockaddr: Union[Tuple[str, int], Tuple[str, int, int, int]], flags: int = 0) -> asyncio.Future:
        """Return a future for ``pycares.Channel.getnameinfo(sockaddr, flags)``."""
        fut, cb = self._new_future()
        self._channel.getnameinfo(sockaddr, flags, cb)
        return fut

    def gethostbyaddr(self, name: str) -> asyncio.Future:
        """Return a future for ``pycares.Channel.gethostbyaddr(name)``."""
        fut, cb = self._new_future()
        self._channel.gethostbyaddr(name, cb)
        return fut

    def cancel(self) -> None:
        """Cancel outstanding requests via ``pycares.Channel.cancel()``."""
        self._channel.cancel()

    def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
        """pycares socket-state callback: mirror the channel's interest in *fd*.

        c-ares reports the full desired read/write interest on each call, e.g.
        ``(True, True)`` while a TCP request is being written and then
        ``(True, False)`` once it has been sent. A direction that is no longer
        requested is therefore unregistered. Otherwise a stale writer on an
        always-writable socket would keep firing. ``(False, False)`` means the
        socket was closed.
        """
        if readable:
            self.loop.add_reader(fd, self._handle_event, fd, READ)
            self._read_fds.add(fd)
        elif fd in self._read_fds:
            self._read_fds.discard(fd)
            self.loop.remove_reader(fd)

        if writable:
            self.loop.add_writer(fd, self._handle_event, fd, WRITE)
            self._write_fds.add(fd)
        elif fd in self._write_fds:
            self._write_fds.discard(fd)
            self.loop.remove_writer(fd)

        if readable or writable:
            if self._timer is None:
                self._start_timer()
        elif not self._read_fds and not self._write_fds and self._timer is not None:
            # No socket is being watched any more: stop the periodic timer.
            self._timer.cancel()
            self._timer = None

    def _handle_event(self, fd: int, event: Any) -> None:
        """Loop reader/writer callback: let pycares process *fd* for *event*."""
        read_fd = pycares.ARES_SOCKET_BAD
        write_fd = pycares.ARES_SOCKET_BAD
        if event == READ:
            read_fd = fd
        elif event == WRITE:
            write_fd = fd
        self._channel.process_fd(read_fd, write_fd)

    def _timer_cb(self) -> None:
        """Periodic tick: while sockets are active, call ``process_fd`` with
        no sockets so pycares can handle expired queries, then re-arm."""
        if self._read_fds or self._write_fds:
            self._channel.process_fd(pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD)
            self._start_timer()
        else:
            self._timer = None

    def _start_timer(self) -> None:
        """Schedule ``_timer_cb`` and store the handle in ``self._timer``.

        The interval is the configured timeout, clamped: ``None``, negative,
        or greater than 1 gives 1 second, and 0 gives 0.1 seconds.
        """
        timeout = self._timeout
        if timeout is None or timeout < 0 or timeout > 1:
            timeout = 1
        elif timeout == 0:
            timeout = 0.1
        self._timer = self.loop.call_later(timeout, self._timer_cb)