Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from tqdm import tqdm | |
import numpy as np | |
import math | |
class RectifiedFlow():
    """Rectified-flow training loss and rolling-window samplers for video/audio.

    A window of ``window_size`` frames is processed at once.  Throughout this
    class the flow time ``t`` uses the convention ``t = 1`` for a clean frame
    and ``t = 0`` for pure noise.  The model is always called as
    ``model(video, audio, t_rf, **model_kwargs)`` and must return a
    ``(score_video, score_audio)`` pair, where ``t_rf`` is ``t`` rescaled to
    the range ``[eps, T]``.
    """

    def __init__(self, num_timesteps, warmup_timesteps=10, noise_scale=1.0,
                 init_type='gaussian', eps=1, sampling='logit', window_size=8):
        """
        num_timesteps: number of integration steps used to advance the rolling
            window by one frame.
        warmup_timesteps: multiplier; the warm-up phase runs
            ``warmup_timesteps * num_timesteps`` steps.
        noise_scale: factor applied to every Gaussian noise sample.
        init_type: noise initialisation; only ``'gaussian'`` is implemented.
        eps: A `float` number. The smallest time step to sample from.
        sampling: ``'logit'`` weights the training loss with a logit-normal
            density of ``t``; any other value uses uniform weights.
        window_size: number of frames in the rolling window.
        """
        self.num_timesteps = num_timesteps
        # Total number of warm-up steps (must stay an int: it is fed to range()).
        self.warmup_timesteps = warmup_timesteps*num_timesteps
        # Upper bound of the rescaled time passed to the model.
        self.T = 1000.
        self.noise_scale = noise_scale
        self.init_type = init_type
        self.eps = eps
        self.window_size = window_size
        self.sampling = sampling

    def logit(self, x):
        """Return the element-wise logit ``log(x / (1 - x))`` of tensor ``x``."""
        return torch.log(x / (1 - x))

    def logit_normal(self, x, mu=0, sigma=1):
        """Return the logit-normal density with parameters ``mu``/``sigma`` at ``x``.

        ``x`` must lie strictly inside (0, 1); the caller in this class clamps
        it before calling.
        """
        normaliser = sigma * math.sqrt(2 * torch.pi) * x * (1 - x)
        return 1 / normaliser * torch.exp(-(self.logit(x) - mu) ** 2 / (2 * sigma ** 2))

    # ------------------------------------------------------------------ helpers

    def _window_indexes(self, batch_size, device):
        """Return a ``[batch_size, window_size]`` tensor of frame indexes 0..W-1."""
        indexes = torch.linspace(0, self.window_size-1, steps=self.window_size, device=device)
        return indexes.unsqueeze(0).repeat(batch_size, 1)

    def _warmup_steps(self, progress):
        """Return the iterable of warm-up step indexes, with a progress bar if asked."""
        steps = range(self.warmup_timesteps)
        # Only build the tqdm wrapper when it will actually display something.
        return tqdm(steps) if progress else steps

    def _rolling_timestep(self, window_indexes, i):
        """Return ``(t_partial, t_rf)`` for inner step ``i`` of the rolling phase.

        ``t_partial`` is ``t`` with three trailing singleton dims appended;
        ``t_rf`` is ``t`` rescaled to ``[eps, T]``.
        """
        tw = (self.num_timesteps - i)/self.num_timesteps
        t = (window_indexes + tw)/self.window_size
        t = 1-t  # swap 0 and 1: 1 is full image, 0 is full noise
        t_partial = t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        t_rf = t*(self.T-self.eps) + self.eps
        return t_partial, t_rf

    def _roll_window(self, z):
        """Drop the first frame of ``z`` and append one fresh noise frame."""
        fresh = torch.randn_like(z[:, 0]).unsqueeze(1)*self.noise_scale
        return torch.cat([z[:, 1:], fresh], dim=1)

    @staticmethod
    def _guided_update(z, score, grad, t_partial, dt, scale):
        """Euler step on ``z`` along ``score`` minus the scaled guidance gradient.

        The guidance term is switched off for frames whose ``t + dt`` equals 1.
        NOTE(review): this is an exact floating-point comparison kept from the
        original code; rounding may make it miss - confirm whether a tolerance
        is wanted.
        """
        guide = (t_partial + dt) != 1
        # score is detached so the returned window carries no autograd graph.
        return z.detach() + dt*score.detach() - guide * dt * grad * scale

    # ----------------------------------------------------------------- training

    def training_loss(self, model, v, a, model_kwargs):
        """Compute the rectified-flow loss for one batch.

        v: [B, T, C, H, W] video window.
        a: audio window; must be 5-D as well, since the code indexes its
           dims 2..4 (the original ``[B, T, N, F]`` note does not match the
           code - TODO confirm the real audio layout).
        T must equal ``window_size`` for the per-frame times to line up.

        Returns ``{"loss": scalar tensor}`` - the sum of the weighted video
        and audio mean-squared errors.
        """
        batch = v.shape[0]
        tw = torch.rand((batch, 1), device=v.device)
        window_indexes = self._window_indexes(batch, v.device)
        # Each sample uses the "rollout" schedule with probability 0.8 and the
        # "pre-rollout" schedule otherwise.
        rollout = torch.bernoulli(torch.full((batch,), 0.8, device=v.device)).bool()
        t_rollout = (window_indexes+tw)/self.window_size
        t_pre_rollout = window_indexes/self.window_size + tw
        t = torch.where(rollout.unsqueeze(1), t_rollout, t_pre_rollout)
        t = 1 - t  # swap 0 and 1, since 1 is full image and 0 is full noise
        # Keep t strictly inside (0, 1) so the logit-normal weight stays finite.
        t = torch.clamp(t, 0+1e-6, 1-1e-6)

        if self.sampling == 'logit':
            weights = self.logit_normal(t, mu=0, sigma=1)
        else:
            weights = torch.ones_like(t)

        B, T = t.shape
        v_z0 = self.get_z0(v).to(v.device)
        a_z0 = self.get_z0(a).to(a.device)

        # Broadcast the per-frame time over the remaining three dims.
        t_frames = t.view(B, T, 1, 1, 1)
        perturbed_video = t_frames*v + (1-t_frames)*v_z0
        perturbed_audio = t_frames*a + (1-t_frames)*a_z0

        t_rf = t*(self.T-self.eps) + self.eps
        score_v, score_a = model(perturbed_video, perturbed_audio, t_rf, **model_kwargs)

        target_video = v - v_z0  # direction of the flow
        target_audio = a - a_z0  # direction of the flow

        # Per-frame squared error, averaged over everything but [B, T].
        loss_video = torch.mean(torch.square(score_v-target_video), dim=[2, 3, 4])
        loss_audio = torch.mean(torch.square(score_a-target_audio), dim=[2, 3, 4])

        # Weight each frame's loss by its time-dependent weight, then average.
        loss_video = torch.mean(loss_video * weights)
        loss_audio = torch.mean(loss_audio * weights)

        return {"loss": (loss_video + loss_audio)}

    # ----------------------------------------------------------------- sampling

    def sample(self, model, v_z, a_z, model_kwargs, progress=True):
        """Jointly sample video and audio with a rolling window.

        v_z, a_z: initial noise windows (first dim batch, second dim frames).
        Runs the warm-up phase immediately and returns a generator *function*;
        calling it yields ``(video_frame, audio_frame)`` pairs indefinitely.
        No grad mode is set here - the caller decides.
        """
        B = v_z.shape[0]
        window_indexes = self._window_indexes(B, v_z.device)

        # warm up with different number of warmup timestep to be more precise
        for i in self._warmup_steps(progress):
            dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)
            score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
            v_z = v_z.detach().clone() + dt*score_v
            a_z = a_z.detach().clone() + dt*score_a

        v_f = v_z[:, 0]
        a_f = a_z[:, 0]
        v_z = self._roll_window(v_z)
        a_z = self._roll_window(a_z)

        def yield_frame():
            nonlocal v_z, a_z
            yield (v_f, a_f)

            dt = 1/(self.num_timesteps*self.window_size)
            while True:
                for i in range(self.num_timesteps):
                    _, t_rf = self._rolling_timestep(window_indexes, i)
                    score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                    v_z = v_z.detach().clone() + dt*score_v
                    a_z = a_z.detach().clone() + dt*score_a

                v = v_z[:, 0]
                a = a_z[:, 0]
                # remove the first element and append fresh noise
                v_z = self._roll_window(v_z)
                a_z = self._roll_window(a_z)
                yield (v, a)

        return yield_frame

    def sample_a2v(self, model, v_z, a, model_kwargs, scale=1, progress=True):
        """Sample video conditioned on the given audio ``a``.

        The audio is noised to the current time and fed to the model; the
        gradient of the audio flow-matching error w.r.t. the video window is
        used as guidance, scaled by ``scale``.  ``a`` needs at least
        ``window_size`` frames; when it runs out it is padded with noise.
        Returns a generator function yielding one detached video frame at a
        time.
        """
        B = v_z.shape[0]
        window_indexes = self._window_indexes(B, v_z.device)

        a_partial = a[:, :self.window_size]
        a_noise = torch.randn_like(a, device=v_z.device)*self.noise_scale
        a_noise_partial = a_noise[:, :self.window_size]

        with torch.enable_grad():
            # warm up with different number of warmup timestep to be more precise
            for i in self._warmup_steps(progress):
                v_z = v_z.detach().requires_grad_(True)
                dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)
                a_z = a_partial*t_partial + a_noise_partial*(1-t_partial)
                score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                loss = torch.square((a_partial-a_noise_partial)-score_a)
                grad = torch.autograd.grad(loss.mean(), v_z)[0]
                v_z = self._guided_update(v_z, score_v, grad, t_partial, dt, scale)

        v_f = v_z[:, 0].detach()
        v_z = self._roll_window(v_z)

        def yield_frame():
            nonlocal v_z, a, a_noise
            yield v_f

            dt = 1/(self.num_timesteps*self.window_size)
            while True:
                torch.cuda.empty_cache()
                # Slide the conditioning audio (and its paired noise) by one frame.
                a = a[:, 1:]
                a_noise = a_noise[:, 1:]
                if a.shape[1] < self.window_size:
                    # Conditioning exhausted: pad both with a noise frame.
                    a = torch.cat([a, torch.randn_like(a[:, 0]).unsqueeze(1)*self.noise_scale], dim=1)
                    a_noise = torch.cat([a_noise, torch.randn_like(a_noise[:, 0]).unsqueeze(1)*self.noise_scale], dim=1)

                a_partial = a[:, :self.window_size]
                a_noise_partial = a_noise[:, :self.window_size]

                with torch.enable_grad():
                    for i in range(self.num_timesteps):
                        v_z = v_z.detach().requires_grad_(True)
                        t_partial, t_rf = self._rolling_timestep(window_indexes, i)
                        # Perturb with the SAME noise the guidance target below
                        # is built from (as the warm-up and sample_v2a do).
                        a_z = a_partial*t_partial + a_noise_partial*(1-t_partial)
                        score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                        loss = torch.square((a_partial-a_noise_partial)-score_a)
                        grad = torch.autograd.grad(loss.mean(), v_z)[0]
                        v_z = self._guided_update(v_z, score_v, grad, t_partial, dt, scale)

                v = v_z[:, 0].detach()
                v_z = self._roll_window(v_z)
                yield v

        return yield_frame

    def sample_v2a(self, model, v, a_z, model_kwargs, scale=2, progress=True):
        """Sample audio conditioned on the given video ``v``.

        Mirror image of ``sample_a2v``: the video is noised to the current
        time, and the gradient of the video flow-matching error w.r.t. the
        audio window guides the audio update.  Returns a generator function
        yielding one detached audio frame at a time.
        """
        B = a_z.shape[0]
        window_indexes = self._window_indexes(B, a_z.device)

        v_partial = v[:, :self.window_size]
        v_noise = torch.randn_like(v, device=a_z.device)*self.noise_scale
        v_noise_partial = v_noise[:, :self.window_size]

        with torch.enable_grad():
            # warm up with different number of warmup timestep to be more precise
            for i in self._warmup_steps(progress):
                a_z = a_z.detach().requires_grad_(True)
                dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)
                v_z = v_partial*t_partial + v_noise_partial*(1-t_partial)
                score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                loss = torch.square((v_partial-v_noise_partial)-score_v)
                grad = torch.autograd.grad(loss.mean(), a_z)[0]
                a_z = self._guided_update(a_z, score_a, grad, t_partial, dt, scale)

        a_f = a_z[:, 0].detach()
        a_z = self._roll_window(a_z)

        def yield_frame():
            nonlocal v, v_noise, a_z
            yield a_f

            dt = 1/(self.num_timesteps*self.window_size)
            while True:
                torch.cuda.empty_cache()
                # Slide the conditioning video (and its paired noise) by one frame.
                v = v[:, 1:]
                v_noise = v_noise[:, 1:]
                if v.shape[1] < self.window_size:
                    # Conditioning exhausted: pad both with a noise frame.
                    v = torch.cat([v, torch.randn_like(v[:, 0]).unsqueeze(1)*self.noise_scale], dim=1)
                    # Extend v_noise itself (not v), so the noise stays
                    # independent of the conditioning frames.
                    v_noise = torch.cat([v_noise, torch.randn_like(v_noise[:, 0]).unsqueeze(1)*self.noise_scale], dim=1)

                v_partial = v[:, :self.window_size]
                v_noise_partial = v_noise[:, :self.window_size]

                with torch.enable_grad():
                    for i in range(self.num_timesteps):
                        a_z = a_z.detach().requires_grad_(True)
                        t_partial, t_rf = self._rolling_timestep(window_indexes, i)
                        v_z = v_partial*t_partial + v_noise_partial*(1-t_partial)
                        score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                        loss = torch.square((v_partial-v_noise_partial)-score_v)
                        grad = torch.autograd.grad(loss.mean(), a_z)[0]
                        a_z = self._guided_update(a_z, score_a, grad, t_partial, dt, scale)

                a = a_z[:, 0].detach()
                a_z = self._roll_window(a_z)
                yield a

        return yield_frame

    def calculate_prerolling_timestep(self, window_indexes, i):
        """Return ``(dt, t_partial, t_rf)`` for warm-up step ``i``.

        window_indexes: ``[B, window_size]`` frame indexes.
        dt: per-frame step size, shape ``[B, window_size, 1, 1, 1]``; it is
            zero for frames whose clamped time does not change on this step.
        t_partial: current per-frame time in ``[0, 1]`` with three trailing
            singleton dims.
        t_rf: current per-frame time rescaled to ``[eps, T]``.
        """
        tw = (self.warmup_timesteps - i)/self.warmup_timesteps
        tw_future = (self.warmup_timesteps - (i+1))/self.warmup_timesteps

        t = window_indexes/self.window_size + tw
        # timestep for the next iteration, to calculate dt
        t_future = window_indexes/self.window_size + tw_future

        # Swap 0 with 1, 1 is full image, 0 is full noise
        t = torch.clamp(1-t, 0, 1)
        t_future = torch.clamp(1 - t_future, 0, 1)

        dt = torch.abs(t_future-t).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)  # [B, window_size, 1, 1, 1]
        t_partial = t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        t_rf = t*(self.T-self.eps) + self.eps
        return dt, t_partial, t_rf

    def get_z0(self, batch, train=True):
        """Return Gaussian noise with the shape of ``batch``, scaled by ``noise_scale``.

        The tensor is created with torch's default dtype and device (not
        ``batch``'s); callers move it as needed.  ``train`` is accepted for
        interface compatibility and is not used.

        Raises NotImplementedError for any ``init_type`` other than 'gaussian'.
        """
        if self.init_type == 'gaussian':
            ### standard gaussian #+ 0.5
            return torch.randn(batch.shape)*self.noise_scale
        else:
            raise NotImplementedError("INITIALIZATION TYPE NOT IMPLEMENTED")