from __future__ import annotations

import datetime
import json
import math
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from time import sleep
from typing import Any, Dict, Iterable, List, Optional, Union

import pandas as pd
from openai import OpenAI
from tqdm.auto import tqdm

from mathruler.grader import extract_boxed_content, grade_answer


def extract_description(predict: str) -> str:
    """
    Extract the content of the <description>…</description> block from `predict`.

    Returns the inner text with leading/trailing whitespace stripped, or the
    original `predict` string unchanged if no <description> tag is found.
    """
    match = re.search(r"<description>([\s\S]*?)</description>", predict, re.DOTALL)
    if not match:
        return predict
    return match.group(1).strip()
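
# Illustrative usage (not part of the pipeline): the helper falls back to the
# raw prediction when no <description> tags are present.
#   extract_description("<description>A 3-4-5 right triangle.</description>")
#       -> "A 3-4-5 right triangle."
#   extract_description("plain text")  -> "plain text"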


# OpenAI-compatible client for a self-hosted endpoint; the API key is a placeholder.
client = OpenAI(
    base_url="http://29.81.244.54:8081/v1",
    api_key="ANYKEY",
)


def chat_once(messages):
    """Send a single chat request and return the reply text."""
    resp = client.chat.completions.create(
        model="Qwen2.5-VL-72B-Instruct",
        messages=messages,
    )
    return resp.choices[0].message.content
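
# Illustrative single-request usage (sketch, not executed by the pipeline):
#   reply = chat_once([
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello, how are you?"},
#   ])
#   print(reply)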


def chat_batch(
    client,
    all_message_batches: List[List[Dict[str, str]]],
    *,
    model: str = "Qwen2.5-VL-72B-Instruct",
    max_workers: int = 8,
    retries: int = 2,
    backoff: float = 0.5,
    timeout: Optional[float] = None,
) -> List[str]:
    """
    Send many chat requests in parallel and return replies as a list of strings,
    preserving the order of `all_message_batches`.
    """

    def _chat_once_with_retry(messages: List[Dict[str, str]]) -> str:
        last_err: Optional[BaseException] = None
        for attempt in range(retries + 1):
            try:
                resp = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    timeout=timeout,
                )
                # Prefer the chat-style response field, fall back to the legacy
                # `text` field, and finally to the raw choice representation.
                choice = resp.choices[0]
                if hasattr(choice, "message") and getattr(choice.message, "content", None) is not None:
                    return choice.message.content
                if hasattr(choice, "text") and choice.text is not None:
                    return choice.text
                return str(choice)
            except Exception as e:
                last_err = e
                if attempt < retries:
                    # Exponential backoff: backoff, 2*backoff, 4*backoff, ...
                    sleep(backoff * (2 ** attempt))
        return f"Error: {last_err!r}"

    # Submit all requests to a thread pool and write each reply back into its
    # original slot so the output order matches the input order.
    results: List[Optional[str]] = [None] * len(all_message_batches)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_idx = {
            executor.submit(_chat_once_with_retry, batch): i
            for i, batch in enumerate(all_message_batches)
        }
        for fut in as_completed(future_to_idx):
            i = future_to_idx[fut]
            results[i] = fut.result()

    return [r if r is not None else "Error: Unknown failure" for r in results]


def load_json_list(path: Union[str, Path], encoding: str = "utf-8") -> List[Dict[str, Any]]:
    """
    Load a JSON file whose top-level structure is a list of dicts.

    Raises:
        FileNotFoundError, json.JSONDecodeError, TypeError
    """
    p = Path(path)
    with p.open("r", encoding=encoding) as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise TypeError(f"Expected top-level JSON to be a list, got {type(data).__name__}")

    for i, item in enumerate(data):
        if not isinstance(item, dict):
            raise TypeError(f"Item at index {i} is {type(item).__name__}, expected dict")

    return data
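
# Illustrative call (the path here is an example; the real input is set below):
#   records = load_json_list("./gemini-flash/sample.json")
#   assert isinstance(records[0], dict)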


# Small fixed batch used as a smoke test of the endpoint and of chat_batch.
all_message_batches = [
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"}
    ],
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Tell me a joke."}
    ],
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Tell me a joke."}
    ],
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Tell me a joke."}
    ],
]

res = chat_batch(client, all_message_batches)

# Prompt for re-answering each question from the cached text description alone.
prompt_template = '''Text description: {Description}\nQuestion: {Question}\nYou are provided a text description of a problem and a question. Determine the answer to the question based on the text description. First provide an internal step-by-step reasoning within <think> </think> tags, then provide a single word or phrase answer in \\boxed{}.'''

MODEL = "Qwen2.5-VL-72B-Instruct"
BATCH_SIZE = 16
filename = "MLLM_rlvr_train"
out_file = f'./caption_out/{filename}.json'  # written as JSON Lines, one record per line
data = load_json_list(f'./gemini-flash/{filename}.json')


def to_messages(example: Dict[str, Any]) -> List[Dict[str, str]]:
    """Build a chat message pair from one record, using the first entry in
    `predictions` as the text description filled into `prompt_template`."""
    preds = example.get("predictions")
    question = example.get("problem")

    if isinstance(preds, list) and preds:
        first = preds[0]
        text = first if isinstance(first, str) else json.dumps(first, ensure_ascii=False)
        description = extract_description(text)
        # str.replace is used instead of str.format so the literal \boxed{} in
        # the template is left untouched.
        input_question = prompt_template.replace('{Description}', description).replace('{Question}', question)
    else:
        # No prediction available for this record; send a placeholder prompt.
        input_question = 'None'

    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": input_question},
    ]
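
# Illustrative record shape this expects (keys inferred from the accessors
# above; any other fields pass through to the output unchanged):
#   {"problem": "What is the area of the circle?",
#    "predictions": ["<description>A circle of radius 5 ...</description>"]}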


# Create the output directory and truncate any previous output file, since
# batch results are appended below.
Path(out_file).parent.mkdir(parents=True, exist_ok=True)
with open(out_file, "w", encoding="utf-8"):
    pass

total = len(data)
num_batches = math.ceil(total / BATCH_SIZE)

for start in tqdm(range(0, total, BATCH_SIZE),
                  total=num_batches, desc="Batches", unit="batch"):
    chunk = data[start : start + BATCH_SIZE]
    batch_messages = [to_messages(ex) for ex in chunk]

    replies = chat_batch(client, batch_messages, model=MODEL,
                         max_workers=8, retries=2, backoff=0.5, timeout=None)

    # Quick visual check of the first reply in each batch.
    print(replies[0])

    # Append one JSON record per example so partial progress survives interruptions.
    with open(out_file, "a", encoding="utf-8") as f:
        for ex, reply in zip(chunk, replies):
            record = {**ex, "model": MODEL, "model_caption_response": reply}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        f.flush()
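
# Reading the output back (illustrative): each line is one JSON record.
#   with open(out_file, encoding="utf-8") as f:
#       records = [json.loads(line) for line in f if line.strip()]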