Spaces:
Sleeping
Sleeping
import re | |
import asyncio | |
import requests | |
import hivemind | |
import functools | |
from async_timeout import timeout | |
from petals.server.handler import TransformerConnectionHandler | |
# Module-level cache of reachability results, keyed by peer ID.
# check_reachability reads it (when use_cache=True) and stores each result with an
# expiration timestamp of hivemind.get_dht_time() + expiration.
info_cache = hivemind.TimedStorage()
async def check_reachability(peer_id, _, node, *, fetch_info=False, connect_timeout=5, expiration=300, use_cache=True):
    """Probe whether ``peer_id`` is reachable from ``node`` and return a status dict.

    Args:
        peer_id: ID of the peer to probe.
        _: Unused positional slot (check_reachability_parallel passes the DHT here).
        node: Object whose ``p2p`` attribute is used to open connections to the peer.
        fetch_info: If True (Petals servers), call the peer's ``rpc_info`` endpoint and
            return its deserialized payload. If False (DHT-only bootstrap peers), only
            connect and disconnect.
        connect_timeout: Seconds allowed for the whole probe.
        expiration: Seconds a stored result stays valid in ``info_cache``.
        use_cache: If True, return a still-valid cached result for ``peer_id`` if present.

    Returns:
        On success, a dict with ``"ok": True``. When ``fetch_info`` is set, it also
        holds the peer's info fields. On failure, ``{"ok": False, "error": <message>}``.
        Both outcomes are cached, except the "protocol not supported" case (see below).
    """
    if use_cache:
        entry = info_cache.get(peer_id)
        if entry is not None:
            return entry.value

    try:
        # async_timeout's timeout() is an asynchronous context manager and must be
        # entered with ``async with``. The synchronous ``with`` form is not supported
        # by newer releases. It would fail here and be caught below, so every peer
        # would be reported as unreachable.
        async with timeout(connect_timeout):
            if fetch_info:  # For Petals servers
                stub = TransformerConnectionHandler.get_stub(node.p2p, peer_id)
                response = await stub.rpc_info(hivemind.proto.runtime_pb2.ExpertUID())
                rpc_info = hivemind.MSGPackSerializer.loads(response.serialized_info)
                rpc_info["ok"] = True
            else:  # For DHT-only bootstrap peers
                await node.p2p._client.connect(peer_id, [])
                await node.p2p._client.disconnect(peer_id)
                rpc_info = {"ok": True}
    except asyncio.TimeoutError:
        message = f"Failed to connect in {connect_timeout:.0f} sec. Firewall may be blocking connections"
        rpc_info = {"ok": False, "error": message}
    except Exception as e:
        # Actual connection error; fall back to repr() when the exception has no text.
        message = str(e) or repr(e)
        if message == "protocol not supported":
            # This may be returned when a server is joining, see https://github.com/petals-infra/health.petals.dev/issues/1
            # Deliberately not cached, so the peer is re-probed next time.
            return {"ok": True}
        rpc_info = {"ok": False, "error": message}

    info_cache.store(peer_id, rpc_info, hivemind.get_dht_time() + expiration)
    return rpc_info
async def check_reachability_parallel(peer_ids, dht, node, *, fetch_info=False):
    """Probe all ``peer_ids`` concurrently with check_reachability.

    Args:
        peer_ids: Iterable of peer IDs. It is copied into a list first, so one-shot
            iterators such as generators pair up correctly with the results.
        dht: Forwarded positionally to check_reachability, which ignores it.
        node: Forwarded to check_reachability.
        fetch_info: Forwarded to check_reachability.

    Returns:
        Dict mapping each peer ID to the status dict returned by check_reachability.
    """
    # peer_ids is iterated twice: once to build the coroutines, once to pair IDs with
    # results. Without this copy, a generator would be exhausted by the first pass.
    # zip() would then yield nothing and an empty dict would be returned silently.
    peer_ids = list(peer_ids)
    rpc_infos = await asyncio.gather(
        *(check_reachability(peer_id, dht, node, fetch_info=fetch_info) for peer_id in peer_ids)
    )
    return dict(zip(peer_ids, rpc_infos))
async def get_peers_ips(dht, dht_node):
    """Return the result of ``dht_node.p2p.list_peers()``.

    ``dht`` is accepted for call-site compatibility but is not used.

    NOTE(review): despite the name, this returns list_peers()'s raw result, not a
    list of IP strings. Callers presumably extract addresses themselves; verify.
    """
    p2p = dht_node.p2p
    peers = await p2p.list_peers()
    return peers
def get_location(ip_address, *, request_timeout=10):
    """Look up geolocation data for ``ip_address`` via the ip-api.com JSON endpoint.

    This is best-effort: any failure yields an empty dict rather than raising.

    Args:
        ip_address: IP address string to look up.
        request_timeout: Seconds to wait for the HTTP request (passed to
            ``requests.get``). Without a timeout, ``requests`` can block indefinitely
            on an unresponsive server. That would stall the caller, and the event loop
            too if called from async code.

    Returns:
        The decoded JSON body when the service answers with HTTP 200, otherwise ``{}``.
    """
    try:
        response = requests.get(f"http://ip-api.com/json/{ip_address}", timeout=request_timeout)
        if response.status_code == 200:
            return response.json()
    except Exception:
        # Deliberately best-effort: connection errors, timeouts and malformed JSON
        # all fall through to the empty result below.
        pass
    return {}
def extract_peer_ip_info(multiaddr_str):
    """Geolocate the first ``/ip4/<a.b.c.d>`` address found in a multiaddr string.

    Returns the get_location() result for that address. If the string has no
    ``/ip4/`` component (e.g. ``/ip6/`` or ``/dns4/`` addresses), returns ``{}``.
    """
    match = re.search(r"/ip4/(\d+\.\d+\.\d+\.\d+)", multiaddr_str)
    if match is None:
        return {}
    return get_location(match.group(1))