GV05 committed on
Commit
5e41772
1 Parent(s): 61f2e1b

Upload Utils.py

Browse files
Files changed (1) hide show
  1. Utils.py +88 -0
Utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
2
+ from transformers import CLIPTextModel, CLIPTokenizer
3
+ from tqdm.auto import tqdm
4
+ from PIL import Image
5
+ import torch
6
+
7
class MingleModel:
    """Bundle of Stable Diffusion v1-4 components for text-embedding-driven image generation.

    The constructor loads the VAE, the CLIP tokenizer and text encoder, the
    UNet, and an LMS noise scheduler. ``generate_with_embs`` then turns a
    (possibly hand-mixed) text-embedding tensor into a PIL image.

    Attributes:
        torch_device: ``"cuda"`` when available, otherwise ``"cpu"``.
        vae: Autoencoder used to decode latents into image space.
        tokenizer: CLIP tokenizer object (callable).
        text_encoder: CLIP text model (callable).
        unet: Conditional UNet that predicts the noise residual.
        scheduler: ``LMSDiscreteScheduler`` driving the denoising loop.
    """

    @staticmethod
    def _resolve_auth_token(auth_token=None):
        """Return the Hugging Face token to use, or ``None`` for anonymous access.

        An explicit ``auth_token`` wins; otherwise the ``HF_TOKEN`` and
        ``HUGGING_FACE_HUB_TOKEN`` environment variables are consulted in that
        order. Secrets must never be hard-coded in the source file.
        """
        import os  # local import: keeps this block self-contained

        if auth_token is not None:
            return auth_token
        for env_name in ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"):
            env_value = os.environ.get(env_name)
            if env_value:
                return env_value
        return None

    def __init__(self, auth_token=None):
        """Load all model components onto the best available device.

        Args:
            auth_token: Optional Hugging Face access token. When omitted, the
                token is read from the environment (see
                ``_resolve_auth_token``); ``None`` is passed through if no
                token is configured.
        """
        # Set device
        self.torch_device = "cuda" if torch.cuda.is_available() else "cpu"
        use_auth_token = self._resolve_auth_token(auth_token)

        # Load the autoencoder model which will be used to decode the latents into image space.
        self.vae = AutoencoderKL.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            subfolder="vae",
            use_auth_token=use_auth_token,
        ).to(self.torch_device)

        # Load the tokenizer and text encoder to tokenize and encode the text.
        self.tokenizer = CLIPTokenizer.from_pretrained(
            "openai/clip-vit-large-patch14", use_auth_token=use_auth_token
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            "openai/clip-vit-large-patch14", use_auth_token=use_auth_token
        ).to(self.torch_device)

        # The UNet model for generating the latents.
        self.unet = UNet2DConditionModel.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            subfolder="unet",
            use_auth_token=use_auth_token,
        ).to(self.torch_device)

        # The noise scheduler
        self.scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
        )

    def tokenize_prompt(self, prompt):
        """Tokenize a single prompt, padded/truncated to the tokenizer's max length.

        Args:
            prompt: The text prompt.

        Returns:
            Whatever ``self.tokenizer`` returns for a one-element batch with
            ``return_tensors="pt"``.
        """
        return self.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

    def encode_text(self, text_input):
        """Encode tokenized text and return the first output of the text encoder.

        Args:
            text_input: Object exposing ``input_ids`` (e.g. the result of
                ``tokenize_prompt``); the ids are moved to ``torch_device``.

        Returns:
            Element ``[0]`` of the text encoder's output.
        """
        return self.text_encoder(text_input.input_ids.to(self.torch_device))[0]

    # NOTE: ``__init__`` assigns instance attributes named ``tokenizer`` and
    # ``text_encoder``, so on an instance those names resolve to the model
    # objects and the two methods below are only reachable through the class,
    # e.g. ``MingleModel.tokenizer(model, prompt)``. They are kept for backward
    # compatibility; use ``tokenize_prompt`` / ``encode_text`` instead.
    def tokenizer(self, prompt):
        """Class-level alias of ``tokenize_prompt`` (shadowed on instances)."""
        return self.tokenize_prompt(prompt)

    def text_encoder(self, text_input):
        """Class-level alias of ``encode_text`` (shadowed on instances)."""
        return self.encode_text(text_input)

    def latents_to_pil(self, latents):
        """Decode a batch of latents into a list of PIL images.

        Args:
            latents: Batch of latents accepted by ``self.vae.decode``.

        Returns:
            A list with one ``PIL.Image`` per batch element.
        """
        # batch of latents -> list of images
        # Undo the 0.18215 scaling applied to the latents before decoding.
        latents = (1 / 0.18215) * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        # Map the decoder output from [-1, 1] to [0, 1], then to uint8 HWC arrays.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        return [Image.fromarray(array) for array in images]

    def generate_with_embs(self, text_embeddings, generator_int=32, num_inference_steps=30, guidance_scale=7.5):
        """Generate one 512x512 image from precomputed text embeddings.

        Runs classifier-free guidance: the unconditional (empty prompt)
        embeddings are concatenated in front of ``text_embeddings`` and both
        are denoised in a single UNet forward pass per step.

        Args:
            text_embeddings: Conditional text embeddings, e.g. the result of
                ``encode_text``. Must be concatenable with the embeddings of a
                single empty prompt (batch size 1 is assumed).
            generator_int: Seed for the initial latent noise.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Scale for classifier-free guidance.

        Returns:
            The first decoded ``PIL.Image``.
        """
        height = 512  # default height of Stable Diffusion
        width = 512  # default width of Stable Diffusion
        batch_size = 1

        # Use a local generator so seeding the initial noise does not reseed
        # the process-wide RNG as a side effect.
        generator = torch.Generator().manual_seed(generator_int)

        # The unconditional embeddings must have the same sequence length as
        # the conditional ones to be concatenated (77 for the CLIP tokenizer).
        max_length = text_embeddings.shape[1]
        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        with torch.no_grad():
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.torch_device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Prep Scheduler
        self.scheduler.set_timesteps(num_inference_steps)

        # Prep latents
        latents = torch.randn((batch_size, self.unet.in_channels, height // 8, width // 8), generator=generator)
        latents = latents.to(self.torch_device)
        latents = latents * self.scheduler.init_noise_sigma

        # Loop (iterating the timesteps directly lets tqdm know the total)
        for t in tqdm(self.scheduler.timesteps):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            with torch.no_grad():
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return self.latents_to_pil(latents)[0]