from functools import partial
from itertools import islice

import pandas as pd
from datasets import load_dataset
from tqdm import tqdm

from perplexity import KenlmModel


def hub_dataset_to_dataframe(
    path: str,
    name: str,
    split: str,
    sample: int,
    text_column: str,
    model: KenlmModel,
    seed: int = 0,
) -> pd.DataFrame:
    """Stream a Hub dataset and collect a perplexity-scored sample of it.

    The dataset is opened in streaming mode, shuffled with a 10000-item
    buffer, and each example is passed through ``model.get_perplexity`` on
    its ``text_column`` value. At most ``sample`` examples are pulled from
    the stream.

    Args:
        path: Dataset path forwarded to ``load_dataset``.
        name: Dataset configuration name; only forwarded when truthy.
        split: Split to load; only forwarded when truthy.
            NOTE(review): without a split, ``load_dataset`` presumably
            returns a mapping of splits rather than a single dataset --
            confirm callers always pass one.
        sample: Maximum number of examples to collect. A value of zero or
            less yields an empty DataFrame without loading anything.
        text_column: Name of the column holding the text to score.
        model: Object whose ``get_perplexity(text)`` scores each example.
        seed: Seed passed to the streaming shuffle.

    Returns:
        A DataFrame with one row per collected example, including a
        ``"perplexity"`` entry produced by ``model``.
    """
    # A non-positive sample can never be reached by counting upwards, so the
    # old `count == sample` check let the loop consume the entire stream.
    if sample <= 0:
        return pd.DataFrame([])

    # Only forward the optional arguments that were actually provided.
    load_kwargs = {"path": path}
    if name:
        load_kwargs["name"] = name
    if split:
        load_kwargs["split"] = split

    def add_perplexity(example):
        text = example[text_column]
        return {text_column: text, "perplexity": model.get_perplexity(text)}

    dataset = (
        load_dataset(streaming=True, **load_kwargs)
        .shuffle(buffer_size=10000, seed=seed)
        .map(add_perplexity)
    )

    # islice stops after `sample` items without pulling (and scoring) any
    # further examples from the stream.
    instances = list(tqdm(islice(dataset, sample), total=sample))
    return pd.DataFrame(instances)