File size: 2,841 Bytes
7a11626
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from pathlib import Path
import argparse
import yaml

import numpy as np
import torch

from ncsn.ncsnv2 import NCSNv2, NCSNv2Deeper, NCSNv2Deepest, get_sigmas
from ncsn.ema import EMAHelper

from adapt import ScoreAdapter

# Device shared by the model and the tensors built in this module.
# CUDA is used whenever it is available; falling back to CPU keeps the module
# usable on hosts without a GPU instead of failing on the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_model(config):
    """Instantiate the NCSNv2 variant selected by ``config.data.dataset``.

    Mapping used:
        * ``'CIFAR10'`` / ``'CELEBA'`` -> ``NCSNv2``
        * ``'FFHQ'`` -> ``NCSNv2Deepest``
        * ``'LSUN'`` -> ``NCSNv2Deeper``

    Args:
        config: namespace-like object exposing ``data.dataset`` and ``device``.

    Returns:
        The constructed model, moved to ``config.device``.

    Raises:
        ValueError: if ``config.data.dataset`` is not one of the names above.
            (Previously the function silently returned ``None`` in that case.)
    """
    dataset = config.data.dataset
    if dataset in ('CIFAR10', 'CELEBA'):
        model_cls = NCSNv2
    elif dataset == "FFHQ":
        model_cls = NCSNv2Deepest
    elif dataset == 'LSUN':
        model_cls = NCSNv2Deeper
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; "
            "expected one of 'CIFAR10', 'CELEBA', 'FFHQ', 'LSUN'"
        )
    return model_cls(config).to(config.device)


def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into ``argparse.Namespace``.

    Every key becomes an attribute. Values that are themselves dicts are
    converted the same way; all other values are stored unchanged.
    """
    ns = argparse.Namespace()
    for name, entry in config.items():
        # Only dict values recurse; lists, scalars, etc. are kept as-is.
        setattr(ns, name, dict2namespace(entry) if isinstance(entry, dict) else entry)
    return ns


class NCSN(ScoreAdapter):
    """Score adapter wrapping a pretrained NCSNv2 network.

    The constructor reads ``ncsn/bedroom.yml`` (located next to this file),
    builds the matching model via ``get_model``, restores the weights from
    ``ncsn/exp/logs/bedroom/checkpoint_150000.pth`` under
    ``self.checkpoint_root()`` and, when ``config.model.ema`` is set, replaces
    the weights with their EMA counterparts.
    """

    def __init__(self):
        config_fname = Path(__file__).resolve().parent / "ncsn" / "bedroom.yml"
        with config_fname.open("r") as f:
            config = dict2namespace(yaml.safe_load(f))

        config.device = device

        # NOTE(review): torch.load unpickles the file -- only use it on
        # trusted checkpoints.
        # map_location places the restored tensors on the configured device
        # rather than on whichever device they were saved from.
        # checkpoint_root() is not defined in this class; presumably inherited
        # from ScoreAdapter.
        states = torch.load(
            self.checkpoint_root() / "ncsn/exp/logs/bedroom/checkpoint_150000.pth",
            map_location=device,
        )

        model = get_model(config)
        # The model is wrapped before load_state_dict and unwrapped again
        # below -- presumably the checkpoint was saved from a DataParallel
        # model; confirm against the training code.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=True)

        if config.model.ema:
            ema_helper = EMAHelper(mu=config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.load_state_dict(states[-1])
            # HC: update the model param with history ema.
            # if don't do this the colors of images become strangely saturated.
            # this is reported in the paper.
            ema_helper.ema(model)

        model = model.module  # remove DataParallel
        model.eval()
        self.model = model
        self._data_shape = (3, config.data.image_size, config.data.image_size)

        # Noise levels of the model, kept as a NumPy array on the host.
        self.σs = model.sigmas.cpu().numpy()
        self._device = device

    def data_shape(self):
        """Return ``(3, image_size, image_size)`` taken from the config."""
        return self._data_shape

    def samps_centered(self):
        """Always ``False`` for this adapter."""
        return False

    @property
    def σ_max(self):
        """First entry of ``self.σs``."""
        return self.σs[0]

    @property
    def σ_min(self):
        """Last entry of ``self.σs``."""
        return self.σs[-1]

    @torch.no_grad()
    def denoise(self, xs, σ):
        """Return ``xs + score * σ**2`` using the network's score at ``σ``.

        ``σ`` is first snapped to the closest entry of ``self.σs``; the index
        of that entry is what the network is conditioned on.

        Args:
            xs: batch of inputs; its first dimension is the batch size.
            σ: requested noise level.

        Returns:
            The value ``xs + score * σ_snapped ** 2``.
        """
        σ, j = self.snap_t_to_nearest_tick(σ)
        N = xs.shape[0]
        # One conditioning index per batch element.
        # `self.device` is not defined in this class; presumably a property
        # provided by ScoreAdapter.
        cond_t = torch.full((N,), int(j), dtype=torch.long, device=self.device)
        score = self.model(xs, cond_t)
        Ds = xs + score * (σ ** 2)
        return Ds

    def unet_is_cond(self):
        """Always ``False`` for this adapter."""
        return False

    def use_cls_guidance(self):
        """Always ``False`` for this adapter."""
        return False

    def snap_t_to_nearest_tick(self, t):
        """Return ``(σ_j, j)`` for the entry of ``self.σs`` closest to ``t``.

        Closeness is the absolute difference ``|t - σ|``; ``argmin`` picks the
        first index on ties.
        """
        j = np.abs(t - self.σs).argmin()
        return self.σs[j], j