Spaces:
Paused
Paused
davanstrien
HF staff
Refactor predict_language function and update interface inputs and outputs
cee1d41
import gradio as gr | |
from httpx import Client | |
import random | |
import os | |
import fasttext | |
from huggingface_hub import hf_hub_download | |
from typing import Union | |
from typing import Iterator | |
from dotenv import load_dotenv | |
from toolz import groupby, valmap, concat | |
from statistics import mean | |
from httpx import Timeout | |
from huggingface_hub.utils import logging | |
# Module-level logger obtained from huggingface_hub's logging helper.
logger = logging.get_logger(__name__)

# Load variables from a local .env file (if any) before reading HF_TOKEN.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"

# Python f-strings interpolate with `{...}`; the previous `${HF_TOKEN}` form
# (JavaScript template syntax) put a literal "$" in front of the token.
# The header is omitted entirely when no token is configured so that
# "Bearer None" is never sent.
headers = {"authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}

# Shared HTTP client used by every datasets-server helper below.
timeout = Timeout(60, read=120)
client = Client(headers=headers, timeout=timeout)
# Non-exhaustive list of column names that might contain text usable for
# language detection. Columns are tried in this order, i.e. if a dataset has a
# column named "text" it is used first. This must be an ordered sequence (not a
# set) for that preference to hold, since set iteration order is arbitrary.
TARGET_COLUMN_NAMES = (
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
    "chosen",
    "rejected",
)
def datasets_server_valid_rows(hub_id: str):
    """Return the ``viewer`` field of the datasets server ``/is-valid`` response.

    Args:
        hub_id: Dataset identifier passed as the ``dataset`` query parameter.

    Raises:
        Whatever ``raise_for_status`` raises on a non-success HTTP status.
    """
    url = f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}"
    response = client.get(url)
    response.raise_for_status()
    payload = response.json()
    return payload["viewer"]
def get_first_config_and_split_name(hub_id: str):
    """Return ``(config, split)`` for the first split listed for a dataset.

    Args:
        hub_id: Dataset identifier passed as the ``dataset`` query parameter.

    Returns:
        A 2-tuple with the ``config`` and ``split`` values of the first entry
        in the ``splits`` list of the ``/splits`` response.

    Raises:
        Whatever ``raise_for_status`` raises on a non-success HTTP status.
        IndexError: If the response lists no splits.
    """
    # Use the shared base URL constant like the other datasets-server helpers
    # instead of repeating the host name.
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}")
    resp.raise_for_status()
    first_split = resp.json()["splits"][0]
    return first_split["config"], first_split["split"]
def get_dataset_info(hub_id: str, config: str | None = None):
    """Fetch the ``/info`` payload for a dataset config from the datasets server.

    Args:
        hub_id: Dataset identifier passed as the ``dataset`` query parameter.
        config: Config name. When ``None`` the config of the dataset's first
            listed split is looked up and used.

    Returns:
        The decoded JSON response, or ``None`` if no config could be resolved.

    Raises:
        Whatever ``raise_for_status`` raises on a non-success HTTP status.
    """
    if config is None:
        first_config_and_split = get_first_config_and_split_name(hub_id)
        if first_config_and_split is None:
            return None
        # Only the config half of the (config, split) pair is needed here.
        config = first_config_and_split[0]
    info_url = f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    response = client.get(info_url)
    response.raise_for_status()
    return response.json()
