from math import log, sqrt
import importlib

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange
from omegaconf import OmegaConf


def load_model(path):
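    """Load a torch checkpoint from disk onto the CPU."""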
    with open(path, "rb") as f:
        return torch.load(f, map_location=torch.device("cpu"))


def map_pixels(x, eps=0.1):
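    """Affinely squeeze pixel values from [0, 1] into [eps, 1 - eps]."""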
    return (1 - 2 * eps) * x + eps


def unmap_pixels(x, eps=0.1):
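    """Invert `map_pixels`, clamping the result back to [0, 1]."""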
    return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1)


def make_contiguous(module):
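    """Replace each parameter's tensor with a contiguous copy, in place."""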
    with torch.no_grad():
        for param in module.parameters():
            param.set_(param.contiguous())


def get_obj_from_str(string, reload=False):
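    """Resolve a dotted path like `torch.nn.Linear` to the object it names,
    optionally reloading the containing module first."""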
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
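    """Instantiate the class named by `config["target"]`, passing any
    `config["params"]` as keyword arguments."""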
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


class VQGanVAE(nn.Module):
    def __init__(self, vqgan_model_path=None, vqgan_config_path=None, channels=1):
        super().__init__()

        assert vqgan_config_path is not None, "`vqgan_config_path` is required"

        model_path = vqgan_model_path
        config_path = vqgan_config_path

        config = OmegaConf.load(config_path)

        model = instantiate_from_config(config["model"])

        if vqgan_model_path:
            state = torch.load(model_path, map_location="cpu")["state_dict"]
            model.load_state_dict(state, strict=True)
            print(f"Loaded VQGAN from {model_path} and {config_path}")

        self.model = model

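        # Number of 2x downsampling steps in the encoder, inferred from the
        # ratio of the input resolution to the first attention resolution.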
        f = (
            config.model.params.ddconfig.resolution
            / config.model.params.ddconfig.attn_resolutions[0]
        )
        self.num_layers = int(log(f) / log(2))
        self.image_size = config.model.params.ddconfig.resolution
        self.num_tokens = config.model.params.n_embed

        # Only the plain VQ codebook layout is wired up here; a Gumbel-quantized
        # checkpoint would need `is_gumbel = True` (see `decode` below).
        self.is_gumbel = False
        # The `channels` constructor argument is unused; the value comes from
        # the config instead.
        self.channels = config.model.params.ddconfig.in_channels

    def encode(self, img):
        return self.model.encode(img)

    def get_codebook_indices(self, img):
        b = img.shape[0]

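        # `encode` returns (quantized, loss, info); the codebook indices are
        # the last element of the info tuple.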
        _, _, [_, _, indices] = self.encode(img)
        if self.is_gumbel:
            return rearrange(indices, "b h w -> b (h w)", b=b)
        return rearrange(indices, "(b n) -> b n", b=b)

    def decode(self, img_seq):
        b, n = img_seq.shape
        one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).float()
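        # Gumbel and vanilla VQ quantizers name their codebook differently
        # (`quantize.embed` vs. `quantize.embedding`).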
        z = (
            one_hot_indices @ self.model.quantize.embed.weight
            if self.is_gumbel
            else (one_hot_indices @ self.model.quantize.embedding.weight)
        )

        z = rearrange(z, "b (h w) c -> b c h w", h=int(sqrt(n)))
        img = self.model.decode(z)

        return img

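    # NOTE: this delegates straight to the wrapped model's Lightning-style
    # `training_step`; depending on the taming-transformers version, that hook
    # may also expect a `batch_idx` argument.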
    def forward(self, img, optimizer_idx=1):
        return self.model.training_step(img, optimizer_idx=optimizer_idx)
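

# Minimal usage sketch (assumes a taming-transformers style YAML config and,
# optionally, a matching checkpoint; the paths below are placeholders):
#
#     vae = VQGanVAE(
#         vqgan_model_path="path/to/model.ckpt",
#         vqgan_config_path="path/to/config.yaml",
#     )
#     img = torch.randn(1, vae.channels, vae.image_size, vae.image_size)
#     tokens = vae.get_codebook_indices(img)  # (1, n) discrete code indices
#     recon = vae.decode(tokens)              # (1, c, h, w) reconstruction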