abhin2564 commited on
Commit
caf419e
1 Parent(s): c744d49

Upload 2 files

Browse files
Files changed (2) hide show
  1. app (1).py +295 -0
  2. requirements.txt +7 -0
app (1).py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
# encoding: utf-8
"""Gradio web demo for the MiniCPM-Llama3-V 2.5 vision-language chat model."""
import spaces
import gradio as gr
from PIL import Image
import traceback
import re
import torch
import argparse
from transformers import AutoModel, AutoTokenizer

# README, how to run the demo on different devices:
#   For Nvidia GPUs:
#     python web_demo_2.5.py --device cuda
#   For Mac with MPS (Apple silicon or AMD GPUs):
#     PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.5.py --device mps

# Command-line arguments.
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
args = parser.parse_args()
device = args.device
# Validate explicitly instead of `assert`, which is silently stripped
# when Python runs with the -O flag.
if device not in ('cuda', 'mps'):
    parser.error(f"--device must be 'cuda' or 'mps', got {device!r}")
26
+
27
# Load model and tokenizer. int4 checkpoints depend on bitsandbytes, which is
# unavailable on Apple silicon, so that combination is rejected up front.
model_path = 'openbmb/MiniCPM-Llama3-V-2_5'
load_as_int4 = 'int4' in model_path
if load_as_int4 and device == 'mps':
    print('Error: running int4 model with bitsandbytes on Mac is not supported right now.')
    exit()
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
if not load_as_int4:
    # Non-quantized weights are cast to fp16 to halve the memory footprint.
    model = model.to(dtype=torch.float16)
