Commit 8035330
Parent(s): 748b101
update

main.py CHANGED

@@ -1,301 +1,73 @@
-import logging
 import os
-import random
-from datetime import timedelta
-from statistics import mean
-from typing import Annotated, Any, Iterator, Union
-
-import fasttext
-from cashews import cache
 from dotenv import load_dotenv
-
-from httpx import AsyncClient, Client, Timeout
-from huggingface_hub import hf_hub_download
-from iso639 import Lang
-from starlette.responses import RedirectResponse
-from toolz import concat, groupby, valmap
 
-from fastapi import FastAPI, Path, Query
 
 
-logger = logging.getLogger(__name__)
-app = FastAPI()
 load_dotenv()
-HF_TOKEN = os.getenv("HF_TOKEN")
-
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
-FASTTEXT_PREFIX_LENGTH = 9  # fasttext labels are formatted like "__label__eng_Latn"
-
-BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
-DEFAULT_FAST_TEXT_MODEL = "facebook/fasttext-language-identification"
-headers = {
-    "authorization": f"Bearer ${HF_TOKEN}",
-}
-timeout = Timeout(60, read=120)
-client = Client(headers=headers, timeout=timeout)
-async_client = AsyncClient(headers=headers, timeout=timeout)
-
-TARGET_COLUMN_NAMES = {
-    "text",
-    "input",
-    "tokens",
-    "prompt",
-    "instruction",
-    "sentence_1",
-    "question",
-    "sentence2",
-    "answer",
-    "sentence",
-    "response",
-    "context",
-    "query",
-    "chosen",
-    "rejected",
-}
-
-
-def datasets_server_valid_rows(hub_id: str):
-    try:
-        resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
-        return resp.json()["viewer"]
-    except Exception as e:
-        logger.error(f"Failed to get is-valid for {hub_id}: {e}")
-        return False
-
-
-async def get_first_config_and_split_name(hub_id: str):
-    try:
-        resp = await async_client.get(
-            f"https://datasets-server.huggingface.co/splits?dataset={hub_id}"
-        )
-
-        data = resp.json()
-        return data["splits"][0]["config"], data["splits"][0]["split"]
-    except Exception as e:
-        logger.error(f"Failed to get splits for {hub_id}: {e}")
-        return None
-
-
-async def get_dataset_info(hub_id: str, config: str | None = None):
-    if config is None:
-        config = get_first_config_and_split_name(hub_id)
-        if config is None:
-            return None
-        else:
-            config = config[0]
-    resp = await async_client.get(
-        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
-    )
-    resp.raise_for_status()
-    return resp.json()
-
-
-@cache(ttl=timedelta(minutes=5))
-async def fetch_rows(url: str) -> list[dict]:
-    response = await async_client.get(url)
-    if response.status_code == 200:
-        data = response.json()
-        return data.get("rows")
-    else:
-        print(f"Failed to fetch data: {response.status_code}")
-        print(url)
-        return []
-
-
-# Function to get random rows from the dataset
-async def get_random_rows(
-    hub_id: str,
-    total_length: int,
-    number_of_rows: int,
-    max_request_calls: int,
-    config="default",
-    split="train",
-):
-    rows = []
-    rows_per_call = min(
-        number_of_rows // max_request_calls, total_length // max_request_calls
-    )
-    rows_per_call = min(rows_per_call, 100)  # Ensure rows_per_call is not more than 100
-    for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
-        offset = random.randint(0, total_length - rows_per_call)
-        url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
-        logger.info(f"Fetching {url}")
-        batch_rows = await fetch_rows(url)
-        rows.extend(batch_rows)
-        if len(rows) >= number_of_rows:
-            break
-    return [row.get("row") for row in rows]
-
-
-def load_model(repo_id: str) -> fasttext.FastText._FastText:
-    from pathlib import Path
-
-    Path("code/models").mkdir(parents=True, exist_ok=True)
-    model_path = hf_hub_download(
-        repo_id,
-        "model.bin",
-        # cache_dir="code/models",
-        # local_dir="code/models",
-        # local_dir_use_symlinks=False,
-    )
-    return fasttext.load_model(model_path)
-
-
-model = load_model(DEFAULT_FAST_TEXT_MODEL)
-
-
-def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
-    for row in rows:
-        if isinstance(row, str):
-            # split on lines and remove empty lines
-            line = row.split("\n")
-            for line in line:
-                if line:
-                    yield line
-        elif isinstance(row, list):
-            try:
-                line = " ".join(row)
-                if len(line) < min_length:
-                    continue
-                else:
-                    yield line
-            except TypeError:
-                continue
-
-
-def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
-    predictions = model.predict(inputs, k=k)
-    return [
-        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
-        for label, prob in zip(predictions[0], predictions[1])
-    ]
-
-
-def get_label(x):
-    return x.get("label")
-
-
-def get_mean_score(preds):
-    return mean([pred.get("score") for pred in preds])
-
-
-def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
-    """Filter a dict to include items whose value is above `threshold_percent`"""
-    total = sum(counts_dict.values())
-    threshold = total * threshold_percent
-    return {k for k, v in counts_dict.items() if v >= threshold}
-
-
-def try_parse_language(lang: str) -> str | None:
-    try:
-        split = lang.split("_")
-        lang = split[0]
-        lang = Lang(lang)
-        return lang.pt1
-    except Exception as e:
-        logger.error(f"Failed to parse language {lang}: {e}")
-        return None
 
 
     }
-
-
-    return
-
-
-
# @app.get("/", response_class=HTMLResponse)
|
232 |
-
# async def read_index():
|
233 |
-
# html_content = Path("index.html").read_text()
|
234 |
-
# return HTMLResponse(content=html_content)
|
235 |
-
|
236 |
-
|
237 |
-
@app.get("/", include_in_schema=False)
|
238 |
-
def root():
|
239 |
-
return RedirectResponse(url="/docs")
|
240 |
-
|
241 |
-
# item_id: Annotated[int, Path(title="The ID of the item to get", ge=1)], q: str
|
242 |
-
|
243 |
-
|
244 |
-
@app.get("/predict_dataset_language/{hub_id:path}")
|
245 |
-
@cache(ttl=timedelta(minutes=10))
|
246 |
-
async def predict_language(
|
247 |
-
hub_id: Annotated[str, Path(title="The hub id of the dataset to predict")],
|
248 |
-
config: str | None = None,
|
249 |
-
split: str | None = None,
|
250 |
-
max_request_calls: Annotated[
|
251 |
-
int, Query(title="Max number of requests to datasets server", gt=0, le=50)
|
252 |
-
] = 10,
|
253 |
-
number_of_rows: int = 1000,
|
254 |
-
language_threshold_percent: float = 0.2,
|
255 |
-
) -> dict[Any, Any] | None:
|
256 |
-
is_valid = datasets_server_valid_rows(hub_id)
|
257 |
-
if not is_valid:
|
258 |
-
logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
|
259 |
-
if not config and not split:
|
260 |
-
config, split = await get_first_config_and_split_name(hub_id)
|
261 |
-
if not config:
|
262 |
-
config, _ = await get_first_config_and_split_name(hub_id)
|
263 |
-
if not split:
|
264 |
-
_, split = await get_first_config_and_split_name(hub_id)
|
265 |
-
info = await get_dataset_info(hub_id, config)
|
266 |
-
if info is None:
|
267 |
-
logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
|
268 |
-
return None
|
269 |
-
if dataset_info := info.get("dataset_info"):
|
270 |
-
total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
|
271 |
-
features = dataset_info.get("features")
|
272 |
-
column_names = set(features.keys())
|
273 |
-
logger.info(f"Column names: {column_names}")
|
274 |
-
if not set(column_names).intersection(TARGET_COLUMN_NAMES):
|
275 |
-
logger.error(
|
276 |
-
f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
|
277 |
-
)
|
278 |
-
return None
|
279 |
-
for column in TARGET_COLUMN_NAMES:
|
280 |
-
if column in column_names:
|
281 |
-
target_column = column
|
282 |
-
logger.info(f"Using column {target_column} for language detection")
|
283 |
-
break
|
284 |
-
random_rows = await get_random_rows(
|
285 |
-
hub_id,
|
286 |
-
total_rows_for_split,
|
287 |
-
number_of_rows,
|
288 |
-
max_request_calls,
|
289 |
-
config,
|
290 |
-
split,
|
291 |
-
)
|
292 |
-
logger.info(f"Predicting language for {len(random_rows)} rows")
|
293 |
-
predictions = predict_rows(
|
294 |
-
random_rows,
|
295 |
-
target_column,
|
296 |
-
language_threshold_percent=language_threshold_percent,
|
297 |
-
)
|
298 |
-
predictions["hub_id"] = hub_id
|
299 |
-
predictions["config"] = config
|
300 |
-
predictions["split"] = split
|
301 |
-
return predictions
|
|
|
+import contextlib
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+import json
+from huggingface_hub import HfApi
 import os
 from dotenv import load_dotenv
+import json
 
+from datetime import datetime
+from pathlib import Path
 
+from huggingface_hub import CommitScheduler
 
 load_dotenv()
 
+HF_TOKEN = os.getenv("HF_TOKEN")
+hf_api = HfApi(token=HF_TOKEN)
 
+app = FastAPI()
+VOTES_FILE = "votes/votes.jsonl"
+# Configure CORS
+origins = [
+    "https://huggingface.co",
+]
+
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["POST"],
+    allow_headers=["*"],
+)
+
+scheduler = CommitScheduler(
+    repo_id="davanstrien/votes",
+    repo_type="dataset",
+    folder_path="votes",
+    path_in_repo="data",
+    every=1,
+    hf_api=hf_api,
+)
+
+
+class Vote(BaseModel):
+    dataset: str
+    description: str
+    vote: int
+    userID: str
+
+
+def save_vote(vote_entry):
+    with open(VOTES_FILE, "a") as file:
+        # add time stamp to the vote entry
+        date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        vote_entry["timestamp"] = date_time
+        json.dump(vote_entry, file)
+        file.write("\n")
+
+
+@app.post("/vote")
+async def receive_vote(vote: Vote, background_tasks: BackgroundTasks):
+    vote_entry = {
+        "dataset": vote.dataset,
+        "vote": vote.vote,
+        "description": vote.description,
+        "userID": vote.userID,
     }
+    # Append the vote entry to the JSONL file
+    background_tasks.add_task(save_vote, vote_entry)
+    return {"message": "Vote submitted successfully"}
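
One caveat in the new write path: save_vote appends to votes/votes.jsonl inside the same folder that the CommitScheduler snapshots and pushes every minute (every=1), so a scheduled commit could in principle catch a half-written line. A minimal lock-guarded variant, reusing the module's existing imports and assuming the scheduler.lock attribute that huggingface_hub's CommitScheduler exposes for this situation:

def save_vote(vote_entry: dict) -> None:
    # Hold the scheduler's lock while appending, so a background commit
    # never snapshots a partially written JSONL line.
    with scheduler.lock:
        with open(VOTES_FILE, "a") as file:
            vote_entry["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            json.dump(vote_entry, file)
            file.write("\n")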
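
For reference, a request that matches the new Vote model could look like the sketch below; the host and payload values are placeholders rather than anything taken from this commit:

import httpx

payload = {
    "dataset": "user/some-dataset",  # hypothetical dataset id
    "description": "example description text",
    "vote": 1,
    "userID": "anon-123",  # hypothetical user id
}

# POST to the /vote endpoint; replace <space-host> with the deployed Space URL.
response = httpx.post("https://<space-host>/vote", json=payload)
print(response.status_code, response.json())

Because receive_vote hands persistence to BackgroundTasks, the success message is returned before the vote is actually written to disk.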