import json
import os
import sys
import time
import yaml
import spacy
import ast
from PIL import Image
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
from io import BytesIO
import base64
from anls import anls_score
import torch
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torchvision.transforms as T
from eval import conversation as conversation_lib
from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \
    process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro
from eval.mmmu_utils import evaluate as evaluate_mmmu
from torchvision.transforms.functional import InterpolationMode
from datasets import load_dataset, concatenate_datasets

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform
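# Illustrative usage of build_transform (a sketch, not executed here; 'example.jpg' is a
# hypothetical path):
#   transform = build_transform(input_size=448)
#   tensor = transform(Image.open('example.jpg'))  # -> torch.FloatTensor of shape (3, 448, 448)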


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    # Pick the candidate (cols, rows) grid whose aspect ratio is closest to the input image;
    # ties are broken in favour of the larger grid when the image has enough pixel area.
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
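# Worked example (a sketch; the numbers assume the ratio grid built in dynamic_preprocess
# with min_num=1, max_num=6 and image_size=448):
#   a 1000x1000 image has enough area, so the tie between (1, 1) and (2, 2) resolves to (2, 2);
#   a 500x500 image keeps (1, 1) because 500 * 500 < 0.5 * 448 * 448 * 4.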


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate all (cols, rows) grids whose tile count lies in [min_num, max_num],
    # ordered by tile count.
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize to the chosen grid and slice it into image_size x image_size tiles.
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
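# Example of the tiling behaviour (a sketch, not executed here):
#   tiles = dynamic_preprocess(Image.new('RGB', (1000, 1000)), image_size=448, use_thumbnail=True)
#   len(tiles) == 5  # a 2x2 grid of 448x448 crops plus one 448x448 thumbnail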


def load_image(image, input_size=448, max_num=6, decoded=False):
    if not decoded:
        image = Image.open(image).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
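# Illustrative call (a sketch; 'example.jpg' is a hypothetical path): the result is a float
# tensor of shape (num_tiles, 3, input_size, input_size), where num_tiles includes the
# thumbnail whenever more than one tile is produced. Callers below cast it to bfloat16.
#   pixel_values = load_image('example.jpg', input_size=448, max_num=6)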


def levenshtein_distance(s1, s2):
    # Single-row dynamic-programming edit distance (insertions, deletions, substitutions).
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]
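# Sanity checks (a sketch):
#   levenshtein_distance('kitten', 'sitting') == 3
#   levenshtein_distance('', 'abc') == 3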


def get_anls_score(pred, gold_labels, threshold, llava_eval=False):
    values = []
    for answer in gold_labels:
        # Normalise whitespace and case before computing the edit distance; note that the
        # normalising length is taken from the raw (un-normalised) strings.
        gt_answer = ' '.join(answer.strip().lower().split())
        det_answer = ' '.join(pred.strip().lower().split())

        dist = levenshtein_distance(gt_answer, det_answer)
        length = max(len(answer.upper()), len(pred.upper()))
        values.append(0.0 if length == 0 else float(dist) / float(length))

    question_result = 1 - min(values)

    if llava_eval:
        question_result = 1.0 if question_result >= threshold else 0.0
    else:
        if question_result < threshold:
            question_result = 0

    return question_result
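# Example (a sketch): a near-miss prediction is still rewarded above the usual 0.5 threshold.
#   get_anls_score('hello wrld', ['hello world'], threshold=0.5)                   # ~0.909 (1 edit / 11 chars)
#   get_anls_score('hello wrld', ['hello world'], threshold=0.5, llava_eval=True)  # 1.0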


def isNumber(n: str):
    try:
        float(n)
        return True
    except ValueError:
        return False


class COCOEvalDataset(Dataset):
    def __init__(self, args, img_dir, subset=None):
        self.args = args
        self.img_files = sorted(glob(os.path.join(img_dir, "*")))

        if subset:
            self.img_files = self.img_files[:subset]

        self.image_ids = [int(img_file.split("_")[-1].split(".")[0]) for img_file in self.img_files]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        return self.image_ids[idx], img
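# Note (an assumption about the data layout): the image id is parsed from COCO-style file
# names, e.g. 'COCO_val2014_000000391895.jpg' -> 391895.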


class Flickr30KEvalDataset(Dataset):
    def __init__(self, args, img_dir, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.test_samples = json.load(open(os.path.join(img_dir, "flickr30k_test.json"), encoding='utf-8'))

        if subset:
            self.test_samples = self.test_samples[:subset]

    def __len__(self):
        return len(self.test_samples)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.test_samples[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        image_id = int(self.test_samples[idx]["image"].split("/")[-1].replace(".jpg", ""))

        return image_id, img


class VQAv2EvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answer"]

        return img, question_id, question, answer


class TextVQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))['data']

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]["image_id"] + '.jpg')
        if not os.path.exists(img_path):
            # Fall back to .png when the .jpg file is missing.
            img_path = img_path.replace('.jpg', '.png')
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answers"]

        return img, question_id, question, answer


class GQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Flatten the GQA {question_id: annotation} mapping into a list of records.
        self.gt = json.load(open(gt_path, encoding='utf-8'))
        self.gt = [{
            "question_id": int(k),
            "image": v['imageId'] + ".jpg",
            "question": v['question'],
            "answer": v['answer']
        } for k, v in self.gt.items()]

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answer"]

        return img, question_id, question, [answer]


class ChartQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))
        for i in range(len(self.gt)):
            self.gt[i]['question_id'] = i

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]["imgname"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["query"]
        answer = self.gt[idx]["label"]

        return img, question_id, question, [answer]


class OKVQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, question_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))['annotations']
        self.questions = json.load(open(question_path, 'r'))['questions']

        if subset:
            self.gt = self.gt[:subset]

        qid2q = {q['question_id']: q['question'] for q in self.questions}

        for ann in self.gt:
            ann['answers'] = [ans['answer'] for ans in ann['answers']]
            ann['question'] = qid2q[ann['question_id']]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        # COCO val2014 file names use 12-digit zero-padded image ids.
        img_id = str(self.gt[idx]["image_id"]).zfill(12)
        img_file_name = f"COCO_val2014_{img_id}.jpg"
        img_path = os.path.join(self.img_dir, img_file_name)
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answers"]

        return img, question_id, question, answer


class DocVQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, split='val', subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))['data']

        if subset:
            self.gt = self.gt[:subset]

        self.split = split

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]['image'].split('/')[-1])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = self.gt[idx]["questionId"]
        question = self.gt[idx]["question"]

        if self.split == 'val':
            answer = self.gt[idx]["answers"]
        else:
            # No ground-truth answers are available for the test split here; return a placeholder.
            answer = ['']

        return img, question_id, question, answer


class OCRBenchEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]['image_path'])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        dataset_name = self.gt[idx]["dataset_name"]
        question_id = f"{idx}"
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answers"]
        data_type = self.gt[idx]["type"]

        return img, question_id, question, answer, dataset_name, data_type


class AI2DiagramEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir

        # The annotation file is JSON Lines: one JSON object per line.
        with open(gt_path, 'r') as json_file:
            json_list = list(json_file)
        self.gt = [json.loads(json_str) for json_str in json_list]

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answer"]

        return img, question_id, question, answer


class AI2DiagramNoMaskEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir

        with open(gt_path, 'r') as json_file:
            json_list = list(json_file)
        self.gt = [json.loads(json_str) for json_str in json_list]

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_file_name = self.gt[idx]['image'].replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES")
        img_path = os.path.join(self.img_dir, img_file_name)
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answer"]

        return img, question_id, question, answer


class RealworldQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = int(self.gt[idx]['image'].replace(".webp", ""))
        question = self.gt[idx]["question"]

        if self.gt[idx]['question_type'] == "multi-choice":
            choices = self.gt[idx]["choices"]
            start_chr = 'A'
            choices_str = ''
            index2ans = {}
            all_choices = []
            for choice in choices:
                all_choices.append(start_chr)
                index2ans[start_chr] = choice
                choices_str += f"{start_chr}. {choice}\n"
                start_chr = chr(ord(start_chr) + 1)

            question = question + '\n' + choices_str
            question = question + "Answer with the option's letter from the given choices directly."
            answer = chr(ord('A') + self.gt[idx]['correct_choice_index'])
        else:
            question = question + "\nAnswer the question using a single word or phrase."
            answer = self.gt[idx]['answer']

        return img, question_id, question, [answer]


class MathVistaEvalDataset(Dataset):
    def __init__(self, args, task_cfg, gt_path=None):
        self.args = args
        self.task_cfg = task_cfg
        self.dataset = load_dataset("AI4Math/MathVista")['testmini']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = self.dataset[idx]['decoded_image']
        img = load_image(img.convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = self.dataset[idx]["pid"]
        question = self.dataset[idx]["question"]
        question_type = self.dataset[idx]["question_type"]
        query = self.dataset[idx]["query"]
        choices = self.dataset[idx]["choices"]
        answer = self.dataset[idx]["answer"]

        if question_type == 'multi_choice':
            start_chr = 'A'
            choices_str = ''
            index2ans = {}
            all_choices = []
            for choice in choices:
                all_choices.append(start_chr)
                index2ans[start_chr] = choice
                choices_str += f"{start_chr}. {choice}\n"
                start_chr = chr(ord(start_chr) + 1)

            question = question + '\n' + choices_str
            question = question + "Answer with the option's letter from the given choices directly."
            # Convert the ground-truth answer into its option letter.
            answer = chr(ord('A') + choices.index(answer))
        else:
            question = query.replace("Hint: ", "")
            index2ans = {}
            all_choices = []

        return img, question_id, question_type, question, answer, str(index2ans), str(all_choices)


def construct_prompt_for_fewshot(sample):
    config = {
        "task_instructions": "",
        "multi_choice_example_format": "{}\n{}Answer with the option's letter from the given choices directly.",
        "short_ans_example_format": "{}\nAnswer the question using a single word or phrase."
    }

    question = sample['question'].strip()

    # 'options' is stored as the string representation of a Python list.
    options = ast.literal_eval(sample['options'])
    example = ""
    if sample['question_type'] == 'multiple-choice':
        start_chr = 'A'
        prediction_range = []
        index2ans = {}
        for option in options:
            prediction_range.append(start_chr)
            example += f"({start_chr}) {option}\n"
            index2ans[start_chr] = option
            start_chr = chr(ord(start_chr) + 1)
        empty_prompt_sample_structure = config['multi_choice_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question, example)
        res_dict = {'type': 'multichoice'}
        res_dict['index2ans'] = index2ans
        res_dict['correct_choice'] = sample['answer']
        res_dict['all_choices'] = prediction_range
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt

        res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
    else:
        empty_prompt_sample_structure = config['short_ans_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question)
        res_dict = {'type': 'open'}
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt
        res_dict['gt_content'] = sample['answer']

    res_dict.update(sample)
    return res_dict
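# Illustrative input/output (a sketch with a hypothetical MMMU-style sample):
#   sample = {'question': '<image 1> Which shape is larger?', 'options': "['circle', 'square']",
#             'question_type': 'multiple-choice', 'answer': 'A'}
#   construct_prompt_for_fewshot(sample)['final_input_prompt']
#   # -> "<image 1> Which shape is larger?\n(A) circle\n(B) square\nAnswer with the option's letter from the given choices directly."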


def process_image_tag(q):
    # Normalise the '<image 1>' placeholder so the question reads naturally once the image
    # token is prepended, then strip any remaining '<image k>' placeholders.
    q = q.strip()

    if q == '<image 1>':
        q = 'Answer the question in the image.'
    elif ':<image 1>' in q:
        q = q.replace(':<image 1>', ' in the image. ')
        q = q.strip()
    elif ': <image 1>' in q:
        q = q.replace(': <image 1>', ' in the image. ')
        q = q.strip()
    elif '.<image 1>' in q or '. <image 1>' in q:
        q_list = q.split('<image 1>')
        q_list = [part.strip() for part in q_list if part.strip() != '']
        q = ' '.join(q_list)
    elif q.startswith('<image 1> '):
        if q[10].isupper():
            q = q.replace('<image 1>', '')
        else:
            q = q.replace('<image 1>', 'The image')
        q = q.strip()
    elif q.startswith('<image 1>'):
        q = q.replace('<image 1>', '')
    elif q.endswith('<image 1>?'):
        q = q.replace('<image 1>', 'the image')
    elif q.endswith('?<image 1>') or q.endswith('? <image 1>') or q.endswith('\n<image 1>'):
        q = q.replace('<image 1>', '')
        q = q.strip()
    elif ' <image 1> ' in q:
        q = q.replace('<image 1>', 'the image')
    elif ' <image 1>' in q:
        q = q.replace('<image 1>', 'the image')
    elif '()<image 1>' in q:
        q = q.replace('()<image 1>', '')
    elif '(<image 1>)' in q:
        q = q.replace('(<image 1>)', '')
    elif '<image 1>.' in q:
        q = q.replace("<image 1>.", ". ")
    else:
        q = q.replace("<image 1>", ". ")
    q = q.strip()

    # Drop any additional image placeholders (<image 2> ... <image 7>).
    for i in range(2, 8):
        q = q.replace(f"<image {i}>", "")

    return q
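# Examples of the rewriting (a sketch):
#   process_image_tag('<image 1> What is shown?')           # -> 'What is shown?'
#   process_image_tag('What is the value of x?<image 1>')   # -> 'What is the value of x?'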


class MMMUProEvalDataset(Dataset):
    def __init__(self, args, task_cfg, subset=None):
        self.args = args
        self.task_cfg = task_cfg

        MMMU_path = "MMMU/MMMU_Pro"
        _split = "test"

        self.dataset = load_dataset(MMMU_path, "standard", split=_split)
        if subset:
            # Dataset.select keeps a proper datasets.Dataset; plain slicing would return a
            # dict of columns and break indexing below.
            self.dataset = self.dataset.select(range(subset))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        sample = process_single_sample_pro(sample)
        sample = construct_prompt_pro(sample, self.task_cfg)
        img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = sample['id']
        question = sample['final_input_prompt']
        answer = sample['answer']

        question = process_image_tag(question)
        question = self.task_cfg['default_image_token'] + '\n' + question

        if sample['question_type'] == 'multiple-choice':
            index2ans = sample['index2ans']
            all_choices = sample['all_choices']
        else:
            index2ans = {}
            all_choices = []

        return (img, question_id, sample['subfield'], sample['question_type'], question, answer,
                str(index2ans), str(all_choices))


class MMMUEvalDataset(Dataset):
    def __init__(self, args, task_cfg, subset=None, start_idx=None):
        self.args = args
        self.task_cfg = task_cfg
        sub_dataset_list = []

        MMMU_path = "MMMU/MMMU"

        _split = "test" if task_cfg["split"] == "test" else "validation"
        for subject in CAT_SHORT2LONG.values():
            sub_dataset = load_dataset(
                MMMU_path, subject,
                split=_split,
            )
            sub_dataset_list.append(sub_dataset)

        dataset = concatenate_datasets(sub_dataset_list)

        if task_cfg["split"] != "test":
            dataset = [s for s in dataset if s['id'].startswith(task_cfg["split"])]

        self.dataset = dataset

        if subset:
            # Default to the start of the dataset when no offset is given.
            start_idx = start_idx or 0
            self.dataset = [dataset[i] for i in range(start_idx, min(start_idx + subset, len(dataset)))]
            print(f"Evaluating a subset of dataset: {len(self.dataset)} from {start_idx} to {start_idx + subset}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        sample = process_single_sample(sample)
        sample = construct_prompt(sample, self.task_cfg)

        img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = sample['id']
        question = sample['final_input_prompt']
        answer = sample['answer']

        question = process_image_tag(question)
        question = self.task_cfg['default_image_token'] + '\n' + question

        if sample['question_type'] == 'multiple-choice':
            index2ans = sample['index2ans']
            all_choices = sample['all_choices']
        else:
            index2ans = {}
            all_choices = []

        return (img, question_id, sample['subfield'], sample['question_type'], question, answer,
                str(index2ans), str(all_choices))


class VizWizEvalDataset(Dataset):
    def __init__(self, args, img_dir, question_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.questions = json.load(open(question_path, encoding='utf-8'))

        # Honour the optional subset, mirroring the other eval datasets.
        if subset:
            self.questions = self.questions[:subset]

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.questions[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question = self.questions[idx]["question"]
        question_id = self.questions[idx]["image"]

        return img, question_id, question


class MMBenchEvalDataset(Dataset):
    def __init__(self, args, gt_path, subset=None):
        self.args = args
        df = pd.read_csv(gt_path, sep='\t')
        self.dataset = []
        for _, row in df.iterrows():
            # Some questions have fewer than four options; skip the NaN columns.
            choices = []
            for choice in ['A', 'B', 'C', 'D']:
                if str(row[choice]) != 'nan':
                    choices.append(row[choice])

            this_sample = {
                'index': row['index'],
                'question': row['question'],
                'hint': row['hint'],
                'category': row['category'],
                # Images are stored in the TSV as base64-encoded strings.
                'image': Image.open(BytesIO(base64.b64decode(row['image']))),
                'choices': choices
            }

            if 'answer' in row.keys():
                this_sample['answer'] = row['answer']
            else:
                this_sample['answer'] = ''

            self.dataset.append(this_sample)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = load_image(self.dataset[idx]["image"].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question = self.dataset[idx]["question"]
        hint = self.dataset[idx]["hint"]
        question_id = self.dataset[idx]["index"]
        choices = self.dataset[idx]["choices"]
        answer = self.dataset[idx]["answer"]

        start_chr = 'A'
        choices_str = ''
        index2ans = {}
        all_choices = []
        for choice in choices:
            all_choices.append(start_chr)
            index2ans[start_chr] = choice
            choices_str += f"{start_chr}. {choice}\n"
            start_chr = chr(ord(start_chr) + 1)

        question = question + '\n' + choices_str

        return img, question_id, question, answer, str(index2ans), str(all_choices), self.dataset[idx]["question"]


def get_task_dataloader(task_name, task_cfg, args):
    subset = task_cfg.get("subset")

    if task_name == "coco_caption":
        dataset = COCOEvalDataset(args, task_cfg["image_dir"], subset)
    elif task_name == "flickr30k_caption":
        dataset = Flickr30KEvalDataset(args, task_cfg["image_dir"], subset)
    elif task_name == "vqav2":
        dataset = VQAv2EvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "textvqa":
        dataset = TextVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "gqa":
        dataset = GQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "chartqa":
        dataset = ChartQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "okvqa":
        dataset = OKVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], task_cfg["question_path"], subset)
    elif task_name == "vizwiz":
        dataset = VizWizEvalDataset(args, task_cfg["image_dir"], task_cfg["question_path"], subset)
    elif task_name == "docvqa":
        dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='val', subset=subset)
    elif task_name == "docvqa_test":
        dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='test', subset=subset)
    elif task_name == "realworldqa":
        dataset = RealworldQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "mmmu":
        dataset = MMMUEvalDataset(args, task_cfg, subset=args.subset, start_idx=args.start_idx)
    elif task_name == "mmmu_pro":
        dataset = MMMUProEvalDataset(args, task_cfg)
    elif task_name == "mathvista":
        dataset = MathVistaEvalDataset(args, task_cfg)
    elif task_name == "mmbench":
        dataset = MMBenchEvalDataset(args, task_cfg["gt_path"])
    elif task_name == 'ocrbench':
        dataset = OCRBenchEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == 'ai2diagram':
        dataset = AI2DiagramEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == 'ai2diagram_nomask':
        dataset = AI2DiagramNoMaskEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    else:
        raise NotImplementedError(f"Task {task_name} is not supported yet.")

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )

    return dataloader
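# Illustrative wiring (a sketch; the YAML layout and paths are hypothetical):
#   task_cfg = yaml.safe_load(open('eval_config.yaml'))['textvqa']
#   # e.g. {'image_dir': '/data/textvqa/train_images', 'gt_path': '/data/textvqa/val.json', 'subset': 100}
#   loader = get_task_dataloader('textvqa', task_cfg, args)
#   for img, question_id, question, answer in loader:
#       ...  # run generation and scoring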