import json |
import os |
import sys |
import time |
import yaml |
import spacy |
import ast |
from PIL import Image |
from glob import glob |
from tqdm import tqdm |
from collections import defaultdict |
import pandas as pd |
from io import BytesIO |
import base64 |
from anls import anls_score |
import torch |
from torch.utils.data import Dataset, DataLoader, DistributedSampler |
import torchvision.transforms as T |
from eval import conversation as conversation_lib |
from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \ |
process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro |
from eval.mmmu_utils import evaluate as evaluate_mmmu |
from torchvision.transforms.functional import InterpolationMode |
from datasets import load_dataset, concatenate_datasets |
# Per-channel (R, G, B) normalization statistics used by build_transform's
# T.Normalize step; the values are the standard ImageNet mean / std.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
    """Build the preprocessing pipeline applied to every image tile.

    The pipeline converts the image to RGB when needed, resizes it to a square
    ``input_size`` x ``input_size`` with bicubic interpolation, converts it to a
    tensor and normalizes it with the ImageNet statistics.

    Args:
        input_size: Side length in pixels of the square output.

    Returns:
        A ``torchvision.transforms.Compose`` callable mapping a PIL image to a
        normalized tensor.
    """
    return T.Compose([
        # Palette / RGBA / grayscale inputs are converted so every tile has 3 channels.
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        # Use the module-level ImageNet statistics (MEAN / STD are not defined
        # anywhere and raised NameError on the first call).
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose shape best matches an image's aspect ratio.

    Each candidate in ``target_ratios`` is a ``(cols, rows)`` pair; its aspect
    ratio is ``cols / rows``. The candidate with the smallest absolute
    difference to ``aspect_ratio`` wins. On an exact tie with the current best,
    the later candidate replaces it only when the image area
    (``width * height``) exceeds half of that grid's pixel area
    (``image_size**2 * cols * rows``).

    Args:
        aspect_ratio: Image width divided by height.
        target_ratios: Iterable of ``(cols, rows)`` candidates, in priority order.
        width: Image width in pixels.
        height: Image height in pixels.
        image_size: Tile side length in pixels.

    Returns:
        The chosen candidate, or ``(1, 1)`` if ``target_ratios`` is empty.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
            continue
        tile_area = 0.5 * image_size * image_size * candidate[0] * candidate[1]
        if gap == smallest_gap and image_area > tile_area:
            chosen = candidate
    return chosen
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Split an image into a grid of square tiles shaped like the image.

    Candidate grids are every ``(cols, rows)`` with
    ``min_num <= cols * rows <= max_num``, ordered by tile count. The grid whose
    ``cols / rows`` best matches the image aspect ratio is chosen by
    ``find_closest_aspect_ratio``. The image is then resized to
    ``(image_size * cols, image_size * rows)`` and cropped row by row into
    ``cols * rows`` tiles of ``image_size`` x ``image_size``.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop`` (a PIL image).
        min_num: Minimum number of tiles.
        max_num: Maximum number of tiles.
        image_size: Tile side length in pixels.
        use_thumbnail: When True and more than one tile was produced, append
            the whole image resized to ``image_size`` x ``image_size``.

    Returns:
        List of tile images (plus the optional thumbnail as the last element).
    """
    width, height = image.size
    ratio = width / height
    candidate_grids = sorted(
        set(
            (cols, rows)
            for n in range(min_num, max_num + 1)
            for cols in range(1, n + 1)
            for rows in range(1, n + 1)
            if min_num <= cols * rows <= max_num
        ),
        key=lambda grid: grid[0] * grid[1],
    )
    grid = find_closest_aspect_ratio(ratio, candidate_grids, width, height, image_size)

    canvas_w = image_size * grid[0]
    canvas_h = image_size * grid[1]
    tile_count = grid[0] * grid[1]
    canvas = image.resize((canvas_w, canvas_h))

    tiles_per_row = canvas_w // image_size
    tiles = []
    for k in range(tile_count):
        col, row = k % tiles_per_row, k // tiles_per_row
        left, top = col * image_size, row * image_size
        tiles.append(canvas.crop((left, top, (col + 1) * image_size, (row + 1) * image_size)))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
def load_image(image, input_size=448, max_num=6, decoded=False):
    """Load an image and turn it into a stack of normalized tiles.

    Args:
        image: A path or file object accepted by ``PIL.Image.open`` when
            ``decoded`` is False; otherwise an already-decoded PIL image.
        input_size: Side length of each square tile.
        max_num: Maximum number of grid tiles (see ``dynamic_preprocess``).
        decoded: Whether ``image`` is already a PIL image.

    Returns:
        ``torch.stack`` of one transformed tensor per tile returned by
        ``dynamic_preprocess`` (grid tiles plus a thumbnail when there is more
        than one tile).
    """
    if not decoded:
        # Close the underlying file once the pixels are decoded. Without the
        # context manager, Pillow may keep the file open (e.g. for multi-frame
        # formats), which can exhaust file descriptors over a long evaluation run.
        with Image.open(image) as opened:
            image = opened.convert('RGB')
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])
def levenshtein_distance(s1, s2):
    """Return the Levenshtein (insert/delete/substitute) distance of two sequences.

    Uses the two-row dynamic-programming formulation. The shorter sequence
    indexes the row, so memory is O(min(len(s1), len(s2))) and time is
    O(len(s1) * len(s2)).
    """
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)
    previous_row = list(range(len(shorter) + 1))
    for row_no, long_ch in enumerate(longer, start=1):
        current_row = [row_no]
        for col, short_ch in enumerate(shorter):
            if short_ch == long_ch:
                current_row.append(previous_row[col])
            else:
                current_row.append(1 + min(previous_row[col], previous_row[col + 1], current_row[-1]))
        previous_row = current_row
    return previous_row[-1]
def get_anls_score(pred, gold_labels, threshold, llava_eval=False):
    """Score a prediction against reference answers with ANLS.

    For each gold answer, compute the Levenshtein distance between the
    whitespace-collapsed, lower-cased gold answer and prediction. Divide it by
    the longer *raw* length (``len(s.upper())`` of the un-normalized strings),
    or use 0.0 when both are empty. The score is ``1 - smallest normalized
    distance``.

    Args:
        pred: Predicted answer string.
        gold_labels: Non-empty iterable of reference answer strings.
        threshold: Similarity cut-off.
        llava_eval: If True, binarize the score (``1.0`` when
            ``score >= threshold``, else ``0.0``).

    Returns:
        The (possibly binarized) score. Without ``llava_eval``, scores below
        ``threshold`` are returned as the int ``0``.

    Raises:
        ValueError: if ``gold_labels`` is empty (``min`` of an empty list).
    """
    def _normalize(text):
        return ' '.join(text.strip().lower().split())

    normalized_distances = []
    for gold in gold_labels:
        edit_dist = levenshtein_distance(_normalize(gold), _normalize(pred))
        denom = max(len(gold.upper()), len(pred.upper()))
        normalized_distances.append(float(edit_dist) / float(denom) if denom != 0 else 0.0)

    score = 1 - min(normalized_distances)
    if llava_eval:
        return 1.0 if score >= threshold else 0.0
    return 0 if score < threshold else score
def isNumber(n: str):
    """Return True when ``float(n)`` succeeds, False when it raises ValueError.

    Anything ``float()`` accepts counts as a number, e.g. ``" 3.5 "``, ``"1e3"``,
    ``"nan"`` or ``"inf"``. A ``TypeError`` from ``float()`` (e.g. for ``None``)
    is not caught.
    """
    try:
        float(n)
    except ValueError:
        return False
    else:
        return True
class COCOEvalDataset(Dataset):
    """COCO captioning eval set: one item per file found directly in ``img_dir``.

    Files are sorted by path. Each image id is parsed from the text after the
    last ``_`` and before the next ``.`` (e.g. ``COCO_val2014_000000000042.jpg``
    -> ``42``). A truthy ``subset`` keeps only the first ``subset`` files.
    """

    def __init__(self, args, img_dir, subset=None):
        self.args = args
        all_files = sorted(glob(os.path.join(img_dir, "*")))
        self.img_files = all_files[:subset] if subset else all_files
        self.image_ids = [self._parse_image_id(path) for path in self.img_files]

    @staticmethod
    def _parse_image_id(path):
        """Return the integer id encoded in the last ``_``-separated part of ``path``."""
        tail = path.rsplit("_", 1)[-1]
        return int(tail.split(".", 1)[0])

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        """Return ``(image_id, pixel_values)``; pixel values are bfloat16 tiles."""
        pixel_values = load_image(self.img_files[idx], max_num=6)
        return self.image_ids[idx], pixel_values.to(torch.bfloat16)
class Flickr30KEvalDataset(Dataset):
    """Flickr30K captioning eval set described by ``<img_dir>/flickr30k_test.json``.

    Each JSON entry must provide an ``"image"`` path relative to ``img_dir``.
    The numeric image id is the entry's base file name without ``.jpg``.
    """

    def __init__(self, args, img_dir, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed deterministically.
        with open(os.path.join(img_dir, "flickr30k_test.json"), encoding='utf-8') as f:
            self.test_samples = json.load(f)
        if subset:
            self.test_samples = self.test_samples[:subset]

    def __len__(self):
        return len(self.test_samples)

    def __getitem__(self, idx):
        """Return ``(image_id, pixel_values)``; pixel values are bfloat16 tiles."""
        image_rel_path = self.test_samples[idx]["image"]
        img = load_image(os.path.join(self.img_dir, image_rel_path), max_num=6).to(torch.bfloat16)
        image_id = int(image_rel_path.split("/")[-1].replace(".jpg", ""))
        return image_id, img
class VQAv2EvalDataset(Dataset):
    """VQAv2 eval set from a JSON list of entries with ``image``, ``question_id``,
    ``question`` and ``answer`` keys (``image`` is relative to ``img_dir``).
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed deterministically.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)
        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, question, answer)`` for sample ``idx``."""
        sample = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, sample["image"]), max_num=6).to(torch.bfloat16)
        return img, sample["question_id"], sample["question"], sample["answer"]
class TextVQAEvalDataset(Dataset):
    """TextVQA eval set from a JSON file whose ``"data"`` list holds entries with
    ``image_id``, ``question_id``, ``question`` and ``answers`` keys.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed deterministically.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)['data']
        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, question, answers)`` for sample ``idx``.

        Looks for ``<image_id>.jpg`` in ``img_dir`` and falls back to
        ``<image_id>.png`` when the JPEG does not exist.
        """
        sample = self.gt[idx]
        img_path = os.path.join(self.img_dir, sample["image_id"] + '.jpg')
        if not os.path.exists(img_path):
            # Rebuild the path with a .png extension rather than str.replace on the
            # full path, which would also rewrite any '.jpg' inside img_dir.
            img_path = os.path.join(self.img_dir, sample["image_id"] + '.png')
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        return img, sample["question_id"], sample["question"], sample["answers"]
class GQAEvalDataset(Dataset):
    """GQA eval set from a JSON object mapping question id -> question record.

    Each record must provide ``imageId``, ``question`` and ``answer``. The
    records are flattened into a list of dicts with ``question_id`` (int),
    ``image`` (``<imageId>.jpg``), ``question`` and ``answer``.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed deterministically.
        with open(gt_path, encoding='utf-8') as f:
            raw = json.load(f)
        self.gt = [{
            "question_id": int(k),
            "image": v['imageId'] + ".jpg",
            "question": v['question'],
            "answer": v['answer'],
        } for k, v in raw.items()]
        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, question, [answer])`` for sample ``idx``."""
        sample = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, sample["image"]), max_num=6).to(torch.bfloat16)
        return img, sample["question_id"], sample["question"], [sample["answer"]]
class ChartQAEvalDataset(Dataset):
    """ChartQA eval set from a JSON list of entries with ``imgname``, ``query``
    and ``label`` keys. Each entry gets ``question_id`` = its position in the file.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed deterministically.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)
        # Ids are assigned over the full file, before any subsetting.
        for i, sample in enumerate(self.gt):
            sample['question_id'] = i
        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, query, [label])`` for sample ``idx``."""
        sample = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, sample["imgname"]), max_num=6).to(torch.bfloat16)
        return img, sample["question_id"], sample["query"], [sample["label"]]
class OKVQAEvalDataset(Dataset):
    """OK-VQA eval set built from an annotations file and a questions file.

    ``gt_path`` is a JSON object with an ``"annotations"`` list (each with
    ``image_id``, ``question_id`` and ``answers`` as ``[{"answer": ...}]``).
    ``question_path`` is a JSON object with a ``"questions"`` list (each with
    ``question_id`` and ``question``). Each kept annotation has its ``answers``
    flattened to strings and its ``question`` text attached.
    """

    def __init__(self, args, img_dir, gt_path, question_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context managers so both file handles are closed deterministically.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)['annotations']
        with open(question_path, encoding='utf-8') as f:
            self.questions = json.load(f)['questions']
        if subset:
            self.gt = self.gt[:subset]
        qid2q = {q['question_id']: q['question'] for q in self.questions}
        for ann in self.gt:
            ann['answers'] = [ans['answer'] for ans in ann['answers']]
            ann['question'] = qid2q[ann['question_id']]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, question, answers)`` for sample ``idx``.

        Images are read from ``COCO_val2014_<12-digit zero-padded image_id>.jpg``.
        """
        ann = self.gt[idx]
        img_file_name = f"COCO_val2014_{str(ann['image_id']).zfill(12)}.jpg"
        img = load_image(os.path.join(self.img_dir, img_file_name), max_num=6).to(torch.bfloat16)
        return img, ann["question_id"], ann["question"], ann["answers"]
