hysts HF staff commited on
Commit
8781a5b
1 Parent(s): c9147bf

Migrate from yapf to black

Browse files
Files changed (4) hide show
  1. .pre-commit-config.yaml +26 -13
  2. .style.yapf +0 -5
  3. .vscode/settings.json +21 -0
  4. app.py +92 -89
.pre-commit-config.yaml CHANGED
@@ -1,7 +1,6 @@
1
- exclude: patch
2
  repos:
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.2.0
5
  hooks:
6
  - id: check-executables-have-shebangs
7
  - id: check-json
@@ -9,29 +8,43 @@ repos:
9
  - id: check-shebang-scripts-are-executable
10
  - id: check-toml
11
  - id: check-yaml
12
- - id: double-quote-string-fixer
13
  - id: end-of-file-fixer
14
  - id: mixed-line-ending
15
- args: ['--fix=lf']
16
  - id: requirements-txt-fixer
17
  - id: trailing-whitespace
18
  - repo: https://github.com/myint/docformatter
19
- rev: v1.4
20
  hooks:
21
  - id: docformatter
22
- args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
  rev: 5.12.0
25
  hooks:
26
  - id: isort
 
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.991
29
  hooks:
30
  - id: mypy
31
- args: ['--ignore-missing-imports']
32
- additional_dependencies: ['types-python-slugify']
33
- - repo: https://github.com/google/yapf
34
- rev: v0.32.0
35
  hooks:
36
- - id: yapf
37
- args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  repos:
2
  - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.4.0
4
  hooks:
5
  - id: check-executables-have-shebangs
6
  - id: check-json
 
8
  - id: check-shebang-scripts-are-executable
9
  - id: check-toml
10
  - id: check-yaml
 
11
  - id: end-of-file-fixer
12
  - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
  - id: requirements-txt-fixer
15
  - id: trailing-whitespace
16
  - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
  hooks:
19
  - id: docformatter
20
+ args: ["--in-place"]
21
  - repo: https://github.com/pycqa/isort
22
  rev: 5.12.0
23
  hooks:
24
  - id: isort
25
+ args: ["--profile", "black"]
26
  - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.5.1
28
  hooks:
29
  - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies: ["types-python-slugify", "types-requests", "types-PyYAML"]
32
+ - repo: https://github.com/psf/black
33
+ rev: 23.9.1
34
  hooks:
35
+ - id: black
36
+ language_version: python3.10
37
+ args: ["--line-length", "119"]
38
+ - repo: https://github.com/kynan/nbstripout
39
+ rev: 0.6.1
40
+ hooks:
41
+ - id: nbstripout
42
+ args: ["--extra-keys", "metadata.interpreter metadata.kernelspec cell.metadata.pycharm"]
43
+ - repo: https://github.com/nbQA-dev/nbQA
44
+ rev: 1.7.0
45
+ hooks:
46
+ - id: nbqa-black
47
+ - id: nbqa-pyupgrade
48
+ args: ["--py37-plus"]
49
+ - id: nbqa-isort
50
+ args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
1
- [style]
2
- based_on_style = pep8
3
- blank_line_before_nested_class_or_def = false
4
- spaces_before_comment = 2
5
- split_before_logical_operator = true
 
 
 
 
 
 
.vscode/settings.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[python]": {
3
+ "editor.defaultFormatter": "ms-python.black-formatter",
4
+ "editor.formatOnType": true,
5
+ "editor.codeActionsOnSave": {
6
+ "source.organizeImports": true
7
+ }
8
+ },
9
+ "black-formatter.args": [
10
+ "--line-length=119"
11
+ ],
12
+ "isort.args": ["--profile", "black"],
13
+ "flake8.args": [
14
+ "--max-line-length=119"
15
+ ],
16
+ "ruff.args": [
17
+ "--line-length=119"
18
+ ],
19
+ "editor.formatOnSave": true,
20
+ "files.insertFinalNewline": true
21
+ }
app.py CHANGED
@@ -10,92 +10,99 @@ import PIL.Image
10
  import torch
11
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
12
 
13
- DESCRIPTION = '# [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)'
14
 
15
- if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
16
  DESCRIPTION += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
17
  if not torch.cuda.is_available():
18
- DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
19
 
20
- device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
21
 
22
- MODEL_ID_OPT_6_7B = 'Salesforce/blip2-opt-6.7b'
23
- MODEL_ID_FLAN_T5_XXL = 'Salesforce/blip2-flan-t5-xxl'
24
 
25
  if torch.cuda.is_available():
26
  model_dict = {
27
- #MODEL_ID_OPT_6_7B: {
28
  # 'processor':
29
  # AutoProcessor.from_pretrained(MODEL_ID_OPT_6_7B),
30
  # 'model':
31
  # Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_OPT_6_7B,
32
  # device_map='auto',
33
  # load_in_8bit=True),
34
- #},
35
  MODEL_ID_FLAN_T5_XXL: {
36
- 'processor':
37
- AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL),
38
- 'model':
39
- Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_FLAN_T5_XXL,
40
- device_map='auto',
41
- load_in_8bit=True),
42
  }
43
  }
44
  else:
45
  model_dict = {}
46
 
47
 
48
- def generate_caption(model_id: str, image: PIL.Image.Image,
49
- decoding_method: str, temperature: float,
50
- length_penalty: float, repetition_penalty: float) -> str:
 
 
 
 
 
51
  model_info = model_dict[model_id]
52
- processor = model_info['processor']
53
- model = model_info['model']
54
 
55
- inputs = processor(images=image,
56
- return_tensors='pt').to(device, torch.float16)
57
  generated_ids = model.generate(
58
  pixel_values=inputs.pixel_values,
59
- do_sample=decoding_method == 'Nucleus sampling',
60
  temperature=temperature,
61
  length_penalty=length_penalty,
62
  repetition_penalty=repetition_penalty,
63
  max_length=50,
64
  min_length=1,
65
  num_beams=5,
66
- top_p=0.9)
67
- result = processor.batch_decode(generated_ids,
68
- skip_special_tokens=True)[0].strip()
69
  return result
70
 
71
 
72
- def answer_question(model_id: str, image: PIL.Image.Image, text: str,
73
- decoding_method: str, temperature: float,
74
- length_penalty: float, repetition_penalty: float) -> str:
 
 
 
 
 
 
75
  model_info = model_dict[model_id]
76
- processor = model_info['processor']
77
- model = model_info['model']
78
 
79
- inputs = processor(images=image, text=text,
80
- return_tensors='pt').to(device, torch.float16)
81
- generated_ids = model.generate(**inputs,
82
- do_sample=decoding_method ==
83
- 'Nucleus sampling',
84
- temperature=temperature,
85
- length_penalty=length_penalty,
86
- repetition_penalty=repetition_penalty,
87
- max_length=30,
88
- min_length=1,
89
- num_beams=5,
90
- top_p=0.9)
91
- result = processor.batch_decode(generated_ids,
92
- skip_special_tokens=True)[0].strip()
93
  return result
94
 
95
 
96
  def postprocess_output(output: str) -> str:
97
- if output and not output[-1] in string.punctuation:
98
- output += '.'
99
  return output
100
 
101
 
@@ -111,9 +118,9 @@ def chat(
111
  history_qa: list[str] = [],
112
  ) -> tuple[dict[str, list[str]], dict[str, list[str]], dict[str, list[str]]]:
113
  history_orig.append(text)
114
- text_qa = f'Question: {text} Answer:'
115
  history_qa.append(text_qa)
116
- prompt = ' '.join(history_qa)
117
 
118
  output = answer_question(
119
  model_id,
@@ -129,73 +136,73 @@ def chat(
129
  history_qa.append(output)
130
 
131
  chat_val = list(zip(history_orig[0::2], history_orig[1::2]))
132
- return gr.update(value=chat_val), gr.update(value=history_orig), gr.update(
133
- value=history_qa)
134
 
135
 
136
  examples = [
137
  [
138
- 'house.png',
139
- 'How could someone get out of the house?',
140
  ],
141
  [
142
- 'flower.jpg',
143
- 'What is this flower and where is it\'s origin?',
144
  ],
145
  [
146
- 'pizza.jpg',
147
- 'What are steps to cook it?',
148
  ],
149
  [
150
- 'sunset.jpg',
151
- 'Here is a romantic message going along the photo:',
152
  ],
153
  [
154
- 'forbidden_city.webp',
155
- 'In what dynasties was this place built?',
156
  ],
157
  ]
158
 
159
- with gr.Blocks(css='style.css') as demo:
160
  gr.Markdown(DESCRIPTION)
161
 
162
- image = gr.Image(type='pil')
163
- with gr.Accordion(label='Advanced settings', open=False):
164
  with gr.Row():
165
  model_id_caption = gr.Dropdown(
166
- label='Model ID for image captioning',
167
  choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
168
  value=MODEL_ID_FLAN_T5_XXL,
169
  interactive=False,
170
- visible=False)
 
171
  model_id_chat = gr.Dropdown(
172
- label='Model ID for VQA',
173
  choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
174
  value=MODEL_ID_FLAN_T5_XXL,
175
  interactive=False,
176
- visible=False)
 
177
  sampling_method = gr.Radio(
178
- label='Text Decoding Method',
179
- choices=['Beam search', 'Nucleus sampling'],
180
- value='Beam search',
181
  )
182
  temperature = gr.Slider(
183
- label='Temperature (used with nucleus sampling)',
184
  minimum=0.5,
185
  maximum=1.0,
186
  value=1.0,
187
  step=0.1,
188
  )
189
  length_penalty = gr.Slider(
190
- label=
191
- 'Length Penalty (set to larger for longer sequence, used with beam search)',
192
  minimum=-1.0,
193
  maximum=2.0,
194
  value=1.0,
195
  step=0.2,
196
  )
197
  rep_penalty = gr.Slider(
198
- label='Repeat Penalty (larger value prevents repetition)',
199
  minimum=1.0,
200
  maximum=5.0,
201
  value=1.5,
@@ -204,21 +211,17 @@ with gr.Blocks(css='style.css') as demo:
204
  with gr.Row():
205
  with gr.Column():
206
  with gr.Box():
207
- caption_button = gr.Button(value='Caption it!')
208
- caption_output = gr.Textbox(
209
- label='Caption Output',
210
- show_label=False).style(container=False)
211
  with gr.Column():
212
  with gr.Box():
213
- chatbot = gr.Chatbot(label='VQA Chat')
214
  history_orig = gr.State(value=[])
215
  history_qa = gr.State(value=[])
216
- vqa_input = gr.Text(label='Chat Input',
217
- show_label=False,
218
- max_lines=1).style(container=False)
219
  with gr.Row():
220
- clear_chat_button = gr.Button(value='Clear')
221
- chat_button = gr.Button(value='Submit')
222
 
223
  gr.Examples(
224
  examples=examples,
@@ -239,7 +242,7 @@ with gr.Blocks(css='style.css') as demo:
239
  rep_penalty,
240
  ],
241
  outputs=caption_output,
242
- api_name='caption',
243
  )
244
 
245
  chat_inputs = [
@@ -267,10 +270,10 @@ with gr.Blocks(css='style.css') as demo:
267
  fn=chat,
268
  inputs=chat_inputs,
269
  outputs=chat_outputs,
270
- api_name='chat',
271
  )
272
  clear_chat_button.click(
273
- fn=lambda: ('', [], [], []),
274
  inputs=None,
275
  outputs=[
276
  vqa_input,
@@ -279,10 +282,10 @@ with gr.Blocks(css='style.css') as demo:
279
  history_qa,
280
  ],
281
  queue=False,
282
- api_name='clear',
283
  )
284
  image.change(
285
- fn=lambda: ('', [], [], []),
286
  inputs=None,
287
  outputs=[
288
  caption_output,
 
10
  import torch
11
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
12
 
13
+ DESCRIPTION = "# [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"
14
 
15
+ if (SPACE_ID := os.getenv("SPACE_ID")) is not None:
16
  DESCRIPTION += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
17
  if not torch.cuda.is_available():
18
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
19
 
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
 
22
+ MODEL_ID_OPT_6_7B = "Salesforce/blip2-opt-6.7b"
23
+ MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"
24
 
25
  if torch.cuda.is_available():
26
  model_dict = {
27
+ # MODEL_ID_OPT_6_7B: {
28
  # 'processor':
29
  # AutoProcessor.from_pretrained(MODEL_ID_OPT_6_7B),
30
  # 'model':
31
  # Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_OPT_6_7B,
32
  # device_map='auto',
33
  # load_in_8bit=True),
34
+ # },
35
  MODEL_ID_FLAN_T5_XXL: {
36
+ "processor": AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL),
37
+ "model": Blip2ForConditionalGeneration.from_pretrained(
38
+ MODEL_ID_FLAN_T5_XXL, device_map="auto", load_in_8bit=True
39
+ ),
 
 
40
  }
41
  }
42
  else:
43
  model_dict = {}
44
 
45
 
46
+ def generate_caption(
47
+ model_id: str,
48
+ image: PIL.Image.Image,
49
+ decoding_method: str,
50
+ temperature: float,
51
+ length_penalty: float,
52
+ repetition_penalty: float,
53
+ ) -> str:
54
  model_info = model_dict[model_id]
55
+ processor = model_info["processor"]
56
+ model = model_info["model"]
57
 
58
+ inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
 
59
  generated_ids = model.generate(
60
  pixel_values=inputs.pixel_values,
61
+ do_sample=decoding_method == "Nucleus sampling",
62
  temperature=temperature,
63
  length_penalty=length_penalty,
64
  repetition_penalty=repetition_penalty,
65
  max_length=50,
66
  min_length=1,
67
  num_beams=5,
68
+ top_p=0.9,
69
+ )
70
+ result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
71
  return result
72
 
73
 
74
+ def answer_question(
75
+ model_id: str,
76
+ image: PIL.Image.Image,
77
+ text: str,
78
+ decoding_method: str,
79
+ temperature: float,
80
+ length_penalty: float,
81
+ repetition_penalty: float,
82
+ ) -> str:
83
  model_info = model_dict[model_id]
84
+ processor = model_info["processor"]
85
+ model = model_info["model"]
86
 
87
+ inputs = processor(images=image, text=text, return_tensors="pt").to(device, torch.float16)
88
+ generated_ids = model.generate(
89
+ **inputs,
90
+ do_sample=decoding_method == "Nucleus sampling",
91
+ temperature=temperature,
92
+ length_penalty=length_penalty,
93
+ repetition_penalty=repetition_penalty,
94
+ max_length=30,
95
+ min_length=1,
96
+ num_beams=5,
97
+ top_p=0.9,
98
+ )
99
+ result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
 
100
  return result
101
 
102
 
103
  def postprocess_output(output: str) -> str:
104
+ if output and output[-1] not in string.punctuation:
105
+ output += "."
106
  return output
107
 
108
 
 
118
  history_qa: list[str] = [],
119
  ) -> tuple[dict[str, list[str]], dict[str, list[str]], dict[str, list[str]]]:
120
  history_orig.append(text)
121
+ text_qa = f"Question: {text} Answer:"
122
  history_qa.append(text_qa)
123
+ prompt = " ".join(history_qa)
124
 
125
  output = answer_question(
126
  model_id,
 
136
  history_qa.append(output)
137
 
138
  chat_val = list(zip(history_orig[0::2], history_orig[1::2]))
139
+ return gr.update(value=chat_val), gr.update(value=history_orig), gr.update(value=history_qa)
 
140
 
141
 
142
  examples = [
143
  [
144
+ "house.png",
145
+ "How could someone get out of the house?",
146
  ],
147
  [
148
+ "flower.jpg",
149
+ "What is this flower and where is it's origin?",
150
  ],
151
  [
152
+ "pizza.jpg",
153
+ "What are steps to cook it?",
154
  ],
155
  [
156
+ "sunset.jpg",
157
+ "Here is a romantic message going along the photo:",
158
  ],
159
  [
160
+ "forbidden_city.webp",
161
+ "In what dynasties was this place built?",
162
  ],
163
  ]
164
 
165
+ with gr.Blocks(css="style.css") as demo:
166
  gr.Markdown(DESCRIPTION)
167
 
168
+ image = gr.Image(type="pil")
169
+ with gr.Accordion(label="Advanced settings", open=False):
170
  with gr.Row():
171
  model_id_caption = gr.Dropdown(
172
+ label="Model ID for image captioning",
173
  choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
174
  value=MODEL_ID_FLAN_T5_XXL,
175
  interactive=False,
176
+ visible=False,
177
+ )
178
  model_id_chat = gr.Dropdown(
179
+ label="Model ID for VQA",
180
  choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
181
  value=MODEL_ID_FLAN_T5_XXL,
182
  interactive=False,
183
+ visible=False,
184
+ )
185
  sampling_method = gr.Radio(
186
+ label="Text Decoding Method",
187
+ choices=["Beam search", "Nucleus sampling"],
188
+ value="Beam search",
189
  )
190
  temperature = gr.Slider(
191
+ label="Temperature (used with nucleus sampling)",
192
  minimum=0.5,
193
  maximum=1.0,
194
  value=1.0,
195
  step=0.1,
196
  )
197
  length_penalty = gr.Slider(
198
+ label="Length Penalty (set to larger for longer sequence, used with beam search)",
 
199
  minimum=-1.0,
200
  maximum=2.0,
201
  value=1.0,
202
  step=0.2,
203
  )
204
  rep_penalty = gr.Slider(
205
+ label="Repeat Penalty (larger value prevents repetition)",
206
  minimum=1.0,
207
  maximum=5.0,
208
  value=1.5,
 
211
  with gr.Row():
212
  with gr.Column():
213
  with gr.Box():
214
+ caption_button = gr.Button(value="Caption it!")
215
+ caption_output = gr.Textbox(label="Caption Output", show_label=False).style(container=False)
 
 
216
  with gr.Column():
217
  with gr.Box():
218
+ chatbot = gr.Chatbot(label="VQA Chat")
219
  history_orig = gr.State(value=[])
220
  history_qa = gr.State(value=[])
221
+ vqa_input = gr.Text(label="Chat Input", show_label=False, max_lines=1).style(container=False)
 
 
222
  with gr.Row():
223
+ clear_chat_button = gr.Button(value="Clear")
224
+ chat_button = gr.Button(value="Submit")
225
 
226
  gr.Examples(
227
  examples=examples,
 
242
  rep_penalty,
243
  ],
244
  outputs=caption_output,
245
+ api_name="caption",
246
  )
247
 
248
  chat_inputs = [
 
270
  fn=chat,
271
  inputs=chat_inputs,
272
  outputs=chat_outputs,
273
+ api_name="chat",
274
  )
275
  clear_chat_button.click(
276
+ fn=lambda: ("", [], [], []),
277
  inputs=None,
278
  outputs=[
279
  vqa_input,
 
282
  history_qa,
283
  ],
284
  queue=False,
285
+ api_name="clear",
286
  )
287
  image.change(
288
+ fn=lambda: ("", [], [], []),
289
  inputs=None,
290
  outputs=[
291
  caption_output,