finalf0 committed on
Commit
dfde69d
1 Parent(s): b4c39c5
Files changed (2) hide show
  1. app.py +256 -62
  2. requirements.txt +8 -1
app.py CHANGED
@@ -1,63 +1,257 @@
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
60
-
61
-
62
- if __name__ == "__main__":
63
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
  import gradio as gr
4
+ from PIL import Image
5
+ import traceback
6
+ import re
7
+ import torch
8
+ import argparse
9
+ from transformers import AutoModel, AutoTokenizer
10
+
11
+ # README, How to run demo on different devices
12
+
13
+ # For Nvidia GPUs.
14
+ # python web_demo_2.5.py --device cuda
15
+
16
+ # For Mac with MPS (Apple silicon or AMD GPUs).
17
+ # PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.5.py --device mps
18
+
19
+ # Argparser
20
+ parser = argparse.ArgumentParser(description='demo')
21
+ parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
22
+ args = parser.parse_args()
23
+ device = args.device
24
+ assert device in ['cuda', 'mps']
25
+
26
+ # Load model
27
+ model_path = 'openbmb/MiniCPM-Llama3-V-2_5'
28
+ if 'int4' in model_path:
29
+ if device == 'mps':
30
+ print('Error: running int4 model with bitsandbytes on Mac is not supported right now.')
31
+ exit()
32
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
33
+ else:
34
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.float16)
35
+ model = model.to(device=device)
36
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
37
+ model.eval()
38
+
39
+
40
+
41
# ---------------------------------------------------------------------------
# UI text and widget configurations
# ---------------------------------------------------------------------------

ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-Llama3-V 2.5'

# Decode-type selector (defaults to sampling).
form_radio = dict(
    choices=['Beam Search', 'Sampling'],
    value='Sampling',
    interactive=True,
    label='Decode Type',
)

# --- Beam-search parameters ---
num_beams_slider = dict(
    minimum=0, maximum=5, value=3, step=1,
    interactive=True, label='Num Beams',
)
repetition_penalty_slider = dict(
    minimum=0, maximum=3, value=1.2, step=0.01,
    interactive=True, label='Repetition Penalty',
)

