# Spaces: Runtime error
import os | |
import torch | |
from torch import nn, Tensor | |
from typing import Dict, Tuple, Optional | |
from argparse import Namespace | |
from einops import repeat | |
from einops.layers.torch import Rearrange | |
from functools import partial | |
from models.diffusion_model import DiffusionModel | |
from trainers.utils import compare_configs | |
# Forward-hook pattern adapted from https://www.lyndonduong.com/saving-activations/
# (accessed 13 Feb 2023).
def save_activations(
    activations: Dict,
    name: str,
    module: nn.Module,
    inp: Tuple,
    out: torch.Tensor,
) -> None:
    """Forward hook that records a module's latest output in ``activations``.

    Meant to be bound as ``functools.partial(save_activations, store, key)`` and
    passed to ``register_forward_hook``. Each forward pass overwrites
    ``activations[name]`` with ``out`` detached from autograd and moved to the
    CPU; only the most recent output is kept. ``module`` and ``inp`` are part of
    the hook signature and are not used.
    """
    snapshot = out.detach()
    activations[name] = snapshot.cpu()
class DatasetDM(nn.Module):
    """Per-pixel classifier on top of features taken from a diffusion model.

    A ``DiffusionModel`` is used as a feature extractor. The input is noised to
    each timestep in ``args.t_steps_to_save`` and run once through the
    denoising network. Forward hooks capture the output of the attention layer
    of every decoder ("up") stage. Those maps are resized to the input's
    spatial size, concatenated along dim 1, and passed through a stack of 1x1
    convolutions that ends in a single output channel.
    """

    def __init__(self, args: Namespace) -> None:
        super().__init__()
        # Build the diffusion model, restoring weights from a checkpoint when
        # one exists at args.saved_diffusion_model.
        if not os.path.isfile(args.saved_diffusion_model):
            self.diffusion_model = DiffusionModel(args)
            if args.verbose:
                print(f'No model found at {args.saved_diffusion_model}. Please load model!')
        else:
            # NOTE(review): the checkpoint stores a config object under 'config'
            # (presumably an argparse Namespace), which needs full unpickling.
            # Only load trusted files. On PyTorch >= 2.6, torch.load defaults to
            # weights_only=True, which may reject this checkpoint; confirm this
            # against the installed version.
            checkpoint = torch.load(args.saved_diffusion_model, map_location=torch.device(args.device))
            old_config = checkpoint['config']
            compare_configs(old_config, args)
            # Rebuild the network with the config its weights were saved with.
            self.diffusion_model = DiffusionModel(old_config)
            self.diffusion_model.load_state_dict(checkpoint['model_state_dict'])
        # NOTE(review): nn.Module.train() is recursive, so calling .train() on
        # this module switches diffusion_model back to training mode. Confirm
        # whether that is intended.
        self.diffusion_model.eval()

        # Latest hooked activation per decoder stage, keyed by stage index.
        # The hooks hold a reference to this exact dict, so it is only ever
        # mutated in place, never rebound.
        self._features: Dict[int, Tensor] = {}
        # NOTE: assumes the UNet layout in model.py, where each entry of `ups`
        # unpacks as (block1, block2, attn, upsample); only attn is hooked.
        for i, (_, _, attn, _) in enumerate(self.diffusion_model.model.ups):
            attn.register_forward_hook(partial(save_activations, self._features, i))

        self.steps = args.t_steps_to_save
        if len(self.steps) == 0:
            raise ValueError('args.t_steps_to_save must contain at least one timestep')

        # 960 is presumably the summed channel count of all hooked attention
        # outputs for the model.py UNet configuration (one copy per timestep).
        # It must be updated if that architecture changes.
        self.classifier = nn.Sequential(
            nn.Conv2d(960 * len(self.steps), 128, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 32, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, 1),
        )

    def extract_features(self, x_0: Tensor, noise: Optional[Tensor] = None) -> Tensor:
        """Return the hooked decoder features of ``x_0`` at every saved timestep.

        For each timestep ``t`` in ``self.steps``:

        1. ``x_0`` is noised to ``x_t`` with ``forward_diffusion_model``.
        2. The denoising network is run once on ``(x_t, t)``.
        3. The attention outputs captured by the hooks are resized
           (nearest neighbour) to the last two dimensions of ``x_0``.

        Args:
            x_0: batch of clean inputs. Presumably (B, C, H, W): the batch size
                is read from dim 0 and the target size from the last two dims.
            noise: optional tensor with the same shape as ``x_0``, reused for
                every timestep. When None it is passed through unchanged to
                ``forward_diffusion_model``.

        Returns:
            All resized feature maps concatenated along dim 1, ordered by
            timestep and then by hook firing order. The hooks store detached CPU
            tensors, so the result carries no autograd history and is on the
            CPU.

        Raises:
            ValueError: if ``noise`` is given and its shape differs from ``x_0``.
        """
        if noise is not None and noise.shape != x_0.shape:
            raise ValueError(
                f'noise shape {tuple(noise.shape)} does not match x_0 shape {tuple(x_0.shape)}'
            )
        batch_size = x_0.shape[0]
        # Use both spatial dims so non-square inputs are resized correctly.
        out_size = tuple(x_0.shape[-2:])
        activations = []
        # The hooks detach what they save, so no autograd graph is needed for
        # these forward passes; skipping it only saves memory.
        with torch.no_grad():
            for step in self.steps:
                t = torch.full((batch_size,), int(step), dtype=torch.long, device=x_0.device)
                # Forward process: add `step` steps of noise to x_0.
                x_t, _ = self.diffusion_model.forward_diffusion_model(x_0=x_0, t=t, noise=noise)
                # Drop the previous timestep's features so a stage whose hook
                # did not fire cannot contribute stale data. clear() keeps the
                # dict object the hooks are bound to.
                self._features.clear()
                # One reverse (denoising) pass. Only the hooked activations are used.
                self.diffusion_model.model(x_t, t)
                for idx in self._features:
                    activations.append(
                        nn.functional.interpolate(self._features[idx], size=out_size)
                    )
        return torch.cat(activations, dim=1)

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier's raw (unactivated) single-channel output for ``x``."""
        # Hooked features are stored on the CPU; move them back to x's device.
        features = self.extract_features(x).to(x.device)
        return self.classifier(features)