import os
import random
from statistics import mean
from typing import Iterator, Union

import fasttext
import gradio as gr
from dotenv import load_dotenv
from httpx import Client, Timeout
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import concat, groupby, valmap

logger = logging.get_logger(__name__)

load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"

headers = {
    "authorization": f"Bearer {HF_TOKEN}",
}
timeout = Timeout(60, read=120)
client = Client(headers=headers, timeout=timeout)

# Non-exhaustive list of columns that might contain text usable for language detection.
# Columns are tried in this order, i.e. if there is a column named "text" it is used first.
TARGET_COLUMN_NAMES = (
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
    "chosen",
    "rejected",
)


def datasets_server_valid_rows(hub_id: str):
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
    resp.raise_for_status()
    return resp.json()["viewer"]


def get_first_config_and_split_name(hub_id: str):
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}")
    resp.raise_for_status()
    data = resp.json()
    return data["splits"][0]["config"], data["splits"][0]["split"]


def get_dataset_info(hub_id: str, config: str | None = None):
    if config is None:
        config = get_first_config_and_split_name(hub_id)
        if config is None:
            return None
        config = config[0]
    resp = client.get(
        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    )
    resp.raise_for_status()
    return resp.json()


def get_random_rows(
    hub_id,
    total_length,
    number_of_rows,
    max_request_calls,
    config="default",
    split="train",
):
    rows = []
    rows_per_call = min(
        number_of_rows // max_request_calls, total_length // max_request_calls
    )
    # Cap at 100 rows per request; keep at least 1 to avoid a zero division for tiny datasets.
    rows_per_call = max(min(rows_per_call, 100), 1)
    for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
        offset = random.randint(0, max(total_length - rows_per_call, 0))
        url = (
            f"{BASE_DATASETS_SERVER_URL}/rows?dataset={hub_id}&config={config}"
            f"&split={split}&offset={offset}&length={rows_per_call}"
        )
        response = client.get(url)
        if response.status_code == 200:
            data = response.json()
            batch_rows = data.get("rows")
            rows.extend(batch_rows)
        else:
            print(f"Failed to fetch data: {response.status_code}")
            print(url)
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]


def load_model(repo_id: str) -> fasttext.FastText._FastText:
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)


# def predict_language_for_rows(rows: list[dict], target_column_names: list[str] | str):
#     pass


def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    for row in rows:
        if isinstance(row, str):
            # Split on newlines and drop empty lines.
            for line in row.split("\n"):
                if line:
                    yield line
        elif isinstance(row, list):
            try:
                line = " ".join(row)
                if len(line) >= min_length:
                    yield line
            except TypeError:
                continue


FASTTEXT_PREFIX_LENGTH = 9  # fasttext labels are formatted like "__label__eng_Latn"

# model = load_model(DEFAULT_FAST_TEXT_MODEL)
model = fasttext.load_model(
    hf_hub_download("facebook/fasttext-language-identification", "model.bin")
)
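
# For reference: fasttext's `predict` returns a (labels, probabilities) pair, roughly
# (("__label__eng_Latn",), array([0.97])) for k=1 (exact scores will vary). `model_predict`
# below strips the 9-character "__label__" prefix and zips the two sequences into a
# list of {"label": ..., "score": ...} dicts.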
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    predictions = model.predict(inputs, k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
        for label, prob in zip(predictions[0], predictions[1])
    ]


def get_label(x):
    return x.get("label")


def get_mean_score(preds):
    return mean([pred.get("score") for pred in preds])


def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the keys whose count is at least `threshold_percent` of the total."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}


def predict_rows(rows, target_column, language_threshold_percent=0.2):
    rows = (row.get(target_column) for row in rows)
    rows = (row for row in rows if row is not None)
    rows = list(yield_clean_rows(rows))
    predictions = [model_predict(row) for row in rows]
    predictions = [pred for pred in predictions if pred is not None]
    predictions = list(concat(predictions))
    predictions_by_lang = groupby(get_label, predictions)
    language_counts = valmap(len, predictions_by_lang)
    keys_to_keep = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )
    filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
    return {
        "predictions": dict(valmap(get_mean_score, filtered_dict)),
        "pred": predictions,
    }


def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
):
    is_valid = datasets_server_valid_rows(hub_id)
    if not is_valid:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if not config:
        config, split = get_first_config_and_split_name(hub_id)
    info = get_dataset_info(hub_id, config)
    if info is None:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if dataset_info := info.get("dataset_info"):
        total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
        features = dataset_info.get("features")
        column_names = set(features.keys())
        logger.info(f"Column names: {column_names}")
        if not column_names.intersection(TARGET_COLUMN_NAMES):
            raise gr.Error(
                f"Dataset {hub_id} does not contain any of the target columns {TARGET_COLUMN_NAMES}"
            )
        for column in TARGET_COLUMN_NAMES:
            if column in column_names:
                target_column = column
                logger.info(f"Using column {target_column} for language detection")
                break
        random_rows = get_random_rows(
            hub_id, total_rows_for_split, 1000, max_request_calls, config, split
        )
        logger.info(f"Predicting language for {len(random_rows)} rows")
        predictions = predict_rows(random_rows, target_column)
        predictions["hub_id"] = hub_id
        predictions["config"] = config
        predictions["split"] = split
        return predictions


inputs = [
    gr.Text(label="dataset id"),
    gr.Textbox(None, label="config"),
    gr.Textbox(None, label="split"),
]
interface = gr.Interface(predict_language, inputs=inputs, outputs="json")
interface.queue()
interface.launch()
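
# Rough usage sketch: `predict_language` can also be called directly, outside the Gradio UI,
# with any Hub dataset id (the id below is only an illustrative example):
#
#     result = predict_language("imdb")
#     # result -> {"predictions": {mean score per detected language}, "pred": [per-row predictions],
#     #            "hub_id": "imdb", "config": ..., "split": ...}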