# sjc / adapt_ncsn.py
# Last updated by amankishore in commit 7a11626 ("Updated app.py").
from pathlib import Path
import argparse
import yaml
import numpy as np
import torch
from ncsn.ncsnv2 import NCSNv2, NCSNv2Deeper, NCSNv2Deepest, get_sigmas
from ncsn.ema import EMAHelper
from adapt import ScoreAdapter
# Target device shared by the NCSN model, its loaded weights and the
# conditioning tensors built in NCSN.denoise. Hard-coded to CUDA (no CPU
# fallback); no explicit index, so it refers to the current CUDA device.
device = torch.device(type="cuda")
def get_model(config):
    """Build the NCSNv2 variant associated with ``config.data.dataset``.

    The chosen network is constructed from ``config`` and moved to
    ``config.device``.

    Dataset -> architecture mapping:
        CIFAR10, CELEBA -> NCSNv2
        FFHQ            -> NCSNv2Deepest
        LSUN            -> NCSNv2Deeper

    Raises:
        ValueError: if ``config.data.dataset`` is not one of the names above.
            (Previously such configs silently produced ``None``, which only
            failed later with an unrelated error.)
    """
    dataset = config.data.dataset
    if dataset in ('CIFAR10', 'CELEBA'):
        model_cls = NCSNv2
    elif dataset == "FFHQ":
        model_cls = NCSNv2Deepest
    elif dataset == 'LSUN':
        model_cls = NCSNv2Deeper
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; "
            "expected one of 'CIFAR10', 'CELEBA', 'FFHQ', 'LSUN'"
        )
    return model_cls(config).to(config.device)
def dict2namespace(config):
    """Recursively turn a mapping into an ``argparse.Namespace``.

    Every dict value becomes a nested Namespace, so a parsed YAML config can
    be read with attribute access (``cfg.data.image_size``). Any other value,
    including lists, is stored unchanged.
    """
    ns = argparse.Namespace()
    for name, raw in config.items():
        setattr(ns, name, dict2namespace(raw) if isinstance(raw, dict) else raw)
    return ns
class NCSN(ScoreAdapter):
    """Score adapter wrapping a pretrained NCSNv2 model (bedroom checkpoint).

    Reads ``ncsn/bedroom.yml`` (next to this file) as the model config and
    ``ncsn/exp/logs/bedroom/checkpoint_150000.pth`` under
    ``checkpoint_root()`` as the weights. If the config enables EMA, the EMA
    weights are copied into the model. The model's discrete noise levels are
    exposed as the numpy array ``self.σs``.
    """

    def __init__(self):
        config_fname = Path(__file__).resolve().parent / "ncsn" / "bedroom.yml"
        with config_fname.open("r") as f:
            config = yaml.safe_load(f)
        config = dict2namespace(config)
        config.device = device

        # map_location: restore every checkpoint tensor directly onto the
        # target device. Without it torch.load puts tensors back on the device
        # they were saved from and raises if that device is not present here.
        states = torch.load(
            self.checkpoint_root() / "ncsn/exp/logs/bedroom/checkpoint_150000.pth",
            map_location=device,
        )

        model = get_model(config)
        # Load through a DataParallel wrapper; presumably the checkpoint was
        # saved from one (state-dict keys prefixed "module.") — TODO confirm.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=True)

        if config.model.ema:
            ema_helper = EMAHelper(mu=config.model.ema_rate)
            ema_helper.register(model)
            # The last checkpoint entry holds the EMA helper's state.
            ema_helper.load_state_dict(states[-1])
            # HC: overwrite the model parameters with the historical (EMA)
            # weights. Without this, generated images come out with strangely
            # saturated colors; this is reported in the paper.
            ema_helper.ema(model)

        model = model.module  # remove DataParallel
        model.eval()

        self.model = model
        self._data_shape = (3, config.data.image_size, config.data.image_size)
        # Discrete noise levels the network was trained on, as a numpy array.
        self.σs = model.sigmas.cpu().numpy()
        # NOTE(review): read back as self.device in denoise(); presumably
        # ScoreAdapter exposes _device through a `device` property — confirm.
        self._device = device

    def data_shape(self):
        """Return the per-sample shape ``(3, image_size, image_size)`` from the config."""
        return self._data_shape

    def samps_centered(self):
        """Return False: this adapter reports its samples as not centered.

        NOTE(review): presumably means samples live in [0, 1] rather than
        [-1, 1]; confirm against the ScoreAdapter contract.
        """
        return False

    @property
    def σ_max(self):
        """First entry of ``σs``; the name assumes σs is sorted descending."""
        return self.σs[0]

    @property
    def σ_min(self):
        """Last entry of ``σs``; the name assumes σs is sorted descending."""
        return self.σs[-1]

    @torch.no_grad()
    def denoise(self, xs, σ):
        """Return a denoised estimate of the batch ``xs`` at noise level ``σ``.

        ``σ`` is first snapped to the nearest trained noise level. The network
        is conditioned on that level's index j (the same index for every item
        in the batch). Its output is treated as the score and converted with
        ``D = xs + score * σ**2``.
        """
        σ, j = self.snap_t_to_nearest_tick(σ)
        N = xs.shape[0]
        # One conditioning index per batch element.
        cond_t = torch.full((N,), int(j), dtype=torch.long, device=self.device)
        score = self.model(xs, cond_t)
        Ds = xs + score * (σ ** 2)
        return Ds

    def unet_is_cond(self):
        """Return False: the network is not used as a class/text-conditioned UNet."""
        return False

    def use_cls_guidance(self):
        """Return False: classifier guidance is not used with this model."""
        return False

    def snap_t_to_nearest_tick(self, t):
        """Snap ``t`` to the closest noise level in ``self.σs``.

        Returns:
            ``(σ, j)``: the nearest noise level and its index into ``self.σs``.
            Ties resolve to the lowest index (numpy ``argmin`` semantics).
        """
        j = np.abs(t - self.σs).argmin()
        return self.σs[j], j