# NOTE: removed page-scrape artifact ("Spaces" / "Runtime error" status
# banner from the hosting page) — it was never part of this module.
import math
from copy import deepcopy

import torch
from einops import repeat
class FrameConditioning():
    """Configuration and helpers for conditioning video generation on frames.

    At most one of ``add_frame_to_input`` / ``add_frame_to_layers`` may be
    enabled (enforced by the assert in ``__init__``).  ``use`` is a read-only
    derived property reporting whether frame conditioning is active at all.
    """

    def __init__(self,
                 add_frame_to_input: bool = False,
                 add_frame_to_layers: bool = False,
                 fill_zero: bool = False,
                 randomize_mask: bool = False,
                 concatenate_mask: bool = False,
                 injection_probability: float = 0.9,
                 ) -> None:
        # BUGFIX: `use` is a property below; assigning None goes through the
        # setter (a deliberate no-op for None) instead of shadowing it with
        # an instance attribute as the undecorated original did.
        self.use = None
        self.add_frame_to_input = add_frame_to_input
        self.add_frame_to_layers = add_frame_to_layers
        self.fill_zero = fill_zero
        self.randomize_mask = randomize_mask
        self.concatenate_mask = concatenate_mask
        self.injection_probability = injection_probability
        assert not add_frame_to_layers or not add_frame_to_input

    def set_random_mask(self, random_mask: bool):
        """Return a deep copy of this config with `randomize_mask` replaced."""
        frame_conditioning = deepcopy(self)
        frame_conditioning.randomize_mask = random_mask
        return frame_conditioning

    @property
    def use(self):
        # BUGFIX: the two `use` defs were missing the property/setter
        # decorators, so the second silently overrode the first.
        return self.add_frame_to_input or self.add_frame_to_layers

    @use.setter
    def use(self, value):
        # `use` is derived from the two flags; only the dummy None
        # assignment from __init__ is tolerated.
        if value is not None:
            raise NotImplementedError("Direct access not allowed")

    def attach_video_frames(self, pl_module, z_0: torch.Tensor = None,
                            batch: torch.Tensor = None, random_mask: bool = False):
        """Build the frame-conditioning input for one step.

        Args:
            pl_module: module providing `inference_params`, `encode_frame`,
                `device` and `opt_params`.
            z_0: latent video, shape [B, F, C, W, H]; encoded from `batch`
                when None.
            batch: raw frames to encode when `z_0` is not given.
            random_mask: if True, inject a frame per sample with probability
                `injection_probability` at a random temporal position;
                otherwise always keep frame 0.

        Returns:
            (x0, frame_idx): the masked latent (optionally with the mask
            concatenated on the channel axis) and, per sample, the kept
            frame index (tensor) or -1 when nothing was injected.
        """
        assert self.fill_zero, "Not filling with zero not implemented yet"
        # BUGFIX: was `self.inference_params.video_length`, but this config
        # object never defines `inference_params`; the module owns them
        # (see the `pl_module.inference_params` read below).
        n_frames_inference = pl_module.inference_params.video_length
        with torch.no_grad():
            if z_0 is None:
                assert batch is not None
                z_0 = pl_module.encode_frame(batch)
            assert n_frames_inference == z_0.shape[1], "For frame injection, the number of frames sampled by the dataloader must match the number of frames used for video generation"
            shape = list(z_0.shape)
            shape[1] = pl_module.inference_params.video_length
            M = torch.zeros(shape, dtype=z_0.dtype,
                            device=pl_module.device)  # [B F C W H]
            bsz = z_0.shape[0]
            if random_mask:
                # Per sample: keep one frame with prob. injection_probability,
                # at a uniformly random temporal position.
                p_inject_frame = self.injection_probability
                use_masks = torch.bernoulli(
                    torch.tensor(p_inject_frame).repeat(bsz)).long()
                keep_frame_idx = torch.randint(
                    0, n_frames_inference, (bsz,), device=pl_module.device).long()
            else:
                use_masks = torch.ones((bsz,), device=pl_module.device).long()
                # keep only first frame
                keep_frame_idx = 0 * use_masks
            frame_idx = []
            for batch_idx, (keep_frame, use_mask) in enumerate(zip(keep_frame_idx, use_masks)):
                M[batch_idx, keep_frame] = use_mask
                frame_idx.append(keep_frame if use_mask == 1 else -1)
            x0 = z_0 * M
            if self.concatenate_mask:
                # flatten mask to a single channel before concatenation
                M = M[:, :, 0, None]
                x0 = torch.cat([x0, M], dim=2)
            if getattr(pl_module.opt_params.noise_decomposition, "use", False) and random_mask:
                assert x0.shape[0] == 1, "randomizing frame injection with noise decomposition not implemented for batch size >1"
        return x0, frame_idx
