Huayang Li commited on
Commit
d711bd7
·
1 Parent(s): 89690cf

update demo with case

Browse files
Files changed (1) hide show
  1. app_case.py +234 -0
app_case.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoTokenizer
2
+ import os
3
+ import ipdb
4
+ import gradio as gr
5
+ import mdtex2html
6
+ from model.openllama import OpenLLAMAPEFTModel
7
+ import torch
8
+ import json
9
+ from header import TaskType, LoraConfig
10
+
11
+ # init the model
12
+ args = {
13
+ 'model': 'openllama_peft',
14
+ 'imagebind_ckpt_path': 'pretrained_ckpt/imagebind_ckpt',
15
+ 'vicuna_ckpt_path': 'openllmplayground/vicuna_7b_v0',
16
+ 'delta_ckpt_path': 'pretrained_ckpt/pandagpt_ckpt/7b/pytorch_model.pt',
17
+ 'stage': 2,
18
+ 'max_tgt_len': 128,
19
+ 'lora_r': 32,
20
+ 'lora_alpha': 32,
21
+ 'lora_dropout': 0.1,
22
+ }
23
+ model = OpenLLAMAPEFTModel(**args)
24
+ delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
25
+ model.load_state_dict(delta_ckpt, strict=False)
26
+ model = model.half().cuda().eval() if torch.cuda.is_available() else model.eval()
27
+ print(f'[!] init the model over ...')
28
+
29
+
30
+ """Override Chatbot.postprocess"""
31
+
32
+
33
+ def postprocess(self, y):
34
+ if y is None:
35
+ return []
36
+ for i, (message, response) in enumerate(y):
37
+ y[i] = (
38
+ None if message is None else mdtex2html.convert((message)),
39
+ None if response is None else mdtex2html.convert(response),
40
+ )
41
+ return y
42
+
43
+
44
+ gr.Chatbot.postprocess = postprocess
45
+
46
+
47
+ def parse_text(text):
48
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
49
+ lines = text.split("\n")
50
+ lines = [line for line in lines if line != ""]
51
+ count = 0
52
+ for i, line in enumerate(lines):
53
+ if "```" in line:
54
+ count += 1
55
+ items = line.split('`')
56
+ if count % 2 == 1:
57
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
58
+ else:
59
+ lines[i] = f'<br></code></pre>'
60
+ else:
61
+ if i > 0:
62
+ if count % 2 == 1:
63
+ line = line.replace("`", "\`")
64
+ line = line.replace("<", "&lt;")
65
+ line = line.replace(">", "&gt;")
66
+ line = line.replace(" ", "&nbsp;")
67
+ line = line.replace("*", "&ast;")
68
+ line = line.replace("_", "&lowbar;")
69
+ line = line.replace("-", "&#45;")
70
+ line = line.replace(".", "&#46;")
71
+ line = line.replace("!", "&#33;")
72
+ line = line.replace("(", "&#40;")
73
+ line = line.replace(")", "&#41;")
74
+ line = line.replace("$", "&#36;")
75
+ lines[i] = "<br>"+line
76
+ text = "".join(lines)
77
+ return text
78
+
79
+
80
+ def predict(
81
+ input,
82
+ image_path,
83
+ audio_path,
84
+ video_path,
85
+ thermal_path,
86
+ chatbot,
87
+ max_length,
88
+ top_p,
89
+ temperature,
90
+ history,
91
+ modality_cache,
92
+ ):
93
+ if image_path is None and audio_path is None and video_path is None and thermal_path is None:
94
+ return [(input, "There is no image/audio/video provided. Please upload the file to start a conversation.")]
95
+ else:
96
+ print(f'[!] image path: {image_path}\n[!] audio path: {audio_path}\n[!] video path: {video_path}\n[!] thermal pah: {thermal_path}')
97
+ # prepare the prompt
98
+ prompt_text = ''
99
+ for idx, (q, a) in enumerate(history):
100
+ if idx == 0:
101
+ prompt_text += f'{q}\n### Assistant: {a}\n###'
102
+ else:
103
+ prompt_text += f' Human: {q}\n### Assistant: {a}\n###'
104
+ if len(history) == 0:
105
+ prompt_text += f'{input}'
106
+ else:
107
+ prompt_text += f' Human: {input}'
108
+
109
+ response = model.generate({
110
+ 'prompt': prompt_text,
111
+ 'image_paths': [image_path] if image_path else [],
112
+ 'audio_paths': [audio_path] if audio_path else [],
113
+ 'video_paths': [video_path] if video_path else [],
114
+ 'thermal_paths': [thermal_path] if thermal_path else [],
115
+ 'top_p': top_p,
116
+ 'temperature': temperature,
117
+ 'max_tgt_len': max_length,
118
+ 'modality_embeds': modality_cache
119
+ })
120
+ chatbot.append((parse_text(input), parse_text(response)))
121
+ history.append((input, response))
122
+ return chatbot, history, modality_cache
123
+
124
+
125
+ def reset_user_input():
126
+ return gr.update(value='')
127
+
128
+
129
+ def reset_state():
130
+ return None, None, None, None, [], [], []
131
+
132
+
133
+ with gr.Blocks() as demo:
134
+ gr.HTML("""<h1 align="center">PandaGPT</h1>""")
135
+ gr.Markdown('''We note that the current online demo uses the 7B version of PandaGPT due to the limitation of computation resource.
136
+
137
+ Better results should be expected when switching to the 13B version of PandaGPT.
138
+
139
+ For more details on how to run 13B PandaGPT, please refer to our [main project repository](https://github.com/yxuansu/PandaGPT).''')
140
+
141
+ with gr.Row(scale=4):
142
+ with gr.Column(scale=2):
143
+ image_path = gr.Image(type="filepath", label="Image", value=None)
144
+
145
+ gr.Examples(
146
+ [
147
+ os.path.join(os.path.dirname(__file__), "assets/images/bird_image.jpg"),
148
+ os.path.join(os.path.dirname(__file__), "assets/images/dog_image.jpg"),
149
+ os.path.join(os.path.dirname(__file__), "assets/images/car_image.jpg"),
150
+ ],
151
+ image_path
152
+ )
153
+ with gr.Column(scale=2):
154
+ audio_path = gr.Audio(type="filepath", label="Audio", value=None)
155
+ gr.Examples(
156
+ [
157
+ os.path.join(os.path.dirname(__file__), "assets/audios/bird_audio.wav"),
158
+ os.path.join(os.path.dirname(__file__), "assets/audios/dog_audio.wav"),
159
+ os.path.join(os.path.dirname(__file__), "assets/audios/car_audio.wav"),
160
+ ],
161
+ audio_path
162
+ )
163
+ with gr.Row(scale=4):
164
+ with gr.Column(scale=2):
165
+ video_path = gr.Video(type='file', label="Video")
166
+
167
+ gr.Examples(
168
+ [
169
+ os.path.join(os.path.dirname(__file__), "assets/videos/world.mp4"),
170
+ os.path.join(os.path.dirname(__file__), "assets/videos/a.mp4"),
171
+ ],
172
+ video_path
173
+ )
174
+ with gr.Column(scale=2):
175
+ thermal_path = gr.Image(type="filepath", label="Thermal Image", value=None)
176
+
177
+ gr.Examples(
178
+ [
179
+ os.path.join(os.path.dirname(__file__), "assets/thermals/190662.jpg"),
180
+ os.path.join(os.path.dirname(__file__), "assets/thermals/210009.jpg"),
181
+ ],
182
+ thermal_path
183
+ )
184
+
185
+ chatbot = gr.Chatbot()
186
+ with gr.Row():
187
+ with gr.Column(scale=4):
188
+ with gr.Column(scale=12):
189
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
190
+ with gr.Column(min_width=32, scale=1):
191
+ submitBtn = gr.Button("Submit", variant="primary")
192
+ with gr.Column(scale=1):
193
+ emptyBtn = gr.Button("Clear History")
194
+ max_length = gr.Slider(0, 512, value=128, step=1.0, label="Maximum length", interactive=True)
195
+ top_p = gr.Slider(0, 1, value=0.01, step=0.01, label="Top P", interactive=True)
196
+ temperature = gr.Slider(0, 1, value=0.8, step=0.01, label="Temperature", interactive=True)
197
+
198
+ history = gr.State([])
199
+ modality_cache = gr.State([])
200
+
201
+ submitBtn.click(
202
+ predict, [
203
+ user_input,
204
+ image_path,
205
+ audio_path,
206
+ video_path,
207
+ thermal_path,
208
+ chatbot,
209
+ max_length,
210
+ top_p,
211
+ temperature,
212
+ history,
213
+ modality_cache,
214
+ ], [
215
+ chatbot,
216
+ history,
217
+ modality_cache
218
+ ],
219
+ show_progress=True
220
+ )
221
+
222
+ submitBtn.click(reset_user_input, [], [user_input])
223
+ emptyBtn.click(reset_state, outputs=[
224
+ image_path,
225
+ audio_path,
226
+ video_path,
227
+ thermal_path,
228
+ chatbot,
229
+ history,
230
+ modality_cache
231
+ ], show_progress=True)
232
+
233
+
234
+ demo.launch(enable_queue=True)