kyleleey commited on
Commit
cd3b424
·
1 Parent(s): 2b1cca8

remove unused pkgs

Browse files
requirements.txt CHANGED
@@ -1,6 +1,4 @@
1
  ConfigArgParse==1.5.3
2
- core==1.0.1
3
- diffusers==0.20.0
4
  einops==0.4.1
5
  faiss==1.7.3
6
  fire==0.5.0
 
1
  ConfigArgParse==1.5.3
 
 
2
  einops==0.4.1
3
  faiss==1.7.3
4
  fire==0.5.0
video3d/diffusion/sd.py DELETED
@@ -1,252 +0,0 @@
1
- import os
2
- # os.environ['HUGGINGFACE_HUB_CACHE'] = '/work/tomj/cache/huggingface_hub'
3
- # os.environ['HF_HOME'] = '/work/tomj/cache/huggingface_hub'
4
- os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli'
5
- os.environ['HF_HOME'] = '/viscam/u/zzli'
6
-
7
- from transformers import CLIPTextModel, CLIPTokenizer, logging
8
- from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler
9
-
10
- # Suppress partial model loading warning
11
- logging.set_verbosity_error()
12
-
13
- import torch
14
- import torch.nn as nn
15
- import torch.nn.functional as F
16
-
17
- from torch.cuda.amp import custom_bwd, custom_fwd
18
-
19
- class SpecifyGradient(torch.autograd.Function):
20
- @staticmethod
21
- @custom_fwd
22
- def forward(ctx, input_tensor, gt_grad):
23
- ctx.save_for_backward(gt_grad)
24
- return torch.zeros([1], device=input_tensor.device, dtype=input_tensor.dtype) # Dummy loss value
25
-
26
- @staticmethod
27
- @custom_bwd
28
- def backward(ctx, grad):
29
- gt_grad, = ctx.saved_tensors
30
- batch_size = len(gt_grad)
31
- return gt_grad / batch_size, None
32
-
33
- def seed_everything(seed):
34
- torch.manual_seed(seed)
35
- torch.cuda.manual_seed(seed)
36
-
37
-
38
- class StableDiffusion(nn.Module):
39
- def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32):
40
- super().__init__()
41
-
42
- self.device = device
43
- self.sd_version = sd_version
44
- self.torch_dtype = torch_dtype
45
-
46
- print(f'[INFO] loading stable diffusion...')
47
-
48
- if hf_key is not None:
49
- print(f'[INFO] using hugging face custom model key: {hf_key}')
50
- model_key = hf_key
51
- elif self.sd_version == '2.1':
52
- model_key = "stabilityai/stable-diffusion-2-1-base"
53
- elif self.sd_version == '2.0':
54
- model_key = "stabilityai/stable-diffusion-2-base"
55
- elif self.sd_version == '1.5':
56
- model_key = "runwayml/stable-diffusion-v1-5"
57
- else:
58
- raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
59
-
60
- # Create model
61
- self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", torch_dtype=torch_dtype).to(self.device)
62
- self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
63
- self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
64
- self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
65
-
66
- self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
67
- # self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler")
68
-
69
- self.num_train_timesteps = self.scheduler.config.num_train_timesteps
70
- self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
71
-
72
- print(f'[INFO] loaded stable diffusion!')
73
-
74
- def get_text_embeds(self, prompt, negative_prompt):
75
- # prompt, negative_prompt: [str]
76
-
77
- # Tokenize text and get embeddings
78
- text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
79
-
80
- with torch.no_grad():
81
- text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
82
-
83
- # Do the same for unconditional embeddings
84
- uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
85
-
86
- with torch.no_grad():
87
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
88
-
89
- # Cat for final embeddings
90
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
91
- return text_embeddings
92
-
93
- def train_step(self, text_embeddings, pred_rgb,
94
- guidance_scale=100, loss_weight=1.0, min_step_pct=0.02, max_step_pct=0.98, return_aux=False):
95
- pred_rgb = pred_rgb.to(self.torch_dtype)
96
- text_embeddings = text_embeddings.to(self.torch_dtype)
97
- b = pred_rgb.shape[0]
98
-
99
- # interp to 512x512 to be fed into vae.
100
-
101
- # _t = time.time()
102
- pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
103
- # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
104
-
105
- # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
106
- min_step = int(self.num_train_timesteps * min_step_pct)
107
- max_step = int(self.num_train_timesteps * max_step_pct)
108
- t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device)
109
-
110
- # encode image into latents with vae, requires grad!
111
- # _t = time.time()
112
- latents = self.encode_imgs(pred_rgb_512)
113
- # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
114
-
115
- # predict the noise residual with unet, NO grad!
116
- # _t = time.time()
117
- with torch.no_grad():
118
- # add noise
119
- noise = torch.randn_like(latents)
120
- latents_noisy = self.scheduler.add_noise(latents, noise, t)
121
- # pred noise
122
- latent_model_input = torch.cat([latents_noisy] * 2)
123
- t_input = torch.cat([t, t])
124
- noise_pred = self.unet(latent_model_input, t_input, encoder_hidden_states=text_embeddings).sample
125
- # torch.cuda.synchronize(); print(f'[TIME] guiding: unet {time.time() - _t:.4f}s')
126
-
127
- # perform guidance (high scale from paper!)
128
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
129
- # noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
130
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
131
-
132
- # w(t), sigma_t^2
133
- w = (1 - self.alphas[t])
134
- # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
135
- grad = loss_weight * w[:, None, None, None] * (noise_pred - noise)
136
-
137
- # clip grad for stable training?
138
- # grad = grad.clamp(-10, 10)
139
- grad = torch.nan_to_num(grad)
140
-
141
- # since we omitted an item in grad, we need to use the custom function to specify the gradient
142
- # _t = time.time()
143
- # loss = SpecifyGradient.apply(latents, grad)
144
- # torch.cuda.synchronize(); print(f'[TIME] guiding: backward {time.time() - _t:.4f}s')
145
-
146
- targets = (latents - grad).detach()
147
- loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
148
-
149
- if return_aux:
150
- aux = {'grad': grad, 't': t, 'w': w}
151
- return loss, aux
152
- else:
153
- return loss
154
-
155
-
156
- def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
157
-
158
- if latents is None:
159
- latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
160
-
161
- self.scheduler.set_timesteps(num_inference_steps)
162
-
163
- with torch.autocast('cuda'):
164
- for i, t in enumerate(self.scheduler.timesteps):
165
- # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
166
- latent_model_input = torch.cat([latents] * 2)
167
-
168
- # predict the noise residual
169
- with torch.no_grad():
170
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
171
-
172
- # perform guidance
173
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
174
- noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
175
-
176
- # compute the previous noisy sample x_t -> x_t-1
177
- latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
178
-
179
- return latents
180
-
181
- def decode_latents(self, latents):
182
-
183
- latents = 1 / self.vae.config.scaling_factor * latents
184
-
185
- with torch.no_grad():
186
- imgs = self.vae.decode(latents).sample
187
-
188
- imgs = (imgs / 2 + 0.5).clamp(0, 1)
189
-
190
- return imgs
191
-
192
- def encode_imgs(self, imgs):
193
- # imgs: [B, 3, H, W]
194
-
195
- imgs = 2 * imgs - 1
196
-
197
- posterior = self.vae.encode(imgs).latent_dist
198
- latents = posterior.sample() * self.vae.config.scaling_factor
199
-
200
- return latents
201
-
202
- def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
203
-
204
- if isinstance(prompts, str):
205
- prompts = [prompts]
206
-
207
- if isinstance(negative_prompts, str):
208
- negative_prompts = [negative_prompts]
209
-
210
- # Prompts -> text embeds
211
- text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2, 77, 768]
212
-
213
- # Text embeds -> img latents
214
- latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
215
-
216
- # Img latents -> imgs
217
- imgs = self.decode_latents(latents) # [1, 3, 512, 512]
218
-
219
- # Img to Numpy
220
- imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
221
- imgs = (imgs * 255).round().astype('uint8')
222
-
223
- return imgs
224
-
225
-
226
- if __name__ == '__main__':
227
- import argparse
228
- import matplotlib.pyplot as plt
229
-
230
- parser = argparse.ArgumentParser()
231
- parser.add_argument('prompt', type=str)
232
- parser.add_argument('--negative', default='', type=str)
233
- parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
234
- parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
235
- parser.add_argument('-H', type=int, default=512)
236
- parser.add_argument('-W', type=int, default=512)
237
- parser.add_argument('--seed', type=int, default=0)
238
- parser.add_argument('--steps', type=int, default=50)
239
- opt = parser.parse_args()
240
-
241
- seed_everything(opt.seed)
242
-
243
- device = torch.device('cuda')
244
-
245
- sd = StableDiffusion(device, opt.sd_version, opt.hf_key)
246
-
247
- imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
248
-
249
- # visualize image
250
- plt.imshow(imgs[0])
251
- plt.show()
252
- plt.savefig(f'{opt.prompt}.png')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video3d/diffusion/sd_utils.py DELETED
@@ -1,123 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import random
4
- import torch.nn.functional as F
5
-
6
- from ..render.light import DirectionalLight
7
-
8
- def safe_normalize(x, eps=1e-20):
9
- return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
10
-
11
- def get_view_direction(thetas, phis, overhead, front, phi_offset=0):
12
- # phis [B,]; thetas: [B,]
13
- # front = 0 [360 - front / 2, front / 2)
14
- # side (left) = 1 [front / 2, 180 - front / 2)
15
- # back = 2 [180 - front / 2, 180 + front / 2)
16
- # side (right) = 3 [180 + front / 2, 360 - front / 2)
17
- # top = 4 [0, overhead]
18
- # bottom = 5 [180-overhead, 180]
19
- res = torch.zeros(thetas.shape[0], dtype=torch.long)
20
-
21
- # first determine by phis
22
- phi_offset = np.deg2rad(phi_offset)
23
- phis = phis + phi_offset
24
- phis = phis % (2 * np.pi)
25
- half_front = front / 2
26
-
27
- res[(phis >= (2*np.pi - half_front)) | (phis < half_front)] = 0
28
- res[(phis >= half_front) & (phis < (np.pi - half_front))] = 1
29
- res[(phis >= (np.pi - half_front)) & (phis < (np.pi + half_front))] = 2
30
- res[(phis >= (np.pi + half_front)) & (phis < (2*np.pi - half_front))] = 3
31
-
32
- # override by thetas
33
- res[thetas <= overhead] = 4
34
- res[thetas >= (np.pi - overhead)] = 5
35
- return res
36
-
37
-
38
- def view_direction_id_to_text(view_direction_id):
39
- dir_texts = ['front', 'side', 'back', 'side', 'overhead', 'bottom']
40
- return [dir_texts[i] for i in view_direction_id]
41
-
42
-
43
- def append_text_direction(prompts, dir_texts):
44
- return [f'{prompt}, {dir_text} view' for prompt, dir_text in zip(prompts, dir_texts)]
45
-
46
-
47
- def rand_lights(camera_dir, fixed_ambient, fixed_diffuse):
48
- size = camera_dir.shape[0]
49
- device = camera_dir.device
50
- random_fixed_dir = F.normalize(torch.randn_like(camera_dir) + camera_dir, dim=-1) # Centered around camera_dir
51
- random_fixed_intensity = torch.tensor([fixed_ambient, fixed_diffuse], device=device)[None, :].repeat(size, 1) # ambient, diffuse
52
- return DirectionalLight(mlp_in=1, mlp_layers=1, mlp_hidden_size=1, # Dummy values
53
- intensity_min_max=[0.5, 1],fixed_dir=random_fixed_dir, fixed_intensity=random_fixed_intensity).to(device)
54
-
55
- def rand_poses(size, device, radius_range=[1, 1], theta_range=[0, 120], phi_range=[0, 360], cam_z_offset=10, return_dirs=False, angle_overhead=30, angle_front=60, phi_offset=0, jitter=False, uniform_sphere_rate=0.5):
56
- ''' generate random poses from an orbit camera
57
- Args:
58
- size: batch size of generated poses.
59
- device: where to allocate the output.
60
- radius_range: [min, max]
61
- theta_range: [min, max], should be in [0, pi]
62
- phi_range: [min, max], should be in [0, 2 * pi]
63
- Return:
64
- poses: [size, 4, 4]
65
- '''
66
-
67
- theta_range = np.deg2rad(theta_range)
68
- phi_range = np.deg2rad(phi_range)
69
- angle_overhead = np.deg2rad(angle_overhead)
70
- angle_front = np.deg2rad(angle_front)
71
-
72
- radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
73
-
74
- phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
75
- if random.random() < uniform_sphere_rate:
76
- # based on http://corysimon.github.io/articles/uniformdistn-on-sphere/
77
- # acos takes in [-1, 1], first convert theta range to fit in [-1, 1]
78
- theta_range = torch.from_numpy(np.array(theta_range)).to(device)
79
- theta_amplitude_range = torch.cos(theta_range)
80
- # sample uniformly in amplitude space range
81
- thetas_amplitude = torch.rand(size, device=device) * (theta_amplitude_range[1] - theta_amplitude_range[0]) + theta_amplitude_range[0]
82
- # convert back
83
- thetas = torch.acos(thetas_amplitude)
84
- else:
85
- thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
86
-
87
- centers = -torch.stack([
88
- radius * torch.sin(thetas) * torch.sin(phis),
89
- radius * torch.cos(thetas),
90
- radius * torch.sin(thetas) * torch.cos(phis),
91
- ], dim=-1) # [B, 3]
92
-
93
- targets = 0
94
-
95
- # jitters
96
- if jitter:
97
- centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
98
- targets = targets + torch.randn_like(centers) * 0.2
99
-
100
- # lookat
101
- forward_vector = safe_normalize(targets - centers)
102
- up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
103
- right_vector = safe_normalize(torch.cross(up_vector, forward_vector, dim=-1))
104
-
105
- if jitter:
106
- up_noise = torch.randn_like(up_vector) * 0.02
107
- else:
108
- up_noise = 0
109
-
110
- up_vector = safe_normalize(torch.cross(forward_vector, right_vector, dim=-1) + up_noise)
111
-
112
- poses = torch.stack([right_vector, up_vector, forward_vector], dim=-1)
113
- radius = radius[..., None] - cam_z_offset
114
- translations = torch.cat([torch.zeros_like(radius), torch.zeros_like(radius), radius], dim=-1)
115
- poses = torch.cat([poses.view(-1, 9), translations], dim=-1)
116
-
117
- if return_dirs:
118
- dirs = get_view_direction(thetas, phis, angle_overhead, angle_front, phi_offset=phi_offset)
119
- dirs = view_direction_id_to_text(dirs)
120
- else:
121
- dirs = None
122
-
123
- return poses, dirs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video3d/diffusion/vsd.py DELETED
@@ -1,323 +0,0 @@
1
- import os
2
- os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli'
3
- os.environ['HF_HOME'] = '/viscam/u/zzli'
4
-
5
- from transformers import CLIPTextModel, CLIPTokenizer, logging
6
- from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler
7
-
8
- from diffusers.loaders import AttnProcsLayers
9
- from diffusers.models.attention_processor import LoRAAttnProcessor
10
- from diffusers.models.embeddings import TimestepEmbedding
11
- from diffusers.utils.import_utils import is_xformers_available
12
-
13
- # Suppress partial model loading warning
14
- logging.set_verbosity_error()
15
-
16
- import gc
17
- import random
18
- import torch
19
- import torch.nn as nn
20
- import torch.nn.functional as F
21
- import tinycudann as tcnn
22
- from video3d.diffusion.sd import StableDiffusion
23
- from torch.cuda.amp import custom_bwd, custom_fwd
24
-
25
-
26
- def seed_everything(seed):
27
- torch.manual_seed(seed)
28
- torch.cuda.manual_seed(seed)
29
-
30
- def cleanup():
31
- gc.collect()
32
- torch.cuda.empty_cache()
33
- tcnn.free_temporary_memory()
34
-
35
- class StableDiffusion_VSD(StableDiffusion):
36
- def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32, lora_n_timestamp_samples=1):
37
- super().__init__(device, sd_version=sd_version, hf_key=hf_key, torch_dtype=torch_dtype)
38
-
39
- # self.device = device
40
- # self.sd_version = sd_version
41
- # self.torch_dtype = torch_dtype
42
-
43
- if hf_key is not None:
44
- print(f'[INFO] using hugging face custom model key: {hf_key}')
45
- model_key = hf_key
46
- elif self.sd_version == '2.1':
47
- model_key = "stabilityai/stable-diffusion-2-1-base"
48
- elif self.sd_version == '2.0':
49
- model_key = "stabilityai/stable-diffusion-2-base"
50
- elif self.sd_version == '1.5':
51
- model_key = "runwayml/stable-diffusion-v1-5"
52
- else:
53
- raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
54
-
55
- # # Create model
56
- # self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", torch_dtype=torch_dtype).to(self.device)
57
- # self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
58
- # self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
59
- # self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
60
-
61
- # self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
62
- # # self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler")
63
-
64
- # self.num_train_timesteps = self.scheduler.config.num_train_timesteps
65
- # self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
66
-
67
- print(f'[INFO] loading stable diffusion VSD modules...')
68
-
69
- self.unet_lora = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
70
- cleanup()
71
-
72
- for p in self.vae.parameters():
73
- p.requires_grad_(False)
74
- for p in self.text_encoder.parameters():
75
- p.requires_grad_(False)
76
- for p in self.unet.parameters():
77
- p.requires_grad_(False)
78
- for p in self.unet_lora.parameters():
79
- p.requires_grad_(False)
80
-
81
- # set up LoRA layers
82
- lora_attn_procs = {}
83
- for name in self.unet_lora.attn_processors.keys():
84
- cross_attention_dim = (
85
- None
86
- if name.endswith("attn1.processor")
87
- else self.unet_lora.config.cross_attention_dim
88
- )
89
- if name.startswith("mid_block"):
90
- hidden_size = self.unet_lora.config.block_out_channels[-1]
91
- elif name.startswith("up_blocks"):
92
- block_id = int(name[len("up_blocks.")])
93
- hidden_size = list(reversed(self.unet_lora.config.block_out_channels))[
94
- block_id
95
- ]
96
- elif name.startswith("down_blocks"):
97
- block_id = int(name[len("down_blocks.")])
98
- hidden_size = self.unet_lora.config.block_out_channels[block_id]
99
-
100
- lora_attn_procs[name] = LoRAAttnProcessor(
101
- hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
102
- )
103
-
104
- self.unet_lora.set_attn_processor(lora_attn_procs)
105
-
106
- self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to(
107
- self.device
108
- )
109
- self.lora_layers._load_state_dict_pre_hooks.clear()
110
- self.lora_layers._state_dict_hooks.clear()
111
- self.lora_n_timestamp_samples = lora_n_timestamp_samples
112
- self.scheduler_lora = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
113
-
114
- print(f'[INFO] loaded stable diffusion VSD modules!')
115
-
116
- def train_lora(
117
- self,
118
- latents,
119
- text_embeddings,
120
- camera_condition
121
- ):
122
- B = latents.shape[0]
123
- lora_n_timestamp_samples = self.lora_n_timestamp_samples
124
- latents = latents.detach().repeat(lora_n_timestamp_samples, 1, 1, 1)
125
-
126
- t = torch.randint(
127
- int(self.num_train_timesteps * 0.0),
128
- int(self.num_train_timesteps * 1.0),
129
- [B * lora_n_timestamp_samples],
130
- dtype=torch.long,
131
- device=self.device,
132
- )
133
-
134
- noise = torch.randn_like(latents)
135
- noisy_latents = self.scheduler_lora.add_noise(latents, noise, t)
136
- if self.scheduler_lora.config.prediction_type == "epsilon":
137
- target = noise
138
- elif self.scheduler_lora.config.prediction_type == "v_prediction":
139
- target = self.scheduler_lora.get_velocity(latents, noise, t)
140
- else:
141
- raise ValueError(
142
- f"Unknown prediction type {self.scheduler_lora.config.prediction_type}"
143
- )
144
-
145
- # use view-independent text embeddings in LoRA
146
- _, text_embeddings_cond = text_embeddings.chunk(2)
147
-
148
- if random.random() < 0.1:
149
- camera_condition = torch.zeros_like(camera_condition)
150
-
151
- noise_pred = self.unet_lora(
152
- noisy_latents,
153
- t,
154
- encoder_hidden_states=text_embeddings_cond.repeat(
155
- lora_n_timestamp_samples, 1, 1
156
- ),
157
- class_labels=camera_condition.reshape(B, -1).repeat(
158
- lora_n_timestamp_samples, 1
159
- ),
160
- cross_attention_kwargs={"scale": 1.0}
161
- ).sample
162
-
163
- loss_lora = 0.5 * F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
164
- return loss_lora
165
-
166
-
167
- def train_step(
168
- self,
169
- text_embeddings,
170
- text_embeddings_vd,
171
- pred_rgb,
172
- camera_condition,
173
- im_features,
174
- guidance_scale=7.5,
175
- guidance_scale_lora=7.5,
176
- loss_weight=1.0,
177
- min_step_pct=0.02,
178
- max_step_pct=0.98,
179
- return_aux=False
180
- ):
181
- pred_rgb = pred_rgb.to(self.torch_dtype)
182
- text_embeddings = text_embeddings.to(self.torch_dtype)
183
- text_embeddings_vd = text_embeddings_vd.to(self.torch_dtype)
184
- camera_condition = camera_condition.to(self.torch_dtype)
185
- im_features = im_features.to(self.torch_dtype)
186
-
187
- # condition_label = camera_condition
188
- condition_label = im_features
189
-
190
- b = pred_rgb.shape[0]
191
-
192
- # interp to 512x512 to be fed into vae.
193
- # _t = time.time()
194
- pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
195
- # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
196
-
197
- # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
198
- min_step = int(self.num_train_timesteps * min_step_pct)
199
- max_step = int(self.num_train_timesteps * max_step_pct)
200
- t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device)
201
-
202
- # encode image into latents with vae, requires grad!
203
- # _t = time.time()
204
- latents = self.encode_imgs(pred_rgb_512)
205
- # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
206
-
207
- # predict the noise residual with unet, NO grad!
208
- # _t = time.time()
209
- with torch.no_grad():
210
- # add noise
211
- noise = torch.randn_like(latents)
212
- latents_noisy = self.scheduler.add_noise(latents, noise, t)
213
- # pred noise
214
- latent_model_input = torch.cat([latents_noisy] * 2)
215
-
216
- # disable unet class embedding here
217
- cls_embedding = self.unet.class_embedding
218
- self.unet.class_embedding = None
219
-
220
- cross_attention_kwargs = None
221
- noise_pred_pretrain = self.unet(
222
- latent_model_input,
223
- torch.cat([t, t]),
224
- encoder_hidden_states=text_embeddings_vd,
225
- class_labels=None,
226
- cross_attention_kwargs=cross_attention_kwargs
227
- ).sample
228
-
229
- self.unet.class_embedding = cls_embedding
230
-
231
- # use view-independent text embeddings in LoRA
232
- _, text_embeddings_cond = text_embeddings.chunk(2)
233
-
234
- noise_pred_est = self.unet_lora(
235
- latent_model_input,
236
- torch.cat([t, t]),
237
- encoder_hidden_states=torch.cat([text_embeddings_cond] * 2),
238
- class_labels=torch.cat(
239
- [
240
- condition_label.reshape(b, -1),
241
- torch.zeros_like(condition_label.reshape(b, -1)),
242
- ],
243
- dim=0,
244
- ),
245
- cross_attention_kwargs={"scale": 1.0},
246
- ).sample
247
-
248
- noise_pred_pretrain_uncond, noise_pred_pretrain_text = noise_pred_pretrain.chunk(2)
249
-
250
- noise_pred_pretrain = noise_pred_pretrain_uncond + guidance_scale * (
251
- noise_pred_pretrain_text - noise_pred_pretrain_uncond
252
- )
253
-
254
- assert self.scheduler.config.prediction_type == "epsilon"
255
- if self.scheduler_lora.config.prediction_type == "v_prediction":
256
- alphas_cumprod = self.scheduler_lora.alphas_cumprod.to(
257
- device=latents_noisy.device, dtype=latents_noisy.dtype
258
- )
259
- alpha_t = alphas_cumprod[t] ** 0.5
260
- sigma_t = (1 - alphas_cumprod[t]) ** 0.5
261
-
262
- noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).reshape(
263
- -1, 1, 1, 1
264
- ) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).reshape(-1, 1, 1, 1)
265
-
266
- noise_pred_est_uncond, noise_pred_est_camera = noise_pred_est.chunk(2)
267
-
268
- noise_pred_est = noise_pred_est_uncond + guidance_scale_lora * (
269
- noise_pred_est_camera - noise_pred_est_uncond
270
- )
271
-
272
- # w(t), sigma_t^2
273
- w = (1 - self.alphas[t])
274
- # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
275
- grad = loss_weight * w[:, None, None, None] * (noise_pred_pretrain - noise_pred_est)
276
-
277
- grad = torch.nan_to_num(grad)
278
-
279
- targets = (latents - grad).detach()
280
- loss_vsd = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
281
-
282
- loss_lora = self.train_lora(latents, text_embeddings, condition_label)
283
-
284
- loss = {
285
- 'loss_vsd': loss_vsd,
286
- 'loss_lora': loss_lora
287
- }
288
-
289
- if return_aux:
290
- aux = {'grad': grad, 't': t, 'w': w}
291
- return loss, aux
292
- else:
293
- return loss
294
-
295
-
296
-
297
- if __name__ == '__main__':
298
- import argparse
299
- import matplotlib.pyplot as plt
300
-
301
- parser = argparse.ArgumentParser()
302
- parser.add_argument('prompt', type=str)
303
- parser.add_argument('--negative', default='', type=str)
304
- parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
305
- parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
306
- parser.add_argument('-H', type=int, default=512)
307
- parser.add_argument('-W', type=int, default=512)
308
- parser.add_argument('--seed', type=int, default=0)
309
- parser.add_argument('--steps', type=int, default=50)
310
- opt = parser.parse_args()
311
-
312
- seed_everything(opt.seed)
313
-
314
- device = torch.device('cuda')
315
-
316
- sd = StableDiffusion_VSD(device, opt.sd_version, opt.hf_key)
317
-
318
- imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
319
-
320
- # visualize image
321
- plt.imshow(imgs[0])
322
- plt.show()
323
- plt.savefig(f'{opt.prompt}.png')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video3d/model_ddp.py CHANGED
@@ -41,10 +41,6 @@ from .render import mesh
41
  from .render import light