class NoiseDecomposition():
    """Decomposes diffusion noise into a shared base part and per-frame
    residuals: composed = sqrt(lambda_f)*base + sqrt(1-lambda_f)*residual.
    """

    def __init__(self,
                 use: bool = False,
                 random_frame: bool = False,
                 lambda_f: float = 0.5,
                 use_base_model: bool = True,
                 ):
        self.use = use                      # whether decomposition is active
        self.random_frame = random_frame    # pick a random base frame (NYI)
        self.lambda_f = lambda_f            # mixing weight of the base noise
        self.use_base_model = use_base_model

    def get_loss(self, x0, unet_base, unet, noise_scheduler, frame_idx,
                 z_t_base, timesteps, encoder_hidden_states, base_noise,
                 z_t_residual, composed_noise, reduction="mean"):
        """Compute the training loss for the base/residual decomposition.

        Returns the residual MSE loss, plus the base-model MSE loss when
        `use_base_model` is set.  `reduction` is forwarded to `mse_loss`
        (BUGFIX: it was an undefined name before; "mean" keeps torch's
        default behavior).
        """
        # BUGFIX: `lambda_f` was an undefined local; the weight lives on self.
        lambda_f = self.lambda_f
        if x0 is not None:
            # x0.shape = [B,F,C,W,H], if extrapolation_params.fill_zero=true, only one frame per batch non-zero
            assert not self.random_frame
            # TODO add x0 injection
            x0_base = []
            for batch_idx, frame in enumerate(frame_idx):
                x0_base.append(x0[batch_idx, frame, None, None])
            x0_base = torch.cat(x0_base, dim=0)
            x0_residual = repeat(
                x0[:, 0], "B C W H -> B F C W H", F=x0.shape[1] - 1)
        else:
            # BUGFIX: x0_base was undefined on this path but used below.
            x0_base = None
            x0_residual = None
        if self.use_base_model:
            base_pred = unet_base(z_t_base, timesteps,
                                  encoder_hidden_states, x0=x0_base).sample
        else:
            base_pred = base_noise
        # Broadcast per-sample alpha_bar(t) to the full [B F C W H] shape.
        timesteps_alphas = [
            noise_scheduler.alphas_cumprod[t.cpu()] for t in timesteps]
        timesteps_alphas = torch.stack(timesteps_alphas).to(base_pred.device)
        timesteps_alphas = repeat(timesteps_alphas, "B -> B F C W H",
                                  F=base_pred.shape[1], C=base_pred.shape[2],
                                  W=base_pred.shape[3], H=base_pred.shape[4])
        # Remove the (scaled) base prediction before running the residual UNet.
        base_correction = math.sqrt(
            lambda_f) * torch.sqrt(1 - timesteps_alphas) * base_pred
        z_t_residual_dash = z_t_residual - base_correction
        residual_pred = unet(
            z_t_residual_dash, timesteps, encoder_hidden_states, x0=x0_residual).sample
        # Base gradients flow only through its own loss term (detach here).
        composed_pred = math.sqrt(lambda_f) * base_pred.detach() + \
            math.sqrt(1 - lambda_f) * residual_pred
        loss_residual = torch.nn.functional.mse_loss(
            composed_noise.float(), composed_pred.float(), reduction=reduction)
        if self.use_base_model:
            loss_base = torch.nn.functional.mse_loss(
                base_noise.float(), base_pred.float(), reduction=reduction)
            loss = loss_residual + loss_base
        else:
            loss = loss_residual
        return loss

    def add_noise(self, z_base, base_noise, z_residual, composed_noise,
                  noise_scheduler, timesteps):
        """Apply the scheduler's forward process to base and residual latents."""
        z_t_base = noise_scheduler.add_noise(z_base, base_noise, timesteps)
        z_t_residual = noise_scheduler.add_noise(
            z_residual, composed_noise, timesteps)
        return z_t_base, z_t_residual

    def split_latent_into_base_residual(self, z_0, pl_module, noise_generator,
                                        frame_idx=None):
        """Split z_0 [B,F,C,W,H] into one base frame and F-1 residual frames,
        and sample the matching (base, composed) noises.

        Args:
            frame_idx: per-sample kept-frame indices; required when the
                frame-conditioning mask is randomized (BUGFIX: it was an
                undefined name on that path before).
        """
        if self.random_frame:
            raise NotImplementedError("Must be synced with x0 mask!")
            # Dead code kept for reference; BUGFIX: `bsz` was undefined and
            # `fr_Select` was a misspelling of `fr_select`.
            bsz = z_0.shape[0]
            fr_select = torch.randint(
                0, z_0.shape[1], (bsz,), device=pl_module.device).long()
            z_base = z_0[:, fr_select, None]
            fr_residual = [fr for fr in range(
                z_0.shape[1]) if fr != fr_select]
            z_residual = z_0[:, fr_residual, None]
        else:
            if not pl_module.unet_params.frame_conditioning.randomize_mask:
                # Deterministic case: frame 0 is the base frame.
                z_base = z_0[:, 0, None]
                z_residual = z_0[:, 1:]
            else:
                assert frame_idx is not None, \
                    "frame_idx required when randomize_mask is set"
                z_base = []
                for batch_idx, frame_at_batch in enumerate(frame_idx):
                    z_base.append(
                        z_0[batch_idx, frame_at_batch, None, None])
                z_base = torch.cat(z_base, dim=0)
                # Residual = all frames except the per-sample base frame.
                z_residual = []
                for batch_idx, frame_idx_batch in enumerate(frame_idx):
                    z_residual_batch = []
                    for frame in range(z_0.shape[1]):
                        if frame_idx_batch != frame:
                            z_residual_batch.append(
                                z_0[batch_idx, frame, None, None])
                    z_residual_batch = torch.cat(z_residual_batch, dim=1)
                    z_residual.append(z_residual_batch)
                z_residual = torch.cat(z_residual, dim=0)
        base_noise = noise_generator.sample_noise(z_base)  # b_t
        residual_noise = noise_generator.sample_noise(z_residual)  # r^f_t
        lambda_f = self.lambda_f
        composed_noise = math.sqrt(
            lambda_f) * base_noise + math.sqrt(1 - lambda_f) * residual_noise
        return z_base, base_noise, z_residual, composed_noise
