import os

from datasets import Dataset
from openai import OpenAI
from tqdm import tqdm

from lmms_eval.tasks.charxiv.constant import REASONING_RESP_INST
from lmms_eval.tasks.charxiv.descriptive_utils import (
    build_descriptive_grading_queries,
    descriptive_query_helper,
    get_descriptive_result_gpt,
    postprocess_descriptive_grading_queries,
)
from lmms_eval.tasks.charxiv.reasoning_utils import (
    build_reasoning_grading_queries,
    get_number_instruction,
    get_reasoning_result_gpt,
)

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "YOUR_OPENAI_BASE_URL")
MODEL_VERSION = os.getenv("MODEL_VERSION", "YOUR_MODEL_VERSION")

client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
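
# The grading client is configured entirely through environment variables; a
# typical setup might look like the following (placeholder values, not real
# credentials or a required model choice):
#
#   export OPENAI_API_KEY=sk-...
#   export OPENAI_BASE_URL=https://api.openai.com/v1
#   export MODEL_VERSION=gpt-4o-2024-05-13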


def charxiv_reasoning_doc_to_text_cot(doc, lmms_eval_specific_kwargs=None):
    inst_category = doc["reasoning_q_source"]
    if inst_category in [1, 2, 3]:
        question = REASONING_RESP_INST[inst_category].format(doc["reasoning_q"])
    elif inst_category == 4:
        # Numeric questions additionally carry an answer-format instruction
        # derived from the ground-truth value.
        question = REASONING_RESP_INST[inst_category].format(doc["reasoning_q"], get_number_instruction(doc["reasoning_a"]))
    else:
        raise ValueError(f"Unexpected reasoning_q_source: {inst_category}")
    return question


def charxiv_descriptive_process_docs(dataset: Dataset) -> Dataset:
    # Every figure carries four descriptive questions, so expand the dataset
    # fourfold and let each copy answer one question slot.
    dataset = dataset.repeat(4)

    def _process_row(example, index):
        q_number = index % 4 + 1
        descriptive_q = example[f"descriptive_q{q_number}"]
        # The raw value in `descriptive_q{n}` is the question-type id; capture
        # it as `qid` before it is replaced by the rendered question text.
        qid = descriptive_q
        subplot_loc = example["subplot_loc"]
        if subplot_loc is None:
            subplot_loc = [example["subplot_row"], example["subplot_col"]]
        descriptive_q = descriptive_query_helper(descriptive_q, subplot_loc)
        example["descriptive_q"] = descriptive_q
        example["descriptive_a"] = example[f"descriptive_a{q_number}"]
        return {"qid": qid, **example}

    dataset = dataset.map(_process_row, with_indices=True, num_proc=1)
    return dataset
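
# Net effect of the expansion above (illustrative only): each original figure
# yields four docs, one per descriptive question slot:
#
#   figure row -> {qid, descriptive_q, descriptive_a} from slot 1
#                 {qid, descriptive_q, descriptive_a} from slot 2
#                 {qid, descriptive_q, descriptive_a} from slot 3
#                 {qid, descriptive_q, descriptive_a} from slot 4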


def charxiv_descriptive_doc_to_text_cot(doc, lmms_eval_specific_kwargs=None):
    return doc["descriptive_q"]


def charxiv_doc_to_visual(doc):
    return [doc["image"].convert("RGB")]


def charxiv_descriptive_doc_to_messages_cot(doc, lmms_eval_specific_kwargs=None):
    question = charxiv_descriptive_doc_to_text_cot(doc, lmms_eval_specific_kwargs)
    visuals = charxiv_doc_to_visual(doc)
    messages = [{"role": "user", "content": []}]
    messages[0]["content"].append({"type": "image", "url": visuals[0]})
    messages[0]["content"].append({"type": "text", "text": question.strip()})
    return messages


def charxiv_reasoning_doc_to_messages_cot(doc, lmms_eval_specific_kwargs=None):
    question = charxiv_reasoning_doc_to_text_cot(doc, lmms_eval_specific_kwargs)
    visuals = charxiv_doc_to_visual(doc)
    messages = [{"role": "user", "content": []}]
    messages[0]["content"].append({"type": "image", "url": visuals[0]})
    messages[0]["content"].append({"type": "text", "text": question.strip()})
    return messages
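
# Both builders emit a single-turn multimodal message; the resulting structure
# is roughly (illustrative shape, the "url" entry holds a PIL image object):
#
#   [{"role": "user",
#     "content": [{"type": "image", "url": <PIL.Image>},
#                 {"type": "text", "text": "<rendered question>"}]}]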


def charxiv_descriptive_process_results(doc, results):
    figure_id = doc["figure_path"]
    qid = doc["qid"]
    # Composite key used to match this response back to its grading result.
    resp_key = f"{figure_id}_{qid}"
    response = results[0].strip()
    answer = doc["descriptive_a"]
    return {"descriptive_acc": {"group": (resp_key, response, answer), "qid": qid}}


def charxiv_descriptive_aggregate_results(results):
    # Bucket (resp_key, response, answer) triples by question type (qids 1-19).
    groups = {i: [] for i in range(1, 20)}
    for result in results:
        groups[result["qid"]].append(result["group"])
    queries = build_descriptive_grading_queries(groups)

    # Grade each batched query with the judge model and merge the verdicts
    # back into the query records.
    combined_queries = []
    for query in tqdm(queries):
        result = get_descriptive_result_gpt(client, query["grading_query"], len(query["resp_keys"]), model=MODEL_VERSION)
        combined_queries.append({**query, **result})
    queries = combined_queries

    queries = postprocess_descriptive_grading_queries(queries)

    scores = [query["score"] for query in queries.values()]
    return sum(scores) / len(scores)
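
# The reported descriptive_acc is the unweighted mean score over every graded
# response, so it lands in [0, 1] (e.g. with binary per-response scores,
# 3 correct out of 4 -> 0.75).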


def charxiv_reasoning_process_results(doc, results):
    figure_id = doc["figure_path"]
    inst_category = doc["reasoning_q_source"]
    response = results[0].strip()
    answer = doc["reasoning_a"]
    data = {
        "inst_category": inst_category,
        "figure_id": figure_id,
        "answer": answer,
    }
    resp_value = {"raw_question": doc["reasoning_q"], "response": response}
    return {"reasoning_acc": {"resp_value": resp_value, "resp_key": figure_id, "data": data}}


def charxiv_reasoning_aggregate_results(results):
    data = {}
    resps = {}
    for i, result in enumerate(results):
        data[i] = result["data"]
        resps[result["resp_key"]] = result["resp_value"]
    queries = build_reasoning_grading_queries(data, resps)

    for figure_id, query in tqdm(queries.items()):
        # Grade with the same judge model as the descriptive task.
        ext, scr = get_reasoning_result_gpt(client, query["grading_query"], model=MODEL_VERSION)
        queries[figure_id]["extracted_answer"] = ext
        queries[figure_id]["score"] = scr
        # Drop the prompt once graded to keep the aggregated payload small.
        queries[figure_id].pop("grading_query")

    scores = [query["score"] for query in queries.values()]
    return sum(scores) / len(scores)