42
  from .render import render
43
 
44
- from .diffusion.sd import StableDiffusion
45
- from .diffusion.vsd import StableDiffusion_VSD
46
- from .diffusion.sd_utils import rand_poses, rand_lights, append_text_direction
47
-
48
  EPS = 1e-7
49
 
50
 
@@ -1269,53 +1265,8 @@ class Unsup3DDDP:
1269
 
1270
  self.enable_sds = cfgs.get('enable_sds', False)
1271
  self.enable_vsd = cfgs.get('enable_vsd', False)
1272
- if self.enable_sds:
1273
- diffusion_torch_dtype = torch.float16 if cfgs.get('diffusion_precision', 'float16') == 'float16' else torch.float32
1274
-
1275
- # decide if use SDS or VSD
1276
- if self.enable_vsd:
1277
- # self.stable_diffusion = misc.LazyClass(StableDiffusion_VSD, device=self.device, torch_dtype=diffusion_torch_dtype)
1278
- self.stable_diffusion = StableDiffusion_VSD(device=self.device, torch_dtype=diffusion_torch_dtype)
1279
- self.diffusion_guidance_scale_lora = cfgs.get('diffusion_guidance_scale_lora', 1.)
1280
- self.diffusion_guidance_scale = cfgs.get('diffusion_guidance_scale', 7.5)
1281
- else:
1282
- self.stable_diffusion = misc.LazyClass(StableDiffusion, device=self.device, torch_dtype=diffusion_torch_dtype)
1283
- self.diffusion_guidance_scale = cfgs.get('diffusion_guidance_scale', 100.)
1284
-
1285
- self.diffusion_loss_weight = cfgs.get('diffusion_loss_weight', 1.)
1286
- self.diffusion_num_random_cameras = cfgs.get('diffusion_num_random_cameras', 1)
1287
-
1288
- # For prompts
1289
- self.diffusion_prompt = cfgs.get('diffusion_prompt', '')
1290
- self.diffusion_negative_prompt = cfgs.get('diffusion_negative_prompt', '')
1291
-
1292
- # For image sampling
1293
- self.diffusion_albedo_ratio = cfgs.get('diffusion_albedo_ratio', 0.2)
1294
- self.diffusion_shading_ratio = cfgs.get('diffusion_shading_ratio', 0.4)
1295
- self.diffusion_light_ambient = cfgs.get('diffusion_light_ambient', 0.5)
1296
- self.diffusion_light_diffuse = cfgs.get('diffusion_light_diffuse', 0.8)
1297
- self.diffusion_radius_range = cfgs.get('diffusion_radius_range', [0.8, 1.4])
1298
- self.diffusion_uniform_sphere_rate = cfgs.get('diffusion_uniform_sphere_rate', 0.5)
1299
- self.diffusion_theta_range = cfgs.get('diffusion_theta_range', [0, 120])
1300
- self.diffusion_phi_offset = cfgs.get('diffusion_phi_offset', 180)
1301
- self.diffusion_resolution = cfgs.get('diffusion_resolution', 256)
1302
-
1303
- print('-----------------------------------------------')
1304
- print(f"!!!!!! the phi offset for diffusion is set as {self.diffusion_phi_offset}!!!!!!!!!!!!!")
1305
- print('-----------------------------------------------')
1306
-
1307
- # For randomizing light
1308
- self.diffusion_random_light = cfgs.get('diffusion_random_light', False)
1309
- self.diffusion_light_ambient = cfgs.get('diffusion_light_ambient', 0.5)
1310
- self.diffusion_light_diffuse = cfgs.get('diffusion_light_diffuse', 0.8)
1311
-
1312
- # For noise scheduling
1313
- self.diffusion_max_step = cfgs.get('diffusion_max_step', 0.98)
1314
-
1315
- # For view-dependent prompting
1316
- self.diffusion_append_prompt_directions = cfgs.get('diffusion_append_prompt_directions', False)
1317
- self.diffusion_angle_overhead = cfgs.get('diffusion_angle_overhead', 30)
1318
- self.diffusion_angle_front = cfgs.get('diffusion_angle_front', 60)
1319
 
