AlanB committed
Commit
c8e1bf7
1 Parent(s): a47751a

Added callback & callback_steps so I can use them in my GUI

Files changed (1)
  1. pipeline.py +158 -0
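For context, a minimal sketch of how the new callback / callback_steps arguments could be hooked up to report denoising progress (e.g. to a GUI progress bar). The base checkpoint, the custom_pipeline path, the on_step handler, and the file names below are illustrative assumptions, not part of this commit:

    import torch
    from PIL import Image
    from diffusers import DiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Assumption: custom_pipeline points at the directory (or repo) that contains this pipeline.py.
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",  # placeholder base checkpoint
        custom_pipeline="./",             # placeholder path to this pipeline.py
    ).to(device)

    # Invoked every `callback_steps` denoising steps with (step index, scheduler timestep, current latents).
    def on_step(i, t, latents):
        print(f"step {i}, timestep {int(t)}")  # e.g. update a GUI progress bar here

    img = Image.open("input.jpg").convert("RGB")  # placeholder layout image (dimensions divisible by 8)
    mixed = pipe(
        img,
        prompt="a bed",  # placeholder content prompt
        callback=on_step,
        callback_steps=1,
    )
    mixed.save("output.jpg")

As added in the diff below, the callback receives the step index, the scheduler timestep, and the current latents, and fires once every callback_steps iterations of the denoising loop.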
pipeline.py ADDED
@@ -0,0 +1,158 @@
+from typing import Callable, Optional, Union
+
+import torch
+
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    DiffusionPipeline,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    UNet2DConditionModel,
+)
+from PIL import Image
+from torchvision import transforms as tfms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+
+class MagicMixPipeline(DiffusionPipeline):
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
+    ):
+        super().__init__()
+
+        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+    # convert PIL image to latents
+    def encode(self, img):
+        with torch.no_grad():
+            latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1)
+            latent = 0.18215 * latent.latent_dist.sample()
+        return latent
+
+    # convert latents to PIL image
+    def decode(self, latent):
+        latent = (1 / 0.18215) * latent
+        with torch.no_grad():
+            img = self.vae.decode(latent).sample
+        img = (img / 2 + 0.5).clamp(0, 1)
+        img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
+        img = (img * 255).round().astype("uint8")
+        return Image.fromarray(img[0])
+
+    # convert prompt into text embeddings, also unconditional embeddings
+    def prep_text(self, prompt):
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+        uncond_input = self.tokenizer(
+            "",
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+        return torch.cat([uncond_embedding, text_embedding])
+
+    def __call__(
+        self,
+        img: Image.Image,
+        prompt: str,
+        kmin: float = 0.3,
+        kmax: float = 0.6,
+        mix_factor: float = 0.5,
+        seed: int = 42,
+        steps: int = 50,
+        guidance_scale: float = 7.5,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+    ) -> Image.Image:
+        tmin = steps - int(kmin * steps)
+        tmax = steps - int(kmax * steps)
+
+        text_embeddings = self.prep_text(prompt)
+
+        self.scheduler.set_timesteps(steps)
+
+        width, height = img.size
+        encoded = self.encode(img)
+
+        torch.manual_seed(seed)
+        noise = torch.randn(
+            (1, self.unet.in_channels, height // 8, width // 8),
+        ).to(self.device)
+
+        latents = self.scheduler.add_noise(
+            encoded,
+            noise,
+            timesteps=self.scheduler.timesteps[tmax],
+        )
+
+        input = torch.cat([latents] * 2)
+
+        input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax])
+
+        with torch.no_grad():
+            pred = self.unet(
+                input,
+                self.scheduler.timesteps[tmax],
+                encoder_hidden_states=text_embeddings,
+            ).sample
+
+        pred_uncond, pred_text = pred.chunk(2)
+        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
+
+        latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample
+
+        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
+            if i > tmax:
+                if i < tmin:  # layout generation phase
+                    orig_latents = self.scheduler.add_noise(
+                        encoded,
+                        noise,
+                        timesteps=t,
+                    )
+
+                    input = (mix_factor * latents) + (
+                        1 - mix_factor
+                    ) * orig_latents  # interpolating between layout noise and conditionally generated noise to preserve layout semantics
+                    input = torch.cat([input] * 2)
+
+                else:  # content generation phase
+                    input = torch.cat([latents] * 2)
+
+                input = self.scheduler.scale_model_input(input, t)
+
+                with torch.no_grad():
+                    pred = self.unet(
+                        input,
+                        t,
+                        encoder_hidden_states=text_embeddings,
+                    ).sample
+
+                pred_uncond, pred_text = pred.chunk(2)
+                pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
+
+                latents = self.scheduler.step(pred, t, latents).prev_sample
+
+                # call the callback, if provided
+                if callback is not None and i % callback_steps == 0:
+                    callback(i, t, latents)
+
+        return self.decode(latents)