# NVLM-D-72B / eval/eval_dataset.py -- benchmark evaluation datasets.
import json
import os
import sys
import time
import yaml
import spacy
import ast
from PIL import Image
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
from io import BytesIO
import base64
from anls import anls_score
import torch
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torchvision.transforms as T
from eval import conversation as conversation_lib
from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \
process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro
from eval.mmmu_utils import evaluate as evaluate_mmmu
from torchvision.transforms.functional import InterpolationMode
from datasets import load_dataset, concatenate_datasets
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
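# Illustrative example: applying the transform to a single PIL image yields a
# normalized tensor of shape (3, input_size, input_size).
#
#   from PIL import Image
#   transform = build_transform(input_size=448)
#   tensor = transform(Image.new("RGB", (800, 600)))  # -> torch.Size([3, 448, 448])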
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
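# Worked example (illustrative): as called from dynamic_preprocess below with max_num=6,
# only grids with i*j <= 6 are candidates. For an 800x600 image (aspect ratio ~1.33) the
# closest available grid ratio is (3, 2) (ratio 1.5), so the image is tiled into a 3x2 grid.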
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
    # enumerate candidate tiling grids (columns x rows) whose tile count lies in [min_num, max_num]
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def load_image(image, input_size=448, max_num=6, decoded=False):
if not decoded:
image = Image.open(image).convert('RGB')
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
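# Usage sketch (the path is illustrative): load_image returns a stacked tensor of image tiles.
# With max_num=6 and use_thumbnail=True, an 800x600 image is split into a 3x2 grid plus one
# thumbnail, i.e. pixel_values.shape == (7, 3, 448, 448).
#
#   pixel_values = load_image("example.jpg", max_num=6).to(torch.bfloat16)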
def levenshtein_distance(s1, s2):
if len(s1) > len(s2):
s1, s2 = s2, s1
distances = range(len(s1) + 1)
for i2, c2 in enumerate(s2):
distances_ = [i2 + 1]
for i1, c1 in enumerate(s1):
if c1 == c2:
distances_.append(distances[i1])
else:
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
distances = distances_
return distances[-1]
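# Example: levenshtein_distance("kitten", "sitting") == 3
# (substitute 'k'->'s', substitute 'e'->'i', insert 'g').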
def get_anls_score(pred, gold_labels, threshold, llava_eval=False):
values = []
for answer in gold_labels:
# preprocess both the answers - gt and prediction
gt_answer = ' '.join(answer.strip().lower().split())
det_answer = ' '.join(pred.strip().lower().split())
dist = levenshtein_distance(gt_answer, det_answer)
length = max(len(answer.upper()), len(pred.upper()))
values.append(0.0 if length == 0 else float(dist) / float(length))
question_result = 1 - min(values)
if llava_eval:
question_result = 1.0 if question_result >= threshold else 0.0
else:
        if question_result < threshold:
question_result = 0
return question_result
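# Example: with the standard ANLS threshold of 0.5,
#   get_anls_score("nvidia", ["nvidia corp"], threshold=0.5)
# gives 1 - 5/11 ~= 0.545, which is kept as-is since it is above the threshold; a score
# below the threshold is clipped to 0 (or binarized to 1.0/0.0 when llava_eval=True).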
def isNumber(n: str):
try:
float(n)
return True
except ValueError:
return False
class COCOEvalDataset(Dataset):
def __init__(self, args, img_dir, subset=None):
self.args = args
self.img_files = sorted(glob(os.path.join(img_dir, "*")))
if subset:
self.img_files = self.img_files[:subset]
self.image_ids = [int(img_file.split("_")[-1].split(".")[0]) for img_file in self.img_files]
def __len__(self):
return len(self.img_files)
def __getitem__(self, idx):
img_path = self.img_files[idx]
img = load_image(img_path, max_num=6).to(torch.bfloat16)
return self.image_ids[idx], img
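# Usage sketch (the path is illustrative): each item is an (image_id, tiles) pair parsed
# from COCO-style filenames such as COCO_val2014_000000XXXXXX.jpg.
#
#   dataset = COCOEvalDataset(args, "/data/coco/val2014")
#   image_id, img = dataset[0]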
class Flickr30KEvalDataset(Dataset):
def __init__(self, args, img_dir, subset=None):
self.args = args
self.img_dir = img_dir
self.test_samples = json.load(open(os.path.join(img_dir, "flickr30k_test.json"), encoding='utf-8'))
if subset:
self.test_samples = self.test_samples[:subset]
def __len__(self):
return len(self.test_samples)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.test_samples[idx]["image"])
img = load_image(img_path, max_num=6).to(torch.bfloat16)
image_id = int(self.test_samples[idx]["image"].split("/")[-1].replace(".jpg", ""))
return image_id, img
class VQAv2EvalDataset(Dataset):
def __init__(self, args, img_dir, gt_path, subset=None):
self.args = args
self.img_dir = img_dir
self.gt = json.load(open(gt_path, encoding='utf-8'))
if subset:
self.gt = self.gt[:subset]
def __len__(self):
return len(self.gt)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
img = load_image(img_path, max_num=6).to(torch.bfloat16)
question_id = self.gt[idx]["question_id"]
question = self.gt[idx]["question"]
answer = self.gt[idx]["answer"]
return img, question_id, question, answer
class TextVQAEvalDataset(Dataset):
def __init__(self, args, img_dir, gt_path, subset=None):
self.args = args
self.img_dir = img_dir
self.gt = json.load(open(gt_path, encoding='utf-8'))['data']
if subset:
self.gt = self.gt[:subset]
def __len__(self):
return len(self.gt)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.gt[idx]["image_id"] + '.jpg')
if not os.path.exists(img_path):
img_path = img_path.replace('.jpg', '.png')
img = load_image(img_path, max_num=6).to(torch.bfloat16)
question_id = self.gt[idx]["question_id"]
question = self.gt[idx]["question"]
answer = self.gt[idx]["answers"]
return img, question_id, question, answer
class GQAEvalDataset(Dataset):
def __init__(self, args, img_dir, gt_path, subset=None):
self.args = args
self.img_dir = img_dir
self.gt = json.load(open(gt_path, encoding='utf-8'))
self.gt = [{
"question_id": int(k),
"image": v['imageId'] + ".jpg",
"question": v['question'],
"answer": v['answer']
} for k, v in self.gt.items()]
if subset:
self.gt = self.gt[:subset]
def __len__(self):
return len(self.gt)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
img = load_image(img_path, max_num=6).to(torch.bfloat16)
question_id = self.gt[idx]["question_id"]
question = self.gt[idx]["question"]
answer = self.gt[idx]["answer"]
return img, question_id, question, [answer]
class ChartQAEvalDataset(Dataset):
def __init__(self, args, img_dir, gt_path, subset=None):
self.args = args
self.img_dir = img_dir
self.gt = json.load(open(gt_path, encoding='utf-8'))
for i in range(len(self.gt)):
self.gt[i]['question_id'] = i
if subset:
self.gt = self.gt[:subset]
def __len__(self):
return len(self.gt)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.gt[idx]["imgname"])
img = load_image(img_path, max_num=6).to(torch.bfloat16)
question_id = self.gt[idx]["question_id"]
question = self.gt[idx]["query"]
answer = self.gt[idx]["label"]
return img, question_id, question, [answer]
class OKVQAEvalDataset(Dataset):
def __init__(self, args, img_dir, gt_path, question_path, subset=None):
self.args = args
self.img_dir = img_dir
self.gt = json.load(open(gt_path, encoding='utf-8'))['annotations']
self.questions = json.load(open(question_path, 'r'))['questions']
if subset:
self.gt = self.gt[:subset]
qid2q = {q['question_id']: q['question'] for q in self.questions}
for ann in self.gt:
ann['answers'] = [ans['answer'] for ans in ann['answers']]
ann['question'] = qid2q[ann['question_id']]
def __len__(self):
return len(self.gt)
def __getitem__(self, idx):
img_id = str(self.gt[idx]["image_id"])
img_id = '0' * (12 - len(img_id)) + img_id
img_file_name = f"COCO_val2014_{img_id}.jpg"
img_path = os.path.join(self.img_dir, img_file_name)
img = load_image(img_path, max_num=6).to(torch.bfloat16)
question_id = self.gt[idx]["question_id"]
question = self.gt[idx]["question"]
answer = self.gt[idx]["answers"]
return img, question_id, question, answer
class DocVQAEvalDataset(Dataset):
def __init__(self, args, img_dir, gt_path, split='val', subset=None):
self.args = args
self.img_dir = img_dir
self.gt = json.load(open(gt_path, encoding='utf-8'))['data']
if subset:
self.gt = self.gt[:subset]
self.split = split
def __len__(self):
return len(self.gt)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.gt[idx]['image'].split('/')[-1])
img = load_image(img_path, max_num=6).to(torch.bfloat16)
question_id = self.gt[idx]["questionId"]
question = self.gt[idx]["question"]
if self.split == 'val':
answer = self.gt[idx]["answers"]
else:
answer = ['']
return img, question_id, question, answer
class OCRBenchEvalDataset(Dataset):
def __init__(self, args, img_dir, gt_path, subset=None):
self.args = args
self.img_dir = img_dir
self.gt = json.load(open(gt_path, encoding='utf-8'))
if subset:
self.gt = self.gt[:subset]
def __len__(self):
return len(self.gt)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.gt[idx]['image_path'])
img = load_image(img_path, max_num=6).to(torch.bfloat16)
dataset_name = self.gt[idx]["dataset_name"]
question_id = f"{idx}"
question = self.gt[idx]["question"]
answer = self.gt[idx]["answers"]
data_type = self.gt[idx]["type"]
return img, question_id, question, answer, dataset_name, data_type
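# Note: unlike the other VQA-style datasets above, OCRBench items also carry the source
# dataset name and question type, which the scoring side can use to break accuracy down
# per sub-task.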
class AI2DiagramEvalDataset(Dataset):
def __init__(self, args, img_dir, gt_path, subset=None):
self.args = args
self.img_dir = img_dir
with open(gt_path, 'r') as json_file:
json_list = list(json_file)
self.gt = [json.loads(json_str) for json_str in json_list]
if subset:
self.gt = self.gt[:subset]
def __len__(self):
return len(self.gt)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
img = load_image(img_path, max_num=6).to(torch.bfloat16)
question_id = self.gt[idx]["question_id"]
question = self.gt[idx]["question"]
answer = self.gt[idx]["answer"]
return img, question_id, question, answer
class AI2DiagramNoMaskEvalDataset(Dataset):
def __init__(self, args, img_dir, gt_path, subset=None):
self.args = args
self.img_dir = img_dir
with open(gt_path, 'r') as json_file:
json_list = list(json_file)
self.gt = [json.loads(json_str) for json_str in json_list]
if subset:
self.gt = self.gt[:subset]
def __len__(self):
return len(self.gt)
def __getitem__(self, idx):
img_file_name = self.gt[idx]['image'].replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES")
img_path = os.path.join(self.img_dir, img_file_name)
img = load_image(img_path, max_num=6).to(torch.bfloat16)
question_id = self.gt[idx]["question_id"]
question = self.gt[idx]["question"]
answer = self.gt[idx]["answer"]
return img, question_id, question, answer
class RealworldQAEvalDataset(Dataset):
def __init__(self, args, img_dir, gt_path, subset=None):
self.args = args
self.img_dir = img_dir
self.gt = json.load(open(gt_path, encoding='utf-8'))
if subset:
self.gt = self.gt[:subset]
def __len__(self):
return len(self.gt)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
img = load_image(img_path, max_num=6).to(torch.bfloat16)
question_id = int(self.gt[idx]['image'].replace(".webp", ""))
question = self.gt[idx]["question"]
if self.gt[idx]['question_type'] == "multi-choice":
choices = self.gt[idx]["choices"]
start_chr = 'A'
choices_str = ''
index2ans = {}
all_choices = []
for choice in choices:
all_choices.append(start_chr)
index2ans[start_chr] = choice
choices_str += f"{start_chr}. {choice}\n"
start_chr = chr(ord(start_chr) + 1)
question = question + '\n' + choices_str
question = question + "Answer with the option's letter from the given choices directly."
answer = chr(ord('A') + self.gt[idx]['correct_choice_index'])
else:
question = question + "\nAnswer the question using a single word or phrase."
answer = self.gt[idx]['answer']
return img, question_id, question, [answer]
class MathVistaEvalDataset(Dataset):
def __init__(self, args, task_cfg, gt_path=None):
self.args = args
self.task_cfg = task_cfg
self.dataset = load_dataset("AI4Math/MathVista")['testmini']
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
img = self.dataset[idx]['decoded_image']
img = load_image(img.convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
question_id = self.dataset[idx]["pid"]
question = self.dataset[idx]["question"]
question_type = self.dataset[idx]["question_type"] # free_form or multi_choice
query = self.dataset[idx]["query"]
choices = self.dataset[idx]["choices"]
answer = self.dataset[idx]["answer"]
if question_type == 'multi_choice':
start_chr = 'A'
choices_str = ''
index2ans = {}
all_choices = []
for choice in choices:
all_choices.append(start_chr)
index2ans[start_chr] = choice
choices_str += f"{start_chr}. {choice}\n"
start_chr = chr(ord(start_chr) + 1)
question = question + '\n' + choices_str
question = question + "Answer with the option's letter from the given choices directly."
answer = chr(ord('A') + choices.index(answer))
else:
question = query.replace("Hint: ", "")
index2ans = {}
all_choices = []
return img, question_id, question_type, question, answer, str(index2ans), str(all_choices)
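# Note: MathVista items mirror the MMMU-style return signature -- multiple-choice questions
# are rewritten with lettered options and the answer mapped to its letter, while free-form
# questions reuse the dataset's own query with the "Hint: " prefix stripped.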
def construct_prompt_for_fewshot(sample):
config = {
"task_instructions": "",
"multi_choice_example_format": "{}\n{}Answer with the option's letter from the given choices directly.",
"short_ans_example_format": "{}\nAnswer the question using a single word or phrase."
}
question = sample['question'].strip()
    options = ast.literal_eval(sample['options'])  # parse the stringified options list safely
example = ""
if sample['question_type'] == 'multiple-choice':
start_chr = 'A'
prediction_range = []
index2ans = {}
for option in options:
prediction_range.append(start_chr)
example += f"({start_chr}) {option}\n"
index2ans[start_chr] = option
start_chr = chr(ord(start_chr) + 1)
empty_prompt_sample_structure = config['multi_choice_example_format']
empty_prompt = empty_prompt_sample_structure.format(question, example)
res_dict = {'type': 'multichoice'}
res_dict['index2ans'] = index2ans
res_dict['correct_choice'] = sample['answer']
res_dict['all_choices'] = prediction_range
res_dict['empty_prompt'] = empty_prompt
if config['task_instructions']:
res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
else:
res_dict['final_input_prompt'] = empty_prompt
res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
else:
empty_prompt_sample_structure = config['short_ans_example_format']
empty_prompt = empty_prompt_sample_structure.format(question)
res_dict = {'type': 'open'}
res_dict['empty_prompt'] = empty_prompt
if config['task_instructions']:
res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
else:
res_dict['final_input_prompt'] = empty_prompt
res_dict['gt_content'] = sample['answer']
res_dict.update(sample)
return res_dict
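# Illustrative example: for a multiple-choice sample with options ["cat", "dog"], the
# resulting 'final_input_prompt' looks like:
#
#   <question text>
#   (A) cat
#   (B) dog
#   Answer with the option's letter from the given choices directly.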
def process_image_tag(q):
q = q.strip()
# heuristic way of removing <image 1>
if q == '<image 1>':
q = 'Answer the question in the image.'
elif ':<image 1>' in q:
q = q.replace(':<image 1>', ' in the image. ')
q = q.strip()
elif ': <image 1>' in q:
q = q.replace(': <image 1>', ' in the image. ')
q = q.strip()
elif '.<image 1>' in q or '. <image 1>' in q:
q_list = q.split('<image 1>')
q_list = [part.strip() for part in q_list if part.strip() != '']
q = ' '.join(q_list)
elif q.startswith('<image 1> '):
if q[10].isupper():
q = q.replace('<image 1>', '')
else:
q = q.replace('<image 1>', 'The image')
q = q.strip()
elif q.startswith('<image 1>'):
q = q.replace('<image 1>', '')
elif q.endswith('<image 1>?'):
q = q.replace('<image 1>', 'the image')
elif q.endswith('?<image 1>') or q.endswith('? <image 1>') or q.endswith('\n<image 1>'):
q = q.replace('<image 1>', '')
q = q.strip()
elif ' <image 1> ' in q:
q = q.replace('<image 1>', 'the image')
elif ' <image 1>' in q:
q = q.replace('<image 1>', 'the image')
elif '()<image 1>' in q:
q = q.replace('()<image 1>', '')
elif '(<image 1>)' in q:
q = q.replace('(<image 1>)', '')
elif '<image 1>.' in q:
q = q.replace("<image 1>.", ". ")
else:
q = q.replace("<image 1>", ". ")
q = q.strip()
    # remove any remaining tags, <image 2> through <image 7>
for i in range(2, 8):
q = q.replace(f"<image {i}>", "")
return q
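# Examples of the heuristics above (illustrative):
#   "What is shown in <image 1>?"         -> "What is shown in the image?"
#   "<image 1> What does the graph show?" -> "What does the graph show?"
#   "Identify the structure: <image 1>"   -> "Identify the structure in the image."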
class MMMUProEvalDataset(Dataset):
def __init__(self, args, task_cfg, subset=None):
self.args = args
self.task_cfg = task_cfg
        # MMMU-Pro (standard setting) is evaluated on its 'test' split.
        MMMU_path = "MMMU/MMMU_Pro"
        _split = "test"
        self.dataset = load_dataset(MMMU_path, "standard", split=_split)
        if subset:
            # Use .select() so the result stays a Dataset (slicing would return a dict of columns).
            self.dataset = self.dataset.select(range(min(subset, len(self.dataset))))
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
# ===== single-image =====
sample = self.dataset[idx]
sample = process_single_sample_pro(sample)
sample = construct_prompt_pro(sample, self.task_cfg)
img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
# img = img.reshape(-1, 3, self.args.img_h, self.args.img_w)
question_id = sample['id']
question = sample['final_input_prompt']
answer = sample['answer']
question = process_image_tag(question)
question = self.task_cfg['default_image_token'] + '\n' + question
if sample['question_type'] == 'multiple-choice':
index2ans = sample['index2ans']
all_choices = sample['all_choices']
else:
index2ans = {}
all_choices = []
        return (img, question_id, sample['subfield'], sample['question_type'], question, answer,
                str(index2ans), str(all_choices))
class MMMUEvalDataset(Dataset):
def __init__(self, args, task_cfg, subset=None, start_idx=None):
self.args = args
self.task_cfg = task_cfg
sub_dataset_list = []
# load_dataset will throw error if split is 'dev'
# 'dev' is part of the 'validation' and we need to manually split them
MMMU_path = "MMMU/MMMU"
_split = "test" if task_cfg["split"] == "test" else "validation"
for subject in CAT_SHORT2LONG.values():
sub_dataset = load_dataset(
MMMU_path, subject,
split=_split,
)
sub_dataset_list.append(sub_dataset)
dataset = concatenate_datasets(sub_dataset_list)
        if task_cfg["split"] != "test":
            dataset = [s for s in dataset if s['id'].startswith(task_cfg["split"])]
        self.dataset = dataset
        if subset:
            start_idx = start_idx or 0  # guard against a missing start_idx when only a subset size is given
            self.dataset = [dataset[i] for i in range(start_idx, min(start_idx + subset, len(dataset)))]
            print(f"Evaluating a subset of the dataset: {len(self.dataset)} samples, from index {start_idx} to {start_idx + subset}")
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
# ===== single-image =====
sample = self.dataset[idx]
sample = process_single_sample(sample)
sample = construct_prompt(sample, self.task_cfg)
img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
question_id = sample['id']
question = sample['final_input_prompt']
answer = sample['answer']
question = process_image_tag(question)
question = self.task_cfg['default_image_token'] + '\n' + question
if sample['question_type'] == 'multiple-choice':
index2ans = sample['index2ans']
all_choices = sample['all_choices']
else:
index2ans = {}
all_choices = []
        return (img, question_id, sample['subfield'], sample['question_type'], question, answer,
                str(index2ans), str(all_choices))
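# Note: index2ans and all_choices are returned as strings here (and in MMMUProEvalDataset),
# likely so the default DataLoader collate can batch them; the evaluation side can recover
# the original structures with ast.literal_eval.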
class VizWizEvalDataset(Dataset):
def __init__(self, args, img_dir, question_path, subset=None):
self.args = args
self.img_dir = img_dir
        self.questions = json.load(open(question_path, encoding='utf-8'))
        if subset:
            self.questions = self.questions[:subset]
def __len__(self):
return len(self.questions)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.questions[idx]["image"])
img = load_image(img_path, max_num=6).to(torch.bfloat16)
question = self.questions[idx]["question"]
question_id = self.questions[idx]["image"]
return img, question_id, question
class MMBenchEvalDataset(Dataset):
def __init__(self, args, gt_path, subset=None):
self.args = args
df = pd.read_csv(gt_path, sep='\t')
self.dataset = []
for i, row in df.iterrows():
choices = []
for choice in ['A', 'B', 'C', 'D']:
if str(row[choice]) != 'nan':
choices.append(row[choice])
this_sample = {
'index': row['index'],
'question': row['question'],
'hint': row['hint'],
'category': row['category'],
'image': Image.open(BytesIO(base64.b64decode(row['image']))),
'choices': choices
}
# Only dev set gives the ground truth answer
if 'answer' in row.keys():
this_sample['answer'] = row['answer']
else:
this_sample['answer'] = ''
self.dataset.append(this_sample)
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
img = load_image(self.dataset[idx]["image"].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
question = self.dataset[idx]["question"]
hint = self.dataset[idx]["hint"]
question_id = self.dataset[idx]["index"]
choices = self.dataset[idx]["choices"]
answer = self.dataset[idx]["answer"]
start_chr = 'A'
choices_str = ''
index2ans = {}
all_choices = []
for choice in choices:
all_choices.append(start_chr)
index2ans[start_chr] = choice
choices_str += f"{start_chr}. {choice}\n"
start_chr = chr(ord(start_chr) + 1)
question = question + '\n' + choices_str
return img, question_id, question, answer, str(index2ans), str(all_choices), self.dataset[idx]["question"]
def get_task_dataloader(task_name, task_cfg, args):
if "subset" in task_cfg.keys():
subset = task_cfg["subset"]
else:
subset = None
if task_name == "coco_caption":
dataset = COCOEvalDataset(args, task_cfg["image_dir"], subset)
elif task_name == "flickr30k_caption":
dataset = Flickr30KEvalDataset(args, task_cfg["image_dir"], subset)
elif task_name == "vqav2":
dataset = VQAv2EvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
elif task_name == "textvqa":
dataset = TextVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
elif task_name == "gqa":
dataset = GQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
elif task_name == "chartqa":
dataset = ChartQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
elif task_name == "okvqa":
dataset = OKVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], task_cfg["question_path"], subset)
elif task_name == "vizwiz":
dataset = VizWizEvalDataset(args, task_cfg["image_dir"], task_cfg["question_path"], subset)
elif task_name == "docvqa":
dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='val', subset=subset)
elif task_name == "docvqa_test":
dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='test', subset=subset)
elif task_name == "realworldqa":
dataset = RealworldQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
elif task_name == "mmmu":
dataset = MMMUEvalDataset(args, task_cfg, subset=args.subset, start_idx=args.start_idx)
elif task_name == "mmmu_pro":
dataset = MMMUProEvalDataset(args, task_cfg)
elif task_name == "mathvista":
dataset = MathVistaEvalDataset(args, task_cfg)
elif task_name == "mmbench":
dataset = MMBenchEvalDataset(args, task_cfg["gt_path"])
elif task_name == 'ocrbench':
dataset = OCRBenchEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
elif task_name == 'ai2diagram':
dataset = AI2DiagramEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
elif task_name == 'ai2diagram_nomask':
dataset = AI2DiagramNoMaskEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
else:
raise NotImplementedError(f"Task {task_name} is not supported yet.")
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
pin_memory=True,
)
return dataloader
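# Minimal usage sketch (the YAML filename and the args object are illustrative; the config
# keys follow the ones read above, e.g. "image_dir" and "gt_path"):
#
#   import yaml
#   with open("evaluation_tasks.yaml") as f:
#       task_cfgs = yaml.safe_load(f)
#   loader = get_task_dataloader("textvqa", task_cfgs["textvqa"], args)
#   for img, question_id, question, answer in loader:
#       pass  # run generation on (img, question) and score the prediction against answer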