Tonic committed
Commit b5f76b2
1 Parent(s): 284fe2b

Create app.py

Files changed (1)
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
import spaces

import argparse
import torch
import os
import json
import math
import tempfile  # temp dirs for the uploaded question file and images
import shutil    # used to copy uploaded images into the temp folder
import gradio as gr

from tqdm import tqdm
import shortuuid

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image


model_path = 'kaist-ai/prometheus-vision-13b-v1.0'
model_name = 'llava-v1.5'


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division so no items are dropped
    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]


@spaces.GPU
def eval_model(args, model_name=model_name, model_path=model_path):
    disable_torch_init()
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    os.makedirs(os.path.dirname(answers_file) or ".", exist_ok=True)  # handle bare filenames with no directory part
    ans_file = open(answers_file, "w")
    for line in tqdm(questions):
        idx = line["question_id"]
        image_file = line["image"]
        qs = line["text"]
        cur_prompt = qs
        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        image = Image.open(os.path.join(args.image_folder, image_file))
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half().cuda(),
                do_sample=True if args.temperature > 0 else False,
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                # no_repeat_ngram_size=3,
                max_new_tokens=1024,
                use_cache=True)

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()

        ans_id = shortuuid.uuid()
        ans_file.write(json.dumps({"question_id": idx,
                                   "prompt": cur_prompt,
                                   "text": outputs,
                                   "answer_id": ans_id,
                                   "model_id": model_name,
                                   "metadata": {}}) + "\n")
        ans_file.flush()
    ans_file.close()

def gradio_wrapper(image_folder, question_file, answers_file, conv_mode,
                   num_chunks, chunk_idx, temperature, top_p, num_beams,
                   model_path=model_path, model_name=model_name, model_base=None):
    # Write the questions (a list of dicts from the JSON input) to a temporary JSONL file.
    question_file_path = os.path.join(tempfile.mkdtemp(), "question.jsonl")
    with open(question_file_path, "w") as f:
        for question in question_file:
            f.write(json.dumps(question) + "\n")

    # Copy the uploaded images into a temporary folder.
    temp_image_folder = tempfile.mkdtemp()
    for image_file in image_folder:
        # Gradio may hand back tempfile wrappers (with a .name path) or plain path strings.
        src_path = image_file.name if hasattr(image_file, "name") else image_file
        shutil.copy(src_path, os.path.join(temp_image_folder, os.path.basename(src_path)))

    args = argparse.Namespace(
        model_path=model_path,
        model_base=model_base,
        image_folder=temp_image_folder,
        question_file=question_file_path,
        answers_file=answers_file,
        conv_mode=conv_mode,
        num_chunks=int(num_chunks),
        chunk_idx=int(chunk_idx),
        temperature=temperature,
        top_p=float(top_p) if top_p else None,  # the Top P textbox returns a string
        num_beams=int(num_beams)
    )

    eval_model(args)

    with open(answers_file, "r") as f:
        answers = f.readlines()

    return answers


iface = gr.Interface(
    fn=gradio_wrapper,
    inputs=[
        gr.File(label="Image Folder", file_count="multiple"),
        gr.JSON(label="Question File"),
        gr.Textbox(label="Answers File"),
        gr.Dropdown(label="Conversation Mode", choices=["llava_v1"]),
        gr.Slider(label="Number of Chunks", minimum=1, maximum=10, step=1, value=1),
        gr.Slider(label="Chunk Index", minimum=0, maximum=9, step=1, value=0),
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.01, value=0.2),
        gr.Textbox(label="Top P", value=""),
        gr.Slider(label="Number of Beams", minimum=1, maximum=10, step=1, value=1)
    ],
    outputs=[
        gr.Textbox(label="Answers")
    ],
    title="Model Evaluation Interface",
    description="A Gradio interface for evaluating models."
)

if __name__ == "__main__":
    iface.launch()
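
For reference, a minimal sketch of the question-file format the evaluation loop above reads: each JSONL line needs "question_id", "image", and "text" keys (taken from the code). The file name and values below are illustrative only, not part of the commit.

import json

# Hypothetical example entry; only the keys mirror what eval_model reads.
question = {
    "question_id": "q-0001",       # arbitrary identifier echoed into the answers file
    "image": "example.png",        # must match a file in the uploaded image folder
    "text": "Describe the image."  # prompt text placed after the image token
}

with open("question.jsonl", "w") as f:  # illustrative path
    f.write(json.dumps(question) + "\n")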