hysts (HF staff) committed
Commit
40ccb44
1 Parent(s): 056e5c6
Files changed (2)
  1. app.py +13 -53
  2. style.css +7 -0
app.py CHANGED
@@ -12,8 +12,6 @@ from transformers import AutoProcessor, Blip2ForConditionalGeneration
 
 DESCRIPTION = "# [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"
 
-if (SPACE_ID := os.getenv("SPACE_ID")) is not None:
-    DESCRIPTION += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
@@ -21,40 +19,23 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 MODEL_ID_OPT_6_7B = "Salesforce/blip2-opt-6.7b"
 MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"
+MODEL_ID = os.getenv("MODEL_ID", MODEL_ID_FLAN_T5_XXL)
 
 if torch.cuda.is_available():
-    model_dict = {
-        # MODEL_ID_OPT_6_7B: {
-        #     'processor':
-        #     AutoProcessor.from_pretrained(MODEL_ID_OPT_6_7B),
-        #     'model':
-        #     Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_OPT_6_7B,
-        #                                                   device_map='auto',
-        #                                                   load_in_8bit=True),
-        # },
-        MODEL_ID_FLAN_T5_XXL: {
-            "processor": AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL),
-            "model": Blip2ForConditionalGeneration.from_pretrained(
-                MODEL_ID_FLAN_T5_XXL, device_map="auto", load_in_8bit=True
-            ),
-        }
-    }
+    processor = AutoProcessor.from_pretrained(MODEL_ID)
+    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto", load_in_8bit=True)
 else:
-    model_dict = {}
+    processor = None
+    model = None
 
 
 def generate_caption(
-    model_id: str,
     image: PIL.Image.Image,
     decoding_method: str,
     temperature: float,
     length_penalty: float,
     repetition_penalty: float,
 ) -> str:
-    model_info = model_dict[model_id]
-    processor = model_info["processor"]
-    model = model_info["model"]
-
     inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
     generated_ids = model.generate(
         pixel_values=inputs.pixel_values,
@@ -72,7 +53,6 @@ def generate_caption(
 
 
 def answer_question(
-    model_id: str,
     image: PIL.Image.Image,
     text: str,
     decoding_method: str,
@@ -80,10 +60,6 @@ def answer_question(
     length_penalty: float,
     repetition_penalty: float,
 ) -> str:
-    model_info = model_dict[model_id]
-    processor = model_info["processor"]
-    model = model_info["model"]
-
     inputs = processor(images=image, text=text, return_tensors="pt").to(device, torch.float16)
     generated_ids = model.generate(
         **inputs,
@@ -107,7 +83,6 @@ def postprocess_output(output: str) -> str:
 
 
 def chat(
-    model_id: str,
     image: PIL.Image.Image,
     text: str,
     decoding_method: str,
@@ -123,7 +98,6 @@ def chat(
     prompt = " ".join(history_qa)
 
     output = answer_question(
-        model_id,
         image,
         prompt,
         decoding_method,
@@ -164,24 +138,14 @@ examples = [
 
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(
+        value="Duplicate Space for private use",
+        elem_id="duplicate-button",
+        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+    )
 
     image = gr.Image(type="pil")
     with gr.Accordion(label="Advanced settings", open=False):
-        with gr.Row():
-            model_id_caption = gr.Dropdown(
-                label="Model ID for image captioning",
-                choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
-                value=MODEL_ID_FLAN_T5_XXL,
-                interactive=False,
-                visible=False,
-            )
-            model_id_chat = gr.Dropdown(
-                label="Model ID for VQA",
-                choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
-                value=MODEL_ID_FLAN_T5_XXL,
-                interactive=False,
-                visible=False,
-            )
         sampling_method = gr.Radio(
             label="Text Decoding Method",
             choices=["Beam search", "Nucleus sampling"],
@@ -225,16 +189,12 @@ with gr.Blocks(css="style.css") as demo:
 
     gr.Examples(
         examples=examples,
-        inputs=[
-            image,
-            vqa_input,
-        ],
+        inputs=[image, vqa_input],
    )
 
     caption_button.click(
         fn=generate_caption,
         inputs=[
-            model_id_caption,
             image,
             sampling_method,
             temperature,
@@ -246,7 +206,6 @@ with gr.Blocks(css="style.css") as demo:
     )
 
     chat_inputs = [
-        model_id_chat,
         image,
         vqa_input,
         sampling_method,
@@ -296,4 +255,5 @@ with gr.Blocks(css="style.css") as demo:
         queue=False,
     )
 
-demo.queue(max_size=10).launch()
+if __name__ == "__main__":
+    demo.queue(max_size=10).launch()
style.css CHANGED
@@ -1,3 +1,10 @@
 h1 {
     text-align: center;
 }
+
+#duplicate-button {
+    margin: auto;
+    color: #fff;
+    background: #1565c0;
+    border-radius: 100vh;
+}
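
Note: the app.py change above drops the per-request model_dict lookup in favor of a single module-level processor/model pair selected by the MODEL_ID environment variable. Below is a minimal sketch of that loading-and-captioning path, assuming a CUDA GPU and the bitsandbytes 8-bit load used in the diff; the image path and max_new_tokens value are illustrative, not taken from the commit.

import os

import PIL.Image
import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration

# Same default as the commit; override by exporting MODEL_ID before launch.
model_id = os.getenv("MODEL_ID", "Salesforce/blip2-flan-t5-xxl")

processor = AutoProcessor.from_pretrained(model_id)
model = Blip2ForConditionalGeneration.from_pretrained(model_id, device_map="auto", load_in_8bit=True)

device = torch.device("cuda:0")
image = PIL.Image.open("example.jpg")  # hypothetical input image

# Mirrors generate_caption(): preprocess, generate, decode.
inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(pixel_values=inputs.pixel_values, max_new_tokens=50)
caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(caption)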