Spaces:
Paused
Paused
import os | |
import random | |
from pathlib import Path | |
from statistics import mean | |
from typing import Any, Iterator, Union | |
from fastapi.responses import HTMLResponse | |
import fasttext | |
from dotenv import load_dotenv | |
from fastapi import FastAPI | |
from httpx import AsyncClient, Client, Timeout | |
from huggingface_hub import hf_hub_download | |
from huggingface_hub.utils import logging | |
from toolz import concat, groupby, valmap | |
from starlette.responses import RedirectResponse | |
# --- Application and logging -------------------------------------------------
app = FastAPI()
logger = logging.get_logger(__name__)

# --- Environment -------------------------------------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
# NOTE(review): this is set after `huggingface_hub` has already been imported;
# confirm the library still honours the flag when it is set this late.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# --- Constants ---------------------------------------------------------------
FASTTEXT_PREFIX_LENGTH = 9  # fasttext labels are formatted like "__label__eng_Latn"
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_FAST_TEXT_MODEL = "facebook/fasttext-language-identification"

# --- HTTP clients ------------------------------------------------------------
# Only send an Authorization header when a token is configured. The previous
# value, f"Bearer ${HF_TOKEN}", injected a literal "$" in front of the token
# (and sent "Bearer $None" when no token was set), which is never a valid
# credential.
headers = {"authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
timeout = Timeout(60, read=120)
client = Client(headers=headers, timeout=timeout)
async_client = AsyncClient(headers=headers, timeout=timeout)

# Column names that are considered candidates for language detection.
TARGET_COLUMN_NAMES = {
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
    "chosen",
    "rejected",
}
def datasets_server_valid_rows(hub_id: str):
    """Query the datasets server ``/is-valid`` endpoint for `hub_id`.

    Returns the value stored under ``"viewer"`` in the JSON response. Any
    failure (request error, bad JSON, missing key) is logged and reported
    as ``False``.
    """
    try:
        is_valid_url = f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}"
        payload = client.get(is_valid_url).json()
        viewer_flag = payload["viewer"]
    except Exception as e:
        logger.error(f"Failed to get is-valid for {hub_id}: {e}")
        return False
    else:
        return viewer_flag
async def get_first_config_and_split_name(hub_id: str):
    """Return the ``(config, split)`` names of the first split listed for `hub_id`.

    The ``/splits`` endpoint of the datasets server is queried and the first
    entry of its ``"splits"`` list is used.

    Returns:
        A ``(config, split)`` tuple, or ``None`` if the request fails or the
        response does not contain a usable first split (the error is logged).
    """
    try:
        # Use the shared base URL constant rather than a second hard-coded copy.
        resp = await async_client.get(
            f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}"
        )
        first_split = resp.json()["splits"][0]
        return first_split["config"], first_split["split"]
    except Exception as e:
        logger.error(f"Failed to get splits for {hub_id}: {e}")
        return None
async def get_dataset_info(hub_id: str, config: str | None = None):
    """Fetch the ``/info`` payload for `hub_id` from the datasets server.

    Args:
        hub_id: Dataset identifier on the Hub.
        config: Dataset config name. When ``None``, the config of the first
            split reported by ``get_first_config_and_split_name`` is used.

    Returns:
        The decoded JSON response, or ``None`` when no config was given and
        none could be discovered.

    Raises:
        httpx.HTTPStatusError: if the ``/info`` request returns an error status
            (via ``raise_for_status``).
    """
    if config is None:
        # The lookup is a coroutine and must be awaited; without the await the
        # result is a coroutine object, the None check never triggers, and
        # indexing it raises TypeError.
        first_config_and_split = await get_first_config_and_split_name(hub_id)
        if first_config_and_split is None:
            return None
        config = first_config_and_split[0]
    resp = await async_client.get(
        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    )
    resp.raise_for_status()
    return resp.json()
