import json
from pathlib import Path

import numpy as np
from tqdm import tqdm

from infer_utils import run_inference_single


def run_geochat_inference(
    model,
    dataset_path,
    processor,
    tokenizer,
    conv_mode,
    answer_path,
    use_video_data=False,
    open_prompt=None,
    repeat_frames=None,
    prompt_strategy="interleave",
    chronological_prefix=True,
    data_frac=1,
    data_size=None,
    delete_system_prompt=False,
    start_ind=0,
    end_ind=None,
    print_prompt=False,
):
    # Load the QFabric-style question file: a JSON list of conversation records.
    with open(dataset_path) as f:
        qfabric_data = json.load(f)

    # Optionally subsample the dataset, either to a fixed number of questions
    # (data_size) or to a random fraction (data_frac).
    if data_size is not None:
        data_size = min(data_size, len(qfabric_data))
        idx = np.random.choice(len(qfabric_data), data_size, replace=False)
        qfabric_data = [qfabric_data[i] for i in idx]
    elif data_frac < 1:
        idx = np.random.choice(len(qfabric_data), int(len(qfabric_data) * data_frac), replace=False)
        qfabric_data = [qfabric_data[i] for i in idx]

    answers = {}

    # Incremental results are appended to a temporary JSONL file whose name
    # encodes the [start_ind, end_ind) slice being processed.
    answers_tmp = str(answer_path).replace(".json", "_tmp.json")
    if end_ind is not None:
        answers_tmp = answers_tmp.replace(".json", f"_{start_ind}_{end_ind}.json")
        qfabric_data = qfabric_data[start_ind:end_ind]
    else:
        answers_tmp = answers_tmp.replace(".json", f"_{start_ind}_end.json")
        qfabric_data = qfabric_data[start_ind:]

    print("answers_tmp: ", answers_tmp)
    print("start ind: ", start_ind)
    print("end ind: ", end_ind)

    for question in tqdm(qfabric_data):
        question_id = question["id"]
        inp = question["conversations"][0]["value"]
        answer_str = question["conversations"][1]["value"]
        metadata = question["metadata"]
        image_paths = question["video"]
        task = question["task"]
        original_input_polygon = question["original_input_polygon"]
        dataset = question["dataset"]

        # Run the model on a single multi-image (or video) question.
        outputs = run_inference_single(
            model=model,
            processor=processor,
            tokenizer=tokenizer,
            conv_mode=conv_mode,
            inp=inp,
            image_paths=image_paths,
            metadata=metadata,
            repeat_frames=repeat_frames,
            use_video_data=use_video_data,
            prompt_strategy=prompt_strategy,
            chronological_prefix=chronological_prefix,
            delete_system_prompt=delete_system_prompt,
            print_prompt=print_prompt,
        )

        entry = {
            "id": question_id,
            "question": inp,
            "predicted": outputs,
            "ground_truth": answer_str,
            "task": task,
            "original_input_polygon": original_input_polygon,
            "dataset": dataset,
        }
        answers[question_id] = entry

        # Append each result immediately so partial progress survives crashes.
        with open(answers_tmp, "a") as f:
            f.write(json.dumps(entry) + "\n")

    return answers
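

# Illustrative usage sketch. The dataset path, answer path, and conversation-mode
# string below are hypothetical placeholders, and `model`, `processor`, and
# `tokenizer` are assumed to come from the caller's own model-loading code:
#
#     answers = run_geochat_inference(
#         model=model,
#         dataset_path="data/qfabric_questions.json",
#         processor=processor,
#         tokenizer=tokenizer,
#         conv_mode="llava_v1",
#         answer_path="results/qfabric_answers.json",
#         data_size=100,          # evaluate on a random subset of 100 questions
#         start_ind=0,
#         end_ind=None,
#     )


def load_tmp_answers(answers_tmp_path):
    """Hypothetical helper (an assumption, not referenced above): reload the
    JSONL file written incrementally by run_geochat_inference into an id-keyed
    dict, e.g. to inspect or resume a partially completed run."""
    answers = {}
    with open(answers_tmp_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
            answers[entry["id"]] = entry
    return answers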