josuelmet committed on
Commit
24230a4
1 Parent(s): 6150099

Upload 6 files

Files changed (6)
  1. bedroom.jpg +0 -0
  2. boomerang.py +252 -0
  3. cat.png +0 -0
  4. einstein.jpg +0 -0
  5. oprah.jpeg +0 -0
  6. original.png +0 -0
bedroom.jpg ADDED
boomerang.py ADDED
@@ -0,0 +1,252 @@
+ import inspect
+ from PIL import Image
+ import torch
+ from torch import autocast
+ from torchvision import transforms as T
+ from types import MethodType
+ from typing import List, Optional, Tuple, Union
+
+ from diffusers import StableDiffusionPipeline
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True)
+ #pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+
+ pipe = pipe.to('cuda')
+
+
+
+
+ # Overriding the U-Net forward pass
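+ # Compared with the stock diffusers forward pass, this override keeps the time embedding
+ # and the `conv_norm_out` post-processing in the sample's dtype (fp16 here) instead of
+ # casting through float32; see the commented-out original lines below.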
+ def forward(
+     self,
+     sample: torch.FloatTensor,
+     timestep: Union[torch.Tensor, float, int],
+     encoder_hidden_states: torch.Tensor,
+     return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+     r"""
+     Args:
+         sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+         timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+         encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
+         return_dict (`bool`, *optional*, defaults to `True`):
+             Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+     Returns:
+         [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+         [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+         returning a tuple, the first element is the sample tensor.
+     """
+     # 0. center input if necessary
+     if self.config.center_input_sample:
+         sample = 2 * sample - 1.0
+
+     # 1. time
+     timesteps = timestep
+     if not torch.is_tensor(timesteps):
+         timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+     elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+         timesteps = timesteps.to(dtype=torch.float32)
+         timesteps = timesteps[None].to(device=sample.device)
+
+     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+     timesteps = timesteps.expand(sample.shape[0])
+
+     t_emb = self.time_proj(timesteps)
+     #emb = self.time_embedding(t_emb)
+     emb = self.time_embedding(t_emb.to(sample.dtype))
+
+     # 2. pre-process
+     sample = self.conv_in(sample)
+
+     # 3. down
+     down_block_res_samples = (sample,)
+     for downsample_block in self.down_blocks:
+         if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+             sample, res_samples = downsample_block(
+                 hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
+             )
+         else:
+             sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+         down_block_res_samples += res_samples
+
+     # 4. mid
+     sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+     # 5. up
+     for upsample_block in self.up_blocks:
+         res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+         down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+         if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+             sample = upsample_block(
+                 hidden_states=sample,
+                 temb=emb,
+                 res_hidden_states_tuple=res_samples,
+                 encoder_hidden_states=encoder_hidden_states,
+             )
+         else:
+             sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
+
+     # 6. post-process
+     # make sure hidden states is in float32
+     # when running in half-precision
+     #sample = self.conv_norm_out(sample.float()).type(sample.dtype)
+     sample = self.conv_norm_out(sample)
+     sample = self.conv_act(sample)
+     sample = self.conv_out(sample)
+
+     if not return_dict:
+         return (sample,)
+
+     return UNet2DConditionOutput(sample=sample)
+
+
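+ # Replacement safety-checker forward pass: return the images unchanged and report no
+ # flagged content, effectively disabling the NSFW filter for this pipeline.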
+ def safety_forward(self, clip_input, images):
+     return images, False
+
+
+ # Overriding the Stable Diffusion call method
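+ # The override adds a `percent_noise` argument: only the final `percent_noise` fraction of
+ # the denoising schedule is run (see the timestep skip inside the loop below), so a
+ # partially noised latent can be supplied via `latents` and denoised the rest of the way.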
+ @torch.no_grad()
+ def call(
+     self,
+     prompt: Union[str, List[str]],
+     height: Optional[int] = 512,
+     width: Optional[int] = 512,
+     num_inference_steps: Optional[int] = 50,
+     guidance_scale: Optional[float] = 7.5,
+     eta: Optional[float] = 0.0,
+     generator: Optional[torch.Generator] = None,
+     latents: Optional[torch.FloatTensor] = None,
+     output_type: Optional[str] = "pil",
+     return_dict: bool = True,
+     percent_noise: float = 0.7,
+     **kwargs,
+ ):
+     if isinstance(prompt, str):
+         batch_size = 1
+     elif isinstance(prompt, list):
+         batch_size = len(prompt)
+     else:
+         raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+     if height % 8 != 0 or width % 8 != 0:
+         raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+     # get prompt text embeddings
+     text_input = self.tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=self.tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+     text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+     # corresponds to doing no classifier free guidance.
+     do_classifier_free_guidance = guidance_scale > 1.0
+     # get unconditional embeddings for classifier free guidance
+     if do_classifier_free_guidance:
+         max_length = text_input.input_ids.shape[-1]
+         uncond_input = self.tokenizer(
+             [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+         )
+         uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+         # For classifier free guidance, we need to do two forward passes.
+         # Here we concatenate the unconditional and text embeddings into a single batch
+         # to avoid doing two forward passes
+         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # get the initial random noise unless the user supplied it
+
+     # Unlike in other pipelines, latents need to be generated in the target device
+     # for 1-to-1 results reproducibility with the CompVis implementation.
+     # However this currently doesn't work in `mps`.
+     latents_device = "cpu" if self.device.type == "mps" else self.device
+     latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+     if latents is None:
+         latents = torch.randn(
+             latents_shape,
+             generator=generator,
+             device=latents_device,
+         )
+     else:
+         if latents.shape != latents_shape:
+             raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+     latents = latents.to(self.device)
+
+     # set timesteps
+     self.scheduler.set_timesteps(num_inference_steps)
+
+     # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+     #if isinstance(self.scheduler, LMSDiscreteScheduler):
+     #    latents = latents * self.scheduler.sigmas[0]
+
+     # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+     # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+     # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+     # and should be between [0, 1]
+     accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+     extra_step_kwargs = {}
+     if accepts_eta:
+         extra_step_kwargs["eta"] = eta
+
+
+     for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+
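+         # Skip the early, high-noise timesteps: only steps with t - 1 <= 1000 * percent_noise
+         # are run, so denoising starts from the noise level implied by `percent_noise`
+         # (presumably matching a latent that was noised to that level before the call).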
+         if t - 1 > 1000 * percent_noise:
+             continue
+
+         #print(t)
+
+         # expand the latents if we are doing classifier free guidance
+         latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+         #if isinstance(self.scheduler, LMSDiscreteScheduler):
+         #    sigma = self.scheduler.sigmas[i]
+         #    # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+         #    latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+         # predict the noise residual
+         noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+         # perform guidance
+         if do_classifier_free_guidance:
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         # compute the previous noisy sample x_t -> x_t-1
+         #if isinstance(self.scheduler, LMSDiscreteScheduler):
+         #    latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+         #else:
+         latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+
+     # scale and decode the image latents with vae
+     latents = 1 / 0.18215 * latents
+     image = self.vae.decode(latents).sample
+
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+     # run safety checker
+     safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+     image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+     if output_type == "pil":
+         image = self.numpy_to_pil(image)
+
+     if not return_dict:
+         return (image, has_nsfw_concept)
+
+     return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+
+ pipe.unet.forward = MethodType(forward, pipe.unet)
+ pipe.safety_checker.forward = MethodType(safety_forward, pipe.safety_checker)
+ type(pipe).__call__ = call
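+
+
+ # A rough usage sketch for the patched pipeline (illustration only; the input image,
+ # size, prompt, and the noising step are assumptions, not taken from this repo):
+ # encode an image to a VAE latent, noise it to `percent_noise`, then denoise with a prompt.
+ #
+ # percent_noise = 0.7
+ # img = T.ToTensor()(Image.open('original.png').convert('RGB').resize((512, 512)))
+ # img = (2 * img - 1).unsqueeze(0).to('cuda', dtype=torch.float16)
+ # with torch.no_grad():
+ #     latents = pipe.vae.encode(img).latent_dist.sample() * 0.18215
+ # ac = float(pipe.scheduler.alphas_cumprod[int(1000 * percent_noise)])
+ # latents = ac ** 0.5 * latents + (1 - ac) ** 0.5 * torch.randn_like(latents)
+ # out = pipe(prompt='a photo of a bedroom', percent_noise=percent_noise, latents=latents)
+ # out.images[0].save('boomerang.png')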
cat.png ADDED
einstein.jpg ADDED
oprah.jpeg ADDED
original.png ADDED