api_for_unity / p2p_utils.py
ldhldh's picture
Upload 8 files
90e26fa
raw
history blame
No virus
2.48 kB
import re
import asyncio
import requests
import hivemind
import functools
from async_timeout import timeout
from petals.server.handler import TransformerConnectionHandler
# Per-peer cache of reachability results. check_reachability() stores each
# result here with an expiration of (current DHT time + `expiration` seconds)
# and, when use_cache is set, returns a stored entry's value instead of probing.
# NOTE(review): presumably TimedStorage.get() returns None once an entry has
# expired — confirm against hivemind's TimedStorage semantics.
info_cache = hivemind.TimedStorage()
async def check_reachability(peer_id, _, node, *, fetch_info=False, connect_timeout=5, expiration=300, use_cache=True):
    """Probe whether a peer can be reached over libp2p, caching the outcome.

    Args:
        peer_id: ID of the peer to probe.
        _: unused positional slot (check_reachability_parallel passes its
            ``dht`` argument here).
        node: object exposing ``.p2p``, used to open connections / RPC stubs.
        fetch_info: if True, call the Petals ``rpc_info`` endpoint and return
            its decoded payload; otherwise just connect and disconnect.
        connect_timeout: seconds allowed for the whole probe.
        expiration: seconds (in DHT time) a stored result stays in ``info_cache``.
        use_cache: if True, return a cached result without probing when present.

    Returns:
        A dict with ``"ok": True`` on success (plus the decoded server info when
        ``fetch_info`` is set), or ``{"ok": False, "error": <message>}`` on failure.
    """
    if use_cache:
        entry = info_cache.get(peer_id)
        if entry is not None:
            return entry.value

    try:
        # async_timeout.timeout is documented as an *async* context manager;
        # the synchronous ``with`` form is deprecated.
        async with timeout(connect_timeout):
            if fetch_info:  # For Petals servers
                stub = TransformerConnectionHandler.get_stub(node.p2p, peer_id)
                response = await stub.rpc_info(hivemind.proto.runtime_pb2.ExpertUID())
                rpc_info = hivemind.MSGPackSerializer.loads(response.serialized_info)
                rpc_info["ok"] = True
            else:  # For DHT-only bootstrap peers
                await node.p2p._client.connect(peer_id, [])
                await node.p2p._client.disconnect(peer_id)
                rpc_info = {"ok": True}
    except asyncio.TimeoutError:
        message = f"Failed to connect in {connect_timeout:.0f} sec. Firewall may be blocking connections"
        rpc_info = {"ok": False, "error": message}
    except Exception as e:
        # Actual connection error; some exceptions stringify to "", so fall back to repr.
        message = str(e) or repr(e)
        if message == "protocol not supported":
            # This may be returned when a server is joining, see https://github.com/petals-infra/health.petals.dev/issues/1
            # Returned early without caching, so the peer is probed again next time.
            return {"ok": True}
        rpc_info = {"ok": False, "error": message}

    info_cache.store(peer_id, rpc_info, hivemind.get_dht_time() + expiration)
    return rpc_info
async def check_reachability_parallel(peer_ids, dht, node, *, fetch_info=False):
    """Run check_reachability for many peers concurrently.

    Args:
        peer_ids: iterable of peer IDs to probe. It is materialized into a list
            first: it is iterated twice (once to launch the probes, once to pair
            results back to IDs), and a one-shot iterator such as a generator
            would otherwise be exhausted, yielding an empty result.
        dht: forwarded to check_reachability as its second positional argument.
        node: forwarded to check_reachability.
        fetch_info: forwarded to check_reachability.

    Returns:
        dict mapping each peer ID to the dict check_reachability returned for it
        (asyncio.gather preserves input order).
    """
    peer_ids = list(peer_ids)
    rpc_infos = await asyncio.gather(
        *[check_reachability(peer_id, dht, node, fetch_info=fetch_info) for peer_id in peer_ids]
    )
    return dict(zip(peer_ids, rpc_infos))
async def get_peers_ips(dht, dht_node):
    """Return the result of ``dht_node.p2p.list_peers()``.

    ``dht`` is accepted for call-site compatibility but is not used.
    NOTE(review): despite the name, this returns whatever ``list_peers()``
    yields (presumably peer info objects carrying addresses, not bare IP
    strings) — confirm against hivemind's P2P API.
    """
    p2p_instance = dht_node.p2p
    known_peers = await p2p_instance.list_peers()
    return known_peers
@functools.lru_cache(maxsize=4096)  # bounded: the set of peer IPs seen over time is open-ended
def get_location(ip_address, *, request_timeout=5):
    """Return ip-api.com geolocation data for ``ip_address`` (best effort).

    Args:
        ip_address: IP address to look up.
        request_timeout: seconds to wait for ip-api.com. ``requests`` applies no
            timeout by default, so without this a stalled request could block
            the caller indefinitely.

    Returns:
        The decoded JSON body when the response status is 200, otherwise ``{}``.
        Any exception raised while requesting or decoding also yields ``{}``.

    NOTE: results are memoized per (ip_address, request_timeout), including
    failures, so an IP whose lookup failed once keeps returning ``{}`` until it
    is evicted or ``get_location.cache_clear()`` is called. The returned dict is
    the cached object itself; callers should not mutate it.
    """
    try:
        response = requests.get(f"http://ip-api.com/json/{ip_address}", timeout=request_timeout)
        if response.status_code == 200:
            return response.json()
    except Exception:
        # Best effort by design: any failure degrades to an empty result
        # instead of propagating to the caller.
        pass
    return {}
def extract_peer_ip_info(multiaddr_str):
    """Geolocate the first ``/ip4/<address>`` component found in a multiaddr string.

    Returns the result of get_location() for that IPv4 address, or ``{}`` when
    the string contains no ``/ip4/`` component (e.g. ``/ip6/`` or ``/dns4/`` addresses).
    """
    found = re.search(r"/ip4/(\d+\.\d+\.\d+\.\d+)", multiaddr_str)
    if found is None:
        return {}
    return get_location(found.group(1))