# --- Sampling parameters ---
repetition_penalty_slider2 = dict(
    minimum=0, maximum=3, value=1.05, step=0.01,
    interactive=True, label='Repetition Penalty',
)
max_new_tokens_slider = dict(
    minimum=1, maximum=4096, value=1024, step=1,
    interactive=True, label='Max New Tokens',
)
top_p_slider = dict(
    minimum=0, maximum=1, value=0.8, step=0.05,
    interactive=True, label='Top P',
)
top_k_slider = dict(
    minimum=0, maximum=200, value=100, step=1,
    interactive=True, label='Top K',
)
temperature_slider = dict(
    minimum=0, maximum=2, value=0.7, step=0.05,
    interactive=True, label='Temperature',
)
109
+
110
+
111
def create_component(params, comp='Slider'):
    """Build a Gradio input component from a parameter dict.

    Args:
        params: keyword settings for the component ('minimum', 'maximum',
            'value', 'step', 'choices', 'interactive', 'label', ... —
            which keys are read depends on `comp`).
        comp: component kind — 'Slider', 'Radio' or 'Button'.

    Returns:
        The constructed Gradio component.

    Raises:
        ValueError: if `comp` is not a supported kind. (The original fell
        through and returned None, which surfaced later as an opaque
        AttributeError at wiring time.)
    """
    if comp == 'Slider':
        return gr.Slider(
            minimum=params['minimum'],
            maximum=params['maximum'],
            value=params['value'],
            step=params['step'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Radio':
        return gr.Radio(
            choices=params['choices'],
            value=params['value'],
            interactive=params['interactive'],
            label=params['label']
        )
    elif comp == 'Button':
        return gr.Button(
            value=params['value'],
            interactive=True
        )
    raise ValueError(f"Unsupported component type: {comp!r}")
133
+
134
+
135
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
    """Run a single chat turn with the model over image `img`.

    Args:
        img: PIL image or None; None short-circuits with an error tuple.
        msgs: list of {"role": ..., "content": ...} message dicts.
        ctx: unused; kept for interface compatibility.
        params: generation kwargs forwarded to `model.chat`; defaults to
            beam search (num_beams=3, repetition_penalty=1.2).
        vision_hidden_states: unused; kept for interface compatibility.

    Returns:
        (code, answer, context, states) where code is 0 on success and -1 on
        error. NOTE: the caller (`respond`) persists the conversation context
        only when code == 0 — the previous version returned -1 even on
        success, silently breaking multi-turn chat memory.
    """
    if params is None:
        params = {"num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
    if img is None:
        return -1, "Error, invalid image, please upload a new image", None, None
    try:
        answer = model.chat(
            image=img.convert('RGB'),
            msgs=msgs,
            tokenizer=tokenizer,
            **params
        )
        # Strip grounding markup from the answer: boxed spans first, then any
        # stray <ref>/<box> tags left behind.
        answer = re.sub(r'(<box>.*</box>)', '', answer)
        for tag in ('<ref>', '</ref>', '<box>', '</box>'):
            answer = answer.replace(tag, '')
        return 0, answer, None, None
    except Exception as err:
        print(err)
        traceback.print_exc()
        return -1, ERROR_MSG, None, None
159
+
160
+
161
def upload_img(image, _chatbot, _app_session):
    """Register a newly uploaded image and reset the session state.

    `image` arrives as a numpy array from the gr.Image widget; it is stored
    in the session as a PIL image, the chat context is cleared, and a
    greeting is appended to the chatbot transcript.
    """
    pil_image = Image.fromarray(image)
    _app_session.update({'sts': None, 'ctx': [], 'img': pil_image})
    _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
    return _chatbot, _app_session
169
+
170
+
171
def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    """Handle one user message: build params, query the model, update state.

    Args:
        _question: the user's text input.
        _chat_bot: gradio chatbot transcript, a list of (user, assistant) pairs.
        _app_cfg: session dict with keys 'sts', 'ctx' (message history), 'img'.
        params_form: 'Beam Search' or 'Sampling' — selects the decode params.
        num_beams / repetition_penalty: beam-search settings.
        repetition_penalty_2 / top_p / top_k / temperature: sampling settings.

    Returns:
        ('', updated transcript, updated session) — the empty string clears
        the input textbox.
    """
    # No image uploaded yet: nothing to chat about.
    if _app_cfg.get('ctx', None) is None:
        _chat_bot.append((_question, 'Please upload an image to start'))
        return '', _chat_bot, _app_cfg

    # Work on a copy; only committed back to the session if generation
    # succeeds. (append works on an empty list too — the original
    # special-cased the empty history unnecessarily.)
    _context = _app_cfg['ctx'].copy()
    _context.append({"role": "user", "content": _question})
    print('<User>:', _question)

    if params_form == 'Beam Search':
        params = {
            'sampling': False,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            "max_new_tokens": 896
        }
    else:
        params = {
            'sampling': True,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty_2,
            "max_new_tokens": 896
        }
    code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
    print('<Assistant>:', _answer)

    _context.append({"role": "assistant", "content": _answer})
    _chat_bot.append((_question, _answer))
    if code == 0:
        # Persist the conversation only on success so a failed turn does not
        # poison the stored history.
        _app_cfg['ctx'] = _context
        _app_cfg['sts'] = sts
    return '', _chat_bot, _app_cfg
208
+
209
+
210
def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
    """Re-run the most recent user question through `respond`.

    Pops the last exchange from both the transcript and the stored context,
    then replays the question. Early-outs when there is nothing to
    regenerate, or when the last entry is itself a Regenerate placeholder.
    """
    # Transcript holds at most the upload greeting: nothing to redo.
    if len(_chat_bot) <= 1:
        _chat_bot.append(('Regenerate', 'No question for regeneration.'))
        return '', _chat_bot, _app_cfg
    # Don't regenerate a 'Regenerate' placeholder turn.
    if _chat_bot[-1][0] == 'Regenerate':
        return '', _chat_bot, _app_cfg
    # Drop the last exchange from the transcript and the model context,
    # then replay the question through the normal respond() path.
    last_question = _chat_bot[-1][0]
    _chat_bot = _chat_bot[:-1]
    _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
    return respond(last_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)
221
+
222
+
223
+
224
# ---------------------------------------------------------------------------
# UI layout and event wiring
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: decode-type selector and generation controls.
        with gr.Column(scale=1, min_width=300):
            params_form = create_component(form_radio, comp='Radio')
            with gr.Accordion("Beam Search") as beam_search_accordion:
                num_beams = create_component(num_beams_slider)
                repetition_penalty = create_component(repetition_penalty_slider)
            with gr.Accordion("Sampling") as sampling_accordion:
                top_p = create_component(top_p_slider)
                top_k = create_component(top_k_slider)
                temperature = create_component(temperature_slider)
                repetition_penalty_2 = create_component(repetition_penalty_slider2)
            regenerate = create_component({'value': 'Regenerate'}, comp='Button')
        # Right column: per-session state, image upload, transcript, input box.
        with gr.Column(scale=3, min_width=500):
            app_session = gr.State({'sts': None, 'ctx': None, 'img': None})
            bt_pic = gr.Image(label="Upload an image to start")
            chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
            txt_message = gr.Textbox(label="Input text")

    # Both handlers take the same inputs and produce the same outputs.
    shared_inputs = [txt_message, chat_bot, app_session, params_form, num_beams,
                     repetition_penalty, repetition_penalty_2, top_p, top_k, temperature]
    shared_outputs = [txt_message, chat_bot, app_session]

    regenerate.click(regenerate_button_clicked, shared_inputs, shared_outputs)
    txt_message.submit(respond, shared_inputs, shared_outputs)
    # Clear the transcript first, then register the new image in the session.
    bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(
        upload_img, inputs=[bt_pic, chat_bot, app_session], outputs=[chat_bot, app_session]
    )

# launch
demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
257
+
requirements.txt CHANGED
@@ -1 +1,8 @@
1
- huggingface_hub==0.22.2
 
 
 
 
 
 
 
 
1
+ Pillow==10.1.0
2
+ timm==0.9.10
3
+ torch==2.1.2
4
+ torchvision==0.16.2
5
+ transformers==4.40.0
6
+ sentencepiece==0.1.99
7
+ opencv-python
8
+ gradio