#!/usr/bin/env python3
"""Build a label-balanced subset of FEVER v1.0 and push it to the Hub.

For each source split, the script keeps the rows whose ``evidence_id`` is not
``-1``, reduces them to ``{id, label, claim}``, removes exact duplicates,
truncates every label group to the size of the smallest one, shuffles the
result with a fixed seed, and uploads all splits as a ``DatasetDict``.
"""

import glob
import os
import random

from datasets import Dataset, DatasetDict, load_dataset
from tqdm import tqdm

# Source splits to process; "labelled_dev" is published under the name "dev".
name_lst = ["train", "labelled_dev"]

# Mapping from the source label strings to the labels used in the output.
old_to_new_label_map = {"SUPPORTS": "supported", "REFUTES": "refuted"}

# Hub repository that receives the processed dataset.
ds_path = "pminervini/hl-fever"


def convert(list_of_dicts):
    """Turn a list of row dicts into a dict of column lists.

    Args:
        list_of_dicts: Iterable of dicts, one per row.

    Returns:
        A dict mapping every key seen in the rows to the list of its values,
        in row order. An empty input yields an empty dict.
    """
    res = {}
    for d in list_of_dicts:
        for k, v in d.items():
            res.setdefault(k, []).append(v)
    return res


def _collect_instances(split):
    """Extract ``{id, label, claim}`` rows from the entries of one split.

    Entries whose ``evidence_id`` equals ``-1`` are skipped (presumably the
    rows without annotated evidence; verify against the dataset card).

    Raises:
        ValueError: If a kept entry has a label that is not a key of
            ``old_to_new_label_map``.
    """
    instances = []
    for entry in tqdm(split):
        if entry["evidence_id"] == -1:
            continue
        label = entry["label"]
        # Explicit check rather than `assert`, which is stripped under -O.
        if label not in old_to_new_label_map:
            raise ValueError(
                f"unexpected label {label!r} for entry id {entry['id']!r}"
            )
        instances.append(
            {
                "id": entry["id"],
                "label": old_to_new_label_map[label],
                "claim": entry["claim"],
            }
        )
    return instances


def _dedupe_and_balance(instances, seed=42):
    """De-duplicate rows, balance them across labels, and shuffle.

    Args:
        instances: List of ``{id, label, claim}`` dicts.
        seed: Seed for the final shuffle.

    Returns:
        A new list holding the same number of rows for every label (the size
        of the smallest label group), in seeded-shuffle order.

    Raises:
        ValueError: If ``instances`` is empty (there is no smallest group).
    """
    unique = {tuple(d.items()) for d in instances}
    # Set iteration order varies between interpreter runs (string hashing is
    # randomised), and sorting on "claim" alone would leave rows that share a
    # claim in that arbitrary order. Sorting on the full (claim, label, id)
    # key is a total order over the de-duplicated rows, so the truncation and
    # the seeded shuffle below give the same result on every run.
    ordered = sorted(
        (dict(t) for t in unique),
        key=lambda d: (d["claim"], d["label"], d["id"]),
    )

    by_label = {}
    for instance in ordered:
        by_label.setdefault(instance["label"], []).append(instance)

    per_label = min(len(group) for group in by_label.values())

    balanced = []
    for label in sorted(by_label):
        balanced.extend(by_label[label][:per_label])

    random.Random(seed).shuffle(balanced)
    return balanced


def main():
    """Load FEVER v1.0, process every split, and push the result to the Hub."""
    v10 = load_dataset("fever", "v1.0")

    data_map = {}
    for name in name_lst:
        key = "dev" if name == "labelled_dev" else name
        data_map[key] = _dedupe_and_balance(_collect_instances(v10[name]))

    task_to_ds_map = {
        k: Dataset.from_dict(convert(v)) for k, v in data_map.items()
    }
    ds_dict = DatasetDict(task_to_ds_map)
    # Pass the configuration name by keyword: the meaning of the second
    # positional parameter of push_to_hub has differed between `datasets`
    # releases, so a positional "v1.0" is not safe.
    ds_dict.push_to_hub(ds_path, config_name="v1.0")


if __name__ == "__main__":
    main()