Spaces:
Running
Running
File size: 2,832 Bytes
569f484 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
from abc import abstractmethod
from ..smp import *
class TextBaseDataset:
    """Base class for text-only (``MODALITY = 'TEXT'``) benchmark datasets.

    Subclasses register their datasets in ``DATASET_URL`` (name -> TSV url)
    and, optionally, ``DATASET_MD5`` (name -> expected MD5 of the file), and
    implement :meth:`evaluate`.  The loaded records are kept in ``self.data``,
    a pandas DataFrame that carries an ``index`` column.
    """

    MODALITY = 'TEXT'
    DATASET_URL = {}
    DATASET_MD5 = {}

    def __init__(self, dataset='MMBench', **kwargs):
        """Load ``dataset`` into ``self.data`` and run the :meth:`post_build` hook.

        Args:
            dataset: Dataset name; with the default :meth:`load_data` it must
                be a key of ``DATASET_URL``.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        self.dataset_name = dataset
        data = self.load_data(dataset)

        # Normalise the index column to strings first; if every entry is
        # judged integer-like by ``istype``, store the whole column as ints.
        data['index'] = [str(x) for x in data['index']]
        if np.all([istype(x, int) for x in data['index']]):
            data['index'] = [int(x) for x in data['index']]

        self.data = data
        self.post_build(dataset)

    def __len__(self):
        """Return the number of records in ``self.data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the record at position ``idx`` as a plain dict."""
        return dict(self.data.iloc[idx])

    def prepare_tsv(self, url, file_md5=None):
        """Make sure the file at ``url`` is cached under ``LMUDataRoot()``, then load it.

        The file is (re)downloaded when it is missing or, if ``file_md5`` is
        given, when its MD5 does not match.  Files larger than 1 GB are passed
        through ``LOCALIZE`` into a sibling ``*_local.tsv`` file, which is
        loaded instead.  That conversion is redone when the download was
        refreshed, when the local file is missing, or when the ``FORCE_LOCAL``
        environment variable is set to a non-empty value.

        Args:
            url: Download URL; its last ``/``-separated component becomes the
                cached file name.
            file_md5: Optional expected MD5 checksum of the cached file.

        Returns:
            The result of ``load`` on the (possibly localized) file.
        """
        data_root = LMUDataRoot()
        os.makedirs(data_root, exist_ok=True)
        file_name = url.split('/')[-1]
        data_path = osp.join(data_root, file_name)

        # ``md5`` is only computed when the file exists and a checksum is known.
        is_cached = osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5)
        update_flag = not is_cached
        if update_flag:
            warnings.warn('The dataset tsv is not downloaded')
            download_file(url, data_path)

        if file_size(data_path, 'GB') > 1:
            # Build the localized name from the file stem only.  A plain
            # ``str.replace('.tsv', ...)`` would also touch '.tsv' inside
            # directory names, and would yield ``data_path`` itself for a file
            # without a '.tsv' suffix, making LOCALIZE overwrite its input.
            local_path = osp.splitext(data_path)[0] + '_local.tsv'
            if not osp.exists(local_path) or os.environ.get('FORCE_LOCAL', None) or update_flag:
                from ..tools import LOCALIZE
                LOCALIZE(data_path, local_path)
            data_path = local_path
        return load(data_path)

    def dump_image(self, line):
        """Text datasets carry no images, so this always returns an empty list."""
        return []

    def display(self, line):
        """Render one record with ``mmqa_display``.

        Args:
            line: A positional row index (Python or NumPy integer), or a
                record given as a ``pd.Series`` or ``dict``.
        """
        if isinstance(line, (int, np.integer)):
            line = self.data.iloc[line]
        assert isinstance(line, (pd.Series, dict))
        mmqa_display(line)

    @classmethod
    def supported_datasets(cls):
        """Return the dataset names registered in ``DATASET_URL``; may be overridden."""
        return list(cls.DATASET_URL)

    def load_data(self, dataset):
        """Download (if needed) and load ``dataset`` as a DataFrame; may be overridden.

        The checksum is optional: a dataset without an entry in
        ``DATASET_MD5`` is loaded without MD5 verification rather than
        raising ``KeyError``.
        """
        url = self.DATASET_URL[dataset]
        file_md5 = self.DATASET_MD5.get(dataset)
        return self.prepare_tsv(url, file_md5)

    def post_build(self, dataset):
        """Hook called at the end of ``__init__``; does nothing by default."""
        pass

    def build_prompt(self, line):
        """Build the prompt message list for one record; may be overridden.

        Args:
            line: A positional row index (Python or NumPy integer), or a
                record exposing a ``'question'`` field.

        Returns:
            A one-element list ``[{'type': 'text', 'value': <question>}]``.
        """
        if isinstance(line, (int, np.integer)):
            line = self.data.iloc[line]
        return [dict(type='text', value=line['question'])]

    @abstractmethod
    def evaluate(self, eval_file, **judge_kwargs):
        """Score the predictions in ``eval_file``.

        Returns the evaluation results as a dict or pandas DataFrame.

        NOTE(review): the class does not derive from ``abc.ABC``, so
        ``@abstractmethod`` is not enforced at instantiation time.  Calling
        this base implementation simply returns ``None``.
        """
        pass
|