1320
  @staticmethod
1321
  def get_data_loaders(cfgs, dataset, in_image_size=256, out_image_size=256, batch_size=64, num_workers=4, run_train=False, run_test=False, train_data_dir=None, val_data_dir=None, test_data_dir=None, flow_bool=False):
@@ -2017,141 +1968,6 @@ class Unsup3DDDP:
2017
 
2018
  return losses, aux
2019
 
2020
- def score_distillation_sampling(self, shape, texture, resolution, im_features, light, prior_shape, random_light=False, prompts=None, classes_vectors=None, im_features_map=None, w2c_pred=None):
2021
- num_instances = im_features.shape[0]
2022
- n_total_random_cameras = num_instances * self.diffusion_num_random_cameras
2023
-
2024
- poses, dirs = rand_poses(
2025
- n_total_random_cameras, self.device, radius_range=self.diffusion_radius_range, uniform_sphere_rate=self.diffusion_uniform_sphere_rate,
2026
- cam_z_offset=self.cam_pos_z_offset, theta_range=self.diffusion_theta_range, phi_offset=self.diffusion_phi_offset, return_dirs=True,
2027
- angle_front=self.diffusion_angle_front, angle_overhead=self.diffusion_angle_overhead,
2028
- )
2029
- mvp, w2c, campos = self.netInstance.get_camera_extrinsics_from_pose(poses, crop_fov_approx=self.crop_fov_approx)
2030
-
2031
- if random_light:
2032
- lights = rand_lights(campos, fixed_ambient=self.diffusion_light_ambient, fixed_diffuse=self.diffusion_light_diffuse)
2033
- else:
2034
- lights = light
2035
-
2036
- proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(num_instances, 1, 1).to(self.device)
2037
- original_mvp = torch.bmm(proj, w2c_pred)
2038
-
2039
- im_features = im_features.repeat(self.diffusion_num_random_cameras, 1) if im_features is not None else None
2040
- num_shapes = shape.v_pos.shape[0]
2041
- assert n_total_random_cameras % num_shapes == 0
2042
- shape = shape.extend(n_total_random_cameras // num_shapes)
2043
-
2044
- bg_color = torch.rand((n_total_random_cameras, 3), device=self.device) # channel-wise random
2045
- background = repeat(bg_color, 'b c -> b h w c', h=resolution[0], w=resolution[1])
2046
-
2047
- # only train the texture
2048
- safe_detach = lambda x: x.detach() if x is not None else None
2049
- shape = safe_detach(shape)
2050
- im_features = safe_detach(im_features)
2051
- im_features_map = safe_detach(im_features_map)
2052
-
2053
- set_requires_grad(texture, True)
2054
- set_requires_grad(light, True)
2055
-
2056
- image_pred, mask_pred, _, _, albedo, shading = self.render(
2057
- shape,
2058
- texture,
2059
- mvp,
2060
- w2c,
2061
- campos,
2062
- resolution,
2063
- im_features=im_features,
2064
- light=lights,
2065
- prior_shape=prior_shape,
2066
- dino_pred=None,
2067
- spp=self.renderer_spp,
2068
- bg_image=background,
2069
- im_features_map={"original_mvp": original_mvp, "im_features_map": im_features_map} if im_features_map is not None else None
2070
- )
2071
- if self.enable_vsd:
2072
- if prompts is None:
2073
- prompts = n_total_random_cameras * [self.diffusion_prompt]
2074
- else:
2075
- if '_' in prompts:
2076
- prompts = prompts.replace('_', ' ')
2077
- prompts = n_total_random_cameras * [prompts]
2078
-
2079
- prompts = ['a high-resolution DSLR image of ' + x for x in prompts]
2080
- assert self.diffusion_append_prompt_directions
2081
- # TODO: check if this implementation is aligned with stable-diffusion-prompt-processor
2082
- prompts_vd = append_text_direction(prompts, dirs)
2083
- negative_prompts = n_total_random_cameras * [self.diffusion_negative_prompt]
2084
-
2085
- text_embeddings = self.stable_diffusion.get_text_embeds(prompts, negative_prompts) # [BB, 77, 768]
2086
- text_embeddings_vd = self.stable_diffusion.get_text_embeds(prompts_vd, negative_prompts)
2087
-
2088
- camera_condition_type = 'c2w'
2089
- if camera_condition_type == 'c2w':
2090
- camera_condition = torch.linalg.inv(w2c).detach()
2091
- elif camera_condition_type == 'mvp':
2092
- camera_condition = mvp.detach()
2093
- else:
2094
- raise NotImplementedError
2095
-
2096
- # Alternate among albedo, shading, and image
2097
- rand = torch.rand(n_total_random_cameras, device=self.device)
2098
- rendered_component = torch.zeros_like(image_pred)
2099
- mask_pred = mask_pred[:, None]
2100
- background = rearrange(background, 'b h w c -> b c h w')
2101
- albedo_flag = rand > (1 - self.diffusion_albedo_ratio)
2102
- rendered_component[albedo_flag] = albedo[albedo_flag] * mask_pred[albedo_flag] + (1 - mask_pred[albedo_flag]) * background[albedo_flag]
2103
- shading_flag = (rand > (1 - self.diffusion_albedo_ratio - self.diffusion_shading_ratio)) & (rand <= (1 - self.diffusion_albedo_ratio))
2104
- rendered_component[shading_flag] = shading.repeat(1, 3, 1, 1)[shading_flag] / 2 * mask_pred[shading_flag] + (1 - mask_pred[shading_flag]) * background[shading_flag]
2105
- rendered_component[~(albedo_flag | shading_flag)] = image_pred[~(albedo_flag | shading_flag)]
2106
-
2107
- condition_label = classes_vectors
2108
- # condition_label = im_features
2109
-
2110
- sd_loss, sd_aux = self.stable_diffusion.train_step(
2111
- text_embeddings,
2112
- text_embeddings_vd,
2113
- rendered_component,
2114
- camera_condition, # TODO: can we input category condition in lora?
2115
- condition_label,
2116
- guidance_scale=self.diffusion_guidance_scale,
2117
- guidance_scale_lora=self.diffusion_guidance_scale_lora,
2118
- loss_weight=self.diffusion_loss_weight,
2119
- max_step_pct=self.diffusion_max_step,
2120
- return_aux=True
2121
- )
2122
-
2123
- aux = {'loss': sd_loss['loss_vsd'], 'loss_lora': sd_loss['loss_lora'], 'dirs': dirs, 'sd_aux': sd_aux, 'rendered_shape': shape}
2124
-
2125
- else:
2126
- # Prompt to text embeds
2127
- if prompts is None:
2128
- prompts = n_total_random_cameras * [self.diffusion_prompt]
2129
- else:
2130
- if '_' in prompts:
2131
- prompts = prompts.replace('_', ' ')
2132
- prompts = n_total_random_cameras * [prompts]
2133
- prompts = ['a high-resolution DSLR image of ' + x for x in prompts]
2134
- if self.diffusion_append_prompt_directions:
2135
- prompts = append_text_direction(prompts, dirs)
2136
- negative_prompts = n_total_random_cameras * [self.diffusion_negative_prompt]
2137
- text_embeddings = self.stable_diffusion.get_text_embeds(prompts, negative_prompts) # [2, 77, 768]
2138
-
2139
- # Alternate among albedo, shading, and image
2140
- rand = torch.rand(n_total_random_cameras, device=self.device)
2141
- rendered_component = torch.zeros_like(image_pred)
2142
- mask_pred = mask_pred[:, None]
2143
- background = rearrange(background, 'b h w c -> b c h w')
2144
- albedo_flag = rand > (1 - self.diffusion_albedo_ratio)
2145
- rendered_component[albedo_flag] = albedo[albedo_flag] * mask_pred[albedo_flag] + (1 - mask_pred[albedo_flag]) * background[albedo_flag]
2146
- shading_flag = (rand > (1 - self.diffusion_albedo_ratio - self.diffusion_shading_ratio)) & (rand <= (1 - self.diffusion_albedo_ratio))
2147
- rendered_component[shading_flag] = shading.repeat(1, 3, 1, 1)[shading_flag] / 2 * mask_pred[shading_flag] + (1 - mask_pred[shading_flag]) * background[shading_flag]
2148
- rendered_component[~(albedo_flag | shading_flag)] = image_pred[~(albedo_flag | shading_flag)]
2149
- sd_loss, sd_aux = self.stable_diffusion.train_step(
2150
- text_embeddings, rendered_component, guidance_scale=self.diffusion_guidance_scale, loss_weight=self.diffusion_loss_weight, max_step_pct=self.diffusion_max_step, return_aux=True)
2151
- aux = {'loss':sd_loss, 'dirs': dirs, 'sd_aux': sd_aux, 'rendered_shape': shape}
2152
-
2153
- return rendered_component, aux
2154
-
2155
  def parse_dict_definition(self, dict_config, total_iter):
2156
  '''
2157
  The dict_config is a diction-based configuration with ascending order
@@ -2987,19 +2803,6 @@ class Unsup3DDDP:
2987
  final_losses[name] = loss.mean()
2988
  final_losses['logit_loss'] = ((expandF(rot_logit) - logit_loss_target.detach())**2.).mean()
2989
 
2990
- ## score distillation sampling
2991
- sds_random_images = None
2992
- if self.enable_sds:
2993
- prompts = None
2994
- if classes_vectors is not None:
2995
- prompts = category_name[0]
2996
- sds_random_images, sds_aux = self.score_distillation_sampling(shape, texture, [self.diffusion_resolution, self.diffusion_resolution], im_features, light, prior_shape, prompts=prompts, classes_vectors=class_vector[None, :].expand(batch_size * num_frames, -1), im_features_map=im_features_map, w2c_pred=w2c)
2997
- if self.enable_vsd:
2998
- final_losses.update({'vsd_loss': sds_aux['loss']})
2999
- final_losses.update({'vsd_lora_loss': sds_aux['loss_lora']})
3000
- else:
3001
- final_losses.update({'sds_loss': sds_aux['loss']})
3002
-
3003
  ## mask distribution loss
3004
  mask_distribution_aux = None
3005
  if self.enable_mask_distribution:
 
41
  from .render import light
42
  from .render import render
43
 
 
 
 
 
44
  EPS = 1e-7
45
 
46
 
 
1265
 
1266
  self.enable_sds = cfgs.get('enable_sds', False)
1267
  self.enable_vsd = cfgs.get('enable_vsd', False)
1268
+ self.enable_sds = False
1269
+ self.enable_vsd = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1270
 
1271
  @staticmethod
1272
  def get_data_loaders(cfgs, dataset, in_image_size=256, out_image_size=256, batch_size=64, num_workers=4, run_train=False, run_test=False, train_data_dir=None, val_data_dir=None, test_data_dir=None, flow_bool=False):
 
1968
 
1969
  return losses, aux
1970
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1971
  def parse_dict_definition(self, dict_config, total_iter):
1972
  '''
1973
  The dict_config is a diction-based configuration with ascending order
 
2803
  final_losses[name] = loss.mean()
2804
  final_losses['logit_loss'] = ((expandF(rot_logit) - logit_loss_target.detach())**2.).mean()
2805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2806
  ## mask distribution loss
2807
  mask_distribution_aux = None
2808
  if self.enable_mask_distribution: