import re
import asyncio
import requests
import hivemind
import functools
from async_timeout import timeout
from petals.server.handler import TransformerConnectionHandler

info_cache = hivemind.TimedStorage()


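async def check_reachability(peer_id, _, node, *, fetch_info=False, connect_timeout=5, expiration=300, use_cache=True):
    """Check whether a peer is reachable and return a dict with an "ok" flag.

    With fetch_info=True the peer is queried via rpc_info (Petals servers); otherwise only a
    connect/disconnect is attempted (DHT-only bootstrap peers). Results are cached for
    `expiration` seconds, except the early return for peers that are still joining.
    """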
    if use_cache:
        entry = info_cache.get(peer_id)
        if entry is not None:
            return entry.value

    try:
        async with timeout(connect_timeout):
            if fetch_info:  # For Petals servers
                stub = TransformerConnectionHandler.get_stub(node.p2p, peer_id)
                response = await stub.rpc_info(hivemind.proto.runtime_pb2.ExpertUID())
                rpc_info = hivemind.MSGPackSerializer.loads(response.serialized_info)
                rpc_info["ok"] = True
            else:  # For DHT-only bootstrap peers
                await node.p2p._client.connect(peer_id, [])
                await node.p2p._client.disconnect(peer_id)
                rpc_info = {"ok": True}
    except Exception as e:
        # Actual connection error
        if not isinstance(e, asyncio.TimeoutError):
            message = str(e) if str(e) else repr(e)
            if message == "protocol not supported":
                # This may be returned when a server is joining, see https://github.com/petals-infra/health.petals.dev/issues/1
                return {"ok": True}
        else:
            message = f"Failed to connect in {connect_timeout:.0f} sec. Firewall may be blocking connections"
        rpc_info = {"ok": False, "error": message}

    info_cache.store(peer_id, rpc_info, hivemind.get_dht_time() + expiration)
    return rpc_info


async def check_reachability_parallel(peer_ids, dht, node, *, fetch_info=False):
    rpc_infos = await asyncio.gather(
        *[check_reachability(peer_id, dht, node, fetch_info=fetch_info) for peer_id in peer_ids]
    )
    return dict(zip(peer_ids, rpc_infos))


async def get_peers_ips(dht, dht_node):
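    return await dht_node.p2p.list_peers()


@functools.cache
def get_location(ip_address):
    """Look up geolocation data for an IP address via ip-api.com; return {} on any failure.

    Results, including empty failure results, are memoized for the lifetime of the process.
    """
    try:
        # Bound the request so an unresponsive service cannot block the caller indefinitely
        response = requests.get(f"http://ip-api.com/json/{ip_address}", timeout=5)
        if response.status_code == 200:
            return response.json()
    except Exception:
        pass
    return {}


def extract_peer_ip_info(multiaddr_str):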
    if ip_match := re.search(r"/ip4/(\d+\.\d+\.\d+\.\d+)", multiaddr_str):
        return get_location(ip_match[1])
    return {}
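

# Illustrative sketch, not part of the original module: the IPv4 extraction step used by
# extract_peer_ip_info() shown in isolation, without the network lookup, so it can be
# checked offline. The helper name `_demo_extract_ipv4` and the sample multiaddrs in the
# comments below are hypothetical and chosen only for this example.
def _demo_extract_ipv4(multiaddr_str):
    """Return the first IPv4 address found in a multiaddr string, or None if there is none."""
    import re  # local import keeps this sketch self-contained

    match = re.search(r"/ip4/(\d+\.\d+\.\d+\.\d+)", multiaddr_str)
    return match[1] if match else None


# Example (hypothetical addresses):
#   _demo_extract_ipv4("/ip4/203.0.113.7/tcp/31337/p2p/QmPeer") returns "203.0.113.7"
#   _demo_extract_ipv4("/ip6/::1/tcp/31337") returns None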