jadechoghari committed
Commit 3db57e8
1 Parent(s): 80f129f
update (lots of bugs to fix)
Files changed: pipeline_spad.py (+188 -53)

pipeline_spad.py
CHANGED
@@ -1,68 +1,203 @@
[Previous version of pipeline_spad.py (68 lines): mostly truncated in this view. The recoverable fragments show that SPADPipeline(DiffusionPipeline) previously took typed components in __init__ (vae: AutoencoderKL, unet: SPADUnetModel, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, scheduler: DPMSolverMultistepScheduler) and ended its much shorter generation routine with images = (images / 2 + 0.5).clamp(0, 1). The full updated file follows.]
import os
import torch
import torch.nn as nn
import numpy as np
import math
from diffusers import DiffusionPipeline
from einops import rearrange, repeat
from itertools import chain
from tqdm import tqdm
from .geometry import get_batch_from_spherical


class SPADPipeline(DiffusionPipeline):
    def __init__(self, unet, vae, text_encoder, tokenizer, scheduler):
        super().__init__()

        self.register_modules(
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
        )

        self.cfg_conds = ["txt", "cam", "epi", "plucker"]
        self.cfg_scales = [7.5, 1.0, 1.0, 1.0]  # Default scales, adjust as needed
        self.use_abs_extrinsics = False
        self.use_intrinsic = False

        # Projects the 4-d (or 8-d, with intrinsics) camera vector into the UNet
        # conditioning space; the output layer is zero-initialized.
        self.cc_projection = nn.Sequential(
            nn.Linear(4 if not self.use_intrinsic else 8, 1280),
            nn.SiLU(),
            nn.Linear(1280, 1280),
        )
        nn.init.zeros_(self.cc_projection[-1].weight)
        nn.init.zeros_(self.cc_projection[-1].bias)

    def generate_camera_batch(self, elevations, azimuths, use_abs=False):
        # Build a camera batch from spherical coordinates plus a pairwise (or absolute)
        # camera encoding of the form [d_theta, sin(d_azimuth), cos(d_azimuth), d_radius].
        batch = get_batch_from_spherical(elevations, azimuths)

        abs_cams = [torch.tensor([theta, azimuth, 3.5]) for theta, azimuth in zip(elevations, azimuths)]

        debug_cams = [[] for _ in range(len(azimuths))]
        for i, icam in enumerate(abs_cams):
            for j, jcam in enumerate(abs_cams):
                if use_abs:
                    dcam = torch.tensor([icam[0], math.sin(icam[1]), math.cos(icam[1]), icam[2]])
                else:
                    dcam = icam - jcam
                    dcam = torch.tensor([dcam[0].item(), math.sin(dcam[1].item()), math.cos(dcam[1].item()), dcam[2].item()])
                debug_cams[i].append(dcam)

        batch["cam"] = torch.stack([torch.stack(dc) for dc in debug_cams])

        # Add intrinsics to the batch
        focal = 1 / np.tan(0.702769935131073 / 2)
        intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
        intrinsics = torch.from_numpy(intrinsics).unsqueeze(0).float().repeat(batch["cam"].shape[0], 1, 1)
        batch["render_intrinsics_flat"] = intrinsics[:, [0, 1, 0, 1], [0, 1, -1, -1]]

        return batch

    def get_gaussian_image(self, blob_width=256, blob_height=256, sigma=0.5):
        # Grayscale gaussian blob (dark center on white background), 3 channels, in [-1, 1].
        X = np.linspace(-1, 1, blob_width)[None, :]
        Y = np.linspace(-1, 1, blob_height)[:, None]
        inv_dev = 1 / sigma ** 2
        gaussian_blob = np.exp(-0.5 * (X**2) * inv_dev) * np.exp(-0.5 * (Y**2) * inv_dev)
        if gaussian_blob.max() > 0:
            gaussian_blob = 255.0 * (gaussian_blob - gaussian_blob.min()) / gaussian_blob.max()
        gaussian_blob = 255.0 - gaussian_blob

        gaussian_blob = (gaussian_blob / 255.0) * 2.0 - 1.0
        gaussian_blob = np.expand_dims(gaussian_blob, axis=-1).repeat(3, -1)
        gaussian_blob = torch.from_numpy(gaussian_blob)

        return gaussian_blob

    @torch.no_grad()
    def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, num_images_per_prompt=1,
                 elevations=None, azimuths=None, blob_sigma=0.5, **kwargs):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        device = self.device

        # Generate camera batch (default: 4 views on a ring at 45 degrees elevation)
        if elevations is None or azimuths is None:
            elevations = [45] * 4
            azimuths = [0, 90, 180, 270]

        n_views = len(elevations)
        camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
        camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}

        # Prepare gaussian blob initialization
        blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
        camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)

        # Encode text
        text_input_ids = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt").input_ids.to(device)
        text_embeddings = self.text_encoder(text_input_ids)[0]

        # Prepare unconditional embeddings for classifier-free guidance
        max_length = text_input_ids.shape[-1]
        uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

        # Encode camera data
        camera_embeddings = self.cc_projection(camera_batch["cam"])

        # Prepare latents
        latent_height, latent_width = self.vae.config.sample_size // 8, self.vae.config.sample_size // 8
        latents = self.prepare_latents(
            batch_size,
            self.unet.in_channels,
            n_views,
            latent_height,
            latent_width,
            self.unet.dtype,
            device,
            generator=None,
        )

        # Prepare epi_constraint_masks (placeholder, replace with actual implementation)
        epi_constraint_masks = torch.ones(batch_size, n_views, latent_height, latent_width, n_views, latent_height, latent_width, dtype=torch.bool, device=device)

        # Prepare plucker embeddings (placeholder, replace with actual implementation)
        plucker_embeds = torch.zeros(batch_size, n_views, 6, latent_height, latent_width, device=device)

        latent_height, latent_width = 64, 64  # Fixed to match the required shape [batch_size, 1, 4, 64, 64]
        n_objects = 2
        latents = torch.randn(n_objects, n_views, 10, 32, 32, device=device, dtype=self.unet.dtype)

        # Set up scheduler
        # self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_timesteps(10)  # temporarily hard-coded; the intended call is commented above

        # Repeat text_embeddings to match the desired dimensions
        text_embeddings = text_embeddings.repeat(n_objects, 1, 1)  # Shape: [n_objects * batch_size, max_seq_len, emb_dim]
        # Reshape text_embeddings to [n_objects * batch_size, n_views, max_seq_len, emb_dim]
        text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)

        # Repeat camera embeddings once, outside the loop (repeating inside the loop
        # would grow the tensor on every iteration)
        camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)

        # Denoising loop
        for t in tqdm(self.scheduler.timesteps):
            # One timestep value per (object, view)
            timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)

            # Prepare context
            context = [
                # text_embeddings.unsqueeze(1),          # [batch_size, 1, max_seq_len, 768]
                # camera_embeddings.unsqueeze(1) * 0.0,  # [batch_size, 1, 1280] * 0.0
                # epi_constraint_masks                   # Keep this as is for now
                text_embeddings,    # [n_objects * batch_size, n_views, max_seq_len, emb_dim]
                camera_embeddings   # [n_objects * batch_size, n_views, n_views, 1280]
            ]

            # Predict noise residual
            noise_pred = self.unet(
                latents,              # Shape: [n_objects, n_views, 10, 32, 32]
                timesteps=timesteps,  # Shape: [n_objects, n_views]
                context=context
            )

            # Perform guidance
            # FIXME: the batch is never duplicated with the unconditional embeddings
            # (uncond_embeddings above is unused), so this chunk splits the object
            # dimension rather than uncond/cond predictions.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Reduce latents
        # EXPERIMENTAL: keep only the first 4 channels (out of 10) so the VAE can decode them
        latents = latents[:, :, :4, :, :]
        latents = latents.view(-1, latents.shape[2], latents.shape[3], latents.shape[4])

        # Decode latents
        images = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

        # Post-process images
        images = (images / 2 + 0.5).clamp(0, 1)

        if images.dim() == 5:
            images = images.cpu().permute(0, 1, 3, 4, 2).float().numpy()  # For 5D tensors
        elif images.dim() == 4:
            images = images.cpu().permute(0, 2, 3, 1).float().numpy()  # For 4D tensors
        else:
            raise ValueError(f"Unexpected image dimensions: {images.shape}")

        return {"images": images, "nsfw_content_detected": [[False] * n_views for _ in range(batch_size)]}

    def prepare_latents(self, batch_size, num_channels, num_views, height, width, dtype, device, generator=None):
        shape = (batch_size, num_views, num_channels, height, width)
        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return latents