davanstrien HF staff committed on
Commit 8035330
1 Parent(s): 748b101
Files changed (1)
  1. main.py +64 -292
main.py CHANGED
@@ -1,301 +1,73 @@
- import logging
  import os
- import random
- from datetime import timedelta
- from statistics import mean
- from typing import Annotated, Any, Iterator, Union
-
- import fasttext
- from cashews import cache
  from dotenv import load_dotenv
- from fastapi import FastAPI, Path, Query
- from httpx import AsyncClient, Client, Timeout
- from huggingface_hub import hf_hub_download
- from iso639 import Lang
- from starlette.responses import RedirectResponse
- from toolz import concat, groupby, valmap

- cache.setup("mem://")


- logger = logging.getLogger(__name__)
- app = FastAPI()
  load_dotenv()
- HF_TOKEN = os.getenv("HF_TOKEN")
-
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
- FASTTEXT_PREFIX_LENGTH = 9 # fasttext labels are formatted like "__label__eng_Latn"
-
- BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
- DEFAULT_FAST_TEXT_MODEL = "facebook/fasttext-language-identification"
- headers = {
-     "authorization": f"Bearer ${HF_TOKEN}",
- }
- timeout = Timeout(60, read=120)
- client = Client(headers=headers, timeout=timeout)
- async_client = AsyncClient(headers=headers, timeout=timeout)
-
- TARGET_COLUMN_NAMES = {
-     "text",
-     "input",
-     "tokens",
-     "prompt",
-     "instruction",
-     "sentence_1",
-     "question",
-     "sentence2",
-     "answer",
-     "sentence",
-     "response",
-     "context",
-     "query",
-     "chosen",
-     "rejected",
- }
-
-
- def datasets_server_valid_rows(hub_id: str):
-     try:
-         resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
-         return resp.json()["viewer"]
-     except Exception as e:
-         logger.error(f"Failed to get is-valid for {hub_id}: {e}")
-         return False
-
-
- async def get_first_config_and_split_name(hub_id: str):
-     try:
-         resp = await async_client.get(
-             f"https://datasets-server.huggingface.co/splits?dataset={hub_id}"
-         )
-
-         data = resp.json()
-         return data["splits"][0]["config"], data["splits"][0]["split"]
-     except Exception as e:
-         logger.error(f"Failed to get splits for {hub_id}: {e}")
-         return None
-
-
- async def get_dataset_info(hub_id: str, config: str | None = None):
-     if config is None:
-         config = get_first_config_and_split_name(hub_id)
-         if config is None:
-             return None
-         else:
-             config = config[0]
-     resp = await async_client.get(
-         f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
-     )
-     resp.raise_for_status()
-     return resp.json()
-
-
- @cache(ttl=timedelta(minutes=5))
- async def fetch_rows(url: str) -> list[dict]:
-     response = await async_client.get(url)
-     if response.status_code == 200:
-         data = response.json()
-         return data.get("rows")
-     else:
-         print(f"Failed to fetch data: {response.status_code}")
-         print(url)
-         return []
-
-
- # Function to get random rows from the dataset
- async def get_random_rows(
-     hub_id: str,
-     total_length: int,
-     number_of_rows: int,
-     max_request_calls: int,
-     config="default",
-     split="train",
- ):
-     rows = []
-     rows_per_call = min(
-         number_of_rows // max_request_calls, total_length // max_request_calls
-     )
-     rows_per_call = min(rows_per_call, 100) # Ensure rows_per_call is not more than 100
-     for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
-         offset = random.randint(0, total_length - rows_per_call)
-         url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
-         logger.info(f"Fetching {url}")
-         batch_rows = await fetch_rows(url)
-         rows.extend(batch_rows)
-         if len(rows) >= number_of_rows:
-             break
-     return [row.get("row") for row in rows]
-
-
- def load_model(repo_id: str) -> fasttext.FastText._FastText:
-     from pathlib import Path
-
-     Path("code/models").mkdir(parents=True, exist_ok=True)
-     model_path = hf_hub_download(
-         repo_id,
-         "model.bin",
-         # cache_dir="code/models",
-         # local_dir="code/models",
-         # local_dir_use_symlinks=False,
-     )
-     return fasttext.load_model(model_path)
-
-
- model = load_model(DEFAULT_FAST_TEXT_MODEL)
-
-
- def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
-     for row in rows:
-         if isinstance(row, str):
-             # split on lines and remove empty lines
-             line = row.split("\n")
-             for line in line:
-                 if line:
-                     yield line
-         elif isinstance(row, list):
-             try:
-                 line = " ".join(row)
-                 if len(line) < min_length:
-                     continue
-                 else:
-                     yield line
-             except TypeError:
-                 continue
-
-
- def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
-     predictions = model.predict(inputs, k=k)
-     return [
-         {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
-         for label, prob in zip(predictions[0], predictions[1])
-     ]
-
-
- def get_label(x):
-     return x.get("label")
-
-
- def get_mean_score(preds):
-     return mean([pred.get("score") for pred in preds])
-
-
- def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
-     """Filter a dict to include items whose value is above `threshold_percent`"""
-     total = sum(counts_dict.values())
-     threshold = total * threshold_percent
-     return {k for k, v in counts_dict.items() if v >= threshold}
-
-
- def try_parse_language(lang: str) -> str | None:
-     try:
-         split = lang.split("_")
-         lang = split[0]
-         lang = Lang(lang)
-         return lang.pt1
-     except Exception as e:
-         logger.error(f"Failed to parse language {lang}: {e}")
-         return None


- def predict_rows(
-     rows, target_column, language_threshold_percent=0.2, return_raw_predictions=False
- ):
-     rows = (row.get(target_column) for row in rows)
-     rows = (row for row in rows if row is not None)
-     rows = list(yield_clean_rows(rows))
-     predictions = [model_predict(row) for row in rows]
-     predictions = [pred for pred in predictions if pred is not None]
-     predictions = list(concat(predictions))
-     predictions_by_lang = groupby(get_label, predictions)
-     langues_counts = valmap(len, predictions_by_lang)
-     keys_to_keep = filter_by_frequency(
-         langues_counts, threshold_percent=language_threshold_percent
-     )
-     filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
-     raw_model_prediction_summary = dict(valmap(get_mean_score, filtered_dict))
-     parsed_langs = {
-         try_parse_language(k): v for k, v in raw_model_prediction_summary.items()
-     }
-     default_data = {
-         "language_prediction_summary": parsed_langs,
-         "raw_model_prediction_summary": raw_model_prediction_summary,
-         "hub_id": "hub_id",
-         "config": "config",
      }
-     if return_raw_predictions:
-         default_data["raw_predictions"] = predictions
-     return default_data
-
-
- # @app.get("/", response_class=HTMLResponse)
- # async def read_index():
- #     html_content = Path("index.html").read_text()
- #     return HTMLResponse(content=html_content)
-
-
- @app.get("/", include_in_schema=False)
- def root():
-     return RedirectResponse(url="/docs")
-
- # item_id: Annotated[int, Path(title="The ID of the item to get", ge=1)], q: str
-
-
- @app.get("/predict_dataset_language/{hub_id:path}")
- @cache(ttl=timedelta(minutes=10))
- async def predict_language(
-     hub_id: Annotated[str, Path(title="The hub id of the dataset to predict")],
-     config: str | None = None,
-     split: str | None = None,
-     max_request_calls: Annotated[
-         int, Query(title="Max number of requests to datasets server", gt=0, le=50)
-     ] = 10,
-     number_of_rows: int = 1000,
-     language_threshold_percent: float = 0.2,
- ) -> dict[Any, Any] | None:
-     is_valid = datasets_server_valid_rows(hub_id)
-     if not is_valid:
-         logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
-     if not config and not split:
-         config, split = await get_first_config_and_split_name(hub_id)
-     if not config:
-         config, _ = await get_first_config_and_split_name(hub_id)
-     if not split:
-         _, split = await get_first_config_and_split_name(hub_id)
-     info = await get_dataset_info(hub_id, config)
-     if info is None:
-         logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
-         return None
-     if dataset_info := info.get("dataset_info"):
-         total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
-         features = dataset_info.get("features")
-         column_names = set(features.keys())
-         logger.info(f"Column names: {column_names}")
-         if not set(column_names).intersection(TARGET_COLUMN_NAMES):
-             logger.error(
-                 f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
-             )
-             return None
-     for column in TARGET_COLUMN_NAMES:
-         if column in column_names:
-             target_column = column
-             logger.info(f"Using column {target_column} for language detection")
-             break
-     random_rows = await get_random_rows(
-         hub_id,
-         total_rows_for_split,
-         number_of_rows,
-         max_request_calls,
-         config,
-         split,
-     )
-     logger.info(f"Predicting language for {len(random_rows)} rows")
-     predictions = predict_rows(
-         random_rows,
-         target_column,
-         language_threshold_percent=language_threshold_percent,
-     )
-     predictions["hub_id"] = hub_id
-     predictions["config"] = config
-     predictions["split"] = split
-     return predictions
 
+ import contextlib
+ from fastapi import FastAPI, Request, BackgroundTasks
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ import json
+ from huggingface_hub import HfApi
  import os
  from dotenv import load_dotenv
+ import json

+ from datetime import datetime
+ from pathlib import Path

+ from huggingface_hub import CommitScheduler

  load_dotenv()

+ HF_TOKEN = os.getenv("HF_TOKEN")
+ hf_api = HfApi(token=HF_TOKEN)

+ app = FastAPI()
+ VOTES_FILE = "votes/votes.jsonl"
+ # Configure CORS
+ origins = [
+     "https://huggingface.co",
+ ]
+
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=origins,
+     allow_credentials=True,
+     allow_methods=["POST"],
+     allow_headers=["*"],
+ )
+
+ scheduler = CommitScheduler(
+     repo_id="davanstrien/votes",
+     repo_type="dataset",
+     folder_path="votes",
+     path_in_repo="data",
+     every=1,
+     hf_api=hf_api,
+ )
+
+
+ class Vote(BaseModel):
+     dataset: str
+     description: str
+     vote: int
+     userID: str
+
+
+ def save_vote(vote_entry):
+     with open(VOTES_FILE, "a") as file:
+         # add time stamp to the vote entry
+         date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         vote_entry["timestamp"] = date_time
+         json.dump(vote_entry, file)
+         file.write("\n")
+
+
+ @app.post("/vote")
+ async def receive_vote(vote: Vote, background_tasks: BackgroundTasks):
+     vote_entry = {
+         "dataset": vote.dataset,
+         "vote": vote.vote,
+         "description": vote.description,
+         "userID": vote.userID,
      }
+     # Append the vote entry to the JSONL file
+     background_tasks.add_task(save_vote, vote_entry)
+     return {"message": "Vote submitted successfully"}