|
|
import os |
|
|
import logging |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Dict, Optional, Union |
|
|
|
|
|
import torch |
|
|
from datasets import load_dataset |
|
|
import json |
|
|
from tqdm import tqdm |
|
|
from PIL import Image |
|
|
import requests |
|
|
from io import BytesIO |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
from enum import Enum |
|
|
|
|
|
|
|
|
from data import ( |
|
|
DatasetType, |
|
|
DatasetConfig, |
|
|
get_dataset_config, |
|
|
get_formatted_instruction, |
|
|
process_response, |
|
|
save_descriptions, |
|
|
load_image_dataset, |
|
|
get_processed_response |
|
|
) |
|
|
from torch.utils.data import Dataset, DataLoader, DistributedSampler |
|
|
import torch.distributed as dist |
|
|
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor |
|
|
from vllm import LLM, SamplingParams |
|
|
|
|
|
|
|
|
import io |
|
|
import base64 |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Log to both a file and the console so evaluation runs are auditable later.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('evaluation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


# Appended to every formatted prompt so the final answer can be extracted
# from a \boxed{...} span downstream.
INSTRUCTION = "\n\nYour final answer MUST BE put in \\boxed{}."
|
|
|
|
|
def pil_to_base64(image_pil, format="PNG"):
    """Serialize a PIL image to a base64-encoded ASCII string.

    Args:
        image_pil: PIL image (any object exposing ``save(buffer, format=...)``).
        format: Image format handed to ``save`` (default ``"PNG"``).

    Returns:
        The encoded image bytes as a base64 ``str``.
    """
    with io.BytesIO() as buffer:
        image_pil.save(buffer, format=format)
        raw = buffer.getvalue()
    return base64.b64encode(raw).decode("utf-8")
|
|
|
|
|
def base64_to_pil(base64_string):
    """Decode a base64 string (as produced by ``pil_to_base64``) into a PIL image.

    Args:
        base64_string: Base64-encoded image bytes.

    Returns:
        A lazily-loaded ``PIL.Image.Image`` backed by the decoded bytes.
    """
    raw_bytes = base64.b64decode(base64_string)
    return Image.open(io.BytesIO(raw_bytes))
|
|
|
|
|
class InstanceDataset(Dataset):
    """Wraps a list of example dicts for use with a torch ``DataLoader``.

    Each item is normalized so the default collate function can batch it:
    'options'/'choices' values are coerced to strings (``None`` becomes ``""``),
    and a PIL image stored under 'image_url' is serialized to a base64 string.
    """

    def __init__(self, data):
        # data: list of per-example dicts as returned by load_image_dataset().
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # BUG FIX: shallow-copy the stored dict. The original mutated
        # self.data[index] in place (stringified options, base64-encoded the
        # image), so reading the same index twice would feed an already
        # base64-encoded string into pil_to_base64 and crash.
        item = dict(self.data[index])
        for k in item:
            if k in ('options', 'choices'):
                # `is None` identity check (was `== None`); stringify so the
                # default collate_fn can batch heterogeneous option lists.
                item[k] = "" if item[k] is None else str(item[k])
        if 'image_url' in item:
            # NOTE(review): despite the name, 'image_url' appears to hold a
            # PIL image here (it is passed to pil_to_base64, which calls
            # .save()) — encoded to a string so collation can batch it.
            item['image_url'] = pil_to_base64(item['image_url'])
        return {'index': index, 'item': item}
|
|
|
|
|
def main():
    """Evaluate a vision-language model on a math benchmark with vLLM.

    Parses CLI args, loads the requested dataset, runs batched generation,
    and writes one JSON result file per example into
    ``./outputs/<dataset>_<name>/<index>.json``. Examples whose output file
    already exists are skipped, making the run resumable.
    """
    import ast  # stdlib; safe literal parsing of options/choices strings

    parser = argparse.ArgumentParser(description='Evaluate model on various math datasets')
    parser.add_argument('--dataset', type=str,
                        choices=['mathvista', 'mathverse', 'mathvision', 'mathvision-mini',
                                 'hallusionbench', 'mmmu-pro-vision', 'we-math', 'math500',
                                 'gpqa', 'dynamath', 'logicvista'],
                        default='mathvista', help='Dataset to evaluate on')
    parser.add_argument('--model_path', type=str, help='Path to the model',
                        default="Qwen/Qwen3-VL-2B-Instruct")
    parser.add_argument('--name', type=str, help='model save name', default="plm")
    parser.add_argument('--bsz', type=int, help='batch size', default=2)
    args = parser.parse_args()

    dataset_type = DatasetType(args.dataset)
    dataset_config = get_dataset_config(dataset_type)

    output_folder = f"./outputs/{dataset_type.value}_{args.name}"
    os.makedirs(output_folder, exist_ok=True)

    processor = AutoProcessor.from_pretrained(args.model_path)
    # One image max per prompt; shard across all visible GPUs.
    vlm = LLM(args.model_path, limit_mm_per_prompt={"image": 1},
              tensor_parallel_size=torch.cuda.device_count())
    sampling_params = SamplingParams(max_tokens=2048, temperature=0.7, top_p=0.8,
                                     top_k=20, repetition_penalty=1.0, presence_penalty=1.5)

    logger.info(f"Loading dataset {dataset_config.name}")
    data = load_image_dataset(dataset_config)

    dataset = InstanceDataset(data)
    dataloader = DataLoader(dataset, batch_size=args.bsz)

    for batch in tqdm(dataloader):
        indices = batch['index']
        global_item = batch['item']

        # Per-batch work lists, holding only examples not yet evaluated.
        run_input_instances = []
        run_indices = []
        run_processed_responses = []
        run_items = []
        run_formatted_instructions = []

        for j in range(len(indices)):
            index = indices[j].item()
            output_file = os.path.join(output_folder, f'{index}.json')
            if os.path.exists(output_file):
                continue  # resumable: this example was already evaluated

            # De-collate: pull the j-th element of every batched field.
            item = {k: global_item[k][j] for k in global_item}

            for k in item:
                if len(item[k]) > 0:
                    if k in ('choices', 'options'):
                        # options/choices were stringified for collation; parse
                        # literals back. BUG FIX: ast.literal_eval instead of
                        # eval() on dataset-supplied strings, and narrow
                        # exceptions instead of a bare `except:`.
                        try:
                            item[k] = ast.literal_eval(item[k])
                        except (ValueError, SyntaxError):
                            pass  # keep the raw string if it is not a literal
                    if k == 'image_url':
                        # base64 string -> PIL image for vLLM multimodal input.
                        item['image_url'] = base64_to_pil(item['image_url'])

            formatted_instruction = get_formatted_instruction(dataset_type, item) + INSTRUCTION

            if 'image_url' in item:
                message = [{"role": "user", "content": [
                    {"type": "image", "image": ""},
                    {"type": "text", "text": formatted_instruction}]}]
            else:
                message = [{"role": "user", "content": [
                    {"type": "text", "text": formatted_instruction}]}]

            text = processor.apply_chat_template(message, tokenize=False,
                                                 add_generation_prompt=True)
            if 'image_url' in item:
                input_instance = {'prompt': text,
                                  'multi_modal_data': {'image': item['image_url']}}
            else:
                input_instance = {'prompt': text}

            run_input_instances.append(input_instance)
            run_indices.append(index)
            run_processed_responses.append(get_processed_response(dataset_type, item))
            run_items.append(item)
            run_formatted_instructions.append(formatted_instruction)

        if not run_input_instances:
            continue  # whole batch already done; avoid generate([])

        outputs = vlm.generate(run_input_instances, sampling_params=sampling_params)

        for j in range(len(run_indices)):
            index = run_indices[j]
            answer = outputs[j].outputs[0].text
            item = run_items[j]
            # The PIL image is not JSON-serializable; drop it before dumping.
            item.pop('image_url', None)

            description = {
                # BUG FIX: record the example's global index (the original
                # wrote the batch-local loop counter j).
                'index': index,
                'item': json.dumps(item),
                'formatted_instruction': run_formatted_instructions[j],
                'processed_response': run_processed_responses[j],
                'answer': answer,
            }

            # BUG FIX: recompute the per-example path. The original reused the
            # stale `output_file` left over from the first loop, so every
            # result in a batch overwrote the same (last-seen) file.
            output_file = os.path.join(output_folder, f'{index}.json')
            with open(output_file, 'w') as f:
                json.dump(description, f, indent=4)
|
|
|
|
|
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|