class DocVQAEvalDataset(Dataset):
    """DocVQA eval set from a JSON file whose ``"data"`` list holds entries with
    ``image``, ``questionId``, ``question`` and (for ``split='val'``) ``answers``.

    Images are looked up in ``img_dir`` by the base name of ``image``. For any
    split other than ``'val'`` the answer is the placeholder ``['']``.
    """

    def __init__(self, args, img_dir, gt_path, split='val', subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed deterministically.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)['data']
        if subset:
            self.gt = self.gt[:subset]
        self.split = split

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, question, answers)`` for sample ``idx``."""
        sample = self.gt[idx]
        img_path = os.path.join(self.img_dir, sample['image'].split('/')[-1])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        answer = sample["answers"] if self.split == 'val' else ['']
        return img, sample["questionId"], sample["question"], answer
class OCRBenchEvalDataset(Dataset):
    """OCRBench eval set from a JSON list of entries with ``image_path``,
    ``dataset_name``, ``question``, ``answers`` and ``type`` keys.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed deterministically.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)
        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, question, answers, dataset_name, type)``.

        ``question_id`` is the sample index as a string.
        """
        sample = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, sample['image_path']), max_num=6).to(torch.bfloat16)
        return (img, f"{idx}", sample["question"], sample["answers"],
                sample["dataset_name"], sample["type"])
class AI2DiagramEvalDataset(Dataset):
    """AI2D eval set from a JSON Lines file (one object per line) with
    ``image``, ``question_id``, ``question`` and ``answer`` keys.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        with open(gt_path, 'r', encoding='utf-8') as json_file:
            # Skip blank lines (e.g. an extra trailing newline); json.loads('')
            # would otherwise raise JSONDecodeError.
            self.gt = [json.loads(line) for line in json_file if line.strip()]
        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, question, answer)`` for sample ``idx``."""
        sample = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, sample['image']), max_num=6).to(torch.bfloat16)
        return img, sample["question_id"], sample["question"], sample["answer"]
class AI2DiagramNoMaskEvalDataset(Dataset):
    """AI2D eval set that reads the "no mask" image variants.

    Annotations are a JSON Lines file with ``image``, ``question_id``,
    ``question`` and ``answer`` keys. Every ``AI2D_TEST`` occurrence in
    ``image`` is replaced by ``AI2D_TEST_NO_MASK_IMAGES`` to locate the image.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        with open(gt_path, 'r', encoding='utf-8') as json_file:
            # Skip blank lines (e.g. an extra trailing newline); json.loads('')
            # would otherwise raise JSONDecodeError.
            self.gt = [json.loads(line) for line in json_file if line.strip()]
        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, question, answer)`` for sample ``idx``."""
        sample = self.gt[idx]
        img_file_name = sample['image'].replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES")
        img = load_image(os.path.join(self.img_dir, img_file_name), max_num=6).to(torch.bfloat16)
        return img, sample["question_id"], sample["question"], sample["answer"]
class RealworldQAEvalDataset(Dataset):
    """RealWorldQA eval set from a JSON list of question records.

    Each record has ``image`` (a ``<int>.webp`` file name, also used as the
    question id), ``question`` and ``question_type``. ``"multi-choice"`` records
    also carry ``choices`` and ``correct_choice_index``; the others carry ``answer``.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed deterministically.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)
        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    @staticmethod
    def _build_prompt(sample):
        """Return ``(prompt, answer)`` for one record.

        Multiple-choice prompts list the options as ``A. ...`` lines and the
        answer is the correct option letter. Other prompts ask for a single
        word or phrase and the answer is the record's ``answer`` field.
        """
        question = sample["question"]
        if sample['question_type'] == "multi-choice":
            choices_str = ''.join(
                f"{chr(ord('A') + i)}. {choice}\n" for i, choice in enumerate(sample["choices"])
            )
            prompt = (question + '\n' + choices_str
                      + "Answer with the option's letter from the given choices directly.")
            answer = chr(ord('A') + sample['correct_choice_index'])
        else:
            prompt = question + "\nAnswer the question using a single word or phrase."
            answer = sample['answer']
        return prompt, answer

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, prompt, [answer])`` for sample ``idx``."""
        sample = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, sample['image']), max_num=6).to(torch.bfloat16)
        question_id = int(sample['image'].replace(".webp", ""))
        question, answer = self._build_prompt(sample)
        return img, question_id, question, [answer]
class MathVistaEvalDataset(Dataset):
    """MathVista ``testmini`` split loaded from the ``AI4Math/MathVista`` HF dataset.

    ``gt_path`` is accepted for signature compatibility but not used.
    """

    def __init__(self, args, task_cfg, gt_path=None):
        self.args = args
        self.task_cfg = task_cfg
        self.dataset = load_dataset("AI4Math/MathVista")['testmini']

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _build_prompt(row):
        """Return ``(question, answer, index2ans, all_choices)`` for one dataset row.

        For ``question_type == 'multi_choice'`` the options are appended as
        ``A. ...`` lines and ``answer`` becomes the letter of the option equal
        to the row's answer. Otherwise the row's ``query`` (with ``"Hint: "``
        removed) is the question, and ``index2ans`` / ``all_choices`` are empty.
        """
        answer = row["answer"]
        if row["question_type"] == 'multi_choice':
            choices = row["choices"]
            all_choices = [chr(ord('A') + i) for i in range(len(choices))]
            index2ans = dict(zip(all_choices, choices))
            choices_str = ''.join(f"{letter}. {choice}\n" for letter, choice in zip(all_choices, choices))
            question = (row["question"] + '\n' + choices_str
                        + "Answer with the option's letter from the given choices directly.")
            answer = chr(ord('A') + choices.index(answer))
        else:
            question = row["query"].replace("Hint: ", "")
            index2ans = {}
            all_choices = []
        return question, answer, index2ans, all_choices

    def __getitem__(self, idx):
        """Return ``(pixel_values, pid, question_type, question, answer,
        str(index2ans), str(all_choices))`` for sample ``idx``."""
        # Fetch the row once: every ``self.dataset[idx]`` on a datasets.Dataset
        # formats the whole row again, including decoding the image.
        row = self.dataset[idx]
        img = load_image(row['decoded_image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
        question, answer, index2ans, all_choices = self._build_prompt(row)
        return img, row["pid"], row["question_type"], question, answer, str(index2ans), str(all_choices)
def construct_prompt_for_fewshot(sample):
    """Build prompt fields for one MMMU-style sample.

    Args:
        sample: Dict with at least ``question``, ``question_type`` and
            ``answer``. When ``question_type == 'multiple-choice'`` it also
            needs ``options``: a list, or a string holding a Python list
            literal such as ``"['a', 'b']"``. In that case ``answer`` is a
            choice letter.

    Returns:
        A new dict with ``type`` (``'multichoice'`` or ``'open'``),
        ``empty_prompt``, ``final_input_prompt`` and ``gt_content``.
        Multiple-choice results also have ``index2ans``, ``correct_choice``
        and ``all_choices``. Every key of ``sample`` is copied in last, so
        sample keys win on conflict.

    Raises:
        ValueError, SyntaxError: if a string ``options`` is not a valid Python literal.
    """
    config = {
        "task_instructions": "",
        "multi_choice_example_format": "{}\n{}Answer with the option's letter from the given choices directly.",
        "short_ans_example_format": "{}\nAnswer the question using a single word or phrase.",
    }
    instructions = config['task_instructions']

    def _final_prompt(prompt):
        return (instructions.strip() + '\n\n' + prompt) if instructions else prompt

    question = sample['question'].strip()
    if sample['question_type'] == 'multiple-choice':
        raw_options = sample['options']
        # ast.literal_eval only accepts Python literals, so unlike eval() a
        # malformed or malicious dataset field cannot execute code.
        options = raw_options if isinstance(raw_options, list) else ast.literal_eval(raw_options)
        letters = [chr(ord('A') + i) for i in range(len(options))]
        example = ''.join(f"({letter}) {option}\n" for letter, option in zip(letters, options))
        empty_prompt = config['multi_choice_example_format'].format(question, example)
        res_dict = {
            'type': 'multichoice',
            'index2ans': dict(zip(letters, options)),
            'correct_choice': sample['answer'],
            'all_choices': letters,
            'empty_prompt': empty_prompt,
            'final_input_prompt': _final_prompt(empty_prompt),
            'gt_content': options[ord(sample['answer'].upper()) - ord('A')],
        }
    else:
        empty_prompt = config['short_ans_example_format'].format(question)
        res_dict = {
            'type': 'open',
            'empty_prompt': empty_prompt,
            'final_input_prompt': _final_prompt(empty_prompt),
            'gt_content': sample['answer'],
        }
    res_dict.update(sample)
    return res_dict
def process_image_tag(q):
    """Rewrite ``<image N>`` placeholder tags in a question into plain text.

    The ``<image 1>`` rules below form an if/elif chain: they are checked in
    order and only the first matching rule is applied. Afterwards every
    ``<image 2>`` ... ``<image 7>`` tag is deleted outright.

    Args:
        q: Question text, possibly containing ``<image N>`` tags.

    Returns:
        The rewritten question string.
    """
    q = q.strip()
    if q == '<image 1>':
        # The whole question is just the tag.
        q = 'Answer the question in the image.'
    elif ':<image 1>' in q:
        q = q.replace(':<image 1>', ' in the image. ')
        q = q.strip()
    elif ': <image 1>' in q:
        q = q.replace(': <image 1>', ' in the image. ')
        q = q.strip()
    elif '.<image 1>' in q or '. <image 1>' in q:
        # Drop every <image 1> tag and re-join the remaining non-empty pieces with single spaces.
        q_list = q.split('<image 1>')
        q_list = [part.strip() for part in q_list if part.strip() != '']
        q = ' '.join(q_list)
    elif q.startswith('<image 1> '):
        # q[10] is the first character after the 10-character prefix '<image 1> '.
        # It exists because q was stripped and so cannot end with that space.
        # A capitalized next word is kept as the sentence start; otherwise the
        # tag becomes the subject "The image".
        if q[10].isupper():
            q = q.replace('<image 1>', '')
        else:
            q = q.replace('<image 1>', 'The image')
        q = q.strip()
    elif q.startswith('<image 1>'):
        q = q.replace('<image 1>', '')
    elif q.endswith('<image 1>?'):
        q = q.replace('<image 1>', 'the image')
    elif q.endswith('?<image 1>') or q.endswith('? <image 1>') or q.endswith('\n<image 1>'):
        q = q.replace('<image 1>', '')
        q = q.strip()
    elif ' <image 1> ' in q:
        q = q.replace('<image 1>', 'the image')
    elif ' <image 1>' in q:
        # NOTE(review): this branch applies the same replacement as the one
        # above and also matches all of its inputs.
        q = q.replace('<image 1>', 'the image')
    elif '()<image 1>' in q:
        q = q.replace('()<image 1>', '')
    elif '(<image 1>)' in q:
        q = q.replace('(<image 1>)', '')
    elif '<image 1>.' in q:
        q = q.replace("<image 1>.", ". ")
    else:
        # Fallback; a no-op when q contains no <image 1> tag.
        q = q.replace("<image 1>", ". ")
        q = q.strip()
    # Remove any remaining tags for images 2-7.
    for i in range(2, 8):
        q = q.replace(f"<image {i}>", "")
    return q
class MMMUProEvalDataset(Dataset):
    """MMMU-Pro (``"standard"`` config, ``test`` split) loaded from the HF hub.

    Each sample is converted with ``process_single_sample_pro`` and
    ``construct_prompt_pro``, has its ``<image N>`` tags rewritten by
    ``process_image_tag``, and is prefixed with ``task_cfg['default_image_token']``.
    """

    def __init__(self, args, task_cfg, subset=None):
        self.args = args
        self.task_cfg = task_cfg
        self.dataset = load_dataset("MMMU/MMMU_Pro", "standard", split="test")
        if subset:
            # Slicing a datasets.Dataset (dataset[:n]) returns a dict of columns,
            # not a Dataset, which broke __len__/__getitem__; select() keeps rows.
            self.dataset = self.dataset.select(range(min(subset, len(self.dataset))))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(pixel_values, id, subfield, question_type, question, answer,
        str(index2ans), str(all_choices))`` for sample ``idx``."""
        sample = process_single_sample_pro(self.dataset[idx])
        sample = construct_prompt_pro(sample, self.task_cfg)
        img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
        question = self.task_cfg['default_image_token'] + '\n' + process_image_tag(sample['final_input_prompt'])
        if sample['question_type'] == 'multiple-choice':
            index2ans = sample['index2ans']
            all_choices = sample['all_choices']
        else:
            index2ans = {}
            all_choices = []
        return (img, sample['id'], sample['subfield'], sample['question_type'], question,
                sample['answer'], str(index2ans), str(all_choices))
