ravi.naik committed
Commit b87b512 · 1 Parent(s): 19147b3

Added source, experiments, and a Gradio app for Stable Diffusion

.gitattributes CHANGED
@@ -1,4 +1,5 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
app.py ADDED
@@ -0,0 +1,91 @@
+ import gradio as gr
+ import random
+ import torch
+ import pathlib
+
+ from src.utils import concept_styles, loss_fn
+ from src.stable_diffusion import StableDiffusion
+
+ PROJECT_PATH = "."
+ CONCEPT_LIBS_PATH = f"{PROJECT_PATH}/concept_libs"
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def generate(prompt, styles, gen_steps, loss_scale):
+     lossless_images, lossy_images = [], []
+     for style in styles:
+         # Load the learned concept embedding for the selected style
+         concept_lib_path = f"{CONCEPT_LIBS_PATH}/{concept_styles[style]}"
+         concept_lib = pathlib.Path(concept_lib_path)
+         concept_embed = torch.load(concept_lib)
+
+         manual_seed = random.randint(0, 100)
+         diffusion = StableDiffusion(
+             device=DEVICE,
+             num_inference_steps=gen_steps,
+             manual_seed=manual_seed,
+         )
+         # Generate once without guidance (loss_scale=0) and once with the chosen guidance scale
+         generated_image_lossless = diffusion.generate_image(
+             prompt=prompt,
+             loss_fn=loss_fn,
+             loss_scale=0,
+             concept_embed=concept_embed,
+         )
+         generated_image_lossy = diffusion.generate_image(
+             prompt=prompt,
+             loss_fn=loss_fn,
+             loss_scale=loss_scale,
+             concept_embed=concept_embed,
+         )
+         lossless_images.append((generated_image_lossless, style))
+         lossy_images.append((generated_image_lossy, style))
+     return {lossless_gallery: lossless_images, lossy_gallery: lossy_images}
+
+
+ with gr.Blocks() as app:
+     gr.Markdown("## ERA Session20 - Stable Diffusion: Generative Art with Guidance")
+     with gr.Row():
+         with gr.Column():
+             prompt_box = gr.Textbox(label="Prompt", interactive=True)
+             style_selector = gr.Dropdown(
+                 choices=list(concept_styles.keys()),
+                 value=list(concept_styles.keys())[0],
+                 multiselect=True,
+                 label="Select a Concept Style",
+                 interactive=True,
+             )
+             gen_steps = gr.Slider(
+                 minimum=10,
+                 maximum=50,
+                 value=30,
+                 step=10,
+                 label="Select Number of Steps",
+                 interactive=True,
+             )
+
+             loss_scale = gr.Slider(
+                 minimum=0,
+                 maximum=32,
+                 value=8,
+                 step=8,
+                 label="Select Guidance Scale",
+                 interactive=True,
+             )
+
+             submit_btn = gr.Button(value="Generate")
+
+         with gr.Column():
+             lossless_gallery = gr.Gallery(
+                 label="Generated Images without Guidance", show_label=True
+             )
+             lossy_gallery = gr.Gallery(
+                 label="Generated Images with Guidance", show_label=True
+             )
+
+     submit_btn.click(
+         generate,
+         inputs=[prompt_box, style_selector, gen_steps, loss_scale],
+         outputs=[lossless_gallery, lossy_gallery],
+     )
+
+ app.launch()
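
Note: app.launch() starts the interface with Gradio defaults. For hosted use (e.g. on a Space), a common variant is to queue the long-running generation calls and bind explicitly; the options below are an assumption for illustration, not something this commit configures:

# hypothetical launch variant (not part of this commit): queue requests and
# listen on all interfaces at the default Gradio port
app.queue().launch(server_name="0.0.0.0", server_port=7860)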
experiments/Stable Diffusion Deep Dive.ipynb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9ff21d4579bafcd26c5ec593bd9020c65b85e552a1a8645dc60cf3eeddec3126
+ size 8313731
experiments/exp.ipynb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e55d79ab1ba786bbcce564b743caf8064c69aa24dddd46e851a974329348e312
+ size 2470336
experiments/exp1.ipynb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f53f8647069e798a316493db4f0e09ef0b798e2c74636f4465cc35236d9e5130
+ size 3992987
experiments/exp2.ipynb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f8f946cb2f730f609ad3bb2b38b55ca142fc5c98916e6415586ab23f71aedbd8
+ size 4713617
experiments/exp3.ipynb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0cfe21389c210759611b9c14e02cdee663bee21276f0f7dbdc326d35899a9dd3
+ size 1108233
experiments/exp4.ipynb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71245f3f38295fec8dd0170face4539b6c176f0385a4c26868e84249a12e1ffd
+ size 18169187
experiments/exp5.ipynb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d42fb890c72a04192dcedbd6b583e8c1666f5217f0b8bd69f4a1834eeab5a45c
+ size 49514010
src/stable_diffusion.py ADDED
@@ -0,0 +1,222 @@
+ import torch
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from PIL import Image
+ from tqdm import tqdm
+
+
+ class StableDiffusion:
+     def __init__(
+         self,
+         vae_arch="CompVis/stable-diffusion-v1-4",
+         tokenizer_arch="openai/clip-vit-large-patch14",
+         encoder_arch="openai/clip-vit-large-patch14",
+         unet_arch="CompVis/stable-diffusion-v1-4",
+         device="cpu",
+         height=512,
+         width=512,
+         num_inference_steps=30,
+         guidance_scale=7.5,
+         manual_seed=1,
+     ) -> None:
+         self.height = height  # default height of Stable Diffusion
+         self.width = width  # default width of Stable Diffusion
+         self.num_inference_steps = num_inference_steps  # Number of denoising steps
+         self.guidance_scale = guidance_scale  # Scale for classifier-free guidance
+         self.device = device
+         self.manual_seed = manual_seed
+
+         vae = AutoencoderKL.from_pretrained(vae_arch, subfolder="vae")
+         # Load the tokenizer and text encoder to tokenize and encode the text.
+         self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_arch)
+         text_encoder = CLIPTextModel.from_pretrained(encoder_arch)
+
+         # The UNet model for generating the latents.
+         unet = UNet2DConditionModel.from_pretrained(unet_arch, subfolder="unet")
+
+         # The noise scheduler
+         self.scheduler = LMSDiscreteScheduler(
+             beta_start=0.00085,
+             beta_end=0.012,
+             beta_schedule="scaled_linear",
+             num_train_timesteps=1000,
+         )
+
+         # To the GPU we go!
+         self.vae = vae.to(self.device)
+         self.text_encoder = text_encoder.to(self.device)
+         self.unet = unet.to(self.device)
+
+         self.token_emb_layer = text_encoder.text_model.embeddings.token_embedding
+         pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
+         position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+         self.position_embeddings = pos_emb_layer(position_ids)
+
+     def get_output_embeds(self, input_embeddings):
+         # CLIP's text model uses a causal mask, so we prepare it here:
+         bsz, seq_len = input_embeddings.shape[:2]
+         causal_attention_mask = (
+             self.text_encoder.text_model._build_causal_attention_mask(
+                 bsz, seq_len, dtype=input_embeddings.dtype
+             )
+         )
+
+         # Getting the output embeddings involves calling the model with output_hidden_states=True
+         # so that it doesn't just return the pooled final predictions:
+         encoder_outputs = self.text_encoder.text_model.encoder(
+             inputs_embeds=input_embeddings,
+             attention_mask=None,  # We aren't using an attention mask, so that can be None
+             causal_attention_mask=causal_attention_mask.to(self.device),
+             output_attentions=None,
+             output_hidden_states=True,  # We want the output embs, not the final output
+             return_dict=None,
+         )
+
+         # We're interested in the output hidden state only
+         output = encoder_outputs[0]
+
+         # There is a final layer norm we need to pass these through
+         output = self.text_encoder.text_model.final_layer_norm(output)
+
+         # And now they're ready!
+         return output
+
+     def set_timesteps(self, scheduler, num_inference_steps):
+         scheduler.set_timesteps(num_inference_steps)
+         scheduler.timesteps = scheduler.timesteps.to(torch.float32)
+
+     def latents_to_pil(self, latents):
+         # batch of latents -> list of images
+         latents = (1 / 0.18215) * latents
+         with torch.no_grad():
+             image = self.vae.decode(latents).sample
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+         images = (image * 255).round().astype("uint8")
+         pil_images = [Image.fromarray(image) for image in images]
+         return pil_images
+
+     def generate_with_embs(self, text_embeddings, text_input, loss_fn, loss_scale):
+         generator = torch.manual_seed(
+             self.manual_seed
+         )  # Seed generator to create the initial latent noise
+         batch_size = 1
+
+         max_length = text_input.input_ids.shape[-1]
+         uncond_input = self.tokenizer(
+             [""] * batch_size,
+             padding="max_length",
+             max_length=max_length,
+             return_tensors="pt",
+         )
+         with torch.no_grad():
+             uncond_embeddings = self.text_encoder(
+                 uncond_input.input_ids.to(self.device)
+             )[0]
+         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+         # Prep Scheduler
+         self.set_timesteps(self.scheduler, self.num_inference_steps)
+
+         # Prep latents
+         latents = torch.randn(
+             (batch_size, self.unet.in_channels, self.height // 8, self.width // 8),
+             generator=generator,
+         )
+         latents = latents.to(self.device)
+         latents = latents * self.scheduler.init_noise_sigma
+
+         # Loop
+         for i, t in tqdm(
+             enumerate(self.scheduler.timesteps), total=len(self.scheduler.timesteps)
+         ):
+             # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+             latent_model_input = torch.cat([latents] * 2)
+             sigma = self.scheduler.sigmas[i]
+             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+             # predict the noise residual
+             with torch.no_grad():
+                 noise_pred = self.unet(
+                     latent_model_input, t, encoder_hidden_states=text_embeddings
+                 )["sample"]
+
+             # perform guidance
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + self.guidance_scale * (
+                 noise_pred_text - noise_pred_uncond
+             )
+             if i % 5 == 0:
+                 # Requires grad on the latents
+                 latents = latents.detach().requires_grad_()
+
+                 # Get the predicted x0:
+                 # latents_x0 = latents - sigma * noise_pred
+                 latents_x0 = self.scheduler.step(
+                     noise_pred, t, latents
+                 ).pred_original_sample
+
+                 # Decode to image space
+                 denoised_images = (
+                     self.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
+                 )  # range (0, 1)
+
+                 # Calculate loss
+                 loss = loss_fn(denoised_images) * loss_scale
+
+                 # Occasionally print it out
+                 # if i % 10 == 0:
+                 #     print(i, "loss:", loss.item())
+
+                 # Get gradient
+                 cond_grad = torch.autograd.grad(loss, latents)[0]
+
+                 # Modify the latents based on this gradient
+                 latents = latents.detach() - cond_grad * sigma**2
+                 self.scheduler._step_index = self.scheduler._step_index - 1
+
+             # compute the previous noisy sample x_t -> x_t-1
+             latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+         return self.latents_to_pil(latents)[0]
+
+     def generate_image(
+         self,
+         prompt="A campfire (oil on canvas)",
+         loss_fn=None,
+         loss_scale=200,
+         concept_embed=None,  # e.g. birb_embed["<birb-style>"]
+     ):
+         prompt += " in the style of cs"
+         text_input = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         input_ids = text_input.input_ids.to(self.device)
+         custom_style_token = self.tokenizer.encode("cs", add_special_tokens=False)[0]
+         # Get token embeddings
+         token_embeddings = self.token_emb_layer(input_ids)
+
+         # The new embedding - the loaded concept-style token embedding
+         embed_key = list(concept_embed.keys())[0]
+         replacement_token_embedding = concept_embed[embed_key]
+
+         # Insert this into the token embeddings
+         token_embeddings[
+             0, torch.where(input_ids[0] == custom_style_token)
+         ] = replacement_token_embedding.to(self.device)
+         # token_embeddings = token_embeddings + (replacement_token_embedding * 0.9)
+         # Combine with pos embs
+         input_embeddings = token_embeddings + self.position_embeddings
+
+         # Feed through to get final output embs
+         modified_output_embeddings = self.get_output_embeds(input_embeddings)
+
+         # And generate an image with this:
+         generated_image = self.generate_with_embs(
+             modified_output_embeddings, text_input, loss_fn, loss_scale
+         )
+         return generated_image
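
Note: the class exposes a small surface: construct once, then call generate_image with a learned concept embedding and a guidance loss. A minimal usage sketch mirroring how app.py drives it (the embedding path, seed, and output filename are illustrative assumptions):

import torch

from src.stable_diffusion import StableDiffusion
from src.utils import loss_fn

device = "cuda" if torch.cuda.is_available() else "cpu"
# assumed: concept_libs/allante.bin holds a {token: embedding_tensor} dict saved with torch.save
concept_embed = torch.load("concept_libs/allante.bin")

sd = StableDiffusion(device=device, num_inference_steps=30, manual_seed=42)
image = sd.generate_image(
    prompt="A campfire (oil on canvas)",  # " in the style of cs" is appended internally
    loss_fn=loss_fn,
    loss_scale=8,
    concept_embed=concept_embed,
)
image.save("campfire_allante.png")  # generate_image returns a PIL.Image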
src/utils.py ADDED
@@ -0,0 +1,11 @@
+ # Guidance loss: stepping the latents to lower this loss pushes the decoded
+ # images toward a higher median pixel value
+ def loss_fn(images):
+     return -images.median() / 3
+
+
+ # Mapping from style names shown in the app to the learned embedding files
+ concept_styles = {
+     "Allante": "allante.bin",
+     "XYZ": "xyz.bin",
+     "Moebius": "moebius.bin",
+     "Oil Style": "oil_style",
+     "Polygons": "poly.bin",
+ }
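
Note: loss_fn is the guidance signal passed into StableDiffusion.generate_image; any differentiable function of the decoded images (a [batch, 3, H, W] tensor in roughly the 0-1 range) can be swapped in. A hedged sketch of an alternative color-based loss (the blue-channel target of 0.9 is an assumption for illustration):

import torch


# hypothetical alternative guidance loss, not part of this commit:
# penalize the blue channel's distance from a bright target value
def blue_loss(images, target=0.9):
    return torch.abs(images[:, 2] - target).mean()

Passing blue_loss in place of loss_fn (with a nonzero loss_scale) would tint generations toward blue without changing the rest of the pipeline.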