import json
import os
from os.path import exists
from os.path import join as pjoin

from datasets import Dataset, load_dataset, load_from_disk
from tqdm import tqdm

_CACHE_DIR = "cache_dir"

def load_truncated_dataset(
    dataset_name,
    config_name=None,
    split_name=None,
    num_rows=0,
    use_streaming=True,
    use_auth_token=None,
    use_dataset=None,
):
    """
    This function loads the first `num_rows` items of a dataset for a
    given `config_name` and `split_name`.
    When the dataset is streamable, we iterate through the first
    `num_rows` examples in streaming mode, write them to a jsonl file,
    then create a new dataset from the json.
    This is the most direct way to make a Dataset from an IterableDataset
    as of datasets version 1.6.1.
    Otherwise, we download the full dataset and select the first
    `num_rows` items.
    Args:
        dataset_name (string):
            dataset id in the dataset library
        config_name (string):
            dataset configuration
        split_name (string):
            optional split name, defaults to `train`
        num_rows (int):
            number of rows to truncate the dataset to, <= 0 means no truncation
        use_streaming (bool):
            whether to use streaming when the dataset supports it
        use_auth_token (string):
            HF authentication token to access private datasets
        use_dataset (Dataset):
            use an existing dataset instead of fetching one from the Hub
    Returns:
        Dataset:
            the truncated dataset as a Dataset object
    """
    split_name = "train" if split_name is None else split_name
    config_id = "default" if config_name is None else config_name
    cache_name = (
        f"{dataset_name.replace('/', '---')}_{config_id}_{split_name}_{num_rows}"
    )
    if use_streaming:
        jsonl_path = pjoin(_CACHE_DIR, "tmp", f"{cache_name}.jsonl")
        if not exists(jsonl_path):
            iterable_dataset = (
                load_dataset(
                    dataset_name,
                    name=config_name,
                    split=split_name,
                    cache_dir=pjoin(_CACHE_DIR, "tmp", cache_name + "_temp"),
                    streaming=True,
                    use_auth_token=use_auth_token,
                )
                if use_dataset is None
                else use_dataset
            )
            if num_rows > 0:
                iterable_dataset = iterable_dataset.take(num_rows)
            # Make sure the temporary cache directory exists before writing:
            # when `use_dataset` is provided, no `load_dataset` call creates it.
            os.makedirs(pjoin(_CACHE_DIR, "tmp"), exist_ok=True)
            # Materialize the streamed rows to a jsonl file on disk.
            with open(jsonl_path, "w", encoding="utf-8") as f:
                for row in tqdm(iterable_dataset):
                    f.write(json.dumps(row) + "\n")
        dataset = Dataset.from_json(
            jsonl_path,
            cache_dir=pjoin(_CACHE_DIR, "tmp", cache_name + "_jsonl"),
        )
    else:
        full_dataset = (
            load_dataset(
                dataset_name,
                name=config_name,
                split=split_name,
                use_auth_token=use_auth_token,
                cache_dir=pjoin(_CACHE_DIR, "tmp", cache_name + "_temp"),
            )
            if use_dataset is None
            else use_dataset
        )
        if num_rows > 0:
            dataset = full_dataset.select(range(num_rows))
        else:
            dataset = full_dataset
    return dataset
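
# Example (illustrative; assumes network access to the Hugging Face Hub and a
# streamable dataset such as "squad"):
#
#   dset = load_truncated_dataset("squad", split_name="train", num_rows=100)
#   print(len(dset))  # 100
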
def extract_features(examples, indices, input_field_path, label_name=None):
    """
    This function prepares examples for further processing by:
    - returning an "unrolled" list of all the field items denoted by
      `input_field_path`
    - with the indices corresponding to the example each field item came from
    - optionally, the corresponding label is also returned with each item
    Args:
        examples (dict):
            a dictionary of lists, provided by dataset.map when batched=True
        indices (list):
            a list of indices, provided by dataset.map when with_indices=True
        input_field_path (tuple):
            a tuple indicating the field we want to extract. Can be a singleton
            for top-level features (e.g. `("text",)`) or a full path for nested
            features (e.g. `("answers", "text")`) to get all answer strings in
            SQuAD
        label_name (string):
            optionally used to align the field items with labels. Currently
            uses the top-most field that has this name, which may fail in some
            edge cases
            TODO: make it so the label is specified through a full path
    Returns:
        Dict:
            a dictionary of lists, as expected by dataset.map with batched=True.
            Labels are all None if label_name is not None but the label field
            is not found
            TODO: raise an error if label_name is specified but not found
    """
    top_name = input_field_path[0]
    if label_name is not None and label_name in examples:
        item_list = [
            {"index": i, "label": label, "items": items}
            for i, items, label in zip(
                indices, examples[top_name], examples[label_name]
            )
        ]
    else:
        item_list = [
            {"index": i, "label": None, "items": items}
            for i, items in zip(indices, examples[top_name])
        ]
    # Walk down the field path one level at a time, unrolling lists as we go
    # and propagating each example's index (and label) to its unrolled items.
    for field_name in input_field_path[1:]:
        new_item_list = []
        for dct in item_list:
            if label_name is not None and label_name in dct["items"]:
                # The label lives at this nesting level: align it item-wise.
                if isinstance(dct["items"][field_name], list):
                    new_item_list += [
                        {"index": dct["index"], "label": label, "items": next_item}
                        for next_item, label in zip(
                            dct["items"][field_name], dct["items"][label_name]
                        )
                    ]
                else:
                    new_item_list += [
                        {
                            "index": dct["index"],
                            "label": dct["items"][label_name],
                            "items": dct["items"][field_name],
                        }
                    ]
            else:
                # No label at this level: carry over the one found so far.
                if isinstance(dct["items"][field_name], list):
                    new_item_list += [
                        {
                            "index": dct["index"],
                            "label": dct["label"],
                            "items": next_item,
                        }
                        for next_item in dct["items"][field_name]
                    ]
                else:
                    new_item_list += [
                        {
                            "index": dct["index"],
                            "label": dct["label"],
                            "items": dct["items"][field_name],
                        }
                    ]
        item_list = new_item_list
    res = (
        {
            "ids": [dct["index"] for dct in item_list],
            "field": [dct["items"] for dct in item_list],
        }
        if label_name is None
        else {
            "ids": [dct["index"] for dct in item_list],
            "field": [dct["items"] for dct in item_list],
            "label": [dct["label"] for dct in item_list],
        }
    )
    return res
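
# Example (illustrative): extracting the nested ("answers", "text") field from
# a SQuAD-style batch unrolls each answer string into its own row, tagged with
# the index of the example it came from:
#
#   batch = {"answers": [{"text": ["a", "b"]}, {"text": ["c"]}]}
#   extract_features(batch, [0, 1], ("answers", "text"))
#   # -> {"ids": [0, 0, 1], "field": ["a", "b", "c"]}
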
def prepare_clustering_dataset(
    dataset_name,
    input_field_path,
    label_name=None,
    config_name=None,
    split_name=None,
    num_rows=0,
    use_streaming=True,
    use_auth_token=None,
    cache_dir=_CACHE_DIR,
    use_dataset=None,
):
    """
    This function loads the first `num_rows` items of a dataset for a
    given `config_name` and `split_name`, and extracts all instances of a
    field of interest denoted by `input_field_path`, along with the indices
    of the examples the instances came from and, optionally, their labels
    (`label_name`) in the original dataset.
    Args:
        dataset_name (string):
            dataset id in the dataset library
        input_field_path (tuple):
            a tuple indicating the field we want to extract. Can be a singleton
            for top-level features (e.g. `("text",)`) or a full path for nested
            features (e.g. `("answers", "text")`) to get all answer strings in
            SQuAD
        label_name (string):
            optionally used to align the field items with labels. Currently
            uses the top-most field that has this name, which may fail in some
            edge cases
        config_name (string):
            dataset configuration
        split_name (string):
            optional split name, defaults to `train`
        num_rows (int):
            number of rows to truncate the dataset to, <= 0 means no truncation
        use_streaming (bool):
            whether to use streaming when the dataset supports it
        use_auth_token (string):
            HF authentication token to access private datasets
        cache_dir (string):
            directory where the extracted dataset is cached
        use_dataset (Dataset):
            use an existing dataset instead of fetching one from the Hub
    Returns:
        Dataset:
            the extracted dataset as a Dataset object. Note that if there is
            more than one instance of the field per example in the original
            dataset (e.g. multiple answers per QA example), the returned
            dataset will have more than `num_rows` rows
        string:
            the path to the newly created dataset directory
    """
    cache_path = [
        cache_dir,
        dataset_name.replace("/", "---"),
        f"{'default' if config_name is None else config_name}",
        f"{'train' if split_name is None else split_name}",
        f"field-{'->'.join(input_field_path)}-label-{label_name}",
        f"{num_rows}_rows",
        "features_dset",
    ]
    if exists(pjoin(*cache_path)):
        pre_clustering_dset = load_from_disk(pjoin(*cache_path))
    else:
        truncated_dset = load_truncated_dataset(
            dataset_name,
            config_name,
            split_name,
            num_rows,
            use_streaming,
            use_auth_token,
            use_dataset,
        )

        def batch_func(examples, indices):
            return extract_features(examples, indices, input_field_path, label_name)

        # Map over the truncated dataset in batches, replacing the original
        # columns with the unrolled `ids` / `field` (/ `label`) columns.
        pre_clustering_dset = truncated_dset.map(
            batch_func,
            remove_columns=truncated_dset.column_names,
            batched=True,
            with_indices=True,
        )
        # Make sure the parent directories of the cache location exist.
        os.makedirs(pjoin(*cache_path[:-1]), exist_ok=True)
        pre_clustering_dset.save_to_disk(pjoin(*cache_path))
    return pre_clustering_dset, pjoin(*cache_path)
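
# Example (illustrative): unrolling all SQuAD answer strings into a cached
# dataset ready for clustering. The result can have more rows than
# `num_rows`, since each example may hold several answers:
#
#   dset, path = prepare_clustering_dataset(
#       "squad", ("answers", "text"), num_rows=100
#   )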