import os
import torch
from torch import nn, Tensor
from typing import Dict, Tuple, Optional
from argparse import Namespace
from einops import repeat
from functools import partial
from models.diffusion_model import DiffusionModel
from trainers.utils import compare_configs



# Hooks code inspired by https://www.lyndonduong.com/saving-activations/
# Accessed on 13Feb23
def save_activations(
        activations: Dict,
        name: int,
        module: nn.Module,
        inp: Tuple,
        out: Tensor
        ) -> None:
    """PyTorch forward hook that saves a module's output at each
    forward pass. Mutates the given dict: the detached output is
    stored on the CPU under key `name`, overwritten on every call.
    """
    activations[name] = out.detach().cpu()
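
# Illustrative example (not used by this module): `partial` pre-binds the
# dict and key, leaving exactly the (module, inp, out) signature that
# register_forward_hook expects:
#
#     feats: Dict[int, Tensor] = {}
#     layer = nn.Conv2d(3, 8, 3)
#     handle = layer.register_forward_hook(partial(save_activations, feats, 0))
#     _ = layer(torch.randn(1, 3, 32, 32))   # feats[0] now holds the output
#     handle.remove()                        # detach the hook when done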


class DatasetDM(nn.Module):
    def __init__(self, args: Namespace) -> None:
        super().__init__()
        # Load a saved diffusion model if a checkpoint exists; otherwise
        # fall back to a freshly initialised (untrained) model.
        if not os.path.isfile(args.saved_diffusion_model):
            self.diffusion_model = DiffusionModel(args)
            if args.verbose:
                print(f'No checkpoint found at {args.saved_diffusion_model}; '
                      'proceeding with an untrained model.')
        else:
            checkpoint = torch.load(args.saved_diffusion_model, map_location=torch.device(args.device))
            old_config = checkpoint['config']
            compare_configs(old_config, args)
            self.diffusion_model = DiffusionModel(old_config)
            self.diffusion_model.load_state_dict(checkpoint['model_state_dict'])
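        # The diffusion model serves as a frozen feature extractor: it is
        # kept in eval mode, and extract_features below runs under no_grad.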
        self.diffusion_model.eval()

        # storage for saved activations
        self._features = {}

        # Register a forward hook on each attention block of the UNet
        # decoder. Note: this relies on the (block1, block2, attn,
        # upsample) layout of `ups` in the model defined in model.py.
        for i, (block1, block2, attn, upsample) in enumerate(self.diffusion_model.model.ups):
            attn.register_forward_hook(
                partial(save_activations, self._features, i)
            )

        self.steps = args.t_steps_to_save

        # Pixel-wise classification head over the concatenated, upsampled
        # activations: 960 feature channels per saved timestep (the
        # combined channel count of the hooked attention outputs).
        self.classifier = nn.Sequential(
            nn.Conv2d(960 * len(self.steps), 128, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 32, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, 1))

    @torch.no_grad()
    def extract_features(self, x_0: Tensor, noise: Optional[Tensor] = None) -> Tensor:
        if noise is not None:
            assert x_0.shape == noise.shape
        activations = []
        for t_step in self.steps:
            # Forward process: add t_step steps of noise to x_0
            t_step = torch.tensor([t_step], dtype=torch.long, device=x_0.device)
            t_step = repeat(t_step, '1 -> b', b=x_0.shape[0])
            x_t, _ = self.diffusion_model.forward_diffusion_model(x_0=x_0, t=t_step, noise=noise)
            # Backward process: a single denoising pass, which fires the
            # forward hooks and refreshes self._features
            _ = self.diffusion_model.model(x_t, t_step)
            # Upsample every hooked feature map to the input resolution
            for idx in self._features:
                activations.append(nn.functional.interpolate(self._features[idx], size=[x_0.shape[-1]] * 2))
        # Concatenate all saved features along the channel dimension
        return torch.cat(activations, dim=1)
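
    # Shape note: since the classifier above expects 960 * len(self.steps)
    # input channels, extract_features returns a tensor of shape
    # (batch, 960 * len(self.steps), H, W), on the CPU (the hooks detach
    # and move activations there).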

    def forward(self, x: Tensor) -> Tensor:
        features = self.extract_features(x).to(x.device)
        out = self.classifier(features)
        return out
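
# Minimal usage sketch (illustrative only; assumes an `args` Namespace with
# the fields referenced above, plus whatever DiffusionModel itself needs.
# The field values here are made up):
#
#     args = Namespace(saved_diffusion_model='checkpoints/ddpm.pt',
#                      device='cpu', verbose=True,
#                      t_steps_to_save=[50, 150, 250])
#     model = DatasetDM(args)
#     x = torch.randn(4, 1, 64, 64)      # dummy batch of images
#     logits = model(x)                  # per-pixel logits, shape (4, 1, 64, 64)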