model = model.to(device=device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()
42
# Message shown in the chat pane when generation fails.
ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-Llama3-V 2.5'

# Decode-type radio button; sampling is the default mode.
form_radio = dict(
    choices=['Beam Search', 'Sampling'],
    value='Sampling',
    interactive=True,
    label='Decode Type',
)

# Slider configurations for beam-search decoding.
num_beams_slider = dict(
    minimum=0, maximum=5, value=3, step=1,
    interactive=True, label='Num Beams',
)
repetition_penalty_slider = dict(
    minimum=0, maximum=3, value=1.2, step=0.01,
    interactive=True, label='Repetition Penalty',
)

# Slider configurations for sampling-based decoding.
repetition_penalty_slider2 = dict(
    minimum=0, maximum=3, value=1.05, step=0.01,
    interactive=True, label='Repetition Penalty',
)
max_new_tokens_slider = dict(
    minimum=1, maximum=4096, value=1024, step=1,
    interactive=True, label='Max New Tokens',
)
top_p_slider = dict(
    minimum=0, maximum=1, value=0.8, step=0.05,
    interactive=True, label='Top P',
)
top_k_slider = dict(
    minimum=0, maximum=200, value=100, step=1,
    interactive=True, label='Top K',
)
temperature_slider = dict(
    minimum=0, maximum=2, value=0.7, step=0.05,
    interactive=True, label='Temperature',
)
110
+
111
+
112
def create_component(params, comp='Slider'):
    """Build a Gradio component from a configuration dict.

    Args:
        params: Constructor options; required keys depend on `comp`
            (Slider: minimum/maximum/value/step/interactive/label,
             Radio: choices/value/interactive/label, Button: value).
        comp: Component kind, one of 'Slider', 'Radio', 'Button'.

    Returns:
        The constructed Gradio component.

    Raises:
        ValueError: if `comp` is not a supported kind. (Previously an
            unknown kind silently returned None, hiding typos.)
    """
    if comp == 'Slider':
        return gr.Slider(
            minimum=params['minimum'],
            maximum=params['maximum'],
            value=params['value'],
            step=params['step'],
            interactive=params['interactive'],
            label=params['label'],
        )
    if comp == 'Radio':
        return gr.Radio(
            choices=params['choices'],
            value=params['value'],
            interactive=params['interactive'],
            label=params['label'],
        )
    if comp == 'Button':
        # Buttons in this demo are always interactive, regardless of params.
        return gr.Button(value=params['value'], interactive=True)
    raise ValueError(f"Unsupported component type: {comp!r}")
134
+
135
@spaces.GPU(duration=120)
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
    """Run one model turn and stream the answer piecewise.

    Args:
        img: PIL image grounding the conversation; None yields an error string.
        msgs: List of {'role': ..., 'content': ...} dicts passed to model.chat.
        ctx: Unused; kept for interface compatibility.
        params: Generation kwargs forwarded to model.chat (stream, sampling,
            num_beams, repetition_penalty, max_new_tokens, ...).
        vision_hidden_states: Unused; kept for interface compatibility.

    Yields:
        str: successive pieces of the answer (characters for non-streamed
        output, chunks for streamed), or an error message on failure.
    """
    if img is None:
        yield "Error, invalid image, please upload a new image"
        return
    if params is None:
        # Conservative beam-search defaults when the caller passes nothing.
        params = {"stream": False, "sampling": False, "num_beams": 3,
                  "repetition_penalty": 1.2, "max_new_tokens": 1024}
    try:
        answer = model.chat(
            image=img.convert('RGB'),
            msgs=msgs,
            tokenizer=tokenizer,
            **params
        )
        # `answer` is a string when stream=False, or a generator of string
        # chunks when stream=True; iterating works for both cases.
        for char in answer:
            yield char
    except Exception as err:
        # Top-level UI boundary: log the failure and surface a retry message.
        print(err)
        traceback.print_exc()
        yield ERROR_MSG
164
+
165
+
166
def upload_img(image, _chatbot, _app_session):
    """Store a freshly uploaded image in the session and greet the user.

    Resets the generation state and conversation context, since a new
    image starts a new conversation.
    """
    pil_image = Image.fromarray(image)
    _app_session.update(sts=None, ctx=[], img=pil_image)
    _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
    return _chatbot, _app_session
174
+
175
+
176
def respond(_chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    """Generate an answer for the newest question and stream it to the chatbot.

    Yields (chat_bot, app_cfg) pairs so Gradio can live-update the UI.

    Fixes vs. the original:
      * The streaming loop appended answer characters to the *user* message
        dict, corrupting the prompt for any follow-up turn.
      * `_app_cfg['ctx']` was never written back, so the history never grew
        and `regenerate_button_clicked`'s `ctx[:-2]` trim had nothing to cut.
    """
    _question = _chat_bot[-1][0]
    print('<Question>:', _question)
    if _app_cfg.get('ctx', None) is None:
        # No image uploaded yet -> no conversation context exists.
        _chat_bot[-1][1] = 'Please upload an image to start'
        yield (_chat_bot, _app_cfg)
        return

    # Work on a copy so a failed generation leaves the stored context intact.
    _context = _app_cfg['ctx'].copy()
    _context.append({"role": "user", "content": _question})

    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'stream': False,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            "max_new_tokens": 896
        }
    else:
        params = {
            'sampling': True,
            'stream': True,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty_2,
            "max_new_tokens": 896
        }

    gen = chat(_app_cfg['img'], _context, None, params)
    _answer = ""
    _chat_bot[-1][1] = ""
    for _char in gen:
        _answer += _char
        _chat_bot[-1][1] += _char
        yield (_chat_bot, _app_cfg)
    # Persist the finished user/assistant turn so regenerate can trim it.
    _context.append({"role": "assistant", "content": _answer})
    _app_cfg['ctx'] = _context
213
+
214
+
215
def request(_question, _chat_bot, _app_cfg):
    """Append the pending question to the chat history and clear the textbox.

    Returns (textbox_value, chat_bot, app_cfg); the answer slot is None
    until the chained `respond` handler fills it in.
    """
    _chat_bot += [(_question, None)]
    return '', _chat_bot, _app_cfg
218
+
219
+
220
def regenerate_button_clicked(_question, _chat_bot, _app_cfg):
    """Drop the last answer and re-ask the same question.

    Returns the same (textbox, chat_bot, app_cfg) triple as `request`, so
    the chained `respond` handler regenerates the reply.

    Fix vs. the original: `_app_cfg['ctx'][:-2]` raised TypeError when no
    image had been uploaded yet (ctx is None); guard before slicing.
    """
    if len(_chat_bot) <= 1:
        _chat_bot.append(('Regenerate', 'No question for regeneration.'))
        return '', _chat_bot, _app_cfg
    if _chat_bot[-1][0] == 'Regenerate':
        # Already showing the "nothing to regenerate" notice; do nothing.
        return '', _chat_bot, _app_cfg
    _question = _chat_bot[-1][0]
    _chat_bot = _chat_bot[:-1]
    if _app_cfg.get('ctx'):
        # Remove the trailing user/assistant pair for the regenerated turn.
        _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
    return request(_question, _chat_bot, _app_cfg)
232
+
233
+
234
def clear_button_clicked(_question, _chat_bot, _app_cfg, _bt_pic):
    """Wipe the chat transcript, session state, and uploaded picture.

    Returns ('', chat_bot, app_cfg, None) so every bound output is reset.
    """
    _chat_bot.clear()
    for key in ('sts', 'ctx', 'img'):
        _app_cfg[key] = None
    return '', _chat_bot, _app_cfg, None
241
+
242
+
243
# ---- UI layout ----
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: decoding controls.
        with gr.Column(scale=1, min_width=300):
            params_form = create_component(form_radio, comp='Radio')
            with gr.Accordion("Beam Search") as beams_according:
                num_beams = create_component(num_beams_slider)
                repetition_penalty = create_component(repetition_penalty_slider)
            with gr.Accordion("Sampling") as sampling_according:
                top_p = create_component(top_p_slider)
                top_k = create_component(top_k_slider)
                temperature = create_component(temperature_slider)
                repetition_penalty_2 = create_component(repetition_penalty_slider2)
            regenerate = create_component({'value': 'Regenerate'}, comp='Button')
            clear = create_component({'value': 'Clear'}, comp='Button')
        # Right column: image upload, chat transcript, and text input.
        with gr.Column(scale=3, min_width=500):
            app_session = gr.State({'sts': None, 'ctx': None, 'img': None})
            bt_pic = gr.Image(label="Upload an image to start")
            chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
            txt_message = gr.Textbox(label="Input text")

    # ---- Event wiring ----
    clear.click(
        clear_button_clicked,
        inputs=[txt_message, chat_bot, app_session, bt_pic],
        outputs=[txt_message, chat_bot, app_session, bt_pic],
        queue=False,
    )
    # Submitting text first records the question, then streams the answer.
    txt_message.submit(
        request,
        inputs=[txt_message, chat_bot, app_session],
        outputs=[txt_message, chat_bot, app_session],
        queue=False,
    ).then(
        respond,
        inputs=[chat_bot, app_session, params_form, num_beams, repetition_penalty,
                repetition_penalty_2, top_p, top_k, temperature],
        outputs=[chat_bot, app_session],
    )
    regenerate.click(
        regenerate_button_clicked,
        inputs=[txt_message, chat_bot, app_session],
        outputs=[txt_message, chat_bot, app_session],
        queue=False,
    ).then(
        respond,
        inputs=[chat_bot, app_session, params_form, num_beams, repetition_penalty,
                repetition_penalty_2, top_p, top_k, temperature],
        outputs=[chat_bot, app_session],
    )
    # Clear the chat pane, then register the newly uploaded image.
    bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(
        upload_img,
        inputs=[bt_pic, chat_bot, app_session],
        outputs=[chat_bot, app_session],
    )

# Launch with request queueing enabled (required for streaming generators).
#demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
demo.queue()
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Pillow==10.1.0
2
+ torch==2.1.2
3
+ torchvision==0.16.2
4
+ transformers==4.40.0
5
+ sentencepiece==0.1.99
6
+ opencv-python
7
+ gradio