Spaces: hysts / (Running on Zero)

hysts (HF staff) committed on
Commit 1cfb0d6
1 Parent(s): a5b3bac
Files changed (1)
  1. app.py +94 -39
app.py CHANGED
@@ -18,10 +18,12 @@ if not torch.cuda.is_available():
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
+MODEL_ID_OPT_2_7B = "Salesforce/blip2-opt-2.7b"
 MODEL_ID_OPT_6_7B = "Salesforce/blip2-opt-6.7b"
+MODEL_ID_FLAN_T5_XL = "Salesforce/blip2-flan-t5-xl"
 MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"
 MODEL_ID = os.getenv("MODEL_ID", MODEL_ID_FLAN_T5_XXL)
-assert MODEL_ID in [MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL]
+assert MODEL_ID in [MODEL_ID_OPT_2_7B, MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XL, MODEL_ID_FLAN_T5_XXL]
 
 if torch.cuda.is_available():
     processor = AutoProcessor.from_pretrained(MODEL_ID)
@@ -31,10 +33,14 @@ if torch.cuda.is_available():
 @spaces.GPU
 def generate_caption(
     image: PIL.Image.Image,
-    decoding_method: str,
-    temperature: float,
-    length_penalty: float,
-    repetition_penalty: float,
+    decoding_method: str = "Nucleus sampling",
+    temperature: float = 1.0,
+    length_penalty: float = 1.0,
+    repetition_penalty: float = 1.5,
+    max_length: int = 50,
+    min_length: int = 1,
+    num_beams: int = 5,
+    top_p: float = 0.9,
 ) -> str:
     inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
     generated_ids = model.generate(
@@ -43,10 +49,10 @@ def generate_caption(
         temperature=temperature,
         length_penalty=length_penalty,
         repetition_penalty=repetition_penalty,
-        max_length=50,
-        min_length=1,
-        num_beams=5,
-        top_p=0.9,
+        max_length=max_length,
+        min_length=min_length,
+        num_beams=num_beams,
+        top_p=top_p,
     )
     result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
     return result
@@ -55,23 +61,27 @@ def generate_caption(
 @spaces.GPU
 def answer_question(
     image: PIL.Image.Image,
-    text: str,
-    decoding_method: str,
-    temperature: float,
-    length_penalty: float,
-    repetition_penalty: float,
+    prompt: str,
+    decoding_method: str = "Nucleus sampling",
+    temperature: float = 1.0,
+    length_penalty: float = 1.0,
+    repetition_penalty: float = 1.5,
+    max_length: int = 50,
+    min_length: int = 1,
+    num_beams: int = 5,
+    top_p: float = 0.9,
 ) -> str:
-    inputs = processor(images=image, text=text, return_tensors="pt").to(device, torch.float16)
+    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
     generated_ids = model.generate(
         **inputs,
         do_sample=decoding_method == "Nucleus sampling",
         temperature=temperature,
         length_penalty=length_penalty,
         repetition_penalty=repetition_penalty,
-        max_length=30,
-        min_length=1,
-        num_beams=5,
-        top_p=0.9,
+        max_length=max_length,
+        min_length=min_length,
+        num_beams=num_beams,
+        top_p=top_p,
     )
     result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
     return result
@@ -86,10 +96,14 @@ def postprocess_output(output: str) -> str:
 def chat(
     image: PIL.Image.Image,
     text: str,
-    decoding_method: str,
-    temperature: float,
-    length_penalty: float,
-    repetition_penalty: float,
+    decoding_method: str = "Nucleus sampling",
+    temperature: float = 1.0,
+    length_penalty: float = 1.0,
+    repetition_penalty: float = 1.5,
+    max_length: int = 50,
+    min_length: int = 1,
+    num_beams: int = 5,
+    top_p: float = 0.9,
     history_orig: list[str] = [],
     history_qa: list[str] = [],
 ) -> tuple[list[tuple[str, str]], list[str], list[str]]:
@@ -99,12 +113,16 @@ def chat(
     prompt = " ".join(history_qa)
 
     output = answer_question(
-        image,
-        prompt,
-        decoding_method,
-        temperature,
-        length_penalty,
-        repetition_penalty,
+        image=image,
+        prompt=prompt,
+        decoding_method=decoding_method,
+        temperature=temperature,
+        length_penalty=length_penalty,
+        repetition_penalty=repetition_penalty,
+        max_length=max_length,
+        min_length=min_length,
+        num_beams=num_beams,
+        top_p=top_p,
     )
     output = postprocess_output(output)
     history_orig.append(output)
@@ -160,7 +178,7 @@ with gr.Blocks(css="style.css") as demo:
         clear_chat_button = gr.Button("Clear")
         chat_button = gr.Button("Submit", variant="primary")
     with gr.Accordion(label="Advanced settings", open=False):
-        sampling_method = gr.Radio(
+        text_decoding_method = gr.Radio(
            label="Text Decoding Method",
            choices=["Beam search", "Nucleus sampling"],
            value="Nucleus sampling",
@@ -170,24 +188,53 @@ with gr.Blocks(css="style.css") as demo:
            info="Used with nucleus sampling.",
            minimum=0.5,
            maximum=1.0,
-           value=1.0,
            step=0.1,
+           value=1.0,
        )
        length_penalty = gr.Slider(
            label="Length Penalty",
            info="Set to larger for longer sequence, used with beam search.",
            minimum=-1.0,
            maximum=2.0,
-           value=1.0,
            step=0.2,
+           value=1.0,
        )
-       rep_penalty = gr.Slider(
-           label="Repeat Penalty",
+       repetition_penalty = gr.Slider(
+           label="Repetition Penalty",
            info="Larger value prevents repetition.",
            minimum=1.0,
            maximum=5.0,
-           value=1.5,
            step=0.5,
+           value=1.5,
+       )
+       max_length = gr.Slider(
+           label="Max Length",
+           minimum=1,
+           maximum=512,
+           step=1,
+           value=50,
+       )
+       min_length = gr.Slider(
+           label="Minimum Length",
+           minimum=1,
+           maximum=100,
+           step=1,
+           value=1,
+       )
+       num_beams = gr.Slider(
+           label="Number of Beams",
+           minimum=1,
+           maximum=10,
+           step=1,
+           value=5,
+       )
+       top_p = gr.Slider(
+           label="Top P",
+           info="Used with nucleus sampling.",
+           minimum=0.5,
+           maximum=1.0,
+           step=0.1,
+           value=0.9,
        )
 
    gr.Examples(
@@ -199,10 +246,14 @@ with gr.Blocks(css="style.css") as demo:
        fn=generate_caption,
        inputs=[
            image,
-           sampling_method,
+           text_decoding_method,
            temperature,
            length_penalty,
-           rep_penalty,
+           repetition_penalty,
+           max_length,
+           min_length,
+           num_beams,
+           top_p,
        ],
        outputs=caption_output,
        api_name="caption",
@@ -211,10 +262,14 @@ with gr.Blocks(css="style.css") as demo:
    chat_inputs = [
        image,
        vqa_input,
-       sampling_method,
+       text_decoding_method,
        temperature,
        length_penalty,
-       rep_penalty,
+       repetition_penalty,
+       max_length,
+       min_length,
+       num_beams,
+       top_p,
        history_orig,
        history_qa,
    ]
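
Because every generation setting now has a keyword default, the Gradio callbacks can also be exercised directly as plain functions. The snippet below is a minimal usage sketch, not part of the commit: it assumes app.py has been imported in an environment where the BLIP-2 model actually loaded (CUDA available), and the image path and VQA prompt are placeholders.

# Minimal usage sketch (assumption: app.py imported with a CUDA-backed model loaded).
import PIL.Image

from app import answer_question, generate_caption  # hypothetical import of this Space's module

image = PIL.Image.open("example.jpg")  # placeholder image path

# All generation settings have defaults, so a bare call works:
# Nucleus sampling, temperature=1.0, repetition_penalty=1.5, max_length=50, num_beams=5, top_p=0.9.
caption = generate_caption(image)

# Override only what you need; with "Beam search", do_sample is False, so top_p/temperature are not used.
answer = answer_question(
    image,
    "Question: What is shown in the image? Answer:",  # placeholder prompt
    decoding_method="Beam search",
    max_length=30,
    num_beams=5,
)
print(caption, answer)

The same defaults are mirrored in the new sliders, so the UI, the API endpoint (api_name="caption"), and direct calls all start from identical settings.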