async def get_random_rows(
    hub_id: str,
    total_length: int,
    number_of_rows: int,
    max_request_calls: int,
    config="default",
    split="train",
):
    """Sample rows from random offsets of a dataset split via the ``/rows`` endpoint.

    Args:
        hub_id: Dataset identifier on the Hub.
        total_length: Number of rows in the split; used to bound the offsets.
        number_of_rows: Target number of rows to collect.
        max_request_calls: Upper bound on the number of HTTP requests made.
        config: Dataset config name.
        split: Dataset split name.

    Returns:
        A list with the ``"row"`` value of every fetched entry. The list is
        empty when any of the size arguments is not positive. Requests that
        do not return HTTP 200 are logged and skipped.
    """
    # Nothing can be sampled from an empty split or with a zero budget; this
    # also protects the integer divisions below.
    if total_length <= 0 or number_of_rows <= 0 or max_request_calls <= 0:
        return []

    rows_per_call = min(
        number_of_rows // max_request_calls, total_length // max_request_calls
    )
    # Clamp to [1, 100]: at least one row per call (the floor divisions above
    # can yield 0), at most 100 per request, and never more than the split has.
    rows_per_call = max(1, min(rows_per_call, 100, total_length))
    request_count = min(max_request_calls, number_of_rows // rows_per_call)

    rows = []
    for _ in range(request_count):
        offset = random.randint(0, total_length - rows_per_call)
        url = (
            f"{BASE_DATASETS_SERVER_URL}/rows?dataset={hub_id}&config={config}"
            f"&split={split}&offset={offset}&length={rows_per_call}"
        )
        logger.info(f"Fetching {url}")
        response = await async_client.get(url)
        if response.status_code == 200:
            # A payload without "rows" contributes nothing instead of crashing.
            rows.extend(response.json().get("rows") or [])
        else:
            logger.warning(f"Failed to fetch data: {response.status_code} ({url})")
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Download ``model.bin`` from the Hub repo `repo_id` and load it with fasttext.

    The file is stored under the local ``code/models`` directory, which is
    created if it does not exist yet.
    """
    models_dir = "code/models"
    Path(models_dir).mkdir(parents=True, exist_ok=True)
    downloaded_path = hf_hub_download(
        repo_id,
        "model.bin",
        cache_dir=models_dir,
        local_dir=models_dir,
        local_dir_use_symlinks=False,
    )
    loaded = fasttext.load_model(downloaded_path)
    return loaded


# Module-level model instance shared by the prediction helpers; loaded at import time.
model = load_model(DEFAULT_FAST_TEXT_MODEL)
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    """Yield single-line text snippets extracted from `rows`.

    Args:
        rows: An iterable of rows, or a single string (treated as one row).
            String rows are split on newlines and every non-empty line is
            yielded. List rows are joined with spaces into one line.
        min_length: Minimum length of a joined list row; shorter ones are
            skipped. It is not applied to lines coming from string rows.

    Yields:
        The cleaned lines. Rows that are neither ``str`` nor ``list``, and
        list rows whose items cannot be joined as strings, are skipped.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(rows, str):
        rows = [rows]
    for row in rows:
        if isinstance(row, str):
            # Split on lines and drop empty ones.
            for line in row.split("\n"):
                if line:
                    yield line
        elif isinstance(row, list):
            try:
                joined = " ".join(row)
            except TypeError:
                # The list contains non-string items; skip it.
                continue
            if len(joined) >= min_length:
                yield joined
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    """Run the module-level fasttext model on `inputs` and return the top `k` guesses.

    Each result is a dict with the label (its first ``FASTTEXT_PREFIX_LENGTH``
    characters removed) under ``"label"`` and the probability under ``"score"``.
    """
    raw = model.predict(inputs, k=k)
    results = []
    for raw_label, probability in zip(raw[0], raw[1]):
        stripped_label = raw_label[FASTTEXT_PREFIX_LENGTH:]
        results.append({"label": stripped_label, "score": probability})
    return results
def get_label(x):
    """Return the ``"label"`` entry of the prediction dict `x` (``None`` if absent)."""
    label = x.get("label")
    return label
def get_mean_score(preds):
    """Return the arithmetic mean of the ``"score"`` values of the dicts in `preds`."""
    scores = []
    for prediction in preds:
        scores.append(prediction.get("score"))
    return mean(scores)
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the set of keys whose count is at least `threshold_percent` of the total.

    The total is the sum of all values in `counts_dict`; a key is kept when
    its value is greater than or equal to ``total * threshold_percent``.
    """
    cutoff = sum(counts_dict.values()) * threshold_percent
    kept_keys = set()
    for key, count in counts_dict.items():
        if count >= cutoff:
            kept_keys.add(key)
    return kept_keys
def predict_rows(rows, target_column, language_threshold_percent=0.2):
    """Predict languages for the `target_column` values of `rows`.

    Non-``None`` column values are cleaned with ``yield_clean_rows``, each
    resulting text is classified with ``model_predict``, and the predictions
    are grouped by label. Labels whose share of all predictions is below
    `language_threshold_percent` are dropped from the summary.

    Returns:
        A dict with ``"predictions"`` (label -> mean score for the kept
        labels) and ``"pred"`` (the flat list of every individual prediction).
    """
    column_values = [
        value
        for value in (row.get(target_column) for row in rows)
        if value is not None
    ]
    texts = list(yield_clean_rows(column_values))
    per_text_predictions = [model_predict(text) for text in texts]
    flat_predictions = list(
        concat(preds for preds in per_text_predictions if preds is not None)
    )

    grouped_by_language = groupby(get_label, flat_predictions)
    language_counts = valmap(len, grouped_by_language)
    kept_languages = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )

    mean_scores = {}
    for language, language_predictions in grouped_by_language.items():
        if language in kept_languages:
            mean_scores[language] = get_mean_score(language_predictions)

    return {
        "predictions": mean_scores,
        "pred": flat_predictions,
    }
# @app.get("/", response_class=HTMLResponse) | |
# async def read_index(): | |
# html_content = Path("index.html").read_text() | |
# return HTMLResponse(content=html_content) | |
def root():
    """Redirect the caller to the ``/docs`` page.

    NOTE(review): no route decorator is visible on this function in this view
    of the file; confirm that it is actually registered on ``app``.
    """
    redirect = RedirectResponse(url="/docs")
    return redirect
async def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
    number_of_rows: int = 1000,
) -> dict[Any, Any] | None:
    """Predict the language(s) of a Hub dataset from a random sample of its rows.

    Args:
        hub_id: Dataset identifier on the Hub.
        config: Dataset config name; when missing, the config of the first
            split reported by the datasets server is used.
        split: Dataset split name; when missing, the first split reported by
            the datasets server is used.
        max_request_calls: Maximum number of ``/rows`` requests to make.
        number_of_rows: Target number of rows to sample.

    Returns:
        The dict produced by ``predict_rows`` extended with ``"hub_id"``,
        ``"config"`` and ``"split"``, or ``None`` when the dataset cannot be
        processed (not served by the datasets server, no usable split
        information, or no column listed in ``TARGET_COLUMN_NAMES``). Every
        ``None`` outcome is logged.

    NOTE(review): no route decorator is visible on this function in this view
    of the file; confirm that it is actually registered on ``app``.
    """
    # NOTE: this validity check uses the synchronous client.
    if not datasets_server_valid_rows(hub_id):
        logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
        return None

    # Resolve whichever of config/split is missing with a single lookup, and
    # handle the lookup failing instead of unpacking None.
    if not config or not split:
        first_config_and_split = await get_first_config_and_split_name(hub_id)
        if first_config_and_split is None:
            logger.error(f"Could not determine a config and split for dataset {hub_id}.")
            return None
        first_config, first_split = first_config_and_split
        config = config or first_config
        split = split or first_split

    info = await get_dataset_info(hub_id, config)
    if info is None:
        logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
        return None

    dataset_info = info.get("dataset_info")
    if not dataset_info:
        logger.error(f"No dataset_info available for dataset {hub_id} (config={config}).")
        return None

    split_info = (dataset_info.get("splits") or {}).get(split) or {}
    total_rows_for_split = split_info.get("num_examples")
    if not total_rows_for_split:
        logger.error(f"No rows reported for split {split} of dataset {hub_id}.")
        return None

    features = dataset_info.get("features") or {}
    column_names = set(features.keys())
    logger.info(f"Column names: {column_names}")

    # First target column (in TARGET_COLUMN_NAMES iteration order) present in the dataset.
    target_column = next(
        (column for column in TARGET_COLUMN_NAMES if column in column_names), None
    )
    if target_column is None:
        logger.error(
            f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
        )
        return None
    logger.info(f"Using column {target_column} for language detection")

    random_rows = await get_random_rows(
        hub_id,
        total_rows_for_split,
        number_of_rows,
        max_request_calls,
        config,
        split,
    )
    logger.info(f"Predicting language for {len(random_rows)} rows")
    predictions = predict_rows(random_rows, target_column)
    predictions["hub_id"] = hub_id
    predictions["config"] = config
    predictions["split"] = split
    return predictions