hysts HF staff committed on
Commit
779acf3
1 Parent(s): 145f872
Files changed (10) hide show
  1. .gitignore +163 -0
  2. .gitmodules +3 -0
  3. .pre-commit-config.yaml +37 -0
  4. .style.yapf +5 -0
  5. Attend-and-Excite +1 -0
  6. README.md +1 -0
  7. app.py +177 -0
  8. model.py +85 -0
  9. requirements.txt +7 -0
  10. style.css +3 -0
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio_cached_examples/
2
+
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ [submodule "Attend-and-Excite"]
2
+ path = Attend-and-Excite
3
+ url = https://github.com/AttendAndExcite/Attend-and-Excite
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: patch
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.12.0
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
Attend-and-Excite ADDED
@@ -0,0 +1 @@
 
1
+ Subproject commit 1b67cfc19cd3952e390dbb8047ccd126471567f2
README.md CHANGED
@@ -5,6 +5,7 @@ colorFrom: gray
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.17.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.17.0
8
+ python_version: 3.10.9
9
  app_file: app.py
10
  pinned: false
11
  license: mit
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import gradio as gr
6
+ import PIL.Image
7
+
8
+ from model import Model
9
+
10
# Markdown shown at the top of the demo page.
DESCRIPTION = '''# Attend-and-Excite

This is an unofficial demo for [https://github.com/AttendAndExcite/Attend-and-Excite](https://github.com/AttendAndExcite/Attend-and-Excite).
'''

# Instantiated once at import time; Model.__init__ eagerly loads the default
# Stable Diffusion v1-4 pipeline, so startup includes the model download/load.
model = Model()
16
+
17
+
18
def process_example(
    prompt: str,
    indices_to_alter_str: str,
    seed: int,
    apply_attend_and_excite: bool,
) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
    """Run generation for a cached example with fixed sampler settings.

    Used as the ``fn`` for ``gr.Examples``: the model ID, step count, and
    guidance scale are pinned so cached results stay reproducible.

    Returns:
        The (index, token) table for the prompt and the generated image,
        exactly as produced by ``model.run``.
    """
    # Examples always use the v1-4 checkpoint, 50 steps, and CFG scale 7.5.
    return model.run('CompVis/stable-diffusion-v1-4', prompt,
                     indices_to_alter_str, seed, apply_attend_and_excite, 50,
                     7.5)
29
+
30
+
31
# Build the Gradio UI. NOTE: component creation order inside the context
# managers determines the on-screen layout, and event wiring relies on the
# component variables defined above it — keep statement order intact.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            # Hidden field: keeps the model ID available to event handlers
            # (model.run / model.get_token_table) without showing it.
            model_id = gr.Text(label='Model ID',
                               value='CompVis/stable-diffusion-v1-4',
                               visible=False)
            prompt = gr.Text(
                label='Prompt',
                max_lines=1,
                placeholder=
                'A pod of dolphins leaping out of the water in an ocean with a ship on the background'
            )
            # Collapsible helper so users can look up token indices before
            # typing them into the field below.
            with gr.Accordion(label='Check token indices', open=False):
                show_token_indices_button = gr.Button('Show token indices')
                token_indices_table = gr.Dataframe(label='Token indices',
                                                   headers=['Index', 'Token'],
                                                   col_count=2)
            token_indices_str = gr.Text(
                label=
                'Token indices (a comma-separated list indices of the tokens you wish to alter)',
                max_lines=1,
                placeholder='4,16')
            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=100000,
                             value=0,
                             step=1)
            apply_attend_and_excite = gr.Checkbox(
                label='Apply Attend-and-Excite', value=True)
            num_steps = gr.Slider(label='Number of steps',
                                  minimum=0,
                                  maximum=100,
                                  step=1,
                                  value=50)
            guidance_scale = gr.Slider(label='CFG scale',
                                       minimum=0,
                                       maximum=50,
                                       step=0.1,
                                       value=7.5)
            run_button = gr.Button('Generate')
        with gr.Column():
            result = gr.Image(label='Result')

    with gr.Row():
        # Each entry is (prompt, token indices, seed, apply A&E). Every
        # prompt appears twice — once with Attend-and-Excite and once
        # without — so the cached outputs can be compared side by side.
        examples = [
            [
                'A horse and a dog',
                '2,5',
                123,
                True,
            ],
            [
                'A horse and a dog',
                '2,5',
                123,
                False,
            ],
            [
                'A painting of an elephant with glasses',
                '5,7',
                123,
                True,
            ],
            [
                'A painting of an elephant with glasses',
                '5,7',
                123,
                False,
            ],
            [
                'A playful kitten chasing a butterfly in a wildflower meadow',
                '3,6,10',
                123,
                True,
            ],
            [
                'A playful kitten chasing a butterfly in a wildflower meadow',
                '3,6,10',
                123,
                False,
            ],
            [
                'A grizzly bear catching a salmon in a crystal clear river surrounded by a forest',
                '2,6,15',
                123,
                True,
            ],
            [
                'A grizzly bear catching a salmon in a crystal clear river surrounded by a forest',
                '2,6,15',
                123,
                False,
            ],
            [
                'A pod of dolphins leaping out of the water in an ocean with a ship on the background',
                '4,16',
                123,
                True,
            ],
            [
                'A pod of dolphins leaping out of the water in an ocean with a ship on the background',
                '4,16',
                123,
                False,
            ],
        ]
        # cache_examples=True precomputes all example outputs at startup
        # (written to gradio_cached_examples/, which is gitignored).
        gr.Examples(examples=examples,
                    inputs=[
                        prompt,
                        token_indices_str,
                        seed,
                        apply_attend_and_excite,
                    ],
                    outputs=[
                        token_indices_table,
                        result,
                    ],
                    fn=process_example,
                    cache_examples=True)

    # Fill the token-index table for the current prompt on demand.
    show_token_indices_button.click(fn=model.get_token_table,
                                    inputs=[
                                        model_id,
                                        prompt,
                                    ],
                                    outputs=token_indices_table)

    # Shared wiring: pressing Enter in either text box or clicking the
    # Generate button all trigger a full model.run generation.
    inputs = [
        model_id,
        prompt,
        token_indices_str,
        seed,
        apply_attend_and_excite,
        num_steps,
        guidance_scale,
    ]
    outputs = [
        token_indices_table,
        result,
    ]
    prompt.submit(fn=model.run, inputs=inputs, outputs=outputs)
    token_indices_str.submit(fn=model.run, inputs=inputs, outputs=outputs)
    run_button.click(fn=model.run, inputs=inputs, outputs=outputs)

# max_size=1 keeps at most one queued request — presumably to bound GPU
# memory/latency on the Space; verify against deployment needs.
demo.queue(max_size=1).launch(share=False)
model.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+
5
+ import gradio as gr
6
+ import PIL.Image
7
+ import torch
8
+
9
+ sys.path.append('Attend-and-Excite')
10
+
11
+ from config import RunConfig
12
+ from pipeline_attend_and_excite import AttendAndExcitePipeline
13
+ from run import run_on_prompt
14
+ from utils.ptp_utils import AttentionStore
15
+
16
+
17
class Model:
    """Wrapper around ``AttendAndExcitePipeline`` for the Gradio demo.

    Loads a diffusion pipeline and caches it per model ID so repeated
    calls with the same ID skip the expensive reload.
    """

    def __init__(self):
        # Prefer the first CUDA device when available; fall back to CPU.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_id = ''  # ID of the currently loaded model ('' = none)
        self.model = None  # AttendAndExcitePipeline instance
        self.tokenizer = None  # tokenizer of the loaded pipeline

        # Eagerly load the default model so the first request is fast.
        self.load_model('CompVis/stable-diffusion-v1-4')

    def load_model(self, model_id: str) -> None:
        """Load the pipeline for ``model_id``, reusing the cached one if set.

        No-op when ``model_id`` matches the currently loaded model.
        """
        if model_id == self.model_id:
            return
        self.model = AttendAndExcitePipeline.from_pretrained(model_id).to(
            self.device)
        self.tokenizer = self.model.tokenizer
        self.model_id = model_id

    def get_token_table(self, model_id: str,
                        prompt: str) -> list[tuple[int, str]]:
        """Return (index, token) pairs for ``prompt``.

        The special start/end tokens are dropped and indices start at 1,
        matching the token-index convention used by Attend-and-Excite.
        """
        self.load_model(model_id)
        tokens = [
            self.tokenizer.decode(t)
            for t in self.tokenizer(prompt)['input_ids']
        ]
        tokens = tokens[1:-1]  # drop BOS/EOS special tokens
        return list(enumerate(tokens, start=1))

    def run(
        self,
        model_id: str,
        prompt: str,
        indices_to_alter_str: str,
        seed: int,
        apply_attend_and_excite: bool,
        num_steps: int,
        guidance_scale: float,
        scale_factor: int = 20,
        thresholds: dict[int, float] | None = None,
        max_iter_to_alter: int = 25,
    ) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
        """Generate an image for ``prompt`` and return it with the token table.

        Args:
            model_id: Hugging Face model ID of the pipeline to use.
            prompt: Text prompt to generate from.
            indices_to_alter_str: Comma-separated 1-based token indices
                whose attention should be excited.
            seed: RNG seed for reproducible generation.
            apply_attend_and_excite: If False, run standard Stable Diffusion.
            num_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance scale.
            scale_factor: Attend-and-Excite update scale factor.
            thresholds: Per-step attention thresholds; ``None`` means the
                default ``{10: 0.5, 20: 0.8}``.
            max_iter_to_alter: Last step at which latents are altered.

        Raises:
            gr.Error: If ``indices_to_alter_str`` is not a comma-separated
                list of integers.
        """
        # Avoid a mutable default argument shared across calls; this is
        # behaviorally identical to the previous default literal.
        if thresholds is None:
            thresholds = {10: 0.5, 20: 0.8}
        generator = torch.Generator(device=self.device).manual_seed(seed)
        try:
            indices_to_alter = list(map(int, indices_to_alter_str.split(',')))
        except ValueError:
            # int() raises ValueError on bad input; surface a friendly
            # message in the UI instead of a traceback.
            raise gr.Error('Invalid token indices.')

        self.load_model(model_id)

        token_table = self.get_token_table(model_id, prompt)

        controller = AttentionStore()
        config = RunConfig(prompt=prompt,
                           n_inference_steps=num_steps,
                           guidance_scale=guidance_scale,
                           run_standard_sd=not apply_attend_and_excite,
                           scale_factor=scale_factor,
                           thresholds=thresholds,
                           max_iter_to_alter=max_iter_to_alter)
        image = run_on_prompt(model=self.model,
                              prompt=[prompt],
                              controller=controller,
                              token_indices=indices_to_alter,
                              seed=generator,
                              config=config)

        return token_table, image
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ diffusers==0.3.0
2
+ ftfy==6.1.1
3
+ jupyter
4
+ opencv-python-headless==4.7.0.68
5
+ pyrallis==0.3.1
6
+ torch==1.13.1
7
+ transformers==4.23.1
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }