Spaces:
Runtime error
Runtime error
Damian Stewart
commited on
Commit
•
d8ffb68
1
Parent(s):
209d166
allow different resolutions for w/h
Browse files- StableDiffuser.py +23 -33
- app.py +78 -14
- finetuning.py +2 -7
- train.py +1 -1
StableDiffuser.py
CHANGED
@@ -1,17 +1,13 @@
|
|
1 |
import argparse
|
2 |
-
import traceback
|
3 |
|
4 |
import torch
|
5 |
from baukit import TraceDict
|
6 |
-
from diffusers import
|
7 |
from PIL import Image
|
8 |
from tqdm.auto import tqdm
|
9 |
-
from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
|
10 |
-
from diffusers.schedulers import EulerAncestralDiscreteScheduler
|
11 |
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
12 |
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
13 |
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
|
14 |
-
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
15 |
import util
|
16 |
|
17 |
|
@@ -39,31 +35,17 @@ class StableDiffuser(torch.nn.Module):
|
|
39 |
def __init__(self,
|
40 |
scheduler='LMS',
|
41 |
repo_id_or_path="CompVis/stable-diffusion-v1-4",
|
|
|
42 |
):
|
43 |
|
44 |
super().__init__()
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
self.
|
52 |
-
repo_id_or_path, subfolder="tokenizer")
|
53 |
-
self.text_encoder = CLIPTextModel.from_pretrained(
|
54 |
-
repo_id_or_path, subfolder="text_encoder")
|
55 |
-
|
56 |
-
# The UNet model for generating the latents.
|
57 |
-
self.unet = UNet2DConditionModel.from_pretrained(
|
58 |
-
repo_id_or_path, subfolder="unet")
|
59 |
-
|
60 |
-
try:
|
61 |
-
self.feature_extractor = CLIPFeatureExtractor.from_pretrained(repo_id_or_path, subfolder="feature_extractor")
|
62 |
-
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id_or_path, subfolder="safety_checker")
|
63 |
-
except Exception as error:
|
64 |
-
print(f"caught exception {error} making feature extractor / safety checker")
|
65 |
-
self.feature_extractor = None
|
66 |
-
self.safety_checker = None
|
67 |
|
68 |
if scheduler == 'LMS':
|
69 |
self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
|
@@ -74,10 +56,14 @@ class StableDiffuser(torch.nn.Module):
|
|
74 |
|
75 |
self.eval()
|
76 |
|
77 |
-
|
|
|
|
|
|
|
|
|
78 |
param = list(self.parameters())[0]
|
79 |
return torch.randn(
|
80 |
-
(batch_size, self.unet.in_channels,
|
81 |
generator=generator).type(param.dtype).to(param.device)
|
82 |
|
83 |
def add_noise(self, latents, noise, step):
|
@@ -109,8 +95,8 @@ class StableDiffuser(torch.nn.Module):
|
|
109 |
def set_scheduler_timesteps(self, n_steps):
|
110 |
self.scheduler.set_timesteps(n_steps, device=self.unet.device)
|
111 |
|
112 |
-
def get_initial_latents(self, n_imgs,
|
113 |
-
noise = self.get_noise(n_imgs,
|
114 |
latents = noise * self.scheduler.init_noise_sigma
|
115 |
return latents
|
116 |
|
@@ -196,7 +182,8 @@ class StableDiffuser(torch.nn.Module):
|
|
196 |
def __call__(self,
|
197 |
prompts,
|
198 |
negative_prompts,
|
199 |
-
|
|
|
200 |
n_steps=50,
|
201 |
n_imgs=1,
|
202 |
end_iteration=None,
|
@@ -210,7 +197,7 @@ class StableDiffuser(torch.nn.Module):
|
|
210 |
prompts = [prompts]
|
211 |
|
212 |
self.set_scheduler_timesteps(n_steps)
|
213 |
-
latents = self.get_initial_latents(n_imgs,
|
214 |
text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
|
215 |
end_iteration = end_iteration or n_steps
|
216 |
latents_steps, trace_steps = self.diffusion(
|
@@ -239,13 +226,16 @@ class StableDiffuser(torch.nn.Module):
|
|
239 |
|
240 |
return images_steps
|
241 |
|
|
|
|
|
|
|
242 |
|
243 |
if __name__ == '__main__':
|
244 |
|
245 |
parser = default_parser()
|
246 |
args = parser.parse_args()
|
247 |
|
248 |
-
diffuser = StableDiffuser(
|
249 |
|
250 |
images = diffuser(args.prompts,
|
251 |
n_steps=args.nsteps,
|
|
|
1 |
import argparse
|
|
|
2 |
|
3 |
import torch
|
4 |
from baukit import TraceDict
|
5 |
+
from diffusers import StableDiffusionPipeline
|
6 |
from PIL import Image
|
7 |
from tqdm.auto import tqdm
|
|
|
|
|
8 |
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
9 |
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
10 |
from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
|
|
|
11 |
import util
|
12 |
|
13 |
|
|
|
35 |
def __init__(self,
|
36 |
scheduler='LMS',
|
37 |
repo_id_or_path="CompVis/stable-diffusion-v1-4",
|
38 |
+
variant='fp16'
|
39 |
):
|
40 |
|
41 |
super().__init__()
|
42 |
|
43 |
+
self.pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path, variant=variant)
|
44 |
+
|
45 |
+
self.vae = self.pipeline.vae
|
46 |
+
self.unet = self.pipeline.unet
|
47 |
+
self.tokenizer = self.pipeline.tokenizer
|
48 |
+
self.text_encoder = self.pipeline.text_encoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
if scheduler == 'LMS':
|
51 |
self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
|
|
|
56 |
|
57 |
self.eval()
|
58 |
|
59 |
+
@property
|
60 |
+
def safety_checker(self):
|
61 |
+
return self.pipeline.safety_checker
|
62 |
+
|
63 |
+
def get_noise(self, batch_size, width, height, generator=None):
|
64 |
param = list(self.parameters())[0]
|
65 |
return torch.randn(
|
66 |
+
(batch_size, self.unet.in_channels, width // 8, height // 8),
|
67 |
generator=generator).type(param.dtype).to(param.device)
|
68 |
|
69 |
def add_noise(self, latents, noise, step):
|
|
|
95 |
def set_scheduler_timesteps(self, n_steps):
|
96 |
self.scheduler.set_timesteps(n_steps, device=self.unet.device)
|
97 |
|
98 |
+
def get_initial_latents(self, n_imgs, width, height, n_prompts, generator=None):
|
99 |
+
noise = self.get_noise(n_imgs, width, height, generator=generator).repeat(n_prompts, 1, 1, 1)
|
100 |
latents = noise * self.scheduler.init_noise_sigma
|
101 |
return latents
|
102 |
|
|
|
182 |
def __call__(self,
|
183 |
prompts,
|
184 |
negative_prompts,
|
185 |
+
width=512,
|
186 |
+
height=512,
|
187 |
n_steps=50,
|
188 |
n_imgs=1,
|
189 |
end_iteration=None,
|
|
|
197 |
prompts = [prompts]
|
198 |
|
199 |
self.set_scheduler_timesteps(n_steps)
|
200 |
+
latents = self.get_initial_latents(n_imgs, width, height, len(prompts), generator=generator)
|
201 |
text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
|
202 |
end_iteration = end_iteration or n_steps
|
203 |
latents_steps, trace_steps = self.diffusion(
|
|
|
226 |
|
227 |
return images_steps
|
228 |
|
229 |
+
def save_pretrained(self, path, **kwargs):
|
230 |
+
self.pipeline.save_pretrained(path, **kwargs)
|
231 |
+
|
232 |
|
233 |
if __name__ == '__main__':
|
234 |
|
235 |
parser = default_parser()
|
236 |
args = parser.parse_args()
|
237 |
|
238 |
+
diffuser = StableDiffuser(scheduler='DDIM').to(torch.device(args.device)).half()
|
239 |
|
240 |
images = diffuser(args.prompts,
|
241 |
n_steps=args.nsteps,
|
app.py
CHANGED
@@ -86,8 +86,16 @@ class Demo:
|
|
86 |
label="Seed",
|
87 |
value=42
|
88 |
)
|
89 |
-
self.
|
90 |
-
label="Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
minimum=256,
|
92 |
maximum=1024,
|
93 |
value=512,
|
@@ -190,11 +198,51 @@ class Demo:
|
|
190 |
|
191 |
self.download = gr.Files()
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
self.infr_button.click(self.inference, inputs = [
|
194 |
self.prompt_input_infr,
|
195 |
self.negative_prompt_input_infr,
|
196 |
self.seed_infr,
|
197 |
-
self.
|
|
|
198 |
self.model_dropdown,
|
199 |
self.base_repo_id_or_path_input_infr
|
200 |
],
|
@@ -214,6 +262,14 @@ class Demo:
|
|
214 |
],
|
215 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
216 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
|
219 |
|
@@ -251,42 +307,50 @@ class Demo:
|
|
251 |
|
252 |
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
|
253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
|
255 |
-
def inference(self, prompt, negative_prompt, seed,
|
256 |
|
257 |
seed = seed or 42
|
258 |
-
generator = torch.manual_seed(seed)
|
259 |
model_path = model_map[model_name]
|
260 |
checkpoint = torch.load(model_path)
|
261 |
|
262 |
self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval().half()
|
263 |
finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
|
264 |
-
torch.cuda.empty_cache()
|
265 |
|
|
|
|
|
|
|
266 |
images = self.diffuser(
|
267 |
prompt,
|
268 |
negative_prompt,
|
269 |
-
|
|
|
270 |
n_steps=50,
|
271 |
generator=generator
|
272 |
)
|
273 |
-
|
274 |
-
|
275 |
orig_image = images[0][0]
|
276 |
|
277 |
torch.cuda.empty_cache()
|
278 |
-
|
279 |
-
generator = torch.manual_seed(seed)
|
280 |
-
|
281 |
with finetuner:
|
282 |
-
|
283 |
images = self.diffuser(
|
284 |
prompt,
|
285 |
negative_prompt,
|
|
|
|
|
286 |
n_steps=50,
|
287 |
generator=generator
|
288 |
)
|
289 |
-
|
290 |
edited_image = images[0][0]
|
291 |
|
292 |
del finetuner
|
|
|
86 |
label="Seed",
|
87 |
value=42
|
88 |
)
|
89 |
+
self.img_width_infr = gr.Slider(
|
90 |
+
label="Image width",
|
91 |
+
minimum=256,
|
92 |
+
maximum=1024,
|
93 |
+
value=512,
|
94 |
+
step=64
|
95 |
+
)
|
96 |
+
|
97 |
+
self.img_height_infr = gr.Slider(
|
98 |
+
label="Image height",
|
99 |
minimum=256,
|
100 |
maximum=1024,
|
101 |
value=512,
|
|
|
198 |
|
199 |
self.download = gr.Files()
|
200 |
|
201 |
+
with gr.Tab("Export") as export_column:
|
202 |
+
|
203 |
+
with gr.Row():
|
204 |
+
|
205 |
+
self.explain_train= gr.Markdown(interactive=False,
|
206 |
+
value='Export a model to Diffusers format. Please enter the base model and select the editing weights.')
|
207 |
+
|
208 |
+
with gr.Row():
|
209 |
+
|
210 |
+
with gr.Column(scale=3):
|
211 |
+
|
212 |
+
self.base_repo_id_or_path_input_export = gr.Text(
|
213 |
+
label="Base model",
|
214 |
+
value="CompVis/stable-diffusion-v1-4",
|
215 |
+
info="Path or huggingface repo id of the base model that this edit was done against"
|
216 |
+
)
|
217 |
+
|
218 |
+
self.model_dropdown_export = gr.Dropdown(
|
219 |
+
label="ESD Model",
|
220 |
+
choices=list(model_map.keys()),
|
221 |
+
value='Van Gogh',
|
222 |
+
interactive=True
|
223 |
+
)
|
224 |
+
|
225 |
+
self.save_path_input_export = gr.Text(
|
226 |
+
label="Output path",
|
227 |
+
placeholder="./exported_models/model_name",
|
228 |
+
info="Path to export the model to. A diffusers folder will be written to this location."
|
229 |
+
)
|
230 |
+
|
231 |
+
self.save_half_export = gr.Checkbox(
|
232 |
+
label="Save as fp16"
|
233 |
+
)
|
234 |
+
|
235 |
+
with gr.Column(scale=1):
|
236 |
+
self.export_button = gr.Button(
|
237 |
+
value="Export",
|
238 |
+
)
|
239 |
+
|
240 |
self.infr_button.click(self.inference, inputs = [
|
241 |
self.prompt_input_infr,
|
242 |
self.negative_prompt_input_infr,
|
243 |
self.seed_infr,
|
244 |
+
self.img_width_infr,
|
245 |
+
self.img_height_infr,
|
246 |
self.model_dropdown,
|
247 |
self.base_repo_id_or_path_input_infr
|
248 |
],
|
|
|
262 |
],
|
263 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
264 |
)
|
265 |
+
self.export_button.click(self.export, inputs = [
|
266 |
+
self.model_dropdown_export,
|
267 |
+
self.base_repo_id_or_path_input_export,
|
268 |
+
self.save_path_input_export,
|
269 |
+
self.save_half_export
|
270 |
+
],
|
271 |
+
outputs=[self.export_button]
|
272 |
+
)
|
273 |
|
274 |
def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
|
275 |
|
|
|
307 |
|
308 |
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
|
309 |
|
310 |
+
def export(self, model_name, base_repo_id_or_path, save_path, save_half):
|
311 |
+
model_path = model_map[model_name]
|
312 |
+
checkpoint = torch.load(model_path)
|
313 |
+
self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval()
|
314 |
+
finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval()
|
315 |
+
with finetuner:
|
316 |
+
if save_half:
|
317 |
+
self.diffuser = self.diffuser.half()
|
318 |
+
self.diffuser.pipeline.to(torch.float16, torch_device=self.diffuser.device)
|
319 |
+
self.diffuser.save_pretrained(save_path)
|
320 |
+
|
321 |
|
322 |
+
def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
|
323 |
|
324 |
seed = seed or 42
|
|
|
325 |
model_path = model_map[model_name]
|
326 |
checkpoint = torch.load(model_path)
|
327 |
|
328 |
self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval().half()
|
329 |
finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
|
|
|
330 |
|
331 |
+
generator = torch.manual_seed(seed)
|
332 |
+
|
333 |
+
torch.cuda.empty_cache()
|
334 |
images = self.diffuser(
|
335 |
prompt,
|
336 |
negative_prompt,
|
337 |
+
width=width,
|
338 |
+
height=height,
|
339 |
n_steps=50,
|
340 |
generator=generator
|
341 |
)
|
|
|
|
|
342 |
orig_image = images[0][0]
|
343 |
|
344 |
torch.cuda.empty_cache()
|
|
|
|
|
|
|
345 |
with finetuner:
|
|
|
346 |
images = self.diffuser(
|
347 |
prompt,
|
348 |
negative_prompt,
|
349 |
+
width=width,
|
350 |
+
height=height,
|
351 |
n_steps=50,
|
352 |
generator=generator
|
353 |
)
|
|
|
354 |
edited_image = images[0][0]
|
355 |
|
356 |
del finetuner
|
finetuning.py
CHANGED
@@ -2,11 +2,12 @@ import copy
|
|
2 |
import re
|
3 |
import torch
|
4 |
import util
|
|
|
5 |
|
6 |
class FineTunedModel(torch.nn.Module):
|
7 |
|
8 |
def __init__(self,
|
9 |
-
model,
|
10 |
modules,
|
11 |
frozen_modules=[]
|
12 |
):
|
@@ -24,11 +25,8 @@ class FineTunedModel(torch.nn.Module):
|
|
24 |
|
25 |
for module_name, module in model.named_modules():
|
26 |
for ft_module_regex in modules:
|
27 |
-
|
28 |
match = re.search(ft_module_regex, module_name)
|
29 |
-
|
30 |
if match is not None:
|
31 |
-
|
32 |
ft_module = copy.deepcopy(module)
|
33 |
|
34 |
self.orig_modules[module_name] = module
|
@@ -39,13 +37,10 @@ class FineTunedModel(torch.nn.Module):
|
|
39 |
print(f"=> Finetuning {module_name}")
|
40 |
|
41 |
for ft_module_name, module in ft_module.named_modules():
|
42 |
-
|
43 |
ft_module_name = f"{module_name}.{ft_module_name}"
|
44 |
-
|
45 |
for freeze_module_name in frozen_modules:
|
46 |
|
47 |
match = re.search(freeze_module_name, ft_module_name)
|
48 |
-
|
49 |
if match:
|
50 |
print(f"=> Freezing {ft_module_name}")
|
51 |
util.freeze(module)
|
|
|
2 |
import re
|
3 |
import torch
|
4 |
import util
|
5 |
+
from StableDiffuser import StableDiffuser
|
6 |
|
7 |
class FineTunedModel(torch.nn.Module):
|
8 |
|
9 |
def __init__(self,
|
10 |
+
model: StableDiffuser,
|
11 |
modules,
|
12 |
frozen_modules=[]
|
13 |
):
|
|
|
25 |
|
26 |
for module_name, module in model.named_modules():
|
27 |
for ft_module_regex in modules:
|
|
|
28 |
match = re.search(ft_module_regex, module_name)
|
|
|
29 |
if match is not None:
|
|
|
30 |
ft_module = copy.deepcopy(module)
|
31 |
|
32 |
self.orig_modules[module_name] = module
|
|
|
37 |
print(f"=> Finetuning {module_name}")
|
38 |
|
39 |
for ft_module_name, module in ft_module.named_modules():
|
|
|
40 |
ft_module_name = f"{module_name}.{ft_module_name}"
|
|
|
41 |
for freeze_module_name in frozen_modules:
|
42 |
|
43 |
match = re.search(freeze_module_name, ft_module_name)
|
|
|
44 |
if match:
|
45 |
print(f"=> Freezing {ft_module_name}")
|
46 |
util.freeze(module)
|
train.py
CHANGED
@@ -36,7 +36,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
|
|
36 |
optimizer.zero_grad()
|
37 |
|
38 |
iteration = torch.randint(1, nsteps - 1, (1,)).item()
|
39 |
-
latents = diffuser.get_initial_latents(1, img_size, 1)
|
40 |
|
41 |
with finetuner:
|
42 |
latents_steps, _ = diffuser.diffusion(
|
|
|
36 |
optimizer.zero_grad()
|
37 |
|
38 |
iteration = torch.randint(1, nsteps - 1, (1,)).item()
|
39 |
+
latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
|
40 |
|
41 |
with finetuner:
|
42 |
latents_steps, _ = diffuser.diffusion(
|