def get_random_rows(
    hub_id,
    total_length,
    number_of_rows,
    max_request_calls,
    config="default",
    split="train",
):
    """Sample rows from random offsets of a split via the ``/rows`` endpoint.

    Up to ``max_request_calls`` requests are made, each asking for a
    contiguous window of rows starting at a random offset. Failed requests are
    logged and skipped, so fewer rows than requested may be returned.

    Args:
        hub_id: Dataset identifier.
        total_length: Number of rows in the split; bounds the random offset.
        number_of_rows: Target number of rows to collect.
        max_request_calls: Maximum number of HTTP requests to issue.
        config: Dataset config name.
        split: Split name.

    Returns:
        A list with the ``row`` value of every fetched entry. Empty when
        there is nothing to sample (non-positive ``total_length``,
        ``number_of_rows`` or ``max_request_calls``).
    """
    if total_length <= 0 or number_of_rows <= 0 or max_request_calls <= 0:
        return []
    rows_per_call = min(
        number_of_rows // max_request_calls,
        total_length // max_request_calls,
        100,  # Ensure rows_per_call is not more than 100
    )
    # Integer division yields 0 for small datasets or small requests; fetch at
    # least one row per call so the division below cannot fail.
    rows_per_call = max(rows_per_call, 1)
    rows = []
    for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
        offset = random.randint(0, total_length - rows_per_call)
        url = (
            f"{BASE_DATASETS_SERVER_URL}/rows?dataset={hub_id}&config={config}"
            f"&split={split}&offset={offset}&length={rows_per_call}"
        )
        response = client.get(url)
        if response.status_code == 200:
            # A payload without "rows" contributes nothing instead of crashing.
            rows.extend(response.json().get("rows") or [])
        else:
            logger.warning(f"Failed to fetch data: {response.status_code} ({url})")
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Download ``model.bin`` from the given Hub repo and load it with fastText.

    Args:
        repo_id: Hub repository that contains a ``model.bin`` file.

    Returns:
        The loaded fastText model.
    """
    return fasttext.load_model(hf_hub_download(repo_id, filename="model.bin"))
# def predict_language_for_rows(rows: list[dict], target_column_names: list[str] | str): | |
# pass | |
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    """Yield single-line text snippets extracted from raw column values.

    Args:
        rows: Iterable of column values. A bare string is treated as a single
            row rather than being iterated character by character.
        min_length: Minimum length of a joined list value for it to be
            yielded. Not applied to plain string rows, where only empty lines
            are dropped.

    Yields:
        For string rows, every non-empty line. For list rows, the items joined
        with spaces (newlines replaced by spaces) when at least ``min_length``
        characters long. Rows of any other type, and lists containing
        non-string items, are skipped.
    """
    if isinstance(rows, str):
        rows = [rows]
    for row in rows:
        if isinstance(row, str):
            # split on lines and remove empty lines
            yield from (line for line in row.split("\n") if line)
        elif isinstance(row, list):
            try:
                joined = " ".join(row)
            except TypeError:
                # The list holds non-string items (e.g. ints or nested lists).
                continue
            # fastText predicts one line at a time and rejects embedded
            # newlines, so flatten them just as the string branch splits them.
            joined = joined.replace("\n", " ")
            if len(joined) >= min_length:
                yield joined
# fastText labels are formatted like "__label__eng_Latn"; this is the length of
# the "__label__" prefix that gets sliced off each predicted label.
FASTTEXT_PREFIX_LENGTH = len("__label__")
# model = load_model(DEFAULT_FAST_TEXT_MODEL)
model = fasttext.load_model(
    hf_hub_download(
        repo_id="facebook/fasttext-language-identification",
        filename="model.bin",
    )
)
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    """Run the module-level fastText model on ``inputs``.

    Args:
        inputs: Text to classify.
        k: Number of predictions requested from the model.

    Returns:
        One ``{"label": ..., "score": ...}`` dict per prediction, where the
        label has its first ``FASTTEXT_PREFIX_LENGTH`` characters removed.
    """
    raw = model.predict(inputs, k=k)
    labels, scores = raw[0], raw[1]
    results = []
    for raw_label, score in zip(labels, scores):
        results.append({"label": raw_label[FASTTEXT_PREFIX_LENGTH:], "score": score})
    return results
def get_label(x):
    """Return the ``"label"`` entry of a prediction mapping (``None`` if absent)."""
    label = x.get("label")
    return label
def get_mean_score(preds):
    """Return the arithmetic mean of the ``"score"`` entries of ``preds``."""
    scores = []
    for prediction in preds:
        scores.append(prediction.get("score"))
    return mean(scores)
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the set of keys whose count is at least a share of the total.

    Args:
        counts_dict: Mapping of key to count.
        threshold_percent: Fraction of the summed counts (e.g. ``0.2``) a
            key's count must reach or exceed to be kept.

    Returns:
        A set of the keys that meet the threshold.
    """
    cutoff = sum(counts_dict.values()) * threshold_percent
    kept = set()
    for key, count in counts_dict.items():
        if count >= cutoff:
            kept.add(key)
    return kept
def predict_rows(rows, target_column, language_threshold_percent=0.2):
    """Predict languages for one column of the given rows and aggregate them.

    Args:
        rows: Iterable of row mappings.
        target_column: Key of the column whose values are classified.
        language_threshold_percent: Fraction of all predictions a language
            must reach to appear in the aggregated result.

    Returns:
        A dict with ``"predictions"`` (mean score per retained language) and
        ``"pred"`` (the flat list of every individual prediction).
    """
    column_values = (row.get(target_column) for row in rows)
    texts = list(
        yield_clean_rows(value for value in column_values if value is not None)
    )

    # Flatten the per-text prediction lists into one list.
    predictions = []
    for text in texts:
        text_predictions = model_predict(text)
        if text_predictions is not None:
            predictions.extend(text_predictions)

    by_language = groupby(get_label, predictions)
    language_counts = {
        language: len(language_preds)
        for language, language_preds in by_language.items()
    }
    frequent_languages = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )
    mean_scores = {
        language: get_mean_score(language_preds)
        for language, language_preds in by_language.items()
        if language in frequent_languages
    }
    return {
        "predictions": mean_scores,
        "pred": predictions,
    }
def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
):
    """Predict the language(s) of a Hub dataset from a random sample of rows.

    Args:
        hub_id: Dataset identifier.
        config: Config name. When empty, the config and split of the
            dataset's first listed split are used.
        split: Split name. When empty but ``config`` is given, the first
            split listed in that config's info is used.
        max_request_calls: Maximum number of row requests to issue.

    Returns:
        The dict produced by ``predict_rows`` extended with ``hub_id``,
        ``config`` and ``split`` keys.

    Raises:
        gr.Error: If the dataset is not served by the datasets server, its
            info or requested split is unavailable, the split has no rows, or
            none of ``TARGET_COLUMN_NAMES`` is present.
    """
    # gr.Error is an exception: it has to be raised to stop processing and be
    # shown to the user; merely constructing it does nothing.
    if not datasets_server_valid_rows(hub_id):
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if not config:
        config, split = get_first_config_and_split_name(hub_id)
    info = get_dataset_info(hub_id, config)
    if info is None:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    dataset_info = info.get("dataset_info")
    if not dataset_info:
        # Without this guard the row count and column names would be unbound.
        raise gr.Error(
            f"No dataset info available for dataset {hub_id} (config: {config})."
        )

    splits = dataset_info.get("splits") or {}
    if not split:
        # A config was supplied without a split: use that config's first split.
        split = next(iter(splits), None)
    split_info = splits.get(split)
    if split_info is None:
        raise gr.Error(
            f"Split {split} not found for dataset {hub_id} (config: {config})."
        )
    total_rows_for_split = split_info.get("num_examples")
    if not total_rows_for_split:
        raise gr.Error(
            f"Split {split} of dataset {hub_id} (config: {config}) has no rows."
        )

    features = dataset_info.get("features") or {}
    column_names = set(features.keys())
    logger.info(f"Column names: {column_names}")
    # The first name in TARGET_COLUMN_NAMES that is also a dataset column.
    target_column = next(
        (column for column in TARGET_COLUMN_NAMES if column in column_names), None
    )
    if target_column is None:
        raise gr.Error(
            f"Dataset {hub_id} does not contain any of the target columns {TARGET_COLUMN_NAMES}"
        )
    logger.info(f"Using column {target_column} for language detection")

    random_rows = get_random_rows(
        hub_id, total_rows_for_split, 1000, max_request_calls, config, split
    )
    logger.info(f"Predicting language for {len(random_rows)} rows")
    predictions = predict_rows(random_rows, target_column)
    predictions["hub_id"] = hub_id
    predictions["config"] = config
    predictions["split"] = split
    return predictions
# Input widgets, in the positional order of predict_language's first three
# parameters: dataset id, config, split.
inputs = [
    gr.Text(label="dataset id"),
    gr.Textbox(value=None, label="config"),
    gr.Textbox(value=None, label="split"),
]
# Expose predict_language through a Gradio interface with JSON output.
interface = gr.Interface(fn=predict_language, inputs=inputs, outputs="json")
interface.queue()
interface.launch()