MeetMeAt92 committed on
Commit
42eef31
1 Parent(s): 715de45

Create diffusion.py

Browse files
Files changed (1)
  1. diffusion.py +586 -0
diffusion.py ADDED
@@ -0,0 +1,586 @@
+ # -*- coding: utf-8 -*-
+ """stabledefusion_using_dataset.ipynb
+
+ Automatically generated by Colaboratory.
+
+ Original file is located at
+     https://colab.research.google.com/drive/1mORMC1aTJ8LzN06Z5zUGbdzlDiPb63EC
+
+ <h1>Stable Diffusion using the <b>Hugging Face diffusers</b> concepts library</h1>
+
+ ## Initial setup
+
+ <h2>Install the required libs</h2>
+ """
+
+ # Install the required libs
+ !pip install -U -qq git+https://github.com/huggingface/diffusers.git
+ !pip install -qq accelerate transformers ftfy
+ !pip install -qq "ipywidgets>=7,<8"
+
+ """<h2>Install xformers for faster and more memory-efficient training (for low-end GPUs)</h2>"""
+
+ # Commented out IPython magic to ensure Python compatibility.
+
+ !pip install -U --pre triton
+
+ from subprocess import getoutput
+ from IPython.display import HTML
+ from IPython.display import clear_output
+ import time
+
+ # Detect which GPU this runtime was assigned from the nvidia-smi output
+ s = getoutput('nvidia-smi')
+ if 'T4' in s:
+     gpu = 'T4'
+ elif 'P100' in s:
+     gpu = 'P100'
+ elif 'V100' in s:
+     gpu = 'V100'
+ elif 'A100' in s:
+     gpu = 'A100'
+
+ # If no branch above matched, `gpu` was never assigned, so referencing it
+ # raises NameError; keep warning until a supported GPU is attached.
+ while True:
+     try:
+         gpu
+         break
+     except NameError:
+         print('\033[1;31mit seems that your GPU is not supported at the moment')
+         time.sleep(5)
+
+ # Each branch originally ran a %pip install of a precompiled xformers wheel;
+ # the magics are commented out for Python compatibility, so `pass` keeps the
+ # branches syntactically valid.
+ if gpu == 'T4':
+     # %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/T4/xformers-0.0.13.dev0-py3-none-any.whl
+     pass
+ elif gpu == 'P100':
+     # %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/P100/xformers-0.0.13.dev0-py3-none-any.whl
+     pass
+ elif gpu == 'V100':
+     # %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/V100/xformers-0.0.13.dev0-py3-none-any.whl
+     pass
+ elif gpu == 'A100':
+     # %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/A100/xformers-0.0.13.dev0-py3-none-any.whl
+     pass
+
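+ # Sanity check (an addition, not in the original notebook): confirm whether an
+ # xformers wheel is actually importable. This is only a diagnostic; training
+ # works either way, just with higher memory use when xformers is absent.
+ try:
+     import xformers
+     print(f"xformers {xformers.__version__} is available")
+ except ImportError:
+     print("xformers not installed; attention will run without memory-efficient kernels")
+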
63
+ """<h2>Enviroment Setup</h2>"""
64
+
65
+ import argparse
66
+ import itertools
67
+ import math
68
+ import os
69
+ import random
70
+
71
+ import numpy as np
72
+ import torch
73
+ import torch.nn.functional as F
74
+ import torch.utils.checkpoint
75
+ from torch.utils.data import Dataset
76
+
77
+ import PIL
78
+ from accelerate import Accelerator
79
+ from accelerate.logging import get_logger
80
+ from accelerate.utils import set_seed
81
+ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
82
+ from diffusers.optimization import get_scheduler
83
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
84
+ from PIL import Image
85
+ from torchvision import transforms
86
+ from tqdm.auto import tqdm
87
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
88
+
89
+ def image_grid(imgs, rows, cols):
90
+ assert len(imgs) == rows*cols
91
+
92
+ w, h = imgs[0].size
93
+ grid = Image.new('RGB', size=(cols*w, rows*h))
94
+ grid_w, grid_h = grid.size
95
+
96
+ for i, img in enumerate(imgs):
97
+ grid.paste(img, box=(i%cols*w, i//cols*h))
98
+ return grid
+
+ """<h2>Getting the model</h2>"""
+
+ pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
+ # #@param ["stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-base", "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"] {allow-input: true}
+
+ """<h2>Adding the dataset using URLs</h2>"""
+
+ urls = [
+     "https://p4.wallpaperbetter.com/wallpaper/990/374/475/jinx-league-of-legends-arcane-hd-wallpaper-preview.jpg",
+     "https://p4.wallpaperbetter.com/wallpaper/597/595/773/arcane-jinx-league-of-legends-hd-wallpaper-preview.jpg",
+     "https://p4.wallpaperbetter.com/wallpaper/455/986/460/jinx-league-of-legends-arcane-hd-wallpaper-preview.jpg",
+     "https://p4.wallpaperbetter.com/wallpaper/769/667/606/jinx-league-of-legends-arcane-hd-wallpaper-preview.jpg",
+     "https://p4.wallpaperbetter.com/wallpaper/836/342/438/jinx-league-of-legends-vi-league-of-legends-arcane-hd-wallpaper-preview.jpg",
+     "https://p4.wallpaperbetter.com/wallpaper/211/1017/269/cyberpunk-edgerunners-cyberpunk-2077-hd-wallpaper-preview.jpg",
+     "https://p4.wallpaperbetter.com/wallpaper/8/868/36/cyberpunk-edgerunners-cyberpunk-2077-lucy-edgerunners-rebecca-edgerunners-hd-wallpaper-preview.jpg",
+     "https://p4.wallpaperbetter.com/wallpaper/288/722/467/cyberpunk-edgerunners-lucy-edgerunners-anime-girls-cyberpunk-2077-cyberpunk-hd-wallpaper-preview.jpg",
+ ]
+
+ """<h2>Checking that the images are loaded</h2>"""
+
+ import requests
+ import glob
+ from io import BytesIO
+
+ def download_image(url):
+     # Return the image as RGB, or None if the download fails
+     try:
+         response = requests.get(url)
+         response.raise_for_status()
+     except requests.RequestException:
+         return None
+     return Image.open(BytesIO(response.content)).convert("RGB")
+
+ images = list(filter(None, [download_image(url) for url in urls]))
+ save_path = "./my_concept"
+ if not os.path.exists(save_path):
+     os.mkdir(save_path)
+ [image.save(f"{save_path}/{i}.jpeg") for i, image in enumerate(images)]
+ image_grid(images, 1, len(images))
+
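+ # Quick sanity check (an addition, not in the original notebook): fail early
+ # if nothing was downloaded, since an empty folder would make the dataset
+ # below useless.
+ assert len(images) > 0, "no images downloaded; check the URLs above"
+ print(f"downloaded {len(images)} of {len(urls)} images to {save_path}")
+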
+ """<h2>Initializing the placeholder and initializer tokens for the newly created concept</h2>"""
+
+ what_to_teach = "object"
+
+ # The new token that will stand in for our concept in prompts
+ placeholder_token = "<anime-style>"
+
+ # Always confirm the token's spelling before using it
+ initializer_token = "character"  # "character" means we want the art to depict a character
+
+ """<h2>Setting prompts for training</h2>"""
+
+ imagenet_templates_small = [
+     "a photo of a {}",
+     "a rendering of a {}",
+     "a cropped photo of the {}",
+     "the photo of a {}",
+     "a photo of a clean {}",
+     "a photo of a dirty {}",
+     "a dark photo of the {}",
+     "a photo of my {}",
+     "a photo of the cool {}",
+     "a close-up photo of a {}",
+     "a bright photo of the {}",
+     "a cropped photo of a {}",
+     "a photo of the {}",
+     "a good photo of the {}",
+     "a photo of one {}",
+     "a close-up photo of the {}",
+     "a rendition of the {}",
+     "a photo of the clean {}",
+     "a rendition of a {}",
+     "a photo of a nice {}",
+     "a good photo of a {}",
+     "a photo of the nice {}",
+     "a photo of the small {}",
+     "a photo of the weird {}",
+     "a photo of the large {}",
+     "a photo of a cool {}",
+     "a photo of a small {}",
+     # every template needs a "{}" so the placeholder token appears in the caption
+     "a 4k photo of a {}",
+     "a hyper-realistic photo of a {}",
+ ]
+
+ imagenet_style_templates_small = [
+     "a painting in the style of {}",
+     "a rendering in the style of {}",
+     "a cropped painting in the style of {}",
+     "the painting in the style of {}",
+     "a clean painting in the style of {}",
+     "a dirty painting in the style of {}",
+     "a dark painting in the style of {}",
+     "a picture in the style of {}",
+     "a cool painting in the style of {}",
+     "a close-up painting in the style of {}",
+     "a bright painting in the style of {}",
+     "a cropped painting in the style of {}",
+     "a good painting in the style of {}",
+     "a close-up painting in the style of {}",
+     "a rendition in the style of {}",
+     "a nice painting in the style of {}",
+     "a small painting in the style of {}",
+     "a weird painting in the style of {}",
+     "a large painting in the style of {}",
+     "a painting of a rose bed in the style of {}",
+ ]
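+
+ # A quick illustration (an addition, not in the original notebook) of how a
+ # training caption is built: the dataset below formats a random template with
+ # the placeholder token, e.g. "a cropped photo of a <anime-style>".
+ print(random.choice(imagenet_templates_small).format(placeholder_token))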
+
+ """<h2>Setting up the dataset</h2>"""
+
+ class TextualInversionDataset(Dataset):
+     def __init__(
+         self,
+         data_root,
+         tokenizer,
+         learnable_property="object",  # [object, style]
+         size=512,
+         repeats=100,
+         interpolation="bicubic",
+         flip_p=0.5,
+         set="train",
+         placeholder_token="*",
+         center_crop=False,
+     ):
+         self.data_root = data_root
+         self.tokenizer = tokenizer
+         self.learnable_property = learnable_property
+         self.size = size
+         self.placeholder_token = placeholder_token
+         self.center_crop = center_crop
+         self.flip_p = flip_p
+
+         self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
+
+         self.num_images = len(self.image_paths)
+         self._length = self.num_images
+
+         if set == "train":
+             self._length = self.num_images * repeats
+
+         self.interpolation = {
+             "linear": PIL.Image.LINEAR,
+             "bilinear": PIL.Image.BILINEAR,
+             "bicubic": PIL.Image.BICUBIC,
+             "lanczos": PIL.Image.LANCZOS,
+         }[interpolation]
+
+         self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
+         self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+     def __len__(self):
+         return self._length
+
+     def __getitem__(self, i):
+         example = {}
+         image = Image.open(self.image_paths[i % self.num_images])
+
+         if not image.mode == "RGB":
+             image = image.convert("RGB")
+
+         placeholder_string = self.placeholder_token
+         text = random.choice(self.templates).format(placeholder_string)
+
+         example["input_ids"] = self.tokenizer(
+             text,
+             padding="max_length",
+             truncation=True,
+             max_length=self.tokenizer.model_max_length,
+             return_tensors="pt",
+         ).input_ids[0]
+
+         # default to score-sde preprocessing
+         img = np.array(image).astype(np.uint8)
+
+         if self.center_crop:
+             crop = min(img.shape[0], img.shape[1])
+             h, w = img.shape[0], img.shape[1]
+             img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+
+         image = Image.fromarray(img)
+         image = image.resize((self.size, self.size), resample=self.interpolation)
+
+         image = self.flip_transform(image)
+         image = np.array(image).astype(np.uint8)
+         # scale pixels from [0, 255] to [-1, 1]
+         image = (image / 127.5 - 1.0).astype(np.float32)
+
+         example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+         return example
+
+ """<h2>Load the tokenizer and add the placeholder token as an additional special token.</h2>"""
+
+ tokenizer = CLIPTokenizer.from_pretrained(
+     pretrained_model_name_or_path,
+     subfolder="tokenizer",
+ )
+
+ # Add the placeholder token to the tokenizer
+ num_added_tokens = tokenizer.add_tokens(placeholder_token)
+ if num_added_tokens == 0:
+     raise ValueError(
+         f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+         " `placeholder_token` that is not already in the tokenizer."
+     )
+
+ """<h2>Get the token ids for our placeholder and initializer tokens.</h2>"""
+
+ # Convert the initializer_token and placeholder_token to ids
+ token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+     raise ValueError("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
+
+ """<h2>Load the Stable Diffusion model</h2>"""
+
+ # Load models and create wrapper for stable diffusion
+ # pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path)
+ # del pipeline
+ text_encoder = CLIPTextModel.from_pretrained(
+     pretrained_model_name_or_path, subfolder="text_encoder"
+ )
+ vae = AutoencoderKL.from_pretrained(
+     pretrained_model_name_or_path, subfolder="vae"
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+     pretrained_model_name_or_path, subfolder="unet"
+ )
+
+ """<h2>We added the "placeholder_token" to the "tokenizer", so we resize the token embeddings here,</h2>
+ <h2>which creates a new embedding vector in the token embeddings.</h2>
+ """
+
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ """<h2>Initialise the newly added placeholder token's embedding</h2>"""
+
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ # Start the new concept from the initializer token's embedding
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+
+ """<h2>Train only the newly added embedding vector</h2>"""
+
+ def freeze_params(params):
+     for param in params:
+         param.requires_grad = False
+
+ # Freeze vae and unet
+ freeze_params(vae.parameters())
+ freeze_params(unet.parameters())
+ # Freeze all parameters except for the token embeddings in text encoder
+ params_to_freeze = itertools.chain(
+     text_encoder.text_model.encoder.parameters(),
+     text_encoder.text_model.final_layer_norm.parameters(),
+     text_encoder.text_model.embeddings.position_embedding.parameters(),
+ )
+ freeze_params(params_to_freeze)
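+
+ # Quick check (an addition, not in the original notebook): after freezing,
+ # the only trainable parameters should be the token embedding matrix of the
+ # text encoder; everything in the VAE and UNet is frozen.
+ trainable = [n for n, p in text_encoder.named_parameters() if p.requires_grad]
+ print("trainable parameters:", trainable)  # expect only the token_embedding weight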
+
+ """<h2>Creating the training data</h2>"""
+
+ train_dataset = TextualInversionDataset(
+     data_root=save_path,
+     tokenizer=tokenizer,
+     size=vae.config.sample_size,  # read from the config; plain `vae.sample_size` is deprecated
+     placeholder_token=placeholder_token,
+     repeats=100,
+     learnable_property=what_to_teach,  # option selected above between object and style
+     center_crop=False,
+     set="train",
+ )
+
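+ # Sanity check (an addition, not in the original notebook): pull one example
+ # and confirm the tensor shapes the training loop expects - pixel_values is
+ # (3, size, size) in [-1, 1] and input_ids is padded to the tokenizer's
+ # model_max_length (77 for CLIP).
+ sample = train_dataset[0]
+ print(sample["pixel_values"].shape, sample["input_ids"].shape)
+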
+ def create_dataloader(train_batch_size=1):
+     return torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
+
+ # Creating the noise scheduler (from_pretrained is the current API; from_config
+ # no longer accepts a model repo path in recent diffusers)
+ noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+
+ """<h2>Setting the training arguments</h2>"""
+
+ hyperparameters = {
+     "learning_rate": 5e-04,
+     "scale_lr": True,
+     "max_train_steps": 2000,
+     "save_steps": 250,
+     "train_batch_size": 4,
+     "gradient_accumulation_steps": 1,
+     "gradient_checkpointing": True,
+     "mixed_precision": "fp16",
+     "seed": 42,
+     "output_dir": "sd-concept-output"
+ }
+ !mkdir -p sd-concept-output
+
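+ # A small illustration (an addition, not in the original notebook) of how the
+ # hyperparameters interact: the effective batch size is the per-device batch
+ # size times the gradient accumulation steps (times the number of processes
+ # when launched on multiple devices).
+ effective_batch = hyperparameters["train_batch_size"] * hyperparameters["gradient_accumulation_steps"]
+ images_seen = effective_batch * hyperparameters["max_train_steps"]
+ print(f"effective batch size: {effective_batch}, images seen over training: {images_seen}")
+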
+ """<h2>Training functions</h2>"""
+
+ logger = get_logger(__name__)
+
+ def save_progress(text_encoder, placeholder_token_id, accelerator, save_path):
+     logger.info("Saving embeddings")
+     # Save only the learned embedding row for the placeholder token
+     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+     learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
+     torch.save(learned_embeds_dict, save_path)
+
+ def training_function(text_encoder, vae, unet):
+     train_batch_size = hyperparameters["train_batch_size"]
+     gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
+     learning_rate = hyperparameters["learning_rate"]
+     max_train_steps = hyperparameters["max_train_steps"]
+     output_dir = hyperparameters["output_dir"]
+     gradient_checkpointing = hyperparameters["gradient_checkpointing"]
+
+     accelerator = Accelerator(
+         gradient_accumulation_steps=gradient_accumulation_steps,
+         mixed_precision=hyperparameters["mixed_precision"]
+     )
+
+     if gradient_checkpointing:
+         text_encoder.gradient_checkpointing_enable()
+         unet.enable_gradient_checkpointing()
+
+     train_dataloader = create_dataloader(train_batch_size)
+
+     if hyperparameters["scale_lr"]:
+         learning_rate = (
+             learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
+         )
+
+     # Initialize the optimizer
+     optimizer = torch.optim.AdamW(
+         text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
+         lr=learning_rate,
+     )
+
+     text_encoder, optimizer, train_dataloader = accelerator.prepare(
+         text_encoder, optimizer, train_dataloader
+     )
+
+     weight_dtype = torch.float32
+     if accelerator.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif accelerator.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+
+     # Move vae and unet to device
+     vae.to(accelerator.device, dtype=weight_dtype)
+     unet.to(accelerator.device, dtype=weight_dtype)
+
+     # Keep vae in eval mode as we don't train it
+     vae.eval()
+     # Keep unet in train mode to enable gradient checkpointing
+     unet.train()
+
+     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+     num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
+
+     # Train!
+     total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
+
+     logger.info("***** Running training *****")
+     logger.info(f"  Num examples = {len(train_dataset)}")
+     logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
+     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+     logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
+     logger.info(f"  Total optimization steps = {max_train_steps}")
+     # Only show the progress bar once on each machine.
+     progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
+     progress_bar.set_description("Steps")
+     global_step = 0
+
+     for epoch in range(num_train_epochs):
+         text_encoder.train()
+         for step, batch in enumerate(train_dataloader):
+             with accelerator.accumulate(text_encoder):
+                 # Convert images to latent space
+                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
+                 latents = latents * 0.18215
+
+                 # Sample noise that we'll add to the latents
+                 noise = torch.randn_like(latents)
+                 bsz = latents.shape[0]
+                 # Sample a random timestep for each image
+                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
+
+                 # Add noise to the latents according to the noise magnitude at each timestep
+                 # (this is the forward diffusion process)
+                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                 # Get the text embedding for conditioning
+                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                 # Predict the noise residual
+                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states.to(weight_dtype)).sample
+
+                 # Get the target for loss depending on the prediction type
+                 if noise_scheduler.config.prediction_type == "epsilon":
+                     target = noise
+                 elif noise_scheduler.config.prediction_type == "v_prediction":
+                     target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                 else:
+                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                 loss = F.mse_loss(noise_pred, target, reduction="none").mean([1, 2, 3]).mean()
+                 accelerator.backward(loss)
+
+                 # Zero out the gradients for all token embeddings except the newly added
+                 # embeddings for the concept, as we only want to optimize the concept embeddings
+                 if accelerator.num_processes > 1:
+                     grads = text_encoder.module.get_input_embeddings().weight.grad
+                 else:
+                     grads = text_encoder.get_input_embeddings().weight.grad
+                 # Get the indices of the tokens whose grads we want to zero
+                 index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
+                 grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+
+                 optimizer.step()
+                 optimizer.zero_grad()
+
+             # Checks if the accelerator has performed an optimization step behind the scenes
+             if accelerator.sync_gradients:
+                 progress_bar.update(1)
+                 global_step += 1
+                 if global_step % hyperparameters["save_steps"] == 0:
+                     save_path = os.path.join(output_dir, f"learned_embeds-step-{global_step}.bin")
+                     save_progress(text_encoder, placeholder_token_id, accelerator, save_path)
+
+             logs = {"loss": loss.detach().item()}
+             progress_bar.set_postfix(**logs)
+
+             if global_step >= max_train_steps:
+                 break
+
+     accelerator.wait_for_everyone()
+
+     # Create the pipeline using the trained modules and save it.
+     if accelerator.is_main_process:
+         pipeline = StableDiffusionPipeline.from_pretrained(
+             pretrained_model_name_or_path,
+             text_encoder=accelerator.unwrap_model(text_encoder),
+             tokenizer=tokenizer,
+             vae=vae,
+             unet=unet,
+         )
+         pipeline.save_pretrained(output_dir)
+         # Also save the newly trained embeddings
+         save_path = os.path.join(output_dir, "learned_embeds.bin")
+         save_progress(text_encoder, placeholder_token_id, accelerator, save_path)
+
+ """<h2>Launching training on the GPU (will not work without a GPU)</h2>"""
+
+ import accelerate
+ accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet))
+
+ # Free some memory by dropping any leftover gradients
+ for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
+     if param.grad is not None:
+         del param.grad
+ torch.cuda.empty_cache()
+
+ """<h2>Set up the pipeline</h2>"""
+
+ from diffusers import DPMSolverMultistepScheduler
+ pipe = StableDiffusionPipeline.from_pretrained(
+     hyperparameters["output_dir"],
+     scheduler=DPMSolverMultistepScheduler.from_pretrained(hyperparameters["output_dir"], subfolder="scheduler"),
+     torch_dtype=torch.float16,
+ ).to("cuda")
+
+ #@title Run the Stable Diffusion pipeline
+
+ # To see the learned concept, include the placeholder token in the prompt,
+ # e.g. "a portrait of a <anime-style>"
+ prompt = "Planet scale halo of water in space digital art, Trending on ArtStation" #@param {type:"string"}
+
+ num_samples = 4
+ num_rows = 1
+
+ all_images = []
+ for _ in range(num_rows):
+     images = pipe([prompt] * num_samples, num_inference_steps=30, guidance_scale=7.5).images
+     all_images.extend(images)
+
+ grid = image_grid(all_images, num_rows, num_samples)
+
+ grid
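+
+ # Optional sketch (an addition, not part of the original notebook): the
+ # learned_embeds.bin file alone is enough to carry the concept to another
+ # copy of the base model - re-add the token and copy its embedding row into
+ # a freshly loaded pipeline. (Loading a second pipeline needs spare GPU RAM.)
+ fresh_pipe = StableDiffusionPipeline.from_pretrained(
+     pretrained_model_name_or_path, torch_dtype=torch.float16
+ ).to("cuda")
+ learned = torch.load(os.path.join(hyperparameters["output_dir"], "learned_embeds.bin"))
+ trained_token, embedding = next(iter(learned.items()))
+ fresh_pipe.tokenizer.add_tokens(trained_token)
+ fresh_pipe.text_encoder.resize_token_embeddings(len(fresh_pipe.tokenizer))
+ token_id = fresh_pipe.tokenizer.convert_tokens_to_ids(trained_token)
+ fresh_pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding.to(torch.float16)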