class MMMUEvalDataset(Dataset):
    """MMMU eval set assembled from every subject config of ``MMMU/MMMU`` on the HF hub.

    ``task_cfg["split"] == "test"`` loads the test split. Any other value loads
    the validation split and keeps only samples whose ``id`` starts with that
    value. When ``subset`` is truthy, only ``subset`` samples starting at
    ``start_idx`` (``None`` is treated as 0) are kept.
    """

    def __init__(self, args, task_cfg, subset=None, start_idx=None):
        self.args = args
        self.task_cfg = task_cfg
        _split = "test" if task_cfg["split"] == "test" else "validation"
        sub_dataset_list = [
            load_dataset("MMMU/MMMU", subject, split=_split)
            for subject in CAT_SHORT2LONG.values()
        ]
        dataset = concatenate_datasets(sub_dataset_list)
        if task_cfg["split"] != "test":
            dataset = [s for s in dataset if s['id'].startswith(task_cfg["split"])]
        self.dataset = dataset
        if subset:
            # A None start_idx used to crash in range(None, ...); default to 0.
            start = start_idx or 0
            end = min(start + subset, len(dataset))
            self.dataset = [dataset[i] for i in range(start, end)]
            # Report the actual (clamped) end index, not start + subset.
            print(f"Evaluating a subset of dataset: {len(self.dataset)} from {start} to {end}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(pixel_values, id, subfield, question_type, question, answer,
        str(index2ans), str(all_choices))`` for sample ``idx``."""
        sample = process_single_sample(self.dataset[idx])
        sample = construct_prompt(sample, self.task_cfg)
        img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
        question = self.task_cfg['default_image_token'] + '\n' + process_image_tag(sample['final_input_prompt'])
        if sample['question_type'] == 'multiple-choice':
            index2ans = sample['index2ans']
            all_choices = sample['all_choices']
        else:
            index2ans = {}
            all_choices = []
        return (img, sample['id'], sample['subfield'], sample['question_type'], question,
                sample['answer'], str(index2ans), str(all_choices))
class VizWizEvalDataset(Dataset):
    """VizWiz eval set from a JSON list of entries with ``image`` and ``question``.

    The image file name doubles as the question id.
    """

    def __init__(self, args, img_dir, question_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the question file handle is closed deterministically.
        with open(question_path, encoding='utf-8') as f:
            self.questions = json.load(f)
        # Honor ``subset`` like every other dataset here (it was silently ignored).
        if subset:
            self.questions = self.questions[:subset]

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        """Return ``(pixel_values, question_id, question)`` for sample ``idx``."""
        entry = self.questions[idx]
        img = load_image(os.path.join(self.img_dir, entry["image"]), max_num=6).to(torch.bfloat16)
        return img, entry["image"], entry["question"]
class MMBenchEvalDataset(Dataset):
    """MMBench eval set loaded from a tab-separated file with base64-encoded images.

    Each row must provide ``index``, ``question``, ``hint``, ``category``,
    ``image`` (base64) and option columns ``A``-``D``. Options whose string
    form is ``'nan'`` (empty cells) are dropped. An ``answer`` column is
    optional; without it the answer is ``''``.
    """

    def __init__(self, args, gt_path, subset=None):
        self.args = args
        df = pd.read_csv(gt_path, sep='\t')
        if subset:
            # Honor ``subset`` like the other datasets; trim before decoding images.
            df = df.iloc[:subset]
        has_answer = 'answer' in df.columns
        self.dataset = []
        for _, row in df.iterrows():
            self.dataset.append({
                'index': row['index'],
                'question': row['question'],
                'hint': row['hint'],
                'category': row['category'],
                'image': Image.open(BytesIO(base64.b64decode(row['image']))),
                'choices': [row[c] for c in ['A', 'B', 'C', 'D'] if str(row[c]) != 'nan'],
                'answer': row['answer'] if has_answer else '',
            })

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(pixel_values, index, question_with_choices, answer,
        str(index2ans), str(all_choices), raw_question)`` for sample ``idx``.

        NOTE(review): the row's ``hint`` is stored but not included in the
        prompt; confirm this is intended.
        """
        sample = self.dataset[idx]
        img = load_image(sample["image"].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
        all_choices = [chr(ord('A') + i) for i in range(len(sample["choices"]))]
        index2ans = dict(zip(all_choices, sample["choices"]))
        choices_str = ''.join(f"{letter}. {choice}\n" for letter, choice in zip(all_choices, sample["choices"]))
        question = sample["question"] + '\n' + choices_str
        return img, sample["index"], question, sample["answer"], str(index2ans), str(all_choices), sample["question"]
def get_task_dataloader(task_name, task_cfg, args):
    """Build a batch-size-1, non-shuffled DataLoader for an evaluation task.

    Args:
        task_name: Task key such as ``"textvqa"``, ``"mmmu"`` or ``"ocrbench"``.
        task_cfg: Task config mapping. Depending on the task it must contain
            ``image_dir``, ``gt_path`` and/or ``question_path``. An optional
            ``subset`` key limits the sample count for the tasks it is
            forwarded to.
        args: Run arguments. For ``"mmmu"``, ``args.subset`` and
            ``args.start_idx`` are used instead of ``task_cfg["subset"]``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the task's dataset.

    Raises:
        NotImplementedError: if ``task_name`` is not a known task.
    """
    subset = task_cfg.get("subset")

    if task_name == "coco_caption":
        dataset = COCOEvalDataset(args, task_cfg["image_dir"], subset)
    elif task_name == "flickr30k_caption":
        dataset = Flickr30KEvalDataset(args, task_cfg["image_dir"], subset)
    elif task_name == "vqav2":
        dataset = VQAv2EvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "textvqa":
        dataset = TextVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "gqa":
        dataset = GQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "chartqa":
        dataset = ChartQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "okvqa":
        dataset = OKVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], task_cfg["question_path"], subset)
    elif task_name == "vizwiz":
        dataset = VizWizEvalDataset(args, task_cfg["image_dir"], task_cfg["question_path"], subset)
    elif task_name == "docvqa":
        dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='val', subset=subset)
    elif task_name == "docvqa_test":
        dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='test', subset=subset)
    elif task_name == "realworldqa":
        dataset = RealworldQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "mmmu":
        dataset = MMMUEvalDataset(args, task_cfg, subset=args.subset, start_idx=args.start_idx)
    elif task_name == "mmmu_pro":
        dataset = MMMUProEvalDataset(args, task_cfg)
    elif task_name == "mathvista":
        dataset = MathVistaEvalDataset(args, task_cfg)
    elif task_name == "mmbench":
        # Forward subset for consistency with the other file-based datasets.
        dataset = MMBenchEvalDataset(args, task_cfg["gt_path"], subset)
    elif task_name == 'ocrbench':
        dataset = OCRBenchEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == 'ai2diagram':
        dataset = AI2DiagramEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == 'ai2diagram_nomask':
        dataset = AI2DiagramNoMaskEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    else:
        raise NotImplementedError(f"Task {task_name} is not supported yet.")

    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )