hysts HF staff commited on
Commit
11cddad
1 Parent(s): deb9721

Migrate from yapf to black

Browse files
Files changed (6) hide show
  1. .pre-commit-config.yaml +26 -13
  2. .style.yapf +0 -5
  3. .vscode/settings.json +11 -8
  4. app.py +92 -75
  5. model.py +7 -13
  6. style.css +7 -0
.pre-commit-config.yaml CHANGED
@@ -1,7 +1,6 @@
1
- exclude: patch
2
  repos:
3
  - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.2.0
5
  hooks:
6
  - id: check-executables-have-shebangs
7
  - id: check-json
@@ -9,29 +8,43 @@ repos:
9
  - id: check-shebang-scripts-are-executable
10
  - id: check-toml
11
  - id: check-yaml
12
- - id: double-quote-string-fixer
13
  - id: end-of-file-fixer
14
  - id: mixed-line-ending
15
- args: ['--fix=lf']
16
  - id: requirements-txt-fixer
17
  - id: trailing-whitespace
18
  - repo: https://github.com/myint/docformatter
19
- rev: v1.4
20
  hooks:
21
  - id: docformatter
22
- args: ['--in-place']
23
  - repo: https://github.com/pycqa/isort
24
  rev: 5.12.0
25
  hooks:
26
  - id: isort
 
27
  - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.991
29
  hooks:
30
  - id: mypy
31
- args: ['--ignore-missing-imports']
32
- additional_dependencies: ['types-python-slugify']
33
- - repo: https://github.com/google/yapf
34
- rev: v0.32.0
35
  hooks:
36
- - id: yapf
37
- args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  repos:
2
  - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.4.0
4
  hooks:
5
  - id: check-executables-have-shebangs
6
  - id: check-json
 
8
  - id: check-shebang-scripts-are-executable
9
  - id: check-toml
10
  - id: check-yaml
 
11
  - id: end-of-file-fixer
12
  - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
  - id: requirements-txt-fixer
15
  - id: trailing-whitespace
16
  - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
  hooks:
19
  - id: docformatter
20
+ args: ["--in-place"]
21
  - repo: https://github.com/pycqa/isort
22
  rev: 5.12.0
23
  hooks:
24
  - id: isort
25
+ args: ["--profile", "black"]
26
  - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.5.1
28
  hooks:
29
  - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies: ["types-python-slugify", "types-requests", "types-PyYAML"]
32
+ - repo: https://github.com/psf/black
33
+ rev: 23.9.1
34
  hooks:
35
+ - id: black
36
+ language_version: python3.10
37
+ args: ["--line-length", "119"]
38
+ - repo: https://github.com/kynan/nbstripout
39
+ rev: 0.6.1
40
+ hooks:
41
+ - id: nbstripout
42
+ args: ["--extra-keys", "metadata.interpreter metadata.kernelspec cell.metadata.pycharm"]
43
+ - repo: https://github.com/nbQA-dev/nbQA
44
+ rev: 1.7.0
45
+ hooks:
46
+ - id: nbqa-black
47
+ - id: nbqa-pyupgrade
48
+ args: ["--py37-plus"]
49
+ - id: nbqa-isort
50
+ args: ["--float-to-top"]
.style.yapf DELETED
@@ -1,5 +0,0 @@
1
- [style]
2
- based_on_style = pep8
3
- blank_line_before_nested_class_or_def = false
4
- spaces_before_comment = 2
5
- split_before_logical_operator = true
 
 
 
 
 
 
.vscode/settings.json CHANGED
@@ -1,18 +1,21 @@
1
  {
2
- "python.linting.enabled": true,
3
- "python.linting.flake8Enabled": true,
4
- "python.linting.pylintEnabled": false,
5
- "python.linting.lintOnSave": true,
6
- "python.formatting.provider": "yapf",
7
- "python.formatting.yapfArgs": [
8
- "--style={based_on_style: pep8, indent_width: 4, blank_line_before_nested_class_or_def: false, spaces_before_comment: 2, split_before_logical_operator: true}"
9
- ],
10
  "[python]": {
 
11
  "editor.formatOnType": true,
12
  "editor.codeActionsOnSave": {
13
  "source.organizeImports": true
14
  }
15
  },
 
 
 
 
 
 
 
 
 
 
16
  "editor.formatOnSave": true,
17
  "files.insertFinalNewline": true
18
  }
 
1
  {
 
 
 
 
 
 
 
 
2
  "[python]": {
3
+ "editor.defaultFormatter": "ms-python.black-formatter",
4
  "editor.formatOnType": true,
5
  "editor.codeActionsOnSave": {
6
  "source.organizeImports": true
7
  }
8
  },
9
+ "black-formatter.args": [
10
+ "--line-length=119"
11
+ ],
12
+ "isort.args": ["--profile", "black"],
13
+ "flake8.args": [
14
+ "--max-line-length=119"
15
+ ],
16
+ "ruff.args": [
17
+ "--line-length=119"
18
+ ],
19
  "editor.formatOnSave": true,
20
  "files.insertFinalNewline": true
21
  }
app.py CHANGED
@@ -9,11 +9,13 @@ import PIL.Image
9
 
10
  from model import Model
11
 
12
- DESCRIPTION = '''# Attend-and-Excite
 
 
13
  This is a demo for [Attend-and-Excite](https://arxiv.org/abs/2301.13826).
14
  Attend-and-Excite performs attention-based generative semantic guidance to mitigate subject neglect in Stable Diffusion.
15
  Select a prompt and a set of indices matching the subjects you wish to strengthen (the `Check token indices` cell can help map between a word and its index).
16
- '''
17
 
18
  model = Model()
19
 
@@ -28,148 +30,157 @@ def process_example(
28
  guidance_scale = 7.5
29
 
30
  token_table = model.get_token_table(prompt)
31
- result = model.run(prompt, indices_to_alter_str, seed,
32
- apply_attend_and_excite, num_steps, guidance_scale)
33
  return token_table, result
34
 
35
 
36
- with gr.Blocks(css='style.css') as demo:
37
  gr.Markdown(DESCRIPTION)
 
 
 
 
 
38
 
39
  with gr.Row():
40
  with gr.Column():
41
  prompt = gr.Text(
42
- label='Prompt',
43
  max_lines=1,
44
- placeholder=
45
- 'A pod of dolphins leaping out of the water in an ocean with a ship on the background'
46
  )
47
- with gr.Accordion(label='Check token indices', open=False):
48
- show_token_indices_button = gr.Button('Show token indices')
49
- token_indices_table = gr.Dataframe(label='Token indices',
50
- headers=['Index', 'Token'],
51
- col_count=2)
52
  token_indices_str = gr.Text(
53
- label=
54
- 'Token indices (a comma-separated list indices of the tokens you wish to alter)',
55
  max_lines=1,
56
- placeholder='4,16')
57
- seed = gr.Slider(label='Seed',
58
- minimum=0,
59
- maximum=100000,
60
- value=0,
61
- step=1)
62
- apply_attend_and_excite = gr.Checkbox(
63
- label='Apply Attend-and-Excite', value=True)
64
- num_steps = gr.Slider(label='Number of steps',
65
- minimum=0,
66
- maximum=100,
67
- step=1,
68
- value=50)
69
- guidance_scale = gr.Slider(label='CFG scale',
70
- minimum=0,
71
- maximum=50,
72
- step=0.1,
73
- value=7.5)
74
- run_button = gr.Button('Generate')
 
 
 
 
 
 
75
  with gr.Column():
76
- result = gr.Image(label='Result')
77
 
78
  with gr.Row():
79
  examples = [
80
  [
81
- 'A mouse and a red car',
82
- '2,6',
83
  2098,
84
  True,
85
  ],
86
  [
87
- 'A mouse and a red car',
88
- '2,6',
89
  2098,
90
  False,
91
  ],
92
  [
93
- 'A horse and a dog',
94
- '2,5',
95
  123,
96
  True,
97
  ],
98
  [
99
- 'A horse and a dog',
100
- '2,5',
101
  123,
102
  False,
103
  ],
104
  [
105
- 'A painting of an elephant with glasses',
106
- '5,7',
107
  123,
108
  True,
109
  ],
110
  [
111
- 'A painting of an elephant with glasses',
112
- '5,7',
113
  123,
114
  False,
115
  ],
116
  [
117
- 'A playful kitten chasing a butterfly in a wildflower meadow',
118
- '3,6,10',
119
  123,
120
  True,
121
  ],
122
  [
123
- 'A playful kitten chasing a butterfly in a wildflower meadow',
124
- '3,6,10',
125
  123,
126
  False,
127
  ],
128
  [
129
- 'A grizzly bear catching a salmon in a crystal clear river surrounded by a forest',
130
- '2,6,15',
131
  123,
132
  True,
133
  ],
134
  [
135
- 'A grizzly bear catching a salmon in a crystal clear river surrounded by a forest',
136
- '2,6,15',
137
  123,
138
  False,
139
  ],
140
  [
141
- 'A pod of dolphins leaping out of the water in an ocean with a ship on the background',
142
- '4,16',
143
  123,
144
  True,
145
  ],
146
  [
147
- 'A pod of dolphins leaping out of the water in an ocean with a ship on the background',
148
- '4,16',
149
  123,
150
  False,
151
  ],
152
  ]
153
- gr.Examples(examples=examples,
154
- inputs=[
155
- prompt,
156
- token_indices_str,
157
- seed,
158
- apply_attend_and_excite,
159
- ],
160
- outputs=[
161
- token_indices_table,
162
- result,
163
- ],
164
- fn=process_example,
165
- cache_examples=os.getenv('CACHE_EXAMPLES') == '1',
166
- examples_per_page=20)
 
 
167
 
168
  show_token_indices_button.click(
169
  fn=model.get_token_table,
170
  inputs=prompt,
171
  outputs=token_indices_table,
172
  queue=False,
 
173
  )
174
 
175
  inputs = [
@@ -185,31 +196,37 @@ with gr.Blocks(css='style.css') as demo:
185
  inputs=prompt,
186
  outputs=token_indices_table,
187
  queue=False,
 
188
  ).then(
189
  fn=model.run,
190
  inputs=inputs,
191
  outputs=result,
 
192
  )
193
  token_indices_str.submit(
194
  fn=model.get_token_table,
195
  inputs=prompt,
196
  outputs=token_indices_table,
197
  queue=False,
 
198
  ).then(
199
  fn=model.run,
200
  inputs=inputs,
201
  outputs=result,
 
202
  )
203
  run_button.click(
204
  fn=model.get_token_table,
205
  inputs=prompt,
206
  outputs=token_indices_table,
207
  queue=False,
 
208
  ).then(
209
  fn=model.run,
210
  inputs=inputs,
211
  outputs=result,
212
- api_name='run',
213
  )
214
 
215
- demo.queue(max_size=10).launch()
 
 
9
 
10
  from model import Model
11
 
12
+ DESCRIPTION = """\
13
+ # Attend-and-Excite
14
+
15
  This is a demo for [Attend-and-Excite](https://arxiv.org/abs/2301.13826).
16
  Attend-and-Excite performs attention-based generative semantic guidance to mitigate subject neglect in Stable Diffusion.
17
  Select a prompt and a set of indices matching the subjects you wish to strengthen (the `Check token indices` cell can help map between a word and its index).
18
+ """
19
 
20
  model = Model()
21
 
 
30
  guidance_scale = 7.5
31
 
32
  token_table = model.get_token_table(prompt)
33
+ result = model.run(prompt, indices_to_alter_str, seed, apply_attend_and_excite, num_steps, guidance_scale)
 
34
  return token_table, result
35
 
36
 
37
+ with gr.Blocks(css="style.css") as demo:
38
  gr.Markdown(DESCRIPTION)
39
+ gr.DuplicateButton(
40
+ value="Duplicate Space for private use",
41
+ elem_id="duplicate-button",
42
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
43
+ )
44
 
45
  with gr.Row():
46
  with gr.Column():
47
  prompt = gr.Text(
48
+ label="Prompt",
49
  max_lines=1,
50
+ placeholder="A pod of dolphins leaping out of the water in an ocean with a ship on the background",
 
51
  )
52
+ with gr.Accordion(label="Check token indices", open=False):
53
+ show_token_indices_button = gr.Button("Show token indices")
54
+ token_indices_table = gr.Dataframe(label="Token indices", headers=["Index", "Token"], col_count=2)
 
 
55
  token_indices_str = gr.Text(
56
+ label="Token indices (a comma-separated list indices of the tokens you wish to alter)",
 
57
  max_lines=1,
58
+ placeholder="4,16",
59
+ )
60
+ seed = gr.Slider(
61
+ label="Seed",
62
+ minimum=0,
63
+ maximum=100000,
64
+ step=1,
65
+ value=0,
66
+ )
67
+ apply_attend_and_excite = gr.Checkbox(label="Apply Attend-and-Excite", value=True)
68
+ num_steps = gr.Slider(
69
+ label="Number of steps",
70
+ minimum=0,
71
+ maximum=100,
72
+ step=1,
73
+ value=50,
74
+ )
75
+ guidance_scale = gr.Slider(
76
+ label="CFG scale",
77
+ minimum=0,
78
+ maximum=50,
79
+ step=0.1,
80
+ value=7.5,
81
+ )
82
+ run_button = gr.Button("Generate")
83
  with gr.Column():
84
+ result = gr.Image(label="Result")
85
 
86
  with gr.Row():
87
  examples = [
88
  [
89
+ "A mouse and a red car",
90
+ "2,6",
91
  2098,
92
  True,
93
  ],
94
  [
95
+ "A mouse and a red car",
96
+ "2,6",
97
  2098,
98
  False,
99
  ],
100
  [
101
+ "A horse and a dog",
102
+ "2,5",
103
  123,
104
  True,
105
  ],
106
  [
107
+ "A horse and a dog",
108
+ "2,5",
109
  123,
110
  False,
111
  ],
112
  [
113
+ "A painting of an elephant with glasses",
114
+ "5,7",
115
  123,
116
  True,
117
  ],
118
  [
119
+ "A painting of an elephant with glasses",
120
+ "5,7",
121
  123,
122
  False,
123
  ],
124
  [
125
+ "A playful kitten chasing a butterfly in a wildflower meadow",
126
+ "3,6,10",
127
  123,
128
  True,
129
  ],
130
  [
131
+ "A playful kitten chasing a butterfly in a wildflower meadow",
132
+ "3,6,10",
133
  123,
134
  False,
135
  ],
136
  [
137
+ "A grizzly bear catching a salmon in a crystal clear river surrounded by a forest",
138
+ "2,6,15",
139
  123,
140
  True,
141
  ],
142
  [
143
+ "A grizzly bear catching a salmon in a crystal clear river surrounded by a forest",
144
+ "2,6,15",
145
  123,
146
  False,
147
  ],
148
  [
149
+ "A pod of dolphins leaping out of the water in an ocean with a ship on the background",
150
+ "4,16",
151
  123,
152
  True,
153
  ],
154
  [
155
+ "A pod of dolphins leaping out of the water in an ocean with a ship on the background",
156
+ "4,16",
157
  123,
158
  False,
159
  ],
160
  ]
161
+ gr.Examples(
162
+ examples=examples,
163
+ inputs=[
164
+ prompt,
165
+ token_indices_str,
166
+ seed,
167
+ apply_attend_and_excite,
168
+ ],
169
+ outputs=[
170
+ token_indices_table,
171
+ result,
172
+ ],
173
+ fn=process_example,
174
+ cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
175
+ examples_per_page=20,
176
+ )
177
 
178
  show_token_indices_button.click(
179
  fn=model.get_token_table,
180
  inputs=prompt,
181
  outputs=token_indices_table,
182
  queue=False,
183
+ api_name=False,
184
  )
185
 
186
  inputs = [
 
196
  inputs=prompt,
197
  outputs=token_indices_table,
198
  queue=False,
199
+ api_name=False,
200
  ).then(
201
  fn=model.run,
202
  inputs=inputs,
203
  outputs=result,
204
+ api_name=False,
205
  )
206
  token_indices_str.submit(
207
  fn=model.get_token_table,
208
  inputs=prompt,
209
  outputs=token_indices_table,
210
  queue=False,
211
+ api_name=False,
212
  ).then(
213
  fn=model.run,
214
  inputs=inputs,
215
  outputs=result,
216
+ api_name=False,
217
  )
218
  run_button.click(
219
  fn=model.get_token_table,
220
  inputs=prompt,
221
  outputs=token_indices_table,
222
  queue=False,
223
+ api_name=False,
224
  ).then(
225
  fn=model.run,
226
  inputs=inputs,
227
  outputs=result,
228
+ api_name="run",
229
  )
230
 
231
+ if __name__ == "__main__":
232
+ demo.queue(max_size=10).launch()
model.py CHANGED
@@ -2,26 +2,20 @@ from __future__ import annotations
2
 
3
  import PIL.Image
4
  import torch
5
- from diffusers import (StableDiffusionAttendAndExcitePipeline,
6
- StableDiffusionPipeline)
7
 
8
 
9
  class Model:
10
  def __init__(self):
11
- self.device = torch.device(
12
- 'cuda:0' if torch.cuda.is_available() else 'cpu')
13
- model_id = 'CompVis/stable-diffusion-v1-4'
14
- self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
15
- model_id)
16
  self.ax_pipe.to(self.device)
