bala1802 committed on
Commit 1975737
1 Parent(s): 2c73d98

Upload 6 files

Files changed (6)
  1. config.py +20 -0
  2. diffusion_loss.py +24 -0
  3. image_generator.py +73 -0
  4. model.py +16 -0
  5. prediction.py +24 -0
  6. utils.py +24 -0
config.py ADDED
@@ -0,0 +1,20 @@
+ import torch
+
+ DEVICE = "mps"
+ # DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ HEIGHT = 512
+ WIDTH = 512
+ GUIDANCE_SCALE = 8
+ LOSS_SCALE = 200
+ NUM_INFERENCE_STEPS = 50
+ BATCH_SIZE = 1
+
+ SEEDS = [2000, 2000, 500, 600, 100]
+ STABLE_DIFFUSION_MODEL = "CompVis/stable-diffusion-v1-4"
+ STABLE_DIFUSION_CONCEPTS = ['<meeg>', '<midjourney-style>', '<moebius>', '<Marc_Allante>', '<wlop-style>']
+
+ # LMS Discrete Scheduler parameters (used in image_generator.py)
+ BETA_START = 0.00085
+ BETA_END = 0.012
+ BETA_SCHEDULE = "scaled_linear"
+ NUM_TRAIN_TIMESTEPS = 1000
diffusion_loss.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ import torchvision.transforms as T
+ import torch.nn.functional as F
+
+ def blue_channel(images):
+     # Penalize deviation of the blue channel (index 2 in RGB) from 0.9.
+     error = torch.abs(images[:, 2] - 0.9).mean()
+     return error
+
+ def elastic_transform(images):
+     # Penalize the difference between the images and an elastically distorted copy.
+     elastic_transformer = T.ElasticTransform(alpha=550.0, sigma=5.0)
+     transformed_imgs = elastic_transformer(images)
+     error = torch.abs(transformed_imgs - images).mean()
+     return error
+
+ def symmetry(images):
+     # Penalize asymmetry by comparing each image against its horizontal mirror.
+     flipped_image = torch.flip(images, [3])
+     error = F.mse_loss(images, flipped_image)
+     print("Loss calculated for symmetry:", error)
+     return error
+
+ def saturation(images):
+     # Penalize the gap between the images and a heavily saturated copy.
+     transformed_imgs = T.functional.adjust_saturation(images, saturation_factor=10)
+     error = torch.abs(transformed_imgs - images).mean()
+     return error
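Each function maps a batch of decoded images in the (0, 1) range to a scalar penalty; image_generator.run_prediction (next file) scales that penalty by config.LOSS_SCALE and differentiates it with respect to the latent. A minimal sanity-check sketch, with an assumed random batch standing in for real decoded images:

import torch
import diffusion_loss

# Hypothetical 4 x 3 x 512 x 512 batch standing in for VAE-decoded images in (0, 1).
images = torch.rand(4, 3, 512, 512)

print(diffusion_loss.blue_channel(images))  # mean |blue - 0.9|
print(diffusion_loss.symmetry(images))      # MSE against the horizontal mirror
print(diffusion_loss.saturation(images))    # mean gap to a saturation-boosted copy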
image_generator.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ from tqdm.auto import tqdm
+ from diffusers import LMSDiscreteScheduler
+
+ import config
+
+ def construct_text_embeddings(pipe, prompt):
+     # Tokenize the prompt and an empty (unconditional) prompt of matching length.
+     text_input = pipe.tokenizer(prompt, padding='max_length',
+                                 max_length=pipe.tokenizer.model_max_length, truncation=True,
+                                 return_tensors="pt")
+     uncond_input = pipe.tokenizer([""] * config.BATCH_SIZE, padding="max_length",
+                                   max_length=text_input.input_ids.shape[-1],
+                                   return_tensors="pt")
+     with torch.no_grad():
+         text_input_embeddings = pipe.text_encoder(text_input.input_ids.to(config.DEVICE))[0]
+     with torch.no_grad():
+         uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(config.DEVICE))[0]
+
+     # Stack unconditional and conditional embeddings for classifier-free guidance.
+     text_embeddings = torch.cat([uncond_embeddings, text_input_embeddings])
+     return text_embeddings
+
+ def initialize_latent(seed_number, pipe, scheduler):
+     # Seeded Gaussian latent, scaled by the scheduler's initial noise sigma.
+     generator = torch.manual_seed(seed_number)
+     latent = torch.randn((config.BATCH_SIZE, pipe.unet.config.in_channels,
+                           config.HEIGHT // 8, config.WIDTH // 8),
+                          generator=generator).to(torch.float16)
+     latent = latent.to(config.DEVICE)
+     latent = latent * scheduler.init_noise_sigma
+     return latent
+
+ def run_prediction(pipe, text_embeddings, scheduler, latent, loss_function=None):
+     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+         # Duplicate the latent so the unconditional and conditional passes run in one batch.
+         latent_model_input = torch.cat([latent] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         with torch.no_grad():
+             noise_pred = pipe.unet(latent_model_input.to(torch.float16), t, encoder_hidden_states=text_embeddings)["sample"]
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         # Classifier-free guidance.
+         noise_pred = noise_pred_uncond + config.GUIDANCE_SCALE * (noise_pred_text - noise_pred_uncond)
+
+         # Every 5th step, nudge the latent along the gradient of the optional extra loss.
+         if loss_function and i % 5 == 0:
+             latent = latent.detach().requires_grad_()
+             latent_x0 = latent - sigma * noise_pred
+
+             denoised_images = pipe.vae.decode((1 / 0.18215) * latent_x0).sample / 2 + 0.5  # range (0, 1)
+
+             loss = loss_function(denoised_images) * config.LOSS_SCALE
+             print(f"loss {loss}")
+
+             cond_grad = torch.autograd.grad(loss, latent)[0]
+             latent = latent.detach() - cond_grad * sigma**2
+
+         latent = scheduler.step(noise_pred, t, latent).prev_sample
+
+     return latent
+
+ def generate_images(pipe, seed_number, prompt, loss_function=None):
+
+     scheduler = LMSDiscreteScheduler(beta_start=config.BETA_START,
+                                      beta_end=config.BETA_END,
+                                      beta_schedule=config.BETA_SCHEDULE,
+                                      num_train_timesteps=config.NUM_TRAIN_TIMESTEPS)
+     scheduler.set_timesteps(config.NUM_INFERENCE_STEPS)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32)
+
+     text_embeddings = construct_text_embeddings(pipe=pipe, prompt=prompt)
+     latent = initialize_latent(seed_number=seed_number, pipe=pipe, scheduler=scheduler)
+     latent = run_prediction(pipe=pipe, text_embeddings=text_embeddings,
+                             scheduler=scheduler, latent=latent,
+                             loss_function=loss_function)
+
+     return latent
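generate_images returns a latent tensor rather than a finished image; it still has to be decoded through the VAE (see utils.convert_latents_to_pil_images further down). A hedged single-image sketch, using initialize_diffusion_model from model.py below; the prompt, seed, and output filename are assumptions:

import model
import utils
import image_generator as generator

pipe = model.initialize_diffusion_model()

# Assumed seed and prompt; '<wlop-style>' is one of the tokens in config.STABLE_DIFUSION_CONCEPTS.
latent = generator.generate_images(pipe=pipe, seed_number=100,
                                   prompt=["a lighthouse at dusk <wlop-style>"],
                                   loss_function=None)
image = utils.convert_latents_to_pil_images(pipe=pipe, latents=latent)[0]
image.save("single_image.png")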
model.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ from diffusers import DiffusionPipeline
+ import config
+
+ def initialize_diffusion_model():
+     pretrained_model_name_or_path = config.STABLE_DIFFUSION_MODEL
+     pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path,
+                                              torch_dtype=torch.float16).to(config.DEVICE)
+
+     # Load textual-inversion concept embeddings; their prompt tokens are listed in config.STABLE_DIFUSION_CONCEPTS.
+     pipe.load_textual_inversion("sd-concepts-library/dreams")
+     pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
+     pipe.load_textual_inversion("sd-concepts-library/moebius")
+     pipe.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
+     pipe.load_textual_inversion("sd-concepts-library/wlop-style")
+
+     return pipe
prediction.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ import gc
+
+ import utils
+ import model
+ import config
+ import image_generator as generator
+
+ def predict(prompt, pipe, loss_function=None):
+     latents = []
+
+     for seed_number, sd_concept in zip(config.SEEDS, config.STABLE_DIFUSION_CONCEPTS):
+         # Free MPS memory between generations.
+         torch.mps.empty_cache()
+         gc.collect()
+         torch.mps.empty_cache()
+
+         # Append the concept token to a copy so the base prompt is not overwritten between iterations.
+         styled_prompt = [f'{prompt} {sd_concept}']
+         latent = generator.generate_images(pipe=pipe, seed_number=seed_number, prompt=styled_prompt, loss_function=loss_function)
+         latents.append(latent)
+
+     latents = torch.vstack(latents)
+     images = utils.convert_latents_to_pil_images(pipe=pipe, latents=latents)
+     grid = utils.populate_image_grid(images, 1, len(latents))
+     return grid
utils.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ from PIL import Image
+
+ def convert_latents_to_pil_images(pipe, latents):
+     # Undo the Stable Diffusion latent scaling factor before decoding with the VAE.
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = pipe.vae.decode(latents).sample
+     # Map from [-1, 1] to [0, 1], then to uint8 HWC arrays for PIL.
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
+ def populate_image_grid(imgs, rows, cols):
+     # Paste the images into a rows x cols grid on a single canvas.
+     assert len(imgs) == rows * cols
+     w, h = imgs[0].size
+     grid = Image.new('RGB', size=(cols * w, rows * h))
+
+     for i, img in enumerate(imgs):
+         grid.paste(img, box=(i % cols * w, i // cols * h))
+     return grid
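Taken together, the modules can be driven from a short entry-point script. The sketch below is an assumed example rather than part of this commit; the prompt, the choice of diffusion_loss.symmetry as the guidance loss, and the output filename are hypothetical:

import model
import prediction
import diffusion_loss

# Load Stable Diffusion v1.4 with the textual-inversion concepts from config.py.
pipe = model.initialize_diffusion_model()

# One image per (seed, concept) pair, with symmetry-guided denoising every 5th step.
grid = prediction.predict(prompt="a cozy cabin in the woods",
                          pipe=pipe,
                          loss_function=diffusion_loss.symmetry)

# predict() returns a single PIL image containing the 1 x len(config.SEEDS) grid.
grid.save("generated_grid.png")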