"""FastAPI service that predicts which languages appear in a Hugging Face dataset.

Rows are sampled through the datasets server API and classified line-by-line with
the fastText language-identification model.
"""

import logging
import os
import random
from datetime import timedelta
from pathlib import Path
from statistics import mean
from typing import Any, Iterable, Iterator, Union

import fasttext
from cashews import cache
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse  # noqa: F401  (currently unused)
from httpx import AsyncClient, Client, Timeout
from huggingface_hub import hf_hub_download

# Aliased so it does not shadow the stdlib `logging` module used for `logger` below.
from huggingface_hub.utils import logging as hf_logging  # noqa: F401
from starlette.responses import RedirectResponse
from toolz import concat, groupby, valmap

cache.setup("mem://")
logger = logging.getLogger(__name__)
app = FastAPI()

load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
# NOTE(review): huggingface_hub most likely reads this flag when it is imported
# (above), so setting it here probably has no effect. Moving it before the import
# would also require the optional `hf_transfer` package to be installed - confirm
# before changing.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

FASTTEXT_PREFIX_LENGTH = 9  # fasttext labels are formatted like "__label__eng_Latn"
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_FAST_TEXT_MODEL = "facebook/fasttext-language-identification"

# Only send an Authorization header when a token is configured; an f-string with a
# stray "$" (or a missing token rendered as "None") would send an invalid token.
headers = {"authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
timeout = Timeout(60, read=120)
client = Client(headers=headers, timeout=timeout)
async_client = AsyncClient(headers=headers, timeout=timeout)

# Candidate text columns in priority order: when a dataset has several of them,
# the first one listed here is used (iterating a set would pick one at random,
# because string hashing is randomized per process).
TARGET_COLUMN_PRIORITY = (
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
    "chosen",
    "rejected",
)
TARGET_COLUMN_NAMES = set(TARGET_COLUMN_PRIORITY)


def datasets_server_valid_rows(hub_id: str) -> bool:
    """Return the datasets server's `viewer` flag from `/is-valid` for `hub_id`.

    Uses the blocking client. Any error is logged and reported as ``False``.
    """
    try:
        resp = client.get(
            f"{BASE_DATASETS_SERVER_URL}/is-valid", params={"dataset": hub_id}
        )
        return resp.json()["viewer"]
    except Exception as e:
        logger.error(f"Failed to get is-valid for {hub_id}: {e}")
        return False


async def _datasets_server_valid_rows_async(hub_id: str) -> bool:
    """Non-blocking twin of `datasets_server_valid_rows` for use in async handlers."""
    try:
        resp = await async_client.get(
            f"{BASE_DATASETS_SERVER_URL}/is-valid", params={"dataset": hub_id}
        )
        return resp.json()["viewer"]
    except Exception as e:
        logger.error(f"Failed to get is-valid for {hub_id}: {e}")
        return False


async def get_first_config_and_split_name(hub_id: str) -> tuple[str, str] | None:
    """Return ``(config, split)`` of the first entry listed by `/splits`.

    Returns ``None`` (after logging) if the request or response parsing fails.
    """
    try:
        resp = await async_client.get(
            f"{BASE_DATASETS_SERVER_URL}/splits", params={"dataset": hub_id}
        )
        first = resp.json()["splits"][0]
        return first["config"], first["split"]
    except Exception as e:
        logger.error(f"Failed to get splits for {hub_id}: {e}")
        return None


async def get_dataset_info(
    hub_id: str, config: str | None = None
) -> dict[str, Any] | None:
    """Fetch the `/info` payload for `hub_id` and `config`.

    When `config` is ``None`` the first config reported by `/splits` is used;
    returns ``None`` if that lookup fails.

    Raises:
        httpx.HTTPStatusError: if the `/info` request returns an error status.
    """
    if config is None:
        first = await get_first_config_and_split_name(hub_id)
        if first is None:
            return None
        config = first[0]
    resp = await async_client.get(
        f"{BASE_DATASETS_SERVER_URL}/info",
        params={"dataset": hub_id, "config": config},
    )
    resp.raise_for_status()
    return resp.json()


async def get_random_rows(
    hub_id: str,
    total_length: int,
    number_of_rows: int,
    max_request_calls: int,
    config="default",
    split="train",
):
    """Sample about `number_of_rows` rows via `/rows` using random offsets.

    At most `max_request_calls` requests are made, each returning one contiguous
    slice of at most 100 rows. Failed requests are logged and skipped.

    Returns:
        A list of the ``"row"`` payloads of the fetched rows (possibly empty).
    """
    if total_length <= 0 or number_of_rows <= 0 or max_request_calls <= 0:
        return []

    rows_per_call = min(
        number_of_rows // max_request_calls,
        total_length // max_request_calls,
        100,  # the /rows endpoint serves at most 100 rows per request
    )
    # Integer division can yield 0 for small datasets/requests; always fetch at
    # least one row per call (still <= total_length since total_length >= 1).
    rows_per_call = max(rows_per_call, 1)
    num_calls = min(max_request_calls, number_of_rows // rows_per_call)

    rows = []
    for _ in range(num_calls):
        offset = random.randint(0, total_length - rows_per_call)
        params = {
            "dataset": hub_id,
            "config": config,
            "split": split,
            "offset": offset,
            "length": rows_per_call,
        }
        logger.info(f"Fetching rows {params}")
        response = await async_client.get(
            f"{BASE_DATASETS_SERVER_URL}/rows", params=params
        )
        if response.status_code == 200:
            rows.extend(response.json().get("rows") or [])
        else:
            logger.error(
                f"Failed to fetch data: {response.status_code} ({response.url})"
            )
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]


def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Download ``model.bin`` from `repo_id` on the Hub and load it with fastText."""
    # NOTE(review): this directory is not used by the download below (it goes to
    # the default Hugging Face cache); kept to preserve the existing side effect.
    Path("code/models").mkdir(parents=True, exist_ok=True)
    model_path = hf_hub_download(repo_id, "model.bin")
    return fasttext.load_model(model_path)


model = load_model(DEFAULT_FAST_TEXT_MODEL)


def yield_clean_rows(
    rows: Iterable[Union[str, list[str]]], min_length: int = 3
) -> Iterator[str]:
    """Turn raw column values into single lines of text suitable for fastText.

    Strings are split into lines; lists (e.g. token columns) are joined with
    spaces. Lines shorter than `min_length` characters (ignoring surrounding
    whitespace) are dropped, as are values of any other type and lists that
    contain non-string elements.
    """
    for row in rows:
        if isinstance(row, str):
            candidates = row.splitlines()
        elif isinstance(row, list):
            try:
                # fastText predicts one line at a time and rejects "\n", so
                # flatten any newlines embedded in the joined tokens.
                candidates = [" ".join(row).replace("\n", " ")]
            except TypeError:
                # Non-string elements cannot be joined; skip the whole row.
                continue
        else:
            continue
        for candidate in candidates:
            text = candidate.strip()
            if len(text) >= min_length:
                yield text


def model_predict(inputs: str, k=1) -> list[dict[str, Any]]:
    """Return the top-`k` fastText predictions for a single line of text.

    Each prediction is ``{"label": <code without "__label__">, "score": prob}``.
    """
    labels, probs = model.predict(inputs, k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
        for label, prob in zip(labels, probs)
    ]


def get_label(x):
    """Return the ``"label"`` of a prediction dict."""
    return x.get("label")


def get_mean_score(preds):
    """Return the mean ``"score"`` over a non-empty list of prediction dicts."""
    return mean([pred.get("score") for pred in preds])


def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the keys whose count is at least `threshold_percent` of the total."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}


def predict_rows(rows, target_column, language_threshold_percent=0.2):
    """Predict languages for `target_column` across `rows`.

    Languages predicted for fewer than `language_threshold_percent` of all lines
    are dropped.

    Returns:
        ``{"predictions": {language: mean_score}}``.
    """
    texts = (row.get(target_column) for row in rows if row)
    texts = (text for text in texts if text is not None)
    lines = list(yield_clean_rows(texts))
    predictions = list(concat(model_predict(line) for line in lines))
    predictions_by_lang = groupby(get_label, predictions)
    language_counts = valmap(len, predictions_by_lang)
    keys_to_keep = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )
    filtered = {
        lang: preds
        for lang, preds in predictions_by_lang.items()
        if lang in keys_to_keep
    }
    return {"predictions": dict(valmap(get_mean_score, filtered))}


@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse(url="/docs")


def _select_target_column(column_names: Iterable[str]) -> str | None:
    """Return the highest-priority target column present in `column_names`."""
    available = set(column_names)
    return next((c for c in TARGET_COLUMN_PRIORITY if c in available), None)


@app.get("/predict_dataset_language/{hub_id:path}")
@cache(ttl=timedelta(minutes=10))
async def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
    number_of_rows: int = 1000,
) -> dict[Any, Any] | None:
    """Predict the languages of a dataset from a random sample of its rows.

    Missing `config`/`split` default to the first ones reported by the datasets
    server. Returns ``None`` (after logging) when the dataset, config, split or a
    usable text column cannot be found. Otherwise returns the `predict_rows`
    result extended with ``hub_id``, ``config`` and ``split``.
    """
    if not await _datasets_server_valid_rows_async(hub_id):
        logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
        return None

    split_was_requested = bool(split)
    if not config or not split:
        first = await get_first_config_and_split_name(hub_id)
        if first is None:
            logger.error(f"Could not determine a config/split for {hub_id}.")
            return None
        first_config, first_split = first
        config = config or first_config
        split = split or first_split

    info = await get_dataset_info(hub_id, config)
    if info is None:
        logger.error(f"Dataset {hub_id} is not accessible via the datasets server.")
        return None
    dataset_info = info.get("dataset_info")
    if not dataset_info:
        logger.error(f"No dataset_info returned for {hub_id} (config={config}).")
        return None

    splits = dataset_info.get("splits") or {}
    split_info = splits.get(split)
    if split_info is None and not split_was_requested and splits:
        # The defaulted split belongs to the first config, which may differ from
        # the requested config; fall back to this config's first split.
        split, split_info = next(iter(splits.items()))
    if split_info is None:
        logger.error(f"Split {split} not found for {hub_id} (config={config}).")
        return None
    total_rows_for_split = split_info.get("num_examples")
    if not total_rows_for_split:
        logger.error(f"No rows reported for {hub_id} (config={config}, split={split}).")
        return None

    column_names = set((dataset_info.get("features") or {}).keys())
    logger.info(f"Column names: {column_names}")
    target_column = _select_target_column(column_names)
    if target_column is None:
        logger.error(
            f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
        )
        return None
    logger.info(f"Using column {target_column} for language detection")

    random_rows = await get_random_rows(
        hub_id,
        total_rows_for_split,
        number_of_rows,
        max_request_calls,
        config,
        split,
    )
    logger.info(f"Predicting language for {len(random_rows)} rows")
    predictions = predict_rows(random_rows, target_column)
    predictions["hub_id"] = hub_id
    predictions["config"] = config
    predictions["split"] = split
    return predictions