jadechoghari committed
Commit 3db57e8
1 Parent(s): 80f129f

update (lots of bugs to fix)

Files changed (1)
  1. pipeline_spad.py +188 -53
pipeline_spad.py CHANGED
@@ -1,68 +1,203 @@
-import torch
-from diffusers import AutoencoderKL, DiffusionPipeline
-from transformers import CLIPTextModel, CLIPTokenizer
-from mv_unet import SPADUnetModel
-from diffusers.schedulers import DPMSolverMultistepScheduler
-
-class SPADPipeline(DiffusionPipeline):
-    def __init__(
-        self,
-        vae: AutoencoderKL,
-        unet: SPADUnetModel,
-        tokenizer: CLIPTokenizer,
-        text_encoder: CLIPTextModel,
-        scheduler: DPMSolverMultistepScheduler,
-    ):
-        super().__init__()
-
-        self.vae = vae
-        self.unet = unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.scheduler = scheduler
-
-        # make sure all our models are on the same device
-        self.vae.to(self.device)
-        self.unet.to(self.device)
-        self.text_encoder.to(self.device)
-
-    def encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None):
-        text_input = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt"
-        )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
-
-        # we duplicate the text embeddings for each generation, just to save time :)
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        return text_embeddings
-
-    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5):
-        # encide the prompt into the text embeddings
-        text_embeddings = self.encode_prompt(prompt, self.device, 1, do_classifier_free_guidance=False)
-
-        # this is the initial noise sample
-        latents = torch.randn(
-            (text_embeddings.shape[0], self.unet.in_channels, self.unet.image_size, self.unet.image_size),
-            device=self.device
-        )
-
-        # setting up the scheduler
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        # iterate and generate
-        for t in self.scheduler.timesteps:
-            latents = self.scheduler.scale_model_input(latents, t)
-            latents = self.unet(latents, t, text_embeddings)["sample"]
-            latents = self.scheduler.step(latents, t, latents, guidance_scale=guidance_scale)["prev_sample"]
-
-        # decode latents into images
-        images = self.vae.decode(latents)
-        images = (images / 2 + 0.5).clamp(0, 1)
-
-        return images
+import os
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+from diffusers import DiffusionPipeline
+from einops import rearrange, repeat
+from itertools import chain
+from tqdm import tqdm
+from .geometry import get_batch_from_spherical
+
+class SPADPipeline(DiffusionPipeline):
+    def __init__(self, unet, vae, text_encoder, tokenizer, scheduler):
+        super().__init__()
+
+        self.register_modules(
+            unet=unet,
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            scheduler=scheduler
+        )
+
+        self.cfg_conds = ["txt", "cam", "epi", "plucker"]
+        self.cfg_scales = [7.5, 1.0, 1.0, 1.0]  # Default scales, adjust as needed
+        self.use_abs_extrinsics = False
+        self.use_intrinsic = False
+
+        self.cc_projection = nn.Sequential(
+            nn.Linear(4 if not self.use_intrinsic else 8, 1280),
+            nn.SiLU(),
+            nn.Linear(1280, 1280),
+        )
+        nn.init.zeros_(self.cc_projection[-1].weight)
+        nn.init.zeros_(self.cc_projection[-1].bias)
+
+    def generate_camera_batch(self, elevations, azimuths, use_abs=False):
+        batch = get_batch_from_spherical(elevations, azimuths)
+
+        abs_cams = [torch.tensor([theta, azimuth, 3.5]) for theta, azimuth in zip(elevations, azimuths)]
+
+        debug_cams = [[] for _ in range(len(azimuths))]
+        for i, icam in enumerate(abs_cams):
+            for j, jcam in enumerate(abs_cams):
+                if use_abs:
+                    dcam = torch.tensor([icam[0], math.sin(icam[1]), math.cos(icam[1]), icam[2]])
+                else:
+                    dcam = icam - jcam
+                    dcam = torch.tensor([dcam[0].item(), math.sin(dcam[1].item()), math.cos(dcam[1].item()), dcam[2].item()])
+                debug_cams[i].append(dcam)
+
+        batch["cam"] = torch.stack([torch.stack(dc) for dc in debug_cams])
+
+        # Add intrinsics to the batch
+        focal = 1 / np.tan(0.702769935131073 / 2)
+        intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
+        intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).float().repeat(batch["cam"].shape[0], 1, 1)
+        batch["render_intrinsics_flat"] = intrinsics[:, [0, 1, 0, 1], [0, 1, -1, -1]]
+
+        return batch
+
+    def get_gaussian_image(self, blob_width=256, blob_height=256, sigma=0.5):
+        X = np.linspace(-1, 1, blob_width)[None, :]
+        Y = np.linspace(-1, 1, blob_height)[:, None]
+        inv_dev = 1 / sigma ** 2
+        gaussian_blob = np.exp(-0.5 * (X**2) * inv_dev) * np.exp(-0.5 * (Y**2) * inv_dev)
+        if gaussian_blob.max() > 0:
+            gaussian_blob = 255.0 * (gaussian_blob - gaussian_blob.min()) / gaussian_blob.max()
+        gaussian_blob = 255.0 - gaussian_blob
+
+        gaussian_blob = (gaussian_blob / 255.0) * 2.0 - 1.0
+        gaussian_blob = np.expand_dims(gaussian_blob, axis=-1).repeat(3, -1)
+        gaussian_blob = torch.from_numpy(gaussian_blob)
+
+        return gaussian_blob
+
+    @torch.no_grad()
+    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=1,
+                 elevations=None, azimuths=None, blob_sigma=0.5, **kwargs):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+        device = self.device
+
+        # Generate camera batch
+        if elevations is None or azimuths is None:
+            elevations = [45] * 4
+            azimuths = [0, 90, 180, 270]
+
+        n_views = len(elevations)
+        camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
+        camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}
+
+        # Prepare gaussian blob initialization
+        blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
+        camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)
+
+        # Encode text
+        text_input_ids = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt").input_ids.to(device)
+        text_embeddings = self.text_encoder(text_input_ids)[0]
+
+        # Prepare unconditional embeddings for classifier-free guidance
+        max_length = text_input_ids.shape[-1]
+        uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
+        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+
+        # Encode camera data
+        camera_embeddings = self.cc_projection(camera_batch["cam"])
+
+        # Prepare latents
+        latent_height, latent_width = self.vae.config.sample_size // 8, self.vae.config.sample_size // 8
+        latents = self.prepare_latents(
+            batch_size,
+            self.unet.in_channels,
+            n_views,
+            latent_height,
+            latent_width,
+            self.unet.dtype,
+            device,
+            generator=None,
+        )
+
+        # Prepare epi_constraint_masks (placeholder, replace with actual implementation)
+        epi_constraint_masks = torch.ones(batch_size, n_views, latent_height, latent_width, n_views, latent_height, latent_width, dtype=torch.bool, device=device)
+
+        # Prepare plucker embeddings (placeholder, replace with actual implementation)
+        plucker_embeds = torch.zeros(batch_size, n_views, 6, latent_height, latent_width, device=device)
+
+        latent_height, latent_width = 64, 64  # Fixed to match the required shape [batch_size, 1, 4, 64, 64]
+        n_objects = 2
+        latents = torch.randn(n_objects, n_views, 10, 32, 32, device=device, dtype=self.unet.dtype)
+
+        # Set up scheduler
+        # self.scheduler.set_timesteps(num_inference_steps)
+        self.scheduler.set_timesteps(10)
+        # Repeat text_embeddings to match the desired dimensions
+        text_embeddings = text_embeddings.repeat(n_objects, 1, 1)  # Shape: [2, max_seq_len, 512]
+
+        # Reshape text_embeddings to match [n_objects, n_views, max_seq_len, 512]
+        text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
+        # Denoising loop
+        for t in tqdm(self.scheduler.timesteps):
+            # Expand timesteps to match shape [batch_size, 1, 1]
+            # timesteps = torch.full((batch_size, 1, 1), t, device=device, dtype=torch.long)
+            timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)
+
+            # # Repeat text_embeddings to match the desired dimensions
+            # text_embeddings = text_embeddings.repeat(n_objects, 1, 1)  # Shape: [2, max_seq_len, 512]
+
+            # # Reshape text_embeddings to match [n_objects, n_views, max_seq_len, 512]
+            # text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
+
+            # print("old cam shape: ", camera_embeddings.shape)
+            camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
+            # print("cam emb shape: ", camera_embeddings.shape)
+            # Prepare context
+            context = [
+                # text_embeddings.unsqueeze(1),  # [batch_size, 1, max_seq_len, 768]
+                # camera_embeddings.unsqueeze(1) * 0.0,  # [batch_size, 1, 1280] * 0.0
+                # epi_constraint_masks  # Keep this as is for now
+                text_embeddings,   # [n_objects, n_views, max_seq_len, 768]
+                camera_embeddings  # [n_objects, n_views, 1280]
+            ]
+
+            # Predict noise residual
+            noise_pred = self.unet(
+                latents,               # Shape: [batch_size, 1, 4, 64, 64]
+                timesteps=timesteps,   # Shape: [batch_size, 1, 1]
+                context=context
+            )
+
+            # Perform guidance
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # Compute previous noisy sample
+            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+        # reduce latents
+        # EXPERIMENTAL
+        # If you need to reduce the channels from 10 to 4
+        latents = latents[:, :, :4, :, :]  # Select only the first 4 channels
+        latents = latents.view(-1, latents.shape[2], latents.shape[3], latents.shape[4])
+        # Decode latents
+        images = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+        # Post-process images
+        images = (images / 2 + 0.5).clamp(0, 1)
+
+        if images.dim() == 5:
+            images = images.cpu().permute(0, 1, 3, 4, 2).float().numpy()  # For 5D tensors
+        elif images.dim() == 4:
+            images = images.cpu().permute(0, 2, 3, 1).float().numpy()  # For 4D tensors
+        else:
+            raise ValueError(f"Unexpected image dimensions: {images.shape}")
+
+        return {"images": images, "nsfw_content_detected": [[False] * n_views for _ in range(batch_size)]}
+
+    def prepare_latents(self, batch_size, num_channels, num_views, height, width, dtype, device, generator=None):
+        shape = (batch_size, num_views, num_channels, height, width)
+        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+        return latents
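
A note on the guidance step in the new __call__: it encodes uncond_embeddings but never passes them to the UNet, and then splits noise_pred with chunk(2) as if an unconditional and a text-conditioned prediction were batched together. For reference, a minimal sketch of the usual classifier-free guidance batching is below; the two-argument-plus-context unet(...) call signature is an assumption for illustration, not the SPAD UNet's verified API.

import torch

def cfg_noise_prediction(unet, latents, t, uncond_context, text_context, guidance_scale):
    # Run the unconditional and the text-conditioned inputs through the UNet in one batch:
    # the first half of the batch is unconditional, the second half is conditioned.
    latent_in = torch.cat([latents, latents], dim=0)
    context_in = torch.cat([uncond_context, text_context], dim=0)
    noise = unet(latent_in, t, context_in)  # hypothetical call signature
    noise_uncond, noise_text = noise.chunk(2, dim=0)
    # Classifier-free guidance: push the prediction away from the unconditional estimate.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)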
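
Usage sketch (not part of this commit): a minimal example of wiring the pipeline together, assuming the SPAD UNet still comes from mv_unet as in the removed imports, that pipeline_spad.py is importable on the Python path, and that standard Stable Diffusion v1.5 components are used for the VAE, CLIP tokenizer/text encoder, and scheduler; the SPADUnetModel() constructor arguments and any pretrained weights are placeholders.

import torch
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from mv_unet import SPADUnetModel  # multi-view UNet defined in this repo
from pipeline_spad import SPADPipeline

# Assumption: SD-1.5-style components; swap in whatever checkpoints the repo actually targets.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
scheduler = DPMSolverMultistepScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

# Hypothetical: how the SPAD UNet is configured and loaded is not shown in this commit.
unet = SPADUnetModel()

pipe = SPADPipeline(unet=unet, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler)
pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Four views at 45 degrees elevation, 90 degrees apart in azimuth (the defaults in __call__).
out = pipe(
    prompt=["a ceramic teapot"],
    num_inference_steps=50,
    guidance_scale=7.5,
    elevations=[45, 45, 45, 45],
    azimuths=[0, 90, 180, 270],
)
images = out["images"]  # numpy array of decoded views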