# Spaces:
# Runtime error
# Runtime error
import torch | |
def sample_x0(x1):
    """Draw standard-normal noise shaped like the data ``x1``.

    Args:
        x1: a data tensor ``[batch, *dim]``, or a list/tuple of tensors.

    Returns:
        For a tensor input, one ``torch.randn_like`` tensor matching its
        shape, dtype and device. For a list/tuple input, a list with one
        such noise tensor per element, in the same order.
    """
    if not isinstance(x1, (list, tuple)):
        return torch.randn_like(x1)
    # Sample each element independently so per-element shapes are preserved.
    return list(map(torch.randn_like, x1))
def sample_timestep(x1):
    """Draw one logit-normal time per sample: ``t = 1 / (1 + exp(-u))``, ``u ~ N(0, 1)``.

    Args:
        x1: a data tensor or a list/tuple of tensors; ``len(x1)`` is the
            number of times drawn.

    Returns:
        A 1-D tensor of length ``len(x1)`` with values between 0 and 1,
        converted via ``.to(x1[0])`` to the dtype and device of ``x1[0]``.
    """
    batch_size = len(x1)
    logits = torch.normal(mean=0.0, std=1.0, size=(batch_size,))
    times = 1 / (1 + torch.exp(-logits))
    return times.to(x1[0])
def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
    """Compute the flow-matching regression loss for one training batch.

    Draws noise ``x0`` (via ``sample_x0``) and per-sample times ``t`` (via
    ``sample_timestep``), builds the linear interpolant
    ``xt = t * x1 + (1 - t) * x0`` and regresses the model output onto
    ``ut = x1 - x0``, the time derivative of that interpolant.

    Args:
        model: callable invoked as ``model(xt, t, **model_kwargs)``. Its
            output is compared directly against ``ut``, so for this loss it
            should predict velocity.
        x1: data batch, either a tensor ``[batch, *dim]`` or a list/tuple of
            per-sample tensors (each processed independently).
        model_kwargs: extra keyword arguments forwarded to ``model``.
        snr_type: currently unused; kept for interface compatibility.

    Returns:
        dict with key ``"loss"``: a tensor holding one mean-squared error per
        sample.
    """
    if model_kwargs is None:
        model_kwargs = {}

    x0 = sample_x0(x1)
    t = sample_timestep(x1)
    is_list = isinstance(x1, (list, tuple))

    if is_list:
        # Each element of t is a 0-dim tensor, so it broadcasts over its sample.
        xt = [t_i * x1_i + (1 - t_i) * x0_i for t_i, x1_i, x0_i in zip(t, x1, x0)]
        ut = [x1_i - x0_i for x1_i, x0_i in zip(x1, x0)]
    else:
        # Reshape t to [batch, 1, ..., 1] so it broadcasts against x1.
        t_ = t.view(t.size(0), *([1] * (x1.dim() - 1)))
        xt = t_ * x1 + (1 - t_) * x0
        ut = x1 - x0

    model_output = model(xt, t, **model_kwargs)

    terms = {}
    if is_list:
        assert len(model_output) == len(ut) == len(x1)
        # One scalar MSE per sample, stacked once into a [batch] tensor.
        terms["loss"] = torch.stack(
            [((u - out) ** 2).mean() for u, out in zip(ut, model_output)],
            dim=0,
        )
    else:
        terms["loss"] = mean_flat((model_output - ut) ** 2)
    return terms
def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.

    Args:
        x: tensor whose first dimension is the batch axis.

    Returns:
        Tensor of shape ``[x.shape[0]]`` with the per-sample mean. A 1-D
        input has no non-batch dimensions, so each element is its own mean
        and the result equals ``x`` (as a new tensor). A 0-d input is
        returned as its mean, as before.
    """
    if x.dim() == 1:
        # torch.mean treats an empty `dim` list as "reduce over every
        # dimension", which would collapse the batch axis to a scalar.
        # Add a singleton axis so the reduction keeps one value per sample.
        return x.unsqueeze(1).mean(dim=1)
    return torch.mean(x, dim=list(range(1, x.dim())))