# R-FLAV/diffusion/rectified_flow.py
import math

import torch
from tqdm import tqdm
class RectifiedFlow:
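    """
    Rectified flow over a rolling window of audio-video frames: training samples
    per-frame timesteps along the window, and sampling denoises the window one
    step at a time, emitting the fully denoised first frame and appending fresh
    noise at the end.
    """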
    def __init__(self, num_timesteps, warmup_timesteps=10, noise_scale=1.0, init_type='gaussian', eps=1, sampling='logit', window_size=8):
        """
        num_timesteps: denoising steps per frame slot in the rolling window.
        warmup_timesteps: multiplier on num_timesteps for the pre-rolling warm-up.
        noise_scale: standard deviation of the Gaussian prior z0.
        init_type: prior type; only 'gaussian' is implemented.
        eps: smallest model timestep; normalized times are mapped to [eps, T].
        sampling: 'logit' weights the training loss with a logit-normal density;
            any other value uses uniform weights.
        window_size: number of frames denoised jointly in the sliding window.
        """
self.num_timesteps = num_timesteps
        self.warmup_timesteps = warmup_timesteps * num_timesteps  # multiplier, not an absolute count
self.T = 1000.
self.noise_scale = noise_scale
self.init_type = init_type
self.eps = eps
self.window_size = window_size
self.sampling = sampling
def logit(self, x):
return torch.log(x / (1 - x))
def logit_normal(self, x, mu=0, sigma=1):
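        # density of the logit-normal distribution, used as a per-timestep loss weight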
return 1 / (sigma * math.sqrt(2 * torch.pi) * x * (1 - x)) * torch.exp(-(self.logit(x) - mu) ** 2 / (2 * sigma ** 2))
def training_loss(self, model, v, a, model_kwargs):
"""
v: [B, T, C, H, W]
a: [B, T, N, F]
"""
        B, T = v.shape[:2]
        tw = torch.rand((B, 1), device=v.device)
        window_indexes = torch.arange(self.window_size, device=v.device, dtype=torch.float32).unsqueeze(0).repeat(B, 1)
        # with probability 0.8 sample timesteps matching steady-state rolling generation;
        # otherwise sample the warm-up (pre-rolling) schedule, whose out-of-range values
        # are clamped below
        rollout = torch.bernoulli(torch.full((B,), 0.8, device=v.device)).bool()
        t_rollout = (window_indexes + tw) / self.window_size
        t_pre_rollout = window_indexes / self.window_size + tw
        t = torch.where(rollout.unsqueeze(1), t_rollout, t_pre_rollout)
        t = 1 - t  # flip so that t=1 is the clean sample and t=0 is pure noise
        t = torch.clamp(t, 1e-6, 1 - 1e-6)
        if self.sampling == 'logit':
            weights = self.logit_normal(t, mu=0, sigma=1)
        else:
            weights = torch.ones_like(t)
        B, T = t.shape  # here T == window_size
v_z0 = self.get_z0(v).to(v.device)
a_z0 = self.get_z0(a).to(a.device)
        t_video = t.view(B, T, 1, 1, 1)  # broadcast over C, H, W
        t_audio = t.view(B, T, 1, 1, 1)  # broadcast over the audio dims
        # linear interpolation between data and noise along the flow
        perturbed_video = t_video * v + (1 - t_video) * v_z0
        perturbed_audio = t_audio * a + (1 - t_audio) * a_z0
        t_rf = t * (self.T - self.eps) + self.eps  # rescale t from [0, 1] to the model's [eps, T]
score_v, score_a = model(perturbed_video, perturbed_audio, t_rf, **model_kwargs)
# score_v = [B, T, C, H, W]
# score_a = [B, T, N, F]
target_video = v - v_z0 # direction of the flow
target_audio = a - a_z0 # direction of the flow
loss_video = torch.square(score_v-target_video)
loss_audio = torch.square(score_a-target_audio)
loss_video = torch.mean(loss_video, dim=[2,3,4])
loss_audio = torch.mean(loss_audio, dim=[2,3,4])
        # weight each timestep's loss (logit-normal or uniform, see above)
        loss_video = torch.mean(loss_video * weights)
        loss_audio = torch.mean(loss_audio * weights)
return {"loss": (loss_video + loss_audio)}
def sample(self, model, v_z, a_z, model_kwargs, progress=True):
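        """
        Warm up the rolling window from pure noise, then return a generator
        function that yields one (video_frame, audio_frame) pair per step.
        """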
B = v_z.shape[0]
window_indexes = torch.linspace(0, self.window_size-1, steps=self.window_size, device=v_z.device).unsqueeze(0).repeat(B,1)
        # warm-up: pre-roll with a finer timestep grid so the window reaches its staggered state more precisely
for i in tqdm(range(self.warmup_timesteps), disable=not progress):
dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)
score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
            v_z = v_z.detach().clone() + dt * score_v  # Euler step along the predicted flow
            a_z = a_z.detach().clone() + dt * score_a
        v_f = v_z[:, 0]
        a_f = a_z[:, 0]
        # pop the fully denoised first frame; append fresh noise at the end of the window
        v_z = torch.cat([v_z[:, 1:], torch.randn_like(v_z[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)
        a_z = torch.cat([a_z[:, 1:], torch.randn_like(a_z[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)
def yield_frame():
nonlocal v_z, a_z, window_indexes
yield (v_f, a_f)
dt = 1/(self.num_timesteps*self.window_size)
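            # steady state: each inner loop advances every frame in the window by one
            # denoising step; after num_timesteps steps the first frame is clean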
while True:
for i in range(self.num_timesteps):
tw = (self.num_timesteps - i)/self.num_timesteps
t = (window_indexes + tw)/self.window_size
t = 1-t
t_rf = t*(self.T-self.eps) + self.eps
score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
v_z = v_z.detach().clone() + dt*score_v
a_z = a_z.detach().clone() + dt*score_a
v = v_z[:,0]
a = a_z[:,0]
                # pop the denoised first frame and append fresh noise
v_noise = torch.randn_like(v_z[:,0]).unsqueeze(1)*self.noise_scale
a_noise = torch.randn_like(a_z[:,0]).unsqueeze(1)*self.noise_scale
v_z = torch.cat([v_z[:,1:],v_noise], dim=1)
a_z = torch.cat([a_z[:,1:],a_noise], dim=1)
yield (v, a)
return yield_frame
def sample_a2v(self, model, v_z, a, model_kwargs, scale=1, progress=True):
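        """
        Audio-to-video: generate video frames conditioned on a given audio track
        `a`, steering the video latents with the gradient of an audio
        reconstruction loss (classifier-guidance style, weighted by `scale`).
        """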
B = v_z.shape[0]
window_indexes = torch.linspace(0, self.window_size-1, steps=self.window_size, device=v_z.device).unsqueeze(0).repeat(B,1)
a_partial = a[:, :self.window_size]
a_noise = torch.randn_like(a, device=v_z.device)*self.noise_scale
a_noise_partial = a_noise[:, :self.window_size]
with torch.enable_grad():
            # warm-up: pre-roll with a finer timestep grid so the window reaches its staggered state more precisely
for i in tqdm(range(self.warmup_timesteps), disable=not progress):
v_z = v_z.detach().requires_grad_(True)
dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)
                a_z = a_partial * t_partial + a_noise_partial * (1 - t_partial)
                score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
                # guidance loss: how far the predicted audio flow is from the known one
                loss = torch.square((a_partial - a_noise_partial) - score_a)
                grad = torch.autograd.grad(loss.mean(), v_z)[0]
                # the mask skips guidance on steps where a frame becomes fully denoised
                v_z = v_z.detach() + dt * score_v - ((t_partial + dt) != 1) * dt * grad * scale
v_f = v_z[:,0].detach()
v_z = torch.cat([v_z[:,1:], torch.randn_like(v_z[:,0]).unsqueeze(1)*self.noise_scale], dim=1)
def yield_frame():
nonlocal v_z, a, a_noise, window_indexes
yield v_f
dt = 1/(self.num_timesteps*self.window_size)
while True:
torch.cuda.empty_cache()
a = a[:,1:]
a_noise = a_noise[:,1:]
if a.shape[1] < self.window_size:
a = torch.cat([a, torch.randn_like(a[:,0]).unsqueeze(1)*self.noise_scale], dim=1)
a_noise = torch.cat([a_noise, torch.randn_like(a[:,0]).unsqueeze(1)*self.noise_scale], dim=1)
a_partial = a[:, :self.window_size]
a_noise_partial = a_noise[:, :self.window_size]
with torch.enable_grad():
for i in range(self.num_timesteps):
v_z = v_z.detach().requires_grad_(True)
tw = (self.num_timesteps - i)/self.num_timesteps
t = (window_indexes + tw)/self.window_size
t = 1-t
t_partial = t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
t_rf = t*(self.T-self.eps) + self.eps
                        a_z = a_partial * t_partial + a_noise_partial * (1 - t_partial)  # reuse the stored noise so it matches the guidance target below
score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
loss = torch.square((a_partial-a_noise_partial)-score_a)
grad = torch.autograd.grad(loss.mean(), v_z)[0]
v_z = v_z.detach() + dt*score_v - ((t_partial+dt)!=1) * dt * grad * scale
v = v_z[:,0].detach()
v_noise = torch.randn_like(v_z[:,0]).unsqueeze(1)*self.noise_scale
v_z = torch.cat([v_z[:,1:],v_noise], dim=1)
yield v
return yield_frame
def sample_v2a(self, model, v, a_z, model_kwargs, scale=2, progress=True):
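        """
        Video-to-audio: generate audio frames conditioned on given video `v`,
        steering the audio latents with the gradient of a video reconstruction
        loss (classifier-guidance style, weighted by `scale`).
        """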
B = a_z.shape[0]
window_indexes = torch.linspace(0, self.window_size-1, steps=self.window_size, device=a_z.device).unsqueeze(0).repeat(B,1)
v_partial = v[:, :self.window_size]
v_noise = torch.randn_like(v, device=a_z.device)*self.noise_scale
v_noise_partial = v_noise[:, :self.window_size]
with torch.enable_grad():
            # warm-up: pre-roll with a finer timestep grid so the window reaches its staggered state more precisely
for i in tqdm(range(self.warmup_timesteps), disable=not progress):
a_z = a_z.detach().requires_grad_(True)
dt, t_partial, t_rf = self.calculate_prerolling_timestep(window_indexes, i)
v_z = v_partial*t_partial + v_noise_partial*(1-t_partial)
score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
loss = torch.square((v_partial-v_noise_partial)-score_v)
grad = torch.autograd.grad(loss.mean(), a_z)[0]
a_z = a_z.detach() + dt*score_a - ((t_partial + dt)!=1) * dt * grad * scale
a_f = a_z[:,0].detach()
a_z = torch.cat([a_z[:,1:], torch.randn_like(a_z[:,0]).unsqueeze(1)*self.noise_scale], dim=1)
def yield_frame():
nonlocal v, v_noise, a_z, window_indexes
yield a_f
dt = 1/(self.num_timesteps*self.window_size)
while True:
torch.cuda.empty_cache()
v = v[:,1:]
v_noise = v_noise[:,1:]
if v.shape[1] < self.window_size:
v = torch.cat([v, torch.randn_like(v[:,0]).unsqueeze(1)*self.noise_scale], dim=1)
                    v_noise = torch.cat([v_noise, torch.randn_like(v[:, 0]).unsqueeze(1) * self.noise_scale], dim=1)
v_partial = v[:, :self.window_size]
v_noise_partial = v_noise[:, :self.window_size]
with torch.enable_grad():
for i in range(self.num_timesteps):
a_z = a_z.detach().requires_grad_(True)
tw = (self.num_timesteps - i)/self.num_timesteps
t = (window_indexes + tw)/self.window_size
t = 1-t
t_partial = t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
t_rf = t*(self.T-self.eps) + self.eps
v_z = v_partial*t_partial + v_noise_partial*(1-t_partial)
score_v, score_a = model(v_z, a_z, t_rf, **model_kwargs)
loss = torch.square((v_partial-v_noise_partial)-score_v)
grad = torch.autograd.grad(loss.mean(), a_z)[0]
a_z = a_z.detach() + dt*score_a - ((t_partial + dt)!=1) * dt * grad * scale
a = a_z[:,0].detach()
a_noise = torch.randn_like(a_z[:,0]).unsqueeze(1)*self.noise_scale
a_z = torch.cat([a_z[:,1:],a_noise], dim=1)
yield a
return yield_frame
def calculate_prerolling_timestep(self, window_indexes, i):
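        """
        Timesteps for the warm-up ("pre-rolling") phase: every frame in the
        window shares the offset tw, so frames near the front of the window
        finish denoising first while later frames stay clamped at pure noise.
        Returns (dt, t_partial, t_rf); dt and t_partial are [B, window_size, 1, 1, 1].
        """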
tw = (self.warmup_timesteps - i)/self.warmup_timesteps
tw_future = (self.warmup_timesteps - (i+1))/self.warmup_timesteps
t = window_indexes/self.window_size + tw
#timestep for the next iteration, to calculate dt
t_future = window_indexes/self.window_size + tw_future
#Swap 0 with 1, 1 is full image, 0 is full noise
t = 1-t
t_future = 1 - t_future
t = torch.clamp(t, 0, 1)
t_future = torch.clamp(t_future, 0, 1)
        dt = torch.abs(t_future - t).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)  # [B, window_size, 1, 1, 1]
        t_partial = t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        t_rf = t * (self.T - self.eps) + self.eps
        return dt, t_partial, t_rf
    def get_z0(self, batch, train=True):
        # sample the prior z0 matching batch's shape ('train' is currently unused)
        if self.init_type == 'gaussian':
            # standard Gaussian, scaled by noise_scale
            return torch.randn_like(batch) * self.noise_scale
        else:
            raise NotImplementedError("Initialization type not implemented")
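
# Minimal smoke-test sketch (illustrative, not part of the original training code):
# assumes a toy model with the signature model(video, audio, t, **kwargs) ->
# (score_video, score_audio), matching how `model` is called above. All shapes
# below are hypothetical.
if __name__ == "__main__":
    class _ToyModel(torch.nn.Module):
        def forward(self, v, a, t, **kwargs):
            # predict zero flow everywhere; stands in for the real AV network
            return torch.zeros_like(v), torch.zeros_like(a)

    rf = RectifiedFlow(num_timesteps=4, warmup_timesteps=2, window_size=8)
    v = torch.randn(2, 8, 4, 16, 16)  # [B, T, C, H, W] video latents
    a = torch.randn(2, 8, 1, 8, 16)   # audio latents, 5D as indexed in training_loss
    out = rf.training_loss(_ToyModel(), v, a, model_kwargs={})
    print(out["loss"])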