class NoiseGenerator():
    """Samples initial diffusion noise [B, F, C, W, H] with several
    frame-correlation strategies selected by `mode`.
    """

    def __init__(self, mode="vanilla") -> None:
        self.mode = mode

    def set_seed(self, seed: int):
        # Stored only; sampling determinism is controlled via the optional
        # `generator` argument of sample_noise.
        self.seed = seed

    def reset_seed(self, seed: int):
        # Intentionally a no-op (kept for interface compatibility).
        pass

    def sample_noise(self, z_0: torch.tensor = None, shape=None, device=None,
                     dtype=None, generator=None):
        """Sample a noise tensor.

        Exactly one of `z_0` (template tensor providing shape/device/dtype)
        or `shape` must be given.  Modes:
          - "vanilla": i.i.d. Gaussian per frame.
          - "free_noise": first 8 frames i.i.d., second half a random
            permutation of them (frame axis may be dim 1 or dim 2).
          - "equal": one Gaussian frame repeated F times.
          - "fusion": frame 0 i.i.d.; each later frame is a spatially
            shifted copy mixed with fresh Gaussian noise.
          - "motion_dynamics" / "equal_noise_per_sequence": extend the last
            initial frame by shifted copies / plain repeats.
        """
        assert (z_0 is not None) != (
            shape is not None), "either z_0 must be None, or shape must be None. Both provided."
        kwargs = {}
        if z_0 is None:
            if device is not None:
                kwargs["device"] = device
            if dtype is not None:
                kwargs["dtype"] = dtype
        else:
            kwargs["device"] = z_0.device
            kwargs["dtype"] = z_0.dtype
            shape = z_0.shape
        if generator is not None:
            kwargs["generator"] = generator
        B, F, C, W, H = shape
        if self.mode == "vanilla":
            noise = torch.randn(shape, **kwargs)
        elif self.mode == "free_noise":
            noise = torch.randn(shape, **kwargs)
            if noise.shape[1] > 4:
                # HARD CODED: keeps 8 base frames, then appends a shuffled copy
                noise = noise[:, :8]
                noise = torch.cat(
                    [noise, noise[:, torch.randperm(noise.shape[1])]], dim=1)
            elif noise.shape[2] > 4:
                noise = noise[:, :, :8]
                noise = torch.cat(
                    [noise, noise[:, :, torch.randperm(noise.shape[2])]], dim=2)
            else:
                raise NotImplementedError(
                    f"Shape of noise vector not as expected {noise.shape}")
        elif self.mode == "equal":
            shape = list(shape)
            shape[1] = 1
            noise_init = torch.randn(shape, **kwargs)
            shape[1] = F
            noise = torch.zeros(
                shape, device=noise_init.device, dtype=noise_init.dtype)
            for fr in range(F):
                noise[:, fr] = noise_init[:, 0]
        elif self.mode == "fusion":
            shape = list(shape)
            shape[1] = 1
            noise_init = torch.randn(shape, **kwargs)
            noises = [noise_init]
            for fr in range(F - 1):
                shift = 2 * (fr + 1)
                local_copy = noise_init
                # Roll along the W axis (dim 3) by `shift` pixels.
                shifted_noise = torch.cat(
                    [local_copy[:, :, :, shift:, :], local_copy[:, :, :, :shift, :]], dim=3)
                # BUGFIX: was torch.rand (uniform, mean 0.5); the
                # variance-preserving sqrt weights only make sense for
                # zero-mean Gaussian noise as used in every other mode.
                noises.append(math.sqrt(0.2) * shifted_noise +
                              math.sqrt(1 - 0.2) * torch.randn(shape, **kwargs))
            noise = torch.cat(noises, dim=1)
        elif self.mode == "motion_dynamics" or self.mode == "equal_noise_per_sequence":
            shape = list(shape)
            normal_frames = 1
            shape[1] = normal_frames
            init_noise = torch.randn(shape, **kwargs)
            noises = [init_noise]
            # Later frames derive from the last "normal" frame only.
            init_noise = init_noise[:, -1, None]
            print(f"UPDATE with noise = {init_noise.shape}")  # TODO: debug print, consider removing
            if self.mode == "motion_dynamics":
                for fr in range(F - normal_frames):
                    shift = 2 * (fr + 1)
                    print(fr, shift)
                    local_copy = init_noise
                    shifted_noise = torch.cat(
                        [local_copy[:, :, :, shift:, :], local_copy[:, :, :, :shift, :]], dim=3)
                    noises.append(shifted_noise)
            elif self.mode == "equal_noise_per_sequence":
                for fr in range(F - 1):
                    noises.append(init_noise)
            else:
                raise NotImplementedError()
            noise = torch.cat(noises, dim=1)
            print(noise.shape)  # TODO: debug print, consider removing
        return noise