17
  self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
18
  self.sd_pipe.to(self.device)
19
 
20
  def get_token_table(self, prompt: str):
21
- tokens = [
22
- self.ax_pipe.tokenizer.decode(t)
23
- for t in self.ax_pipe.tokenizer(prompt)['input_ids']
24
- ]
25
  tokens = tokens[1:-1]
26
  return list(enumerate(tokens, start=1))
27
 
@@ -44,9 +38,9 @@ class Model:
44
 
45
  if apply_attend_and_excite:
46
  try:
47
- token_indices = list(map(int, indices_to_alter_str.split(',')))
48
  except Exception:
49
- raise ValueError('Invalid token indices.')
50
  out = self.ax_pipe(
51
  prompt=prompt,
52
  token_indices=token_indices,
 
2
 
3
  import PIL.Image
4
  import torch
5
+ from diffusers import StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline
 
6
 
7
 
8
  class Model:
9
  def __init__(self):
10
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
11
+ model_id = "CompVis/stable-diffusion-v1-4"
12
+ self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(model_id)
 
 
13
  self.ax_pipe.to(self.device)
14
  self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
15
  self.sd_pipe.to(self.device)
16
 
17
  def get_token_table(self, prompt: str):
18
+ tokens = [self.ax_pipe.tokenizer.decode(t) for t in self.ax_pipe.tokenizer(prompt)["input_ids"]]
 
 
 
19
  tokens = tokens[1:-1]
20
  return list(enumerate(tokens, start=1))
21
 
 
38
 
39
  if apply_attend_and_excite:
40
  try:
41
+ token_indices = list(map(int, indices_to_alter_str.split(",")))
42
  except Exception:
43
+ raise ValueError("Invalid token indices.")
44
  out = self.ax_pipe(
45
  prompt=prompt,
46
  token_indices=token_indices,
style.css CHANGED
@@ -1,3 +1,10 @@
1
  h1 {
2
  text-align: center;
3
  }
 
 
 
 
 
 
 
 
1
  h1 {
2
  text-align: center;
3
  }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: #fff;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }