dataset-tldr / app.py
davanstrien's picture
davanstrien HF staff
Refactor predict_language function and update interface inputs and outputs
cee1d41
raw
history blame
7.54 kB
import os
import random
from statistics import mean
from typing import Iterable
from typing import Iterator
from typing import Union

import fasttext
import gradio as gr
from dotenv import load_dotenv
from httpx import Client
from httpx import Timeout
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import groupby, valmap, concat
logger = logging.get_logger(__name__)

# Pull variables such as HF_TOKEN from a local .env file, if one exists.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
# Hub repo of an alternative fastText language-ID model.
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"

# Python f-strings interpolate "{...}" only: writing "${HF_TOKEN}" would send a
# literal "$" in front of the token and make it invalid. When no token is
# configured, send no Authorization header instead of "Bearer None".
headers = {"authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
# 60 s default for connect/write/pool, 120 s for reads.
timeout = Timeout(60, read=120)
# Shared HTTP client used for every datasets-server request in this module.
client = Client(headers=headers, timeout=timeout)
# Non-exhaustive list of column names that might contain text usable for
# language detection, in order of preference: if a dataset has a "text"
# column it is used first, then "input", and so on. This is a tuple rather
# than a set so that iteration actually follows this order (set iteration
# order for strings varies between interpreter runs).
TARGET_COLUMN_NAMES = (
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "sentence1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
    "chosen",
    "rejected",
)
def datasets_server_valid_rows(hub_id: str):
    """Ask the datasets server whether its viewer supports ``hub_id``.

    Returns the ``"viewer"`` field of the ``/is-valid`` response. Raises the
    httpx status error from ``raise_for_status`` on a non-2xx response.
    """
    url = f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}"
    response = client.get(url)
    response.raise_for_status()
    payload = response.json()
    return payload["viewer"]
def get_first_config_and_split_name(hub_id: str):
    """Return ``(config, split)`` of the first split listed for ``hub_id``.

    Queries the datasets-server ``/splits`` endpoint. Returns None when the
    response lists no splits. Raises the httpx status error from
    ``raise_for_status`` on a non-2xx response.
    """
    # Use the shared base URL, and let httpx encode the query parameters.
    resp = client.get(
        f"{BASE_DATASETS_SERVER_URL}/splits", params={"dataset": hub_id}
    )
    resp.raise_for_status()
    splits = resp.json().get("splits") or []
    if not splits:
        return None
    first = splits[0]
    return first["config"], first["split"]
def get_dataset_info(hub_id: str, config: str | None = None):
    """Fetch the datasets-server ``/info`` payload for one config of ``hub_id``.

    Args:
        hub_id: Hub dataset id, e.g. ``"user/dataset"``.
        config: Config name. When None, the config of the first split listed
            by ``get_first_config_and_split_name`` is used.

    Returns:
        The decoded JSON response, or None when no config could be determined.

    Raises:
        The httpx status error from ``raise_for_status`` on a non-2xx response.
    """
    if config is None:
        first = get_first_config_and_split_name(hub_id)
        if first is None:
            return None
        config = first[0]
    # Pass the query as params so config names with reserved characters
    # (e.g. "&", "+", spaces) are URL-encoded correctly.
    resp = client.get(
        f"{BASE_DATASETS_SERVER_URL}/info",
        params={"dataset": hub_id, "config": config},
    )
    resp.raise_for_status()
    return resp.json()
def get_random_rows(
    hub_id,
    total_length,
    number_of_rows,
    max_request_calls,
    config="default",
    split="train",
):
    """Sample up to ``number_of_rows`` rows of a split via datasets-server ``/rows``.

    Rows are fetched as up to ``max_request_calls`` contiguous windows, each
    starting at a random offset (windows may overlap). Each request asks for
    at most 100 rows. Requests that do not return HTTP 200 are logged and
    skipped.

    Args:
        hub_id: Hub dataset id.
        total_length: Number of rows in the split.
        number_of_rows: Target sample size.
        max_request_calls: Maximum number of ``/rows`` requests to make.
        config: Config name.
        split: Split name.

    Returns:
        A list of the ``"row"`` payloads of the fetched rows. The list is empty
        if any of the size arguments is non-positive.
    """
    if total_length <= 0 or number_of_rows <= 0 or max_request_calls <= 0:
        return []

    rows_per_call = min(
        number_of_rows // max_request_calls,
        total_length // max_request_calls,
        100,  # cap the page size requested per call
    )
    if rows_per_call == 0:
        # The split (or the requested sample) is smaller than the number of
        # calls, so the even split above rounds down to zero. A single request
        # is enough here, and it avoids dividing by zero below.
        rows_per_call = min(number_of_rows, total_length, 100)
        num_calls = 1
    else:
        num_calls = min(max_request_calls, number_of_rows // rows_per_call)

    rows = []
    for _ in range(num_calls):
        # rows_per_call <= total_length, so this range is never empty.
        offset = random.randint(0, total_length - rows_per_call)
        response = client.get(
            f"{BASE_DATASETS_SERVER_URL}/rows",
            params={
                "dataset": hub_id,
                "config": config,
                "split": split,
                "offset": offset,
                "length": rows_per_call,
            },
        )
        if response.status_code != 200:
            logger.warning(
                f"Failed to fetch data: {response.status_code} ({response.url})"
            )
            continue
        # "rows" may be missing or null; treat that as an empty page.
        rows.extend(response.json().get("rows") or [])
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Download ``model.bin`` from the Hub repo ``repo_id`` and load it with fastText."""
    local_path = hf_hub_download(repo_id, filename="model.bin")
    loaded = fasttext.load_model(local_path)
    return loaded
# def predict_language_for_rows(rows: list[dict], target_column_names: list[str] | str):
# pass
def yield_clean_rows(
    rows: Iterable[Union[str, list[str]]], min_length: int = 3
) -> Iterator[str]:
    """Yield single-line text snippets suitable for fastText language detection.

    - ``str`` rows are split into lines. Each line is yielded when its
      whitespace-stripped length is at least ``min_length``.
    - ``list`` rows (e.g. token lists) are joined with spaces, and newlines
      are replaced by spaces. The result is yielded when its stripped length is
      at least ``min_length``. Lists with non-string items are skipped.
    - Rows of any other type are ignored.

    Every yielded string is free of ``"\\n"``, because fastText's ``predict``
    rejects text containing newlines.
    """
    for row in rows:
        if isinstance(row, str):
            for line in row.splitlines():
                if len(line.strip()) >= min_length:
                    yield line
        elif isinstance(row, list):
            try:
                joined = " ".join(row)
            except TypeError:
                # e.g. a list of ints or nested lists: not usable text
                continue
            joined = joined.replace("\n", " ")
            if len(joined.strip()) >= min_length:
                yield joined
# fastText labels look like "__label__eng_Latn"; this many leading characters
# are stripped to obtain the bare language code.
FASTTEXT_PREFIX_LENGTH = len("__label__")

# Language-ID model used by `model_predict`. To try OpenLID instead, load
# `DEFAULT_FAST_TEXT_MODEL` with `load_model`.
model = load_model("facebook/fasttext-language-identification")
def model_predict(inputs: str, k=1) -> list[dict[str, str | float]]:
    """Run the module-level fastText model on a single string.

    Args:
        inputs: Text to classify. Newline characters are replaced with spaces
            first, because fastText's ``predict`` raises ValueError on text
            containing a newline.
        k: Number of top labels to return.

    Returns:
        One ``{"label": ..., "score": ...}`` dict per predicted label. The label
        has its ``"__label__"`` prefix removed (e.g. ``"eng_Latn"``), and the
        score is a plain float.
    """
    labels, probs = model.predict(inputs.replace("\n", " "), k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": float(prob)}
        for label, prob in zip(labels, probs)
    ]
def get_label(x):
    """Return the ``"label"`` entry of a prediction dict, or None if it is absent."""
    label = x.get("label")
    return label
def get_mean_score(preds):
    """Return the arithmetic mean of the ``"score"`` values in ``preds``."""
    scores = (p.get("score") for p in preds)
    return mean(scores)
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2) -> set:
    """Return the keys whose count is at least ``threshold_percent`` of the total.

    Args:
        counts_dict: Mapping from key to a numeric count.
        threshold_percent: Threshold as a fraction of the summed counts
            (0.2 means 20%). The comparison is inclusive (``>=``).

    Returns:
        A set of the qualifying keys, not a filtered dict. An empty
        ``counts_dict`` yields an empty set.
    """
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {key for key, count in counts_dict.items() if count >= threshold}
def predict_rows(rows, target_column, language_threshold_percent=0.2):
    """Predict languages for the ``target_column`` text of ``rows``.

    Each row's value for ``target_column`` is cleaned into single-line
    snippets, and each snippet is classified with ``model_predict``.

    Returns:
        A dict with:
        - ``"predictions"``: language label -> mean score, restricted to labels
          that make up at least ``language_threshold_percent`` of all
          predictions.
        - ``"pred"``: the flat list of every individual prediction.
    """
    texts = (row.get(target_column) for row in rows)
    cleaned = list(yield_clean_rows(text for text in texts if text is not None))

    per_text = [model_predict(text) for text in cleaned]
    all_preds = list(concat(p for p in per_text if p is not None))

    by_language = groupby(get_label, all_preds)
    counts = {lang: len(preds) for lang, preds in by_language.items()}
    kept = filter_by_frequency(counts, threshold_percent=language_threshold_percent)
    mean_scores = {
        lang: get_mean_score(preds)
        for lang, preds in by_language.items()
        if lang in kept
    }
    return {
        "predictions": mean_scores,
        "pred": all_preds,
    }
def _select_target_column(column_names, hub_id: str) -> str:
    """Return the first entry of TARGET_COLUMN_NAMES present in ``column_names``.

    Raises gr.Error when none of the known text columns is present.
    """
    for column in TARGET_COLUMN_NAMES:
        if column in column_names:
            logger.info(f"Using column {column} for language detection")
            return column
    raise gr.Error(
        f"Dataset {hub_id} does not contain any of the target columns {TARGET_COLUMN_NAMES}"
    )


def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
):
    """Detect the language(s) of a Hub dataset from a random sample of rows.

    Args:
        hub_id: Hub dataset id.
        config: Config name. When empty, the first config listed by the
            datasets server is used.
        split: Split name. When empty, the first split of the chosen config
            is used.
        max_request_calls: Maximum number of ``/rows`` requests made to sample
            up to 1000 rows.

    Returns:
        The dict produced by ``predict_rows``, with ``"hub_id"``, ``"config"``
        and ``"split"`` added.

    Raises:
        gr.Error: if the dataset is not served by the datasets server, has no
            usable config/split/info, or has no recognised text column.
    """
    if not datasets_server_valid_rows(hub_id):
        # gr.Error must be *raised* for Gradio to show it and stop processing.
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")

    if not config:
        first = get_first_config_and_split_name(hub_id)
        if first is None:
            raise gr.Error(f"Dataset {hub_id} has no splits on the datasets server.")
        config, first_split = first
        # Keep a split the user typed in; otherwise use the first listed one.
        split = split or first_split

    info = get_dataset_info(hub_id, config)
    if info is None:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    dataset_info = info.get("dataset_info")
    if not dataset_info:
        raise gr.Error(f"No dataset info available for {hub_id} (config {config}).")

    splits = dataset_info.get("splits") or {}
    if not split:
        # A config was given without a split: fall back to its first split.
        split = next(iter(splits), None)
    split_info = splits.get(split)
    if split_info is None:
        raise gr.Error(
            f"Split {split!r} not found for dataset {hub_id} (config {config}); "
            f"available splits: {list(splits)}"
        )
    total_rows_for_split = split_info.get("num_examples")
    if not total_rows_for_split:
        raise gr.Error(f"Split {split!r} of dataset {hub_id} has no rows to sample.")

    features = dataset_info.get("features") or {}
    column_names = set(features)
    logger.info(f"Column names: {column_names}")
    target_column = _select_target_column(column_names, hub_id)

    random_rows = get_random_rows(
        hub_id, total_rows_for_split, 1000, max_request_calls, config, split
    )
    logger.info(f"Predicting language for {len(random_rows)} rows")
    predictions = predict_rows(random_rows, target_column)
    predictions["hub_id"] = hub_id
    predictions["config"] = config
    predictions["split"] = split
    return predictions
# Gradio UI: a dataset id plus optional config and split text fields.
dataset_id_box = gr.Text(label="dataset id")
config_box = gr.Textbox(None, label="config")
split_box = gr.Textbox(None, label="split")
inputs = [dataset_id_box, config_box, split_box]

interface = gr.Interface(fn=predict_language, inputs=inputs, outputs="json")
# Enable Gradio's request queue, then start the server.
interface.queue()
interface.launch()