import os
import random
from pathlib import Path
from statistics import mean
from typing import Any, Iterator, Union

import fasttext
from dotenv import load_dotenv
from fastapi import FastAPI
from httpx import AsyncClient, Client, Timeout
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import concat, groupby, valmap

app = FastAPI()
logger = logging.get_logger(__name__)
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")


BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"

headers = {
    "authorization": f"Bearer {HF_TOKEN}",
}
timeout = Timeout(60, read=120)
client = Client(headers=headers, timeout=timeout)
async_client = AsyncClient(headers=headers, timeout=timeout)

TARGET_COLUMN_NAMES = {
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
    "chosen",
    "rejected",
}


def datasets_server_valid_rows(hub_id: str):
    try:
        resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
        return resp.json()["viewer"]
    except Exception as e:
        logger.error(f"Failed to get is-valid for {hub_id}: {e}")
        return False


async def get_first_config_and_split_name(hub_id: str):
    try:
        resp = await async_client.get(
            f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}"
        )
        data = resp.json()
        return data["splits"][0]["config"], data["splits"][0]["split"]
    except Exception as e:
        logger.error(f"Failed to get splits for {hub_id}: {e}")
        return None


async def get_dataset_info(hub_id: str, config: str | None = None):
    if config is None:
        config = await get_first_config_and_split_name(hub_id)
        if config is None:
            return None
        config = config[0]
    resp = await async_client.get(
        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    )
    resp.raise_for_status()
    return resp.json()


async def get_random_rows(
    hub_id: str,
    total_length: int,
    number_of_rows: int,
    max_request_calls: int,
    config="default",
    split="train",
):
    rows = []
    rows_per_call = min(
        number_of_rows // max_request_calls, total_length // max_request_calls
    )
    # Request at most 100 rows per call, and at least 1 to avoid a zero division below.
    rows_per_call = min(rows_per_call, 100)
    rows_per_call = max(rows_per_call, 1)
    for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
        offset = random.randint(0, max(total_length - rows_per_call, 0))
        url = (
            f"{BASE_DATASETS_SERVER_URL}/rows?dataset={hub_id}&config={config}"
            f"&split={split}&offset={offset}&length={rows_per_call}"
        )
        logger.info(f"Fetching {url}")
        response = await async_client.get(url)
        if response.status_code == 200:
            data = response.json()
            rows.extend(data.get("rows", []))
        else:
            logger.error(f"Failed to fetch rows ({response.status_code}): {url}")
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]
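
# Worked example (illustrative numbers, not used anywhere in this module): with
# total_length=10_000, number_of_rows=1_000 and max_request_calls=10, rows_per_call is
# min(1_000 // 10, 10_000 // 10) = 100, so the loop makes up to 10 requests of 100 rows
# each and stops as soon as 1_000 rows have been collected.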


def load_model(repo_id: str) -> fasttext.FastText._FastText:
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)


def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    for row in rows:
        if isinstance(row, str):
            # Split multi-line strings and yield each non-empty line.
            for line in row.split("\n"):
                if line:
                    yield line
        elif isinstance(row, list):
            # Join token lists into a single string, skipping very short results.
            try:
                line = " ".join(row)
                if len(line) >= min_length:
                    yield line
            except TypeError:
                continue
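
# Illustrative behaviour (hypothetical inputs):
#   list(yield_clean_rows(["a\nhello", ["to", "ken", "s"]]))
# is expected to yield ["a", "hello", "to ken s"]: multi-line strings are split per
# line, while token lists are joined with spaces and filtered by min_length.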


# fastText labels are prefixed with "__label__" (9 characters), which is stripped
# from predictions below.
FASTTEXT_PREFIX_LENGTH = 9


Path("code/models").mkdir(parents=True, exist_ok=True)
model = fasttext.load_model(
    hf_hub_download(
        "facebook/fasttext-language-identification",
        "model.bin",
        cache_dir="code/models",
        local_dir="code/models",
        local_dir_use_symlinks=False,
    )
)


def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    predictions = model.predict(inputs, k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
        for label, prob in zip(predictions[0], predictions[1])
    ]
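
# Illustrative output (hypothetical input, scores will vary): with the
# facebook/fasttext-language-identification model, model_predict("Bonjour tout le monde")
# should return something like [{"label": "fra_Latn", "score": 0.99}], i.e. the
# "__label__" prefix removed and the fastText confidence kept as the score.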


def get_label(x):
    return x.get("label")


def get_mean_score(preds):
    return mean([pred.get("score") for pred in preds])


def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the keys whose count is at least `threshold_percent` of the total."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}
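
# Worked example (illustrative counts): for {"eng_Latn": 80, "fra_Latn": 15, "deu_Latn": 5}
# with the default threshold of 0.2, the total is 100 and the cut-off is 20, so only
# {"eng_Latn"} is returned.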


def predict_rows(rows, target_column, language_threshold_percent=0.2):
    rows = (row.get(target_column) for row in rows)
    rows = (row for row in rows if row is not None)
    rows = list(yield_clean_rows(rows))
    predictions = [model_predict(row) for row in rows]
    predictions = [pred for pred in predictions if pred is not None]
    predictions = list(concat(predictions))
    predictions_by_lang = groupby(get_label, predictions)
    language_counts = valmap(len, predictions_by_lang)
    keys_to_keep = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )
    filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
    return {
        "predictions": dict(valmap(get_mean_score, filtered_dict)),
        "pred": predictions,
    }
@app.get("/predict_dataset_language/{hub_id}") |
|
async def predict_language( |
|
hub_id: str, |
|
config: str | None = None, |
|
split: str | None = None, |
|
max_request_calls: int = 10, |
|
number_of_rows: int = 1000, |
|
) -> dict[Any, Any] | None: |
|
is_valid = datasets_server_valid_rows(hub_id) |
|
if not is_valid: |
|
logger.error(f"Dataset {hub_id} is not accessible via the datasets server.") |
|
if not config and not split: |
|
config, split = await get_first_config_and_split_name(hub_id) |
|
if not config: |
|
config, _ = await get_first_config_and_split_name(hub_id) |
|
if not split: |
|
_, split = await get_first_config_and_split_name(hub_id) |
|
info = await get_dataset_info(hub_id, config) |
|
if info is None: |
|
logger.error(f"Dataset {hub_id} is not accessible via the datasets server.") |
|
return None |
|
if dataset_info := info.get("dataset_info"): |
|
total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples") |
|
features = dataset_info.get("features") |
|
column_names = set(features.keys()) |
|
logger.info(f"Column names: {column_names}") |
|
if not set(column_names).intersection(TARGET_COLUMN_NAMES): |
|
logger.error( |
|
f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}" |
|
) |
|
return None |
|
for column in TARGET_COLUMN_NAMES: |
|
if column in column_names: |
|
target_column = column |
|
logger.info(f"Using column {target_column} for language detection") |
|
break |
|
random_rows = await get_random_rows( |
|
hub_id, |
|
total_rows_for_split, |
|
number_of_rows, |
|
max_request_calls, |
|
config, |
|
split, |
|
) |
|
logger.info(f"Predicting language for {len(random_rows)} rows") |
|
predictions = predict_rows(random_rows, target_column) |
|
predictions["hub_id"] = hub_id |
|
predictions["config"] = config |
|
predictions["split"] = split |
|
return predictions |
|
|
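
# Illustrative usage (hypothetical request, not exercised anywhere in this module):
# with this file saved as `app.py` and served via `uvicorn app:app`, a request such as
#   GET /predict_dataset_language/some_dataset?number_of_rows=200
# returns a JSON object with a "predictions" mapping of language labels to their mean
# fastText scores, the raw per-line "pred" list, and the hub_id/config/split used.
# Note that `{hub_id}` matches a single path segment, so namespaced ids containing "/"
# would require declaring the route parameter as `{hub_id:path}`.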