Upload 8 files
- app.py +34 -170
- config.py +31 -0
- data_structures.py +20 -0
- health.py +124 -0
- metrics.py +118 -0
- p2p_utils.py +67 -0
- pyproject.toml +10 -0
- state_updater.py +57 -0
app.py
CHANGED
@@ -1,187 +1,51 @@
-from ...
-import gradio as gr
-import inspect
-from gradio import routes
-from typing import List, Type

-import requests, os, re, asyncio, queue, sys, git
-import math
-import time
-import datetime
-import requests, json
-
-from pprint import pprint
 import hivemind
-from ...
-from ...
-
-dht = hivemind.DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)
-model_name = "quantumaikr/llama-2-70b-fb16-korean"
-
-loop = asyncio.get_event_loop()
-
-# Monkey patch
-def get_types(cls_set: List[Type], component: str):
-    docset = []
-    types = []
-    if component == "input":
-        for cls in cls_set:
-            doc = inspect.getdoc(cls)
-            doc_lines = doc.split("\n")
-            docset.append(doc_lines[1].split(":")[-1])
-            types.append(doc_lines[1].split(")")[0].split("(")[-1])
-    else:
-        for cls in cls_set:
-            doc = inspect.getdoc(cls)
-            doc_lines = doc.split("\n")
-            docset.append(doc_lines[-1].split(":")[-1])
-            types.append(doc_lines[-1].split(")")[0].split("(")[-1])
-    return docset, types
-routes.get_types = get_types
-
-# App code
-account_list = dict()
-account_list['id'] = "pass"
-
-name_list = dict()
-name_list['id'] = 'name'
-
-p2p_list = dict()
-p2p_list['id'] = '11111111'
-
-def chat(x):
-    return "AI 응답입니다."  # "This is an AI response."
-
-def register(id, pw):
-    if id in account_list:
-        return "exist"
-    else:
-        account_list[id] = pw
-        return "ok"
-
-def login(id, pw):
-    if id in account_list:
-        if account_list[id] == pw:
-            return "ok"
-        else:
-            return "password error"
-    else:
-        return "no id"
-
-def add_name(id, name):
-    name_list[id] = name
-    return "ok"
-
-def get_name(id):
-    if id in name_list:
-        return name_list[id]
-    else:
-        return "no id"
-
-def get_id(name):
-    reverse_dict = dict(map(reversed, name_list.items()))
-    if name in reverse_dict:
-        return reverse_dict[name]
-    else:
-        return "no name"
-
-def add_p(id, p):
-    p2p_list[id] = p
-    return "ok"
-
-def get_p(id):
-    if id in p2p_list:
-        return p2p_list[id]
-    else:
-        return "no id"
-
-def get_id_from_p2p(i):
-    reverse_dict = dict(map(reversed, p2p_list.items()))
-    if i in reverse_dict:
-        return reverse_dict[i]
-    else:
-        return "no id"
-
-def get_peers():
-    data = fetch_health_state(dht)
-    out = []
-    for d in data['model_reports']:
-        if d['name'] == model_name:
-            for r in d['server_rows']:
-                out.append(r['peer_id'])
-    return out
-
-get_peers()
-
...
-)
-
-rr = gr.Interface(
-    fn=register,
-    inputs=["text", "text"],
-    outputs="text",
-    description="register, sign-up (returns ok on success, exist if the id already exists)\n /run/predict_1",
-)
-
-ll = gr.Interface(
-    fn=login,
-    inputs=["text", "text"],
-    outputs="text",
-    description="login (ok on success, password error on failure, no id if the id does not exist)\n /run/predict_2",
-)
-
-ad = gr.Interface(
-    fn=add_name,
-    inputs=["text", "text"],
-    outputs="text",
-    description="add_name, adds a nickname for an id. Returns ok.\n /run/predict_3",
-)
-
-nn = gr.Interface(
-    fn=get_name,
-    inputs=["text"],
-    outputs="text",
-    description="get_name, returns the nickname for an id (no id if absent)\n /run/predict_4",
-)
-
...
-    description="get_name, returns the id for a nickname (no name if absent)\n /run/predict_5",
-)
-
-adp = gr.Interface(
-    fn=add_p,
-    inputs=["text", "text"],
-    outputs="text",
-    description="add_p, adds a p2p id for an id. Returns ok.\n /run/predict_6",
-)
-
-nnp = gr.Interface(
-    fn=get_p,
-    inputs=["text"],
-    outputs="text",
-    description="get_p, returns the p2p id for an id. no id if absent.\n /run/predict_7",
-)
-
-nnp = gr.Interface(
-    fn=get_id_from_p2p,
-    inputs=["text"],
-    outputs="text",
-    description="get_p, returns the regular id for a p2p id. no id if absent.\n /run/predict_8",
-)
-
-demo.queue(max_size=32).launch(enable_queue=True)
+from functools import partial
+
 import hivemind
+from flask import Flask, jsonify, request
+from flask_cors import CORS
+
+import config
+from p2p_utils import check_reachability
+from state_updater import StateUpdaterThread
+
+logger = hivemind.get_logger(__name__)
+
+logger.info("Connecting to DHT")
+dht = hivemind.DHT(initial_peers=config.INITIAL_PEERS, client_mode=True, num_workers=32, start=True)
+
+logger.info("Starting Flask app")
+app = Flask(__name__)
+CORS(app)
+
+logger.info("Starting updater")
+updater = StateUpdaterThread(dht, app, daemon=True)
+updater.start()
+updater.ready.wait()
+
+
+@app.route("/")
+def main_page():
+    return updater.state_html
+
+
+@app.route("/api/v1/state")
+def api_v1_state():
+    return app.response_class(response=updater.state_json, status=200, mimetype="application/json")
+
+
+@app.route("/api/v1/is_reachable/<peer_id>")
+def api_v1_is_reachable(peer_id):
+    peer_id = hivemind.PeerID.from_base58(peer_id)
+    rpc_info = dht.run_coroutine(partial(check_reachability, peer_id, use_cache=False))
+    return jsonify(
+        success=rpc_info["ok"],
+        message=rpc_info.get("error"),
+        your_ip=request.remote_addr,
+    )
+
+
+@app.route("/metrics")
+@app.route("/api/prometheus")
+def metrics():
+    return app.response_class(response=updater.prometheus_metrics, status=200, mimetype="text/plain")
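For reference, a minimal client for the three endpoints defined above could look like the sketch below; the base URL is a placeholder for the deployed Space, and the peer ID is illustrative:

    import requests

    BASE = "http://localhost:5000"  # placeholder; substitute the deployed Space URL

    # Full health state as JSON (the same data that renders the dashboard at "/")
    state = requests.get(f"{BASE}/api/v1/state").json()
    print([report["name"] for report in state["model_reports"]])

    # Reachability probe for a single peer (base58 peer ID; illustrative value)
    check = requests.get(f"{BASE}/api/v1/is_reachable/QmExamplePeerId").json()
    print(check["success"], check["message"], check["your_ip"])

    # Prometheus exposition
    print(requests.get(f"{BASE}/metrics").text)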
config.py
ADDED
@@ -0,0 +1,31 @@
+from petals.constants import PUBLIC_INITIAL_PEERS
+
+from data_structures import ModelInfo
+
+INITIAL_PEERS = PUBLIC_INITIAL_PEERS
+
+MODELS = [
+    ModelInfo(
+        dht_prefix="StableBeluga2-hf",
+        repository="https://huggingface.co/petals-team/StableBeluga2",
+        num_blocks=80,
+    ),
+    ModelInfo(
+        dht_prefix="falcon-180B-chat",
+        repository="https://huggingface.co/tiiuae/falcon-180B-chat",
+        num_blocks=80,
+        limited=True,
+    ),
+    ModelInfo(
+        dht_prefix="Llama-2-70b-chat-hf",
+        repository="https://huggingface.co/meta-llama/Llama-2-70b-chat-hf",
+        num_blocks=80,
+    ),
+    ModelInfo(
+        dht_prefix="Llama-2-70b-hf",
+        repository="https://huggingface.co/meta-llama/Llama-2-70b-hf",
+        num_blocks=80,
+    ),
+]
+
+UPDATE_PERIOD = 60
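MODELS lists the official entries; fetch_health_state() in health.py additionally appends custom models announced in the DHT at runtime. A hypothetical extra entry would follow the same shape (every value below is made up for illustration):

    ModelInfo(
        dht_prefix="my-model-hf",                          # hypothetical DHT prefix
        repository="https://huggingface.co/org/my-model",  # hypothetical repo URL
        num_blocks=32,
    ),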
data_structures.py
ADDED
@@ -0,0 +1,20 @@
+from typing import Optional
+from urllib.parse import urlparse
+
+import petals
+import pydantic
+
+
+@pydantic.dataclasses.dataclass
+class ModelInfo(petals.data_structures.ModelInfo):
+    dht_prefix: Optional[str] = None
+    official: bool = True
+    limited: bool = False
+
+    @property
+    def name(self) -> str:
+        return urlparse(self.repository).path.lstrip("/")
+
+    @property
+    def short_name(self) -> str:
+        return self.name.split("/")[-1]
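A quick sketch of how the derived properties behave, using one of the entries from config.py (the values follow directly from urlparse):

    info = ModelInfo(
        dht_prefix="StableBeluga2-hf",
        repository="https://huggingface.co/petals-team/StableBeluga2",
        num_blocks=80,
    )
    assert info.name == "petals-team/StableBeluga2"   # URL path without the leading "/"
    assert info.short_name == "StableBeluga2"         # last path component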
health.py
ADDED
@@ -0,0 +1,124 @@
+import datetime
+import time
+from collections import Counter
+from contextlib import suppress
+from dataclasses import asdict
+from functools import partial
+
+import hivemind
+import numpy as np
+from multiaddr import Multiaddr
+from petals.data_structures import UID_DELIMITER, ServerState
+from petals.utils.dht import compute_spans, get_remote_module_infos
+
+import config
+from data_structures import ModelInfo
+from p2p_utils import check_reachability_parallel, get_peers_ips, extract_peer_ip_info
+
+logger = hivemind.get_logger(__name__)
+
+
+def fetch_health_state(dht: hivemind.DHT) -> dict:
+    start_time = time.perf_counter()
+    bootstrap_peer_ids = []
+    for addr in config.INITIAL_PEERS:
+        peer_id = hivemind.PeerID.from_base58(Multiaddr(addr)["p2p"])
+        if peer_id not in bootstrap_peer_ids:
+            bootstrap_peer_ids.append(peer_id)
+
+    reach_infos = dht.run_coroutine(partial(check_reachability_parallel, bootstrap_peer_ids))
+    bootstrap_states = ["online" if reach_infos[peer_id]["ok"] else "unreachable" for peer_id in bootstrap_peer_ids]
+
+    models = config.MODELS[:]
+    model_index = dht.get("_petals.models", latest=True)
+    if model_index is not None and isinstance(model_index.value, dict):
+        official_dht_prefixes = {model.dht_prefix for model in models}
+        custom_models = []
+        for dht_prefix, model in model_index.value.items():
+            if dht_prefix in official_dht_prefixes:
+                continue
+            with suppress(TypeError, ValueError):
+                model_info = ModelInfo.from_dict(model.value)
+                if model_info.repository is None or not model_info.repository.startswith("https://huggingface.co/"):
+                    continue
+                model_info.dht_prefix = dht_prefix
+                model_info.official = False
+                custom_models.append(model_info)
+        models.extend(sorted(custom_models, key=lambda info: (-info.num_blocks, info.dht_prefix)))
+    logger.info(f"Fetching info for models {[info.name for info in models]}")
+
+    block_uids = [f"{model.dht_prefix}{UID_DELIMITER}{i}" for model in models for i in range(model.num_blocks)]
+    module_infos = get_remote_module_infos(dht, block_uids, latest=True)
+
+    model_servers = {}
+    all_servers = {}
+    offset = 0
+    for model in models:
+        model_servers[model.dht_prefix] = compute_spans(
+            module_infos[offset : offset + model.num_blocks], min_state=ServerState.OFFLINE
+        )
+        all_servers.update(model_servers[model.dht_prefix])
+        offset += model.num_blocks
+
+    online_servers = [peer_id for peer_id, span in all_servers.items() if span.state == ServerState.ONLINE]
+
+    reach_infos.update(dht.run_coroutine(partial(check_reachability_parallel, online_servers, fetch_info=True)))
+    peers_info = {
+        str(peer.peer_id): {
+            "location": extract_peer_ip_info(str(peer.addrs[0])),
+            "multiaddrs": [str(multiaddr) for multiaddr in peer.addrs],
+        }
+        for peer in dht.run_coroutine(get_peers_ips)
+    }
+
+    top_contributors = Counter()
+    model_reports = []
+    for model in models:
+        block_healthy = np.zeros(model.num_blocks, dtype=bool)
+        server_rows = []
+        for peer_id, span in sorted(model_servers[model.dht_prefix].items()):
+            reachable = reach_infos[peer_id]["ok"] if peer_id in reach_infos else True
+            state = span.state.name.lower() if reachable else "unreachable"
+            if state == "online":
+                block_healthy[span.start : span.end] = True
+
+            show_public_name = state == "online" and span.length >= 10
+            if model.official and span.server_info.public_name and show_public_name:
+                top_contributors[span.server_info.public_name] += span.length
+
+            row = {
+                "short_peer_id": "..." + str(peer_id)[-6:],
+                "peer_id": peer_id,
+                "peer_ip_info": peers_info.get(str(peer_id), "unknown"),
+                "show_public_name": show_public_name,
+                "state": state,
+                "span": span,
+                "adapters": [dict(name=name, short_name=name.split("/")[-1]) for name in span.server_info.adapters],
+                "pings_to_me": {
+                    str(origin_id): origin.server_info.next_pings[str(peer_id)]
+                    for origin_id, origin in model_servers[model.dht_prefix].items()
+                    if origin.server_info.next_pings is not None and str(peer_id) in origin.server_info.next_pings
+                },
+            }
+            if span.server_info.cache_tokens_left is not None:
+                # We use num_blocks * 2 to account for both keys and values
+                row["cache_tokens_left_per_block"] = span.server_info.cache_tokens_left // (span.length * 2)
+            server_rows.append(row)
+
+        model_reports.append(
+            dict(
+                name=model.name,
+                short_name=model.short_name,
+                state="healthy" if block_healthy.all() else "broken",
+                server_rows=server_rows,
+                **asdict(model),
+            )
+        )
+
+    reachability_issues = [
+        dict(peer_id=peer_id, err=info["error"]) for peer_id, info in sorted(reach_infos.items()) if not info["ok"]
+    ]
+
+    return dict(
+        bootstrap_states=bootstrap_states,
+        top_contributors=top_contributors,
+        model_reports=model_reports,
+        reachability_issues=reachability_issues,
+        last_updated=datetime.datetime.now(datetime.timezone.utc),
+        update_period=config.UPDATE_PERIOD,
+        update_duration=time.perf_counter() - start_time,
+    )
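For orientation, the dictionary returned by fetch_health_state() has the following top-level shape; the keys match the return statement above, and the values shown are illustrative:

    {
        "bootstrap_states": ["online", "online"],              # one entry per bootstrap peer
        "top_contributors": Counter({"some-contributor": 80}),
        "model_reports": [...],                                # one dict per model: name, state, server_rows, ...
        "reachability_issues": [{"peer_id": ..., "err": ...}],
        "last_updated": datetime.datetime(...),                # UTC timestamp
        "update_period": 60,
        "update_duration": 12.3,                               # seconds spent fetching
    }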
metrics.py
ADDED
@@ -0,0 +1,118 @@
+from collections import Counter, defaultdict
+from typing import List
+
+import numpy as np
+
+
+def get_servers_metrics(model_reports) -> List[str]:
+    servers_num_total = 0
+    servers_num_relay = 0
+    num_peers = 0
+    pings = []
+    num_ping_infs = 0
+    version_counts = Counter()
+    result = ["# SERVER LEVEL METRICS"]
+
+    for model_report in model_reports:
+        for server in model_report["server_rows"]:
+            if server["span"].server_info is not None:
+                next_pings = server["span"].server_info.next_pings
+                if next_pings is not None:
+                    servers_num_total += 1
+                    num_peers += len(next_pings)
+                    pings_not_inf = [v for k, v in next_pings.items() if v != float("inf")]
+                    pings.extend(pings_not_inf)
+                    num_ping_infs += len([v for v in next_pings.values() if v == float("inf")])
+
+                if server["span"].server_info.using_relay:
+                    servers_num_relay += 1
+
+                version = server["span"].server_info.version
+                if version:
+                    version_counts[version] += 1
+
+    if servers_num_total > 0 and pings:
+        peers_per_srv = (len(pings) + num_ping_infs) / servers_num_total
+        pings_inf_share = num_ping_infs / (num_ping_infs + len(pings))
+
+        result.extend(
+            [
+                f"peers_per_srv {peers_per_srv:.1f}",
+                f"pings_inf_share {pings_inf_share:.3f}",
+            ]
+        )
+
+    result.append(f"servers_num_total {servers_num_total}")
+    result.append(f"servers_num_relay {servers_num_relay}")
+
+    if pings:
+        result.append("# PINGS")
+        pings = np.sort(pings).tolist()
+        for pct in (25, 50, 75, 90, 95):
+            result.append(f'ping_pct{{pct="{pct}"}} {np.percentile(pings, pct):.4f}')
+
+    result.append("# VERSIONS")
+    for version_number, version_count in version_counts.items():
+        result.append(f'server_version{{version_number="{version_number}"}} {version_count}')
+
+    return result
+
+
+def get_models_metrics(model_reports) -> List[str]:
+    result = [
+        "# MODEL LEVEL METRICS",
+    ]
+
+    for model_report in model_reports:
+        model_name = model_report["dht_prefix"]
+
+        result.append(f"# MODEL: {model_name} {'-' * 50}")
+
+        blocks = defaultdict(lambda: np.zeros(model_report["num_blocks"]))
+
+        for server in model_report["server_rows"]:
+            for block_idx in range(server["span"].start, server["span"].end):
+                blocks["total"][block_idx] += 1
+                blocks[server["state"]][block_idx] += 1
+
+                if server["span"].server_info is not None:
+                    for rps in ("network_rps", "inference_rps", "forward_rps"):
+                        rps_value = getattr(server["span"].server_info, rps, 0)
+                        if rps_value is not None:
+                            blocks[rps][block_idx] += rps_value
+
+        result.extend(
+            [
+                f'n_blocks{{model="{model_name}"}} {model_report["num_blocks"]}',
+                f'servers_num{{model="{model_name}"}} {len(model_report["server_rows"])}',
+                f'blocks_total{{model="{model_name}"}} {blocks["total"].sum()}',
+                f'blocks_online_min{{model="{model_name}"}} {blocks["online"].min()}',
+            ]
+        )
+
+        for block_state in ("online", "joining", "offline", "unreachable"):
+            result.append(f'blocks{{model="{model_name}",state="{block_state}"}} {blocks[block_state].sum():.0f}')
+
+        for rps in ("network_rps", "inference_rps", "forward_rps"):
+            rps_type = rps.split("_")[0]
+            result.append(f'rps_avg{{model="{model_name}",rps="{rps_type}"}} {blocks[rps].mean():.1f}')
+            result.append(f'rps_min{{model="{model_name}",rps="{rps_type}"}} {blocks[rps].min():.1f}')
+
+    return result
+
+
+def get_prometheus_metrics(state_dict) -> str:
+    """Prepares metrics in the Prometheus exposition format
+    description: https://prometheus.io/docs/instrumenting/exposition_formats/
+    returns a multiline string with a single metric per line
+    """
+    result = []
+
+    result.append("# GENERAL METRICS")
+    result.append(f"update_duration {state_dict.get('update_duration', None):.1f}")
+
+    result.extend(get_servers_metrics(state_dict["model_reports"]))
+
+    result.extend(get_models_metrics(state_dict["model_reports"]))
+
+    return "\n".join(result)
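The resulting exposition is plain text with one metric per line; an illustrative excerpt (all numbers made up):

    # GENERAL METRICS
    update_duration 12.3
    # SERVER LEVEL METRICS
    peers_per_srv 18.5
    pings_inf_share 0.042
    servers_num_total 25
    servers_num_relay 3
    # PINGS
    ping_pct{pct="50"} 0.1234
    # VERSIONS
    server_version{version_number="2.2.0"} 21
    # MODEL LEVEL METRICS
    blocks{model="StableBeluga2-hf",state="online"} 80
    rps_avg{model="StableBeluga2-hf",rps="inference"} 15.2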
p2p_utils.py
ADDED
@@ -0,0 +1,67 @@
+import re
+import asyncio
+import requests
+import hivemind
+import functools
+from async_timeout import timeout
+from petals.server.handler import TransformerConnectionHandler
+
+info_cache = hivemind.TimedStorage()
+
+
+async def check_reachability(peer_id, _, node, *, fetch_info=False, connect_timeout=5, expiration=300, use_cache=True):
+    if use_cache:
+        entry = info_cache.get(peer_id)
+        if entry is not None:
+            return entry.value
+
+    try:
+        with timeout(connect_timeout):
+            if fetch_info:  # For Petals servers
+                stub = TransformerConnectionHandler.get_stub(node.p2p, peer_id)
+                response = await stub.rpc_info(hivemind.proto.runtime_pb2.ExpertUID())
+                rpc_info = hivemind.MSGPackSerializer.loads(response.serialized_info)
+                rpc_info["ok"] = True
+            else:  # For DHT-only bootstrap peers
+                await node.p2p._client.connect(peer_id, [])
+                await node.p2p._client.disconnect(peer_id)
+                rpc_info = {"ok": True}
+    except Exception as e:
+        if not isinstance(e, asyncio.TimeoutError):  # Actual connection error
+            message = str(e) if str(e) else repr(e)
+            if message == "protocol not supported":
+                # This may be returned when a server is joining, see https://github.com/petals-infra/health.petals.dev/issues/1
+                return {"ok": True}
+        else:
+            message = f"Failed to connect in {connect_timeout:.0f} sec. Firewall may be blocking connections"
+        rpc_info = {"ok": False, "error": message}
+
+    info_cache.store(peer_id, rpc_info, hivemind.get_dht_time() + expiration)
+    return rpc_info
+
+
+async def check_reachability_parallel(peer_ids, dht, node, *, fetch_info=False):
+    rpc_infos = await asyncio.gather(
+        *[check_reachability(peer_id, dht, node, fetch_info=fetch_info) for peer_id in peer_ids]
+    )
+    return dict(zip(peer_ids, rpc_infos))
+
+
+async def get_peers_ips(dht, dht_node):
+    return await dht_node.p2p.list_peers()
+
+
+@functools.cache
+def get_location(ip_address):
+    try:
+        response = requests.get(f"http://ip-api.com/json/{ip_address}")
+        if response.status_code == 200:
+            return response.json()
+    except Exception:
+        pass
+    return {}
+
+
+def extract_peer_ip_info(multiaddr_str):
+    if ip_match := re.search(r"/ip4/(\d+\.\d+\.\d+\.\d+)", multiaddr_str):
+        return get_location(ip_match[1])
+    return {}
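These coroutines are written to be driven through hivemind's DHT event loop, which supplies the dht and node arguments; health.py and app.py call them exactly this way. A minimal sketch (the peer ID list is a placeholder):

    from functools import partial
    import hivemind

    # dht: a running hivemind.DHT in client mode, as created in app.py
    # peer_ids: a list of hivemind.PeerID objects (placeholder)
    infos = dht.run_coroutine(partial(check_reachability_parallel, peer_ids, fetch_info=True))
    for peer_id, info in infos.items():
        print(peer_id, "ok" if info["ok"] else info["error"])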
pyproject.toml
ADDED
@@ -0,0 +1,10 @@
+[tool.black]
+line-length = 120
+required-version = "22.3.0"
+
+[tool.isort]
+profile = "black"
+line_length = 120
+combine_as_imports = true
+combine_star = true
+known_local_folder = ["tests", "cli"]
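These settings are picked up automatically when the formatters run from the repository root, e.g. `black .` and `isort .`; note that required-version pins black to exactly 22.3.0, so that release must be installed for formatting to succeed.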
state_updater.py
ADDED
@@ -0,0 +1,57 @@
+import datetime
+import threading
+import time
+from dataclasses import asdict, is_dataclass
+from enum import Enum
+
+import hivemind
+import simplejson
+from flask import Flask, render_template
+
+import config
+from health import fetch_health_state
+from metrics import get_prometheus_metrics
+
+logger = hivemind.get_logger(__name__)
+
+
+class StateUpdaterThread(threading.Thread):
+    def __init__(self, dht: hivemind.DHT, app: Flask, **kwargs):
+        super().__init__(**kwargs)
+        self.dht = dht
+        self.app = app
+
+        self.state_json = self.state_html = None
+        self.ready = threading.Event()
+
+    def run(self):
+        while True:
+            start_time = time.perf_counter()
+            try:
+                state_dict = fetch_health_state(self.dht)
+                with self.app.app_context():
+                    self.state_html = render_template("index.html", **state_dict)
+                    self.prometheus_metrics = get_prometheus_metrics(state_dict)
+                    self.state_json = simplejson.dumps(state_dict, indent=2, ignore_nan=True, default=json_default)
+
+                self.ready.set()
+                logger.info(f"Fetched new state in {time.perf_counter() - start_time:.1f} sec")
+            except Exception:
+                logger.error("Failed to update state:", exc_info=True)
+
+            delay = config.UPDATE_PERIOD - (time.perf_counter() - start_time)
+            if delay < 0:
+                logger.warning("Update took more than update_period, consider increasing it")
+            time.sleep(max(delay, 0))
+
+
+def json_default(value):
+    if is_dataclass(value):
+        return asdict(value)
+    if isinstance(value, Enum):
+        return value.name.lower()
+    if isinstance(value, hivemind.PeerID):
+        return value.to_base58()
+    if isinstance(value, datetime.datetime):
+        return value.timestamp()
+    raise TypeError(f"Can't serialize {repr(value)}")
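json_default above covers the non-JSON types that occur in the health state (dataclasses, enums, peer IDs, timestamps). A small sketch of the behavior (the printed timestamp is illustrative):

    import datetime
    import simplejson

    payload = {"last_updated": datetime.datetime.now(datetime.timezone.utc)}
    print(simplejson.dumps(payload, default=json_default))
    # {"last_updated": 1700000000.0}  <- datetime serialized as a Unix timestamp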