|
|
|
|
|
import glob |
|
import os |
|
|
|
import random |
|
from tqdm import tqdm |
|
|
|
from datasets import Dataset, DatasetDict, load_dataset |
|
|
|
|
|
def convert(list_of_dicts):
    """Turn a sequence of row dicts into a dict of per-key value lists.

    Keys appear in the order they are first encountered. Each key's list
    receives one value per row that contains that key, so rows with
    differing keys yield lists of differing lengths.
    """
    columns = {}
    for row in list_of_dicts:
        for column_name in row:
            if column_name not in columns:
                columns[column_name] = []
            columns[column_name].append(row[column_name])
    return columns
|
|
|
|
|
# Source label -> label name used in the published dataset. Only these two
# labels are kept; other labels never reach the output.
old_to_new_label_map = dict(SUPPORTS="supported", REFUTES="refuted")

# Splits of the source dataset to process ("labelled_dev" is stored as "dev").
name_lst = ["train", "labelled_dev"]

# FEVER v1.0, fetched through `datasets.load_dataset`.
v10 = load_dataset("fever", "v1.0")
|
|
|
def _collect_instances(split, label_map):
    """Return the unique evidence-backed claims of one split in a stable order.

    Args:
        split: iterable of rows exposing "id", "label", "claim" and
            "evidence_id" (here: one split of FEVER v1.0).
        label_map: maps each accepted source label to its output name.

    Returns:
        A list of {"id", "label", "claim"} dicts (label already renamed),
        de-duplicated and sorted by (claim, id, label).

    Raises:
        ValueError: if a row with evidence has a label not in ``label_map``.
    """
    unique = set()
    for entry in tqdm(split):
        # Rows with evidence_id == -1 are skipped entirely.
        if entry["evidence_id"] == -1:
            continue
        label = entry["label"]
        # Explicit check rather than `assert`, which `python -O` removes.
        if label not in label_map:
            raise ValueError(
                f"unexpected label {label!r} on evidence row with id {entry['id']!r}"
            )
        # The same claim can occur in several rows (presumably one per
        # evidence item); the set collapses identical (claim, id, label).
        unique.add((entry["claim"], entry["id"], label_map[label]))
    # Sorting the whole tuple, not just the claim, breaks ties between
    # identical claim texts by id. Set iteration order depends on the
    # per-process string-hash seed, so claim-only sorting was not reproducible.
    return [
        {"id": id_, "label": label, "claim": claim}
        for claim, id_, label in sorted(unique)
    ]


def _balance_by_label(instances, seed=42):
    """Down-sample every label to the size of the rarest one, then shuffle.

    Labels are visited in sorted order. Each label contributes its first
    ``min_len`` instances in input order. The result is shuffled with a
    dedicated ``random.Random(seed)``, so the global RNG is left untouched.

    Raises:
        ValueError: if ``instances`` is empty.
    """
    by_label = {}
    for inst in instances:
        by_label.setdefault(inst["label"], []).append(inst)
    if not by_label:
        raise ValueError("no labelled instances to balance")
    min_len = min(len(group) for group in by_label.values())
    # NOTE(review): because instances arrive sorted by claim, the majority
    # label keeps its alphabetically-first claims rather than a random
    # sample. Confirm this is intended before relying on it.
    balanced = [
        inst for label in sorted(by_label) for inst in by_label[label][:min_len]
    ]
    random.Random(seed).shuffle(balanced)
    return balanced


# Output split name -> balanced, shuffled list of instance dicts.
data_map = {}

for name in name_lst:
    key = "dev" if name == "labelled_dev" else name
    data_map[key] = _balance_by_label(
        _collect_instances(v10[name], old_to_new_label_map)
    )
|
|
|
# Hub repository that receives the balanced splits.
ds_path = "pminervini/hl-fever"

# One Dataset per output split, built column-wise from the instance dicts.
task_to_ds_map = {
    split_name: Dataset.from_dict(convert(rows))
    for split_name, rows in data_map.items()
}

ds_dict = DatasetDict(task_to_ds_map)

# Pass config_name by keyword. In older `datasets` releases (before
# config_name existed) the second positional parameter of push_to_hub was
# `private`, so a positional "v1.0" would silently create a private repo
# instead of naming the config. Newer releases handle either form.
ds_dict.push_to_hub(ds_path, config_name="v1.0")
|
|
|
|
|
|