| import random |
| import sys |
| import unittest |
| import warnings |
| from os import environ |
|
|
| from datasets import Dataset, DatasetDict |
| from mmengine.config import read_base |
| from tqdm import tqdm |
|
|
| from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
# Silence DeprecationWarnings raised while importing/using dataset code.
warnings.simplefilter('ignore', DeprecationWarning)
|
|
|
|
def reload_datasets():
    """Re-import the dataset config modules and return all their dataset configs.

    Any module whose name starts with ``configs.datasets`` is first removed
    from ``sys.modules``. The ``from ... import`` statements below therefore
    execute those modules again instead of returning the cached objects.

    Returns:
        The concatenation (via ``sum(..., [])``) of every local variable whose
        name ends with ``_datasets``, in the order ``locals()`` yields them.
        Each such value is assumed to be a list of dataset config dicts; verify
        this against the config modules.
    """
    # Evict cached config modules so the imports below re-run module code.
    # NOTE(review): presumably this is what lets a changed DATASET_SOURCE
    # environment variable take effect on reload; confirm in the config modules.
    modules_to_remove = [
        module_name for module_name in sys.modules
        if module_name.startswith('configs.datasets')
    ]

    for module_name in modules_to_remove:
        del sys.modules[module_name]

    # NOTE(review): read_base() comes from mmengine. It is assumed to be what
    # permits these config-style imports; its behavior is not visible here.
    with read_base():
        from configs.datasets.ceval.ceval_gen import ceval_datasets
        from configs.datasets.gsm8k.gsm8k_gen import gsm8k_datasets
        from configs.datasets.cmmlu.cmmlu_gen import cmmlu_datasets
        from configs.datasets.ARC_c.ARC_c_gen import ARC_c_datasets
        from configs.datasets.ARC_e.ARC_e_gen import ARC_e_datasets
        from configs.datasets.humaneval.humaneval_gen import humaneval_datasets
        from configs.datasets.humaneval.humaneval_repeat10_gen_8e312c import humaneval_datasets as humaneval_repeat10_datasets
        from configs.datasets.race.race_ppl import race_datasets
        from configs.datasets.commonsenseqa.commonsenseqa_gen import commonsenseqa_datasets

        from configs.datasets.mmlu.mmlu_gen import mmlu_datasets
        from configs.datasets.strategyqa.strategyqa_gen import strategyqa_datasets
        from configs.datasets.bbh.bbh_gen import bbh_datasets
        from configs.datasets.Xsum.Xsum_gen import Xsum_datasets
        from configs.datasets.winogrande.winogrande_gen import winogrande_datasets
        from configs.datasets.winogrande.winogrande_ll import winogrande_datasets as winogrande_ll_datasets
        from configs.datasets.winogrande.winogrande_5shot_ll_252f01 import winogrande_datasets as winogrande_5shot_ll_datasets
        from configs.datasets.obqa.obqa_gen import obqa_datasets
        from configs.datasets.obqa.obqa_ppl_6aac9e import obqa_datasets as obqa_ppl_datasets
        from configs.datasets.agieval.agieval_gen import agieval_datasets as agieval_v2_datasets

        from configs.datasets.siqa.siqa_gen import siqa_datasets as siqa_v2_datasets
        from configs.datasets.siqa.siqa_gen_18632c import siqa_datasets as siqa_v3_datasets
        from configs.datasets.siqa.siqa_ppl_42bc6e import siqa_datasets as siqa_ppl_datasets
        from configs.datasets.storycloze.storycloze_gen import storycloze_datasets
        from configs.datasets.storycloze.storycloze_ppl import storycloze_datasets as storycloze_ppl_datasets
        from configs.datasets.summedits.summedits_gen import summedits_datasets as summedits_v2_datasets

        from configs.datasets.hellaswag.hellaswag_gen import hellaswag_datasets as hellaswag_v2_datasets
        from configs.datasets.hellaswag.hellaswag_10shot_gen_e42710 import hellaswag_datasets as hellaswag_ice_datasets
        from configs.datasets.hellaswag.hellaswag_ppl_9dbb12 import hellaswag_datasets as hellaswag_v1_datasets
        from configs.datasets.hellaswag.hellaswag_ppl_a6e128 import hellaswag_datasets as hellaswag_v3_datasets
        from configs.datasets.mbpp.mbpp_gen import mbpp_datasets as mbpp_v1_datasets
        from configs.datasets.mbpp.mbpp_passk_gen_830460 import mbpp_datasets as mbpp_v2_datasets
        from configs.datasets.mbpp.sanitized_mbpp_gen_830460 import sanitized_mbpp_datasets
        from configs.datasets.nq.nq_gen import nq_datasets
        from configs.datasets.lcsts.lcsts_gen import lcsts_datasets
        from configs.datasets.math.math_gen import math_datasets
        from configs.datasets.piqa.piqa_gen import piqa_datasets as piqa_v2_datasets
        from configs.datasets.piqa.piqa_ppl import piqa_datasets as piqa_v1_datasets
        from configs.datasets.piqa.piqa_ppl_0cfff2 import piqa_datasets as piqa_v3_datasets
        from configs.datasets.lambada.lambada_gen import lambada_datasets
        from configs.datasets.tydiqa.tydiqa_gen import tydiqa_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_gen import GaokaoBench_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_mixed import GaokaoBench_datasets as GaokaoBench_mixed_datasets
        from configs.datasets.GaokaoBench.GaokaoBench_no_subjective_gen_4c31db import GaokaoBench_datasets as GaokaoBench_no_subjective_datasets
        from configs.datasets.triviaqa.triviaqa_gen import triviaqa_datasets
        from configs.datasets.triviaqa.triviaqa_wiki_1shot_gen_20a989 import triviaqa_datasets as triviaqa_wiki_1shot_datasets

        from configs.datasets.CLUE_cmnli.CLUE_cmnli_gen import cmnli_datasets
        from configs.datasets.CLUE_cmnli.CLUE_cmnli_ppl import cmnli_datasets as cmnli_ppl_datasets
        from configs.datasets.CLUE_ocnli.CLUE_ocnli_gen import ocnli_datasets

        from configs.datasets.ceval.ceval_clean_ppl import ceval_datasets as ceval_clean_datasets
        from configs.datasets.ARC_c.ARC_c_clean_ppl import ARC_c_datasets as ARC_c_clean_datasets
        from configs.datasets.mmlu.mmlu_clean_ppl import mmlu_datasets as mmlu_clean_datasets
        from configs.datasets.hellaswag.hellaswag_clean_ppl import hellaswag_datasets as hellaswag_clean_datasets
        from configs.datasets.FewCLUE_ocnli_fc.FewCLUE_ocnli_fc_gen import ocnli_fc_datasets

    # Gather every imported ``*_datasets`` value by scanning locals().
    # Any local whose name ends in ``_datasets`` is picked up automatically, and
    # renaming an alias so it no longer ends in ``_datasets`` silently drops it.
    # ``modules_to_remove`` / ``module_name`` do not match the suffix.
    return sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
