wcy1122 commited on
Commit
9066a31
Β·
1 Parent(s): 37cde01

add app file

Browse files
Files changed (3) hide show
  1. .gitignore +37 -0
  2. app.py +339 -0
  3. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__
3
+ *.pyc
4
+ *.egg-info
5
+ dist
6
+
7
+ # Log
8
+ *.log
9
+ *.log.*
10
+ *.jsonl
11
+
12
+ # Data
13
+ !**/alpaca-data-conversation.json
14
+
15
+ # Editor
16
+ .idea
17
+ *.swp
18
+
19
+ # Other
20
+ .DS_Store
21
+ wandb
22
+ output
23
+ work_dirs
24
+ data
25
+ model_zoo
26
+
27
+ checkpoints
28
+ ckpts*
29
+
30
+ .ipynb_checkpoints
31
+ *.ipynb
32
+
33
+ # DevContainer
34
+ !.devcontainer/*
35
+
36
+ # Demo
37
+ serve_images/
app.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import subprocess
3
+
4
+ import timm
5
+ import spaces
6
+ import io
7
+ import base64
8
+
9
+ import torch
10
+ import gradio as gr
11
+ import os
12
+ from PIL import Image
13
+ import tempfile
14
+ from huggingface_hub import snapshot_download
15
+ from transformers import TextIteratorStreamer
16
+ from threading import Thread
17
+
18
+ from diffusers import StableDiffusionXLPipeline
19
+
20
+ from minigemini.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
21
+ from minigemini.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
22
+ from minigemini.conversation import default_conversation, conv_templates, SeparatorStyle, Conversation
23
+ from minigemini.serve.gradio_web_server import function_markdown, tos_markdown, learn_more_markdown, title_markdown, block_css
24
+ from minigemini.model.builder import load_pretrained_model
25
+
26
+ # os.system('python -m pip install paddlepaddle-gpu==2.4.2.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html')
27
+ # os.system('pip install paddleocr>=2.0.1')
28
+ # from paddleocr import PaddleOCR
29
+
30
+ def download_model(repo_id):
31
+ local_dir = os.path.join('./checkpoints', repo_id.split('/')[-1])
32
+ os.makedirs(local_dir)
33
+ snapshot_download(repo_id=repo_id, local_dir=local_dir, local_dir_use_symlinks=False)
34
+
35
+
36
+ if not os.path.exists('./checkpoints/'):
37
+ os.makedirs('./checkpoints/')
38
+ download_model('YanweiLi/Mini-Gemini-13B-HD')
39
+ download_model('laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup')
40
+
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ load_8bit = False
43
+ load_4bit = False
44
+ dtype = torch.float16
45
+ conv_mode = "vicuna_v1"
46
+ model_path = './checkpoints/Mini-Gemini-13B-HD'
47
+ model_name = 'Mini-Gemini-13B-HD'
48
+ model_base = None
49
+
50
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name,
51
+ load_8bit, load_4bit,
52
+ device=device)
53
+
54
+ diffusion_pipe = StableDiffusionXLPipeline.from_pretrained(
55
+ "stabilityai/stable-diffusion-xl-base-1.0",
56
+ torch_dtype=torch.float16,
57
+ use_safetensors=True, variant="fp16"
58
+ ).to(device=device)
59
+
60
+
61
+ if hasattr(model.config, 'image_size_aux'):
62
+ if not hasattr(image_processor, 'image_size_raw'):
63
+ image_processor.image_size_raw = image_processor.crop_size.copy()
64
+ image_processor.crop_size['height'] = model.config.image_size_aux
65
+ image_processor.crop_size['width'] = model.config.image_size_aux
66
+ image_processor.size['shortest_edge'] = model.config.image_size_aux
67
+
68
+ no_change_btn = gr.Button()
69
+ enable_btn = gr.Button(interactive=True)
70
+ disable_btn = gr.Button(interactive=False)
71
+
72
+
73
+ def upvote_last_response(state):
74
+ return ("",) + (disable_btn,) * 3
75
+
76
+ def downvote_last_response(state):
77
+ return ("",) + (disable_btn,) * 3
78
+
79
+ def flag_last_response(state):
80
+ return ("",) + (disable_btn,) * 3
81
+
82
+ def clear_history():
83
+ state = conv_templates[conv_mode].copy()
84
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
85
+
86
+
87
+ def process_image(prompt, images):
88
+ if images is not None and len(images) > 0:
89
+ image_convert = images
90
+
91
+ # Similar operation in model_worker.py
92
+ image_tensor = process_images(image_convert, image_processor, model.config)
93
+
94
+ image_grid = getattr(model.config, 'image_grid', 1)
95
+ if hasattr(model.config, 'image_size_aux'):
96
+ raw_shape = [image_processor.image_size_raw['height'] * image_grid,
97
+ image_processor.image_size_raw['width'] * image_grid]
98
+ image_tensor_aux = image_tensor
99
+ image_tensor = torch.nn.functional.interpolate(image_tensor,
100
+ size=raw_shape,
101
+ mode='bilinear',
102
+ align_corners=False)
103
+ else:
104
+ image_tensor_aux = []
105
+
106
+ if image_grid >= 2:
107
+ raw_image = image_tensor.reshape(3,
108
+ image_grid,
109
+ image_processor.image_size_raw['height'],
110
+ image_grid,
111
+ image_processor.image_size_raw['width'])
112
+ raw_image = raw_image.permute(1, 3, 0, 2, 4)
113
+ raw_image = raw_image.reshape(-1, 3,
114
+ image_processor.image_size_raw['height'],
115
+ image_processor.image_size_raw['width'])
116
+
117
+ if getattr(model.config, 'image_global', False):
118
+ global_image = image_tensor
119
+ if len(global_image.shape) == 3:
120
+ global_image = global_image[None]
121
+ global_image = torch.nn.functional.interpolate(global_image,
122
+ size=[image_processor.image_size_raw['height'],
123
+ image_processor.image_size_raw['width']],
124
+ mode='bilinear',
125
+ align_corners=False)
126
+ # [image_crops, image_global]
127
+ raw_image = torch.cat([raw_image, global_image], dim=0)
128
+ image_tensor = raw_image.contiguous()
129
+ image_tensor = image_tensor.unsqueeze(0)
130
+
131
+ if type(image_tensor) is list:
132
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
133
+ image_tensor_aux = [image.to(model.device, dtype=torch.float16) for image in image_tensor_aux]
134
+ else:
135
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
136
+ image_tensor_aux = image_tensor_aux.to(model.device, dtype=torch.float16)
137
+ else:
138
+ images = None
139
+ image_tensor = None
140
+ image_tensor_aux = []
141
+
142
+ image_tensor_aux = image_tensor_aux if len(image_tensor_aux) > 0 else None
143
+
144
+ replace_token = DEFAULT_IMAGE_TOKEN
145
+ if getattr(model.config, 'mm_use_im_start_end', False):
146
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
147
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
148
+
149
+ image_args = {"images": image_tensor, "images_aux": image_tensor_aux}
150
+
151
+ return prompt, image_args
152
+
153
+
154
+ @spaces.GPU
155
+ def generate(state, imagebox, textbox, image_process_mode, gen_image, temperature, top_p, max_output_tokens):
156
+ prompt = state.get_prompt()
157
+ images = state.get_images(return_pil=True)
158
+ prompt, image_args = process_image(prompt, images)
159
+
160
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to("cuda:0")
161
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=30)
162
+
163
+ max_new_tokens = 512
164
+ do_sample = True if temperature > 0.001 else False
165
+ stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
166
+
167
+ thread = Thread(target=model.generate, kwargs=dict(
168
+ inputs=input_ids,
169
+ do_sample=do_sample,
170
+ temperature=temperature,
171
+ top_p=top_p,
172
+ max_new_tokens=max_new_tokens,
173
+ streamer=streamer,
174
+ use_cache=True,
175
+ **image_args
176
+ ))
177
+ thread.start()
178
+
179
+ generated_text = ''
180
+ for new_text in streamer:
181
+ generated_text += new_text
182
+ if generated_text.endswith(stop_str):
183
+ generated_text = generated_text[:-len(stop_str)]
184
+ state.messages[-1][-1] = generated_text
185
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
186
+
187
+ if gen_image == 'Yes' and '<h>' in generated_text and '</h>' in generated_text:
188
+ common_neg_prompt = "out of frame, lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
189
+ prompt = generated_text.split("<h>")[1].split("</h>")[0]
190
+ generated_text = generated_text.split("<h>")[0] + '\n' + 'Prompt: ' + prompt + '\n'
191
+
192
+ torch.cuda.empty_cache()
193
+ output_img = diffusion_pipe(prompt, negative_prompt=common_neg_prompt).images[0]
194
+ buffered = io.BytesIO()
195
+ output_img.save(buffered, format='JPEG')
196
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
197
+
198
+ output = (generated_text, img_b64_str)
199
+ state.messages[-1][-1] = output
200
+
201
+ yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
202
+
203
+ torch.cuda.empty_cache()
204
+
205
+
206
+ @spaces.GPU
207
+ def add_text(state, imagebox, textbox, image_process_mode, gen_image):
208
+ if state is None:
209
+ state = conv_templates[conv_mode].copy()
210
+
211
+ if imagebox is not None:
212
+ textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
213
+ image = Image.open(imagebox).convert('RGB')
214
+
215
+ if gen_image == 'Yes':
216
+ textbox = textbox + ' <GEN>'
217
+
218
+ if imagebox is not None:
219
+ textbox = (textbox, image, image_process_mode)
220
+
221
+ state.append_message(state.roles[0], textbox)
222
+ state.append_message(state.roles[1], None)
223
+
224
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
225
+
226
+
227
+ def delete_text(state, image_process_mode):
228
+ state.messages[-1][-1] = None
229
+ prev_human_msg = state.messages[-2]
230
+ if type(prev_human_msg[1]) in (tuple, list):
231
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
232
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
233
+
234
+
235
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
236
+ with gr.Blocks(title='Mini-Gemini') as demo:
237
+ gr.Markdown(title_markdown)
238
+ # state = default_conversation.copy()
239
+ state = gr.State()
240
+
241
+ with gr.Row():
242
+ with gr.Column(scale=3):
243
+ imagebox = gr.Image(label="Input Image", type="filepath")
244
+ image_process_mode = gr.Radio(
245
+ ["Crop", "Resize", "Pad", "Default"],
246
+ value="Default",
247
+ label="Preprocess for non-square image", visible=False)
248
+
249
+ gr.Examples(examples=[
250
+ ["./minigemini/serve/examples/monday.jpg", "Explain why this meme is funny, and generate a picture when the weekend coming."],
251
+ ["./minigemini/serve/examples/woolen.png", "Show me one idea of what I could make with this?"],
252
+ ["./minigemini/serve/examples/extreme_ironing.jpg", "What is unusual about this image?"],
253
+ ["./minigemini/serve/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
254
+ ], inputs=[imagebox, textbox])
255
+
256
+ with gr.Accordion("Function", open=True) as parameter_row:
257
+ gen_image = gr.Radio(choices=['Yes', 'No'], value='No', interactive=True, label="Generate Image")
258
+
259
+ with gr.Accordion("Parameters", open=False) as parameter_row:
260
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
261
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
262
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
263
+
264
+ with gr.Column(scale=7):
265
+ chatbot = gr.Chatbot(
266
+ elem_id="chatbot",
267
+ label="Mini-Gemini Chatbot",
268
+ height=850,
269
+ layout="panel",
270
+ )
271
+ with gr.Row():
272
+ with gr.Column(scale=7):
273
+ textbox.render()
274
+ with gr.Column(scale=1, min_width=50):
275
+ submit_btn = gr.Button(value="Send", variant="primary")
276
+ with gr.Row(elem_id="buttons") as button_row:
277
+ upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False)
278
+ downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False)
279
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
280
+ regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False)
281
+ clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False)
282
+
283
+ gr.Markdown(function_markdown)
284
+ gr.Markdown(tos_markdown)
285
+ gr.Markdown(learn_more_markdown)
286
+
287
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
288
+ upvote_btn.click(
289
+ upvote_last_response,
290
+ [state],
291
+ [textbox, upvote_btn, downvote_btn, flag_btn]
292
+ )
293
+ downvote_btn.click(
294
+ downvote_last_response,
295
+ [state],
296
+ [textbox, upvote_btn, downvote_btn, flag_btn]
297
+ )
298
+ flag_btn.click(
299
+ flag_last_response,
300
+ [state],
301
+ [textbox, upvote_btn, downvote_btn, flag_btn]
302
+ )
303
+ clear_btn.click(
304
+ clear_history,
305
+ None,
306
+ [state, chatbot, textbox, imagebox] + btn_list,
307
+ queue=False
308
+ )
309
+ regenerate_btn.click(
310
+ delete_text,
311
+ [state, image_process_mode],
312
+ [state, chatbot, textbox, imagebox] + btn_list,
313
+ ).then(
314
+ generate,
315
+ [state, imagebox, textbox, image_process_mode, gen_image, temperature, top_p, max_output_tokens],
316
+ [state, chatbot, textbox, imagebox] + btn_list,
317
+ )
318
+ textbox.submit(
319
+ add_text,
320
+ [state, imagebox, textbox, image_process_mode, gen_image],
321
+ [state, chatbot, textbox, imagebox] + btn_list,
322
+ ).then(
323
+ generate,
324
+ [state, imagebox, textbox, image_process_mode, gen_image, temperature, top_p, max_output_tokens],
325
+ [state, chatbot, textbox, imagebox] + btn_list,
326
+ )
327
+ submit_btn.click(
328
+ add_text,
329
+ [state, imagebox, textbox, image_process_mode, gen_image],
330
+ [state, chatbot, textbox, imagebox] + btn_list,
331
+ ).then(
332
+ generate,
333
+ [state, imagebox, textbox, image_process_mode, gen_image, temperature, top_p, max_output_tokens],
334
+ [state, chatbot, textbox, imagebox] + btn_list,
335
+ )
336
+
337
+
338
+ demo.launch(debug=True)
339
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers==4.36.2
2
+ torch
3
+ accelerate
4
+ open_clip_torch
5
+ diffusers
6
+ deepspeed
7
+ gradio==4.25.0
8
+ gradio-client==0.15.0
9
+ diffusers==0.26.3