merve (HF staff) committed
Commit 411cee3
1 Parent(s): add43e5

Update app.py

Files changed (1)
  1. app.py +136 -231
app.py CHANGED
@@ -1,9 +1,7 @@
- import copy
  import gradio as gr
- from transformers import AutoProcessor, Idefics2ForConditionalGeneration, TextIteratorStreamer
- from threading import Thread
+ from transformers import AutoProcessor, Idefics2ForConditionalGeneration
  import re
- import time
+ import time
  from PIL import Image
  import torch
  import spaces
@@ -11,7 +9,7 @@ import subprocess
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)


- PROCESSOR = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
+ processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")

  model = Idefics2ForConditionalGeneration.from_pretrained(
  "HuggingFaceM4/idefics2-8b",
@@ -19,117 +17,35 @@ model = Idefics2ForConditionalGeneration.from_pretrained(
  _attn_implementation="flash_attention_2",
  trust_remote_code=True).to("cuda")

-
-
- def turn_is_pure_media(turn):
- return turn[1] is None
- def format_user_prompt_with_im_history_and_system_conditioning(
- user_prompt, chat_history
+ @spaces.GPU
+ def model_inference(
+ image, text, decoding_strategy, temperature,
+ max_new_tokens, repetition_penalty, top_p
  ):
- """
- Produces the resulting list that needs to go inside the processor.
- It handles the potential image(s), the history and the system conditionning.
- """
- resulting_messages = copy.deepcopy([])
- resulting_images = []
-
- # Format history
- for turn in chat_history:
- if not resulting_messages or (resulting_messages and resulting_messages[-1]["role"] != "user"):
- resulting_messages.append(
- {
- "role": "user",
- "content": [],
- }
- )
+ if text == "" and not image:
+ gr.Error("Please input a query and optionally image(s).")

- if turn_is_pure_media(turn):
- media = turn[0][0]
- resulting_messages[-1]["content"].append({"type": "image"})
- resulting_images.append(Image.open(media))
- else:
- user_utterance, assistant_utterance = turn
- resulting_messages[-1]["content"].append(
- {"type": "text", "text": user_utterance.strip()}
- )
- resulting_messages.append(
- {
- "role": "assistant",
- "content": [
- {"type": "text", "text": user_utterance.strip()}
- ]
- }
- )
+ if text == "" and image:
+ gr.Error("Please input a text query along the image(s).")

- # Format current input
- if not user_prompt["files"]:
- resulting_messages.append(
+ resulting_messages = [
  {
  "role": "user",
- "content": [
- {"type": "text", "text": user_prompt['text']}
- ],
- }
- )
- else:
- # Choosing to put the image first (i.e. before the text), but this is an arbiratrary choice.
- resulting_messages.append(
- {
- "role": "user",
- "content": [{"type": "image"}] * len(user_prompt['files']) + [
- {"type": "text", "text": user_prompt['text']}
+ "content": [{"type": "image"}] + [
+ {"type": "text", "text": text}
  ]
  }
- )
- for im in user_prompt["files"]:
- print(im)
- if isinstance(im, str):
-
- resulting_images.extend([Image.open(im)])
- elif isinstance(im, dict):
- resulting_images.extend([Image.open(im['path'])])
-
-
- return resulting_messages, resulting_images
-
-
- def extract_images_from_msg_list(msg_list):
- all_images = []
- for msg in msg_list:
- for c_ in msg["content"]:
- if isinstance(c_, Image.Image):
- all_images.append(c_)
- return all_images
-
- @spaces.GPU(duration=180)
- def model_inference(
- user_prompt,
- chat_history,
- decoding_strategy,
- temperature,
- max_new_tokens,
- repetition_penalty,
- top_p,
- ):
- if user_prompt["text"].strip() == "" and not user_prompt["files"]:
- gr.Error("Please input a query and optionally image(s).")
-
- if user_prompt["text"].strip() == "" and user_prompt["files"]:
- gr.Error("Please input a text query along the image(s).")
-
-
- streamer = TextIteratorStreamer(
- PROCESSOR.tokenizer,
- skip_prompt=True,
- timeout=5.,
- )
-
- # Common parameters to all decoding strategies
- # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
+ ]
+
+
+ prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
+ inputs = processor(text=prompt, images=[image], return_tensors="pt")
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
  generation_args = {
  "max_new_tokens": max_new_tokens,
  "repetition_penalty": repetition_penalty,
- "streamer": streamer,
+
  }

  assert decoding_strategy in [
@@ -143,133 +59,122 @@ def model_inference(
  generation_args["do_sample"] = True
  generation_args["top_p"] = top_p

- # Creating model inputs
- resulting_text, resulting_images = format_user_prompt_with_im_history_and_system_conditioning(
- user_prompt=user_prompt,
- chat_history=chat_history,
- )
- prompt = PROCESSOR.apply_chat_template(resulting_text, add_generation_prompt=True)
- inputs = PROCESSOR(text=prompt, images=resulting_images if resulting_images else None, return_tensors="pt")
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
  generation_args.update(inputs)

-
- thread = Thread(
- target=model.generate,
- kwargs=generation_args,
- )
- thread.start()
-
- print("Start generating")
- acc_text = ""
- for text_token in streamer:
- time.sleep(0.04)
- acc_text += text_token
- if acc_text.endswith("<end_of_utterance>"):
- acc_text = acc_text[:-18]
- yield acc_text
- print("Success - generated the following text:", acc_text)
- print("-----")
- BOT_AVATAR = "IDEFICS_logo.png"
-
- # Hyper-parameters for generation
- max_new_tokens = gr.Slider(
- minimum=8,
- maximum=1024,
- value=512,
- step=1,
- interactive=True,
- label="Maximum number of new tokens to generate",
- )
- repetition_penalty = gr.Slider(
- minimum=0.01,
- maximum=5.0,
- value=1.2,
- step=0.01,
- interactive=True,
- label="Repetition penalty",
- info="1.0 is equivalent to no penalty",
- )
- decoding_strategy = gr.Radio(
- [
- "Greedy",
- "Top P Sampling",
- ],
- value="Greedy",
- label="Decoding strategy",
- interactive=True,
- info="Higher values is equivalent to sampling more low-probability tokens.",
- )
- temperature = gr.Slider(
- minimum=0.0,
- maximum=5.0,
- value=0.4,
- step=0.1,
- interactive=True,
- label="Sampling temperature",
- info="Higher values will produce more diverse outputs.",
- )
- top_p = gr.Slider(
- minimum=0.01,
- maximum=0.99,
- value=0.8,
- step=0.01,
- interactive=True,
- label="Top P",
- info="Higher values is equivalent to sampling more low-probability tokens.",
- )
-
-
- chatbot = gr.Chatbot(
- label="Idefics2",
- avatar_images=[None, BOT_AVATAR],
- # height=750,
- )
-
-
- with gr.Blocks(fill_height=True, css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img { width: auto; max-width: 30%; height: auto; max-height: 30%; }") as demo:
- decoding_strategy.change(
- fn=lambda selection: gr.Slider(
- visible=(
- selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
- )
- ),
- inputs=decoding_strategy,
- outputs=temperature,
- )
- decoding_strategy.change(
- fn=lambda selection: gr.Slider(
- visible=(
- selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
- )
- ),
- inputs=decoding_strategy,
- outputs=repetition_penalty,
- )
- decoding_strategy.change(
- fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
- inputs=decoding_strategy,
- outputs=top_p,
- )
-
- description = "Try [IDEFICS2-8B](https://huggingface.co/HuggingFaceM4/idefics2-8b), the instruction fine-tuned IDEFICS2 in this demo. 💬 IDEFICS2 is a state-of-the-art vision language model in various benchmarks. To get started, upload an image and write a text prompt or try one of the examples. You can also play with advanced generation parameters. To learn more about IDEFICS2, read [the blog](https://huggingface.co/blog/idefics2). Note that this model is not as chatty as the upcoming chatty model, and it will give shorter answers."
-
-
- gr.ChatInterface(
- fn=model_inference,
- chatbot=chatbot,
- examples=[[{"text": "How many items are sold?", "files":["./example_images/docvqa_example.png"]}],
- [{"text": "What is this UI about?", "files":["./example_images/s2w_example.png"]}],
- [{"text": "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips.", "files":["./example_images/example_images_travel_tips.jpg"]}],
- [{"text": "Can you tell me a very short story based on this image?", "files":["./example_images/chicken_on_money.png"]}],
- [{"text": "Where is this pastry from?", "files":["./example_images/baklava.png"]}],
- [{"text": "How much percent is the order status?", "files":["./example_images/dummy_pdf.png"]}],
- [{"text":"As an art critic AI assistant, could you describe this painting in details and make a thorough critic?.", "files":["./example_images/art_critic.png"]}]
+ # Generate
+ generated_ids = model.generate(**generation_args)
+
+ generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+ print(generated_texts)
+ pattern = r"Assistant: (.*)"
+
+ # Use regular expression to find the desired part
+ result = re.search(pattern, generated_texts[0]).group(1)
+
+ return result[:-1]
+
+
+ with gr.Blocks(fill_height=True) as demo:
+ gr.Markdown("## IDEFICS2 Instruction 🐶")
+ gr.Markdown("Play with fine-tuned [IDEFICS2](https://huggingface.co/HuggingFaceM4/idefics2-8b) in this demo. To get started, upload an image and text or try one of the examples.")
+ gr.Markdown("**Important note**: This model is not made for chatting, the chatty IDEFICS2 will be released in the upcoming days. **This model is very strong on various tasks, including visual question answering, document retrieval and more.**")
+ gr.Markdown("Learn more about IDEFICS2 in this [blog post](https://huggingface.co/blog/idefics2).")
+
+ with gr.Row():
+ with gr.Column():
+ image_input = gr.Image(label="Upload your Image", type="pil")
+ query_input = gr.Textbox(label="Prompt")
+ submit_btn = gr.Button("Submit")
+
+ with gr.Column():
+
+ output = gr.Textbox(label="Output")
+
+ with gr.Accordion():
+ # Hyper-parameters for generation
+ max_new_tokens = gr.Slider(
+ minimum=8,
+ maximum=1024,
+ value=512,
+ step=1,
+ interactive=True,
+ label="Maximum number of new tokens to generate",
+ )
+ repetition_penalty = gr.Slider(
+ minimum=0.01,
+ maximum=5.0,
+ value=1.2,
+ step=0.01,
+ interactive=True,
+ label="Repetition penalty",
+ info="1.0 is equivalent to no penalty",
+ )
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=5.0,
+ value=0.4,
+ step=0.1,
+ interactive=True,
+ label="Sampling temperature",
+ info="Higher values will produce more diverse outputs.",
+ )
+ top_p = gr.Slider(
+ minimum=0.01,
+ maximum=0.99,
+ value=0.8,
+ step=0.01,
+ interactive=True,
+ label="Top P",
+ info="Higher values is equivalent to sampling more low-probability tokens.",
+ )
+ decoding_strategy = gr.Radio(
+ [
+ "Greedy",
+ "Top P Sampling",
+ ],
+ value="Greedy",
+ label="Decoding strategy",
+ interactive=True,
+ info="Higher values is equivalent to sampling more low-probability tokens.",
+ )
+ decoding_strategy.change(
+ fn=lambda selection: gr.Slider(
+ visible=(
+ selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
+ )
+ ),
+ inputs=decoding_strategy,
+ outputs=temperature,
+ )
+
+ decoding_strategy.change(
+ fn=lambda selection: gr.Slider(
+ visible=(
+ selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
+ )
+ ),
+ inputs=decoding_strategy,
+ outputs=repetition_penalty,
+ )
+ decoding_strategy.change(
+ fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
+ inputs=decoding_strategy,
+ outputs=top_p,
+ )
+ examples=[["./example_images/docvqa_example.png", "How many items are sold?", "Greedy", 0.4, 512, 1.2, 0.8],
+ ["./example_images/s2w_example.png", "What is this UI about?", "Greedy", 0.4, 512, 1.2, 0.8],
+ ["./example_images/example_images_travel_tips.jpg", "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips.", 0.4, 512, 1.2, 0.8],
+ ["./example_images/chicken_on_money.png", "Can you tell me a very short story based on this image?", 0.4, 512, 1.2, 0.8],
+ ["./example_images/baklava.png", "Where is this pastry from?", 0.4, 512, 1.2, 0.8],
+ ["./example_images/dummy_pdf.png", "How much percent is the order status?", 0.4, 512, 1.2, 0.8],
+ ["./example_images/art_critic.png", "As an art critic AI assistant, could you describe this painting in details and make a thorough critic?.",
+ 0.4, 512, 1.2, 0.8]]
  ],
- description=description,
- title="Idefics2 Playground 🐶 ",
- multimodal=True,
- additional_inputs=[decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p],
- )
+
+ submit_btn.click(model_inference, inputs = [image_input, query_input, decoding_strategy, temperature,
+ max_new_tokens, repetition_penalty, top_p],
+ outputs=output)
+

  demo.launch(debug=True)
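
For reference, the single-turn inference path this commit switches to can be exercised outside Gradio with the short sketch below. It only rearranges calls that appear in the new app.py (chat-template formatting, the processor call, generate, batch_decode, and the "Assistant:" regex); the example image path, the bfloat16 dtype, and the CUDA device handling are illustrative assumptions, not part of the commit.

# Minimal sketch of the new single-turn IDEFICS2 flow.
# Assumptions: example image path from the demo's examples list, bfloat16 weights, a CUDA device.
import re
import torch
from PIL import Image
from transformers import AutoProcessor, Idefics2ForConditionalGeneration

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b", torch_dtype=torch.bfloat16
).to("cuda")

# One user turn: an image placeholder followed by the text prompt, mirroring model_inference().
image = Image.open("./example_images/docvqa_example.png")  # placeholder path
messages = [{"role": "user",
             "content": [{"type": "image"}, {"type": "text", "text": "How many items are sold?"}]}]

prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}

generated_ids = model.generate(**inputs, max_new_tokens=512, repetition_penalty=1.2)
decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# The app keeps only the assistant turn from the decoded string.
answer = re.search(r"Assistant: (.*)", decoded).group(1)
print(answer)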