|
|
|
|
def load_datasets_conf(source):
    """Set ``DATASET_SOURCE`` to *source*, then return the reloaded dataset configs.

    The environment variable is written before ``reload_datasets`` re-imports
    the config modules, so it is already set while they execute.
    """
    environ['DATASET_SOURCE'] = source
    return reload_datasets()
|
|
|
|
def load_datasets(source, conf):
    """Load the dataset described by *conf* with ``DATASET_SOURCE`` set to *source*.

    The loader is ``conf['type'].load``. Which keyword arguments it receives
    depends on the first matching key in *conf*, checked in this order:

    * ``lang``         -> ``load(path=..., lang=...)``
    * ``setting_name`` -> ``load(path=..., name=..., setting_name=...)``
    * ``name``         -> ``load(path=..., name=...)``
    * ``local_mode``   -> ``load(path=..., local_mode=...)``

    If none of these keys is present, ``load(path=...)`` is tried first. If it
    raises any exception, the loader is called again with the entire *conf*
    unpacked as keyword arguments.
    """
    environ['DATASET_SOURCE'] = source

    # (trigger key, extra keys forwarded alongside ``path``), in priority order.
    dispatch = (
        ('lang', ('lang',)),
        ('setting_name', ('name', 'setting_name')),
        ('name', ('name',)),
        ('local_mode', ('local_mode',)),
    )
    for trigger, forwarded in dispatch:
        if trigger not in conf:
            continue
        loader = conf['type'].load
        return loader(path=conf['path'],
                      **{key: conf[key] for key in forwarded})

    try:
        return conf['type'].load(path=conf['path'])
    except Exception:
        return conf['type'].load(**conf)
|
|
|
|
def clean_string(value):
    """Normalize whitespace in *value* when it is a string.

    For a ``str``, leading/trailing whitespace is dropped and every internal
    run of whitespace is collapsed to a single space. Any other value is
    returned unchanged.
    """
    if not isinstance(value, str):
        return value
    words = value.split()
    return ' '.join(words)
|
|
|
|
class TestingLocalDatasets(unittest.TestCase):
    """Check that every configured dataset can be loaded from the 'Local' source."""

    def test_datasets(self):
        """Load every dataset config with ``DATASET_SOURCE='Local'`` and report.

        Configs come from ``load_datasets_conf('Local')``. Each one is loaded
        via ``load_datasets`` in a 16-worker thread pool. Successes and
        failures are collected and printed as a summary. Failures are
        reported, not asserted, so this test does not fail when a dataset
        cannot be loaded.
        """
        local_datasets_conf = load_datasets_conf('Local')

        successful_comparisons = []
        failed_comparisons = []

        def compare_datasets(local_conf):
            """Try to load one dataset config; return ``(status, message)``."""
            local_path_name = (f"{local_conf.get('path')}/"
                               f"{local_conf.get('name', '')}\t"
                               f"{local_conf.get('lang', '')}")
            try:
                # Load exactly once, and only inside the try. A load error must
                # become a 'failure' entry rather than escape the worker
                # thread; there future.result() would re-raise it and abort
                # the whole run before the summary is printed.
                load_datasets('Local', local_conf)
            except Exception as exception:
                return 'failure', (f"can't load {local_path_name}: "
                                   f'{type(exception).__name__}: {exception}')
            return 'success', local_path_name

        with ThreadPoolExecutor(16) as executor:
            futures = [
                executor.submit(compare_datasets, local_conf)
                for local_conf in local_datasets_conf
            ]
            for future in tqdm(as_completed(futures), total=len(futures)):
                result, message = future.result()
                if result == 'success':
                    successful_comparisons.append(message)
                else:
                    failed_comparisons.append(message)

        # as_completed() yields in completion order; sort for a stable report.
        successful_comparisons.sort()
        failed_comparisons.sort()

        total_datasets = len(local_datasets_conf)
        print(f'All {total_datasets} datasets')
        print(f'OK {len(successful_comparisons)} datasets')
        for success in successful_comparisons:
            print(f'  {success}')
        print(f'Fail {len(failed_comparisons)} datasets')
        for failure in failed_comparisons:
            print(f'  {failure}')
|
|
|
|
def _check_data(ms_dataset: Dataset | DatasetDict,
                oc_dataset: Dataset | DatasetDict,
                sample_size):
    """Assert that two datasets contain the same data.

    ``DatasetDict`` inputs are compared split by split (recursively). For
    ``Dataset`` inputs, the checks are: same column set, same row count, and,
    for up to *sample_size* randomly chosen rows, every column value equal
    after ``str()`` + whitespace normalization via ``clean_string``.

    Raises:
        AssertionError: on any type, key, column, length, or value mismatch.
        ValueError: if the inputs are neither ``Dataset`` nor ``DatasetDict``.
    """
    assert type(ms_dataset) is type(oc_dataset), (
        f'Dataset type not match: {type(ms_dataset)} != {type(oc_dataset)}')

    if isinstance(oc_dataset, DatasetDict):
        assert ms_dataset.keys() == oc_dataset.keys(), (
            f'DatasetDict not match: {ms_dataset.keys()} != {oc_dataset.keys()}')
        for key in ms_dataset.keys():
            _check_data(ms_dataset[key], oc_dataset[key], sample_size=sample_size)

    elif isinstance(oc_dataset, Dataset):
        assert set(ms_dataset.column_names) == set(oc_dataset.column_names), (
            f'Column names do not match: {ms_dataset.column_names} != {oc_dataset.column_names}')

        assert len(ms_dataset) == len(oc_dataset), (
            f'Number of rows do not match: {len(ms_dataset)} != {len(oc_dataset)}')

        sample_indices = random.sample(range(len(ms_dataset)),
                                       min(sample_size, len(ms_dataset)))

        for idx in sample_indices:
            # Fetch each sampled row once. Indexing ``dataset[col][idx]``
            # slices the whole column for every single cell lookup.
            ms_row = ms_dataset[idx]
            oc_row = oc_dataset[idx]
            for col in ms_dataset.column_names:
                ms_value = clean_string(str(ms_row[col]))
                oc_value = clean_string(str(oc_row[col]))
                if ms_value != oc_value:
                    print(f"Assertion failed for column '{col}', index {idx}")
                    print(f'ms_data: {ms_row}')
                    print(f'oc_data: {oc_row}')
                    print(f'ms_value: {ms_value} ({type(ms_value)})')
                    print(f'oc_value: {oc_value} ({type(oc_value)})')
                    raise AssertionError(
                        f"Value mismatch in column '{col}', index {idx}: "
                        f'{ms_value} != {oc_value}')
    else:
        raise ValueError(f'Datasets type not supported {type(ms_dataset)}')
|
|
|
|
if __name__ == '__main__':
    # Run the unittest suite defined in this module. (The former
    # ``sample_size = 100`` assignment here was never read and is dropped.)
    unittest.main()
|
|