Spaces:
Runtime error
Runtime error
| import math | |
| from contextlib import nullcontext | |
| from functools import partial | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import kornia | |
| import numpy as np | |
| import open_clip | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange, repeat | |
| from omegaconf import ListConfig | |
# Import the submodule explicitly: ``import torch`` alone is not guaranteed to
# load ``torch.utils.checkpoint``, and the attribute access below needs it.
import torch.utils.checkpoint

# Non-reentrant activation checkpointing, used by the OpenCLIP text embedders
# when their transformer has ``grad_checkpointing`` enabled.
checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
| from transformers import ( | |
| ByT5Tokenizer, | |
| CLIPTextModel, | |
| CLIPTokenizer, | |
| T5EncoderModel, | |
| T5Tokenizer, | |
| ) | |
| from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer | |
| from ...modules.diffusionmodules.model import Encoder | |
| from ...modules.diffusionmodules.openaimodel import Timestep | |
| from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule | |
| from ...modules.distributions.distributions import DiagonalGaussianDistribution | |
| from ...util import ( | |
| append_dims, | |
| autocast, | |
| count_params, | |
| default, | |
| disabled_train, | |
| expand_dims_like, | |
| instantiate_from_config, | |
| ) | |
class AbstractEmbModel(nn.Module):
    """Base class for every embedder managed by ``GeneralConditioner``.

    Exposes three readable / assignable / deletable properties, each backed by
    a private attribute that starts out as ``None``:

    * ``is_trainable`` -- whether the embedder's parameters should get gradients.
    * ``ucg_rate`` -- dropout probability used for unconditional guidance.
    * ``input_key`` -- batch key whose value is fed to the embedder.
    """

    def __init__(self):
        super().__init__()
        self._is_trainable = None
        self._ucg_rate = None
        self._input_key = None

    # Each accessor must be a real property: without the decorators every later
    # ``def`` with the same name replaces the earlier one, and attribute reads
    # return bound methods instead of the stored values.

    @property
    def is_trainable(self) -> bool:
        return self._is_trainable

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        return self._ucg_rate

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    @property
    def input_key(self) -> str:
        return self._input_key

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @input_key.deleter
    def input_key(self):
        del self._input_key
class GeneralConditioner(nn.Module):
    """Run a list of embedders over a batch and merge their outputs.

    Each embedding is routed to an output key chosen from its rank
    (``OUTPUT_DIM2KEYS``). Embeddings that land on the same key are concatenated
    along the axis given by ``KEY2CATDIM``.
    """

    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

    def __init__(self, emb_models: Union[List, ListConfig]):
        """Instantiate and configure every embedder described in ``emb_models``.

        Args:
            emb_models: configs accepted by ``instantiate_from_config``. Each may
                also carry ``is_trainable``, ``ucg_rate`` and ``legacy_ucg_value``,
                and must carry either ``input_key`` or ``input_keys``.

        Raises:
            KeyError: if a config has neither ``input_key`` nor ``input_keys``.
        """
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                # Freeze: make .train() a no-op and stop gradients.
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )
            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(
                    f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                )
            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()
            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
        """Legacy dropout: overwrite each entry of ``batch[embedder.input_key]``
        with ``embedder.legacy_ucg_val`` with probability ``embedder.ucg_rate``.

        ``batch`` is modified in place and returned.
        """
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def forward(
        self, batch: Dict, force_zero_embeddings: Optional[List] = None
    ) -> Dict:
        """Embed ``batch`` with every embedder and merge the results.

        Args:
            batch: mapping from input keys to embedder inputs.
            force_zero_embeddings: input keys whose embeddings are replaced by
                zeros.

        Returns:
            dict keyed by ``"vector"`` / ``"crossattn"`` / ``"concat"``.

        Raises:
            KeyError: if an embedder has neither a non-None ``input_key`` nor
                ``input_keys``.
        """
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []
        for embedder in self.embedders:
            # Frozen embedders run without building an autograd graph.
            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
            with embedding_context():
                if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    emb_out = embedder(batch[embedder.input_key])
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(*[batch[k] for k in embedder.input_keys])
                else:
                    # Without this, ``emb_out`` would be unbound on the first
                    # embedder or silently reuse the previous embedder's output.
                    raise KeyError(
                        f"embedder {embedder.__class__.__name__} has neither a usable "
                        f"'input_key' nor 'input_keys'"
                    )
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            for emb in emb_out:
                out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    # Per-sample Bernoulli mask: zero the whole embedding with
                    # probability ucg_rate.
                    emb = (
                        expand_dims_like(
                            torch.bernoulli(
                                (1.0 - embedder.ucg_rate)
                                * torch.ones(emb.shape[0], device=emb.device)
                            ),
                            emb,
                        )
                        * emb
                    )
                if (
                    hasattr(embedder, "input_key")
                    and embedder.input_key in force_zero_embeddings
                ):
                    emb = torch.zeros_like(emb)
                if out_key in output:
                    output[out_key] = torch.cat(
                        (output[out_key], emb), self.KEY2CATDIM[out_key]
                    )
                else:
                    output[out_key] = emb
        return output

    def get_unconditional_conditioning(
        self,
        batch_c: Dict,
        batch_uc: Optional[Dict] = None,
        force_uc_zero_embeddings: Optional[List[str]] = None,
        force_cond_zero_embeddings: Optional[List[str]] = None,
    ):
        """Compute conditional and unconditional embeddings with dropout disabled.

        Every embedder's ``ucg_rate`` is temporarily set to 0. The original rates
        are restored afterwards, even if an embedder raises.

        Returns:
            ``(c, uc)``. ``uc`` is computed from ``batch_uc`` if given, otherwise
            from ``batch_c``.
        """
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = [embedder.ucg_rate for embedder in self.embedders]
        for embedder in self.embedders:
            embedder.ucg_rate = 0.0
        try:
            c = self(batch_c, force_cond_zero_embeddings)
            uc = self(
                batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings
            )
        finally:
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate
        return c, uc
class InceptionV3(nn.Module):
    """Thin wrapper around the pytorch-fid Inception port.

    Input resizing is always forced on. When the wrapped model returns a
    single-element result, that element is unwrapped and squeezed.
    """

    def __init__(self, normalize_input=False, **kwargs):
        super().__init__()
        # Imported here so pytorch-fid is only required when this class is built.
        from pytorch_fid import inception

        options = kwargs
        options["resize_input"] = True
        self.model = inception.InceptionV3(normalize_input=normalize_input, **options)

    def forward(self, inp):
        features = self.model(inp)
        if len(features) != 1:
            return features
        return features[0].squeeze()
class IdentityEncoder(AbstractEmbModel):
    """Embedder that returns its input unchanged."""

    def forward(self, x):
        return x

    def encode(self, x):
        # Same pass-through, exposed under the ``encode`` name.
        return x
class ClassEmbedder(AbstractEmbModel):
    """Map integer class labels to learned embedding vectors."""

    def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
        """
        Args:
            embed_dim: size of each embedding vector.
            n_classes: number of rows in the embedding table.
            add_sequence_dim: if True, ``forward`` inserts a length-1 axis at
                position 1 of the output.
        """
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.add_sequence_dim = add_sequence_dim

    def forward(self, c):
        c = self.embedding(c)
        if self.add_sequence_dim:
            c = c[:, None, :]
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return ``{key: labels}`` holding ``bs`` copies of the unconditional label.

        The last table index (``n_classes - 1``) is used as the unconditional
        label. The table therefore needs one row beyond the real classes
        (e.g. ``n_classes=1001`` for 1000 real classes).

        ``key`` is ``self.key`` if that attribute was set externally; otherwise
        it is ``self.input_key``. This class never assigns ``self.key`` itself.
        """
        uc_class = self.n_classes - 1
        key = getattr(self, "key", None)
        if key is None:
            key = self.input_key
        uc = torch.full((bs,), uc_class, device=device, dtype=torch.long)
        return {key: uc}
class ClassEmbedderForMultiCond(ClassEmbedder):
    """``ClassEmbedder`` that reads labels from, and writes embeddings into, a
    batch dict."""

    def forward(self, batch, key=None, disable_dropout=False):
        """Embed ``batch[key]`` and store the result back under ``key``.

        Args:
            batch: dict holding the class labels. It is updated in place and
                returned.
            key: entry to embed. Defaults to ``self.key`` if that attribute
                exists, otherwise ``self.input_key``.
            disable_dropout: accepted for interface compatibility only. The
                parent ``forward`` applies no dropout, so it has no effect.

        Returns:
            ``batch`` with ``batch[key]`` replaced by the embedding. If the input
            was a list, the embedding is wrapped in a one-element list.
        """
        if key is None:
            key = getattr(self, "key", None)
        if key is None:
            key = self.input_key
        labels = batch[key]
        islist = isinstance(labels, list)
        if islist:
            labels = labels[0]
        # ClassEmbedder.forward takes only the label tensor.
        c_out = super().forward(labels)
        batch[key] = [c_out] if islist else c_out
        return batch
class FrozenT5Embedder(AbstractEmbModel):
    """Text embedder built on the encoder of a pretrained T5 model."""

    def __init__(
        self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
    ):
        """
        Args:
            version: Hugging Face model id (e.g. ``google/t5-v1_1-xl`` or
                ``google/t5-v1_1-xxl``).
            device: device the token ids are moved to before encoding.
            max_length: tokens per prompt after padding / truncation.
            freeze: if True, switch the encoder to eval mode and disable gradients.
        """
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the encoder in eval mode and stop gradients for every parameter."""
        self.transformer.eval()
        self.requires_grad_(False)

    def forward(self, text):
        """Return the encoder's ``last_hidden_state`` for ``text``."""
        token_ids = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            return_length=True,
            return_overflowing_tokens=False,
        )["input_ids"].to(self.device)
        # CUDA autocast is switched off for the encoder call.
        with torch.autocast("cuda", enabled=False):
            return self.transformer(input_ids=token_ids).last_hidden_state

    def encode(self, text):
        return self(text)
class FrozenByT5Embedder(AbstractEmbModel):
    """
    Uses the ByT5 transformer encoder for text. Is character-aware.
    """

    def __init__(
        self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
    ):
        """
        Args:
            version: Hugging Face id of a ByT5 checkpoint.
            device: device the token ids are moved to before encoding.
            max_length: tokens per prompt after padding / truncation.
            freeze: if True, switch the encoder to eval mode and disable gradients.
        """
        super().__init__()
        self.tokenizer = ByT5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the encoder in eval mode and stop gradients for every parameter."""
        self.transformer.eval()
        self.requires_grad_(False)

    def forward(self, text):
        """Return the encoder's ``last_hidden_state`` for ``text``."""
        encoded = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=True,
            return_tensors="pt",
        )
        ids = encoded["input_ids"].to(self.device)
        # CUDA autocast is switched off for the encoder call.
        with torch.autocast("cuda", enabled=False):
            hidden = self.transformer(input_ids=ids).last_hidden_state
        return hidden

    def encode(self, text):
        return self(text)
class FrozenCLIPEmbedder(AbstractEmbModel):
    """Text embedder using the Hugging Face CLIP text transformer.

    ``layer`` picks the returned features:

    * ``"last"`` -- ``last_hidden_state``;
    * ``"pooled"`` -- ``pooler_output`` with a length-1 axis inserted at dim 1;
    * ``"hidden"`` -- ``hidden_states[layer_idx]``.

    With ``always_return_pooled`` the pooled output is returned alongside.
    """

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
        always_return_pooled=False,
    ):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        self.return_pooled = always_return_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Put the text model in eval mode and stop gradients for all parameters."""
        self.transformer.eval()
        self.requires_grad_(False)

    def forward(self, text):
        token_ids = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            return_length=True,
            return_overflowing_tokens=False,
        )["input_ids"].to(self.device)
        want_hidden = self.layer == "hidden"
        outputs = self.transformer(input_ids=token_ids, output_hidden_states=want_hidden)
        if self.layer == "last":
            features = outputs.last_hidden_state
        elif self.layer == "pooled":
            features = outputs.pooler_output.unsqueeze(1)
        else:
            features = outputs.hidden_states[self.layer_idx]
        return (features, outputs.pooler_output) if self.return_pooled else features

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
    """
    Uses the OpenCLIP transformer encoder for text
    """

    # "pooled" is listed, but __init__ only maps "last" / "penultimate" to an index
    # and raises NotImplementedError for anything else.
    LAYERS = ["pooled", "last", "penultimate"]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        always_return_pooled=False,
        legacy=True,
    ):
        """
        Args:
            arch, version: OpenCLIP architecture name and pretrained tag.
            device: device token ids are moved to in ``forward``.
            max_length: stored only; ``forward`` tokenizes without it.
            freeze: if True, eval mode and no gradients.
            layer: ``"last"`` or ``"penultimate"``.
            always_return_pooled: also return the pooled (EOT) features;
                requires ``legacy=False``.
            legacy: if True, return only the layer-normed ``layer`` features;
                otherwise work with the dict of per-layer outputs.
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        # Only the text tower is used.
        del model.visual
        self.model = model
        self.device = device
        self.max_length = max_length
        self.return_pooled = always_return_pooled
        if freeze:
            self.freeze()
        self.layer = layer
        layer_to_idx = {"last": 0, "penultimate": 1}
        if self.layer not in layer_to_idx:
            raise NotImplementedError()
        self.layer_idx = layer_to_idx[self.layer]
        self.legacy = legacy

    def freeze(self):
        """Put the model in eval mode and stop gradients for all parameters."""
        self.model = self.model.eval()
        self.requires_grad_(False)

    def forward(self, text):
        token_ids = open_clip.tokenize(text).to(self.device)
        encoded = self.encode_with_transformer(token_ids)
        if self.return_pooled:
            assert not self.legacy
            return encoded[self.layer], encoded["pooled"]
        if self.legacy:
            return encoded
        return encoded[self.layer]

    def encode_with_transformer(self, text):
        """Run the text transformer.

        Returns the layer-normed ``self.layer`` features in legacy mode.
        Otherwise returns a dict with ``"last"``, ``"penultimate"`` and
        ``"pooled"`` entries. Only ``"last"`` is layer-normed, and it feeds the
        pooled features.
        """
        x = self.model.token_embedding(text) + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        layer_outputs = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        if self.legacy:
            return self.model.ln_final(layer_outputs[self.layer])
        normed_last = self.model.ln_final(layer_outputs["last"])
        layer_outputs["pooled"] = self.pool(normed_last, text)
        return layer_outputs

    def pool(self, x, text):
        # Take the features at the eot position (eot_token is the highest id in each sequence).
        eot_positions = text.argmax(dim=-1)
        rows = torch.arange(x.shape[0])
        return x[rows, eot_positions]

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Apply the residual blocks and collect the input to the final block
        (``"penultimate"``) and the final output (``"last"``), both in NLD layout."""
        blocks = self.model.transformer.resblocks
        final_index = len(blocks) - 1
        use_checkpoint = (
            self.model.transformer.grad_checkpointing and not torch.jit.is_scripting()
        )
        collected = {}
        for index, block in enumerate(blocks):
            if index == final_index:
                collected["penultimate"] = x.permute(1, 0, 2)  # LND -> NLD
            x = checkpoint(block, x, attn_mask) if use_checkpoint else block(x, attn_mask=attn_mask)
        collected["last"] = x.permute(1, 0, 2)  # LND -> NLD
        return collected

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPEmbedder(AbstractEmbModel):
    """Text embedder built on the OpenCLIP text tower.

    ``"last"`` runs every residual block; ``"penultimate"`` stops one block
    early. Either way the result passes through the model's final layer norm.
    """

    LAYERS = ["last", "penultimate"]  # "pooled" is not supported here

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
    ):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        # Only the text tower is used.
        del model.visual
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # Number of trailing residual blocks to skip.
        skip_by_layer = {"last": 0, "penultimate": 1}
        if self.layer not in skip_by_layer:
            raise NotImplementedError()
        self.layer_idx = skip_by_layer[self.layer]

    def freeze(self):
        """Put the model in eval mode and stop gradients for all parameters."""
        self.model = self.model.eval()
        self.requires_grad_(False)

    def forward(self, text):
        return self.encode_with_transformer(open_clip.tokenize(text).to(self.device))

    def encode_with_transformer(self, text):
        embedded = self.model.token_embedding(text) + self.model.positional_embedding
        hidden = self.text_transformer_forward(
            embedded.permute(1, 0, 2),  # NLD -> LND
            attn_mask=self.model.attn_mask,
        )
        return self.model.ln_final(hidden.permute(1, 0, 2))  # LND -> NLD

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Apply all residual blocks except the last ``self.layer_idx`` ones."""
        blocks = list(self.model.transformer.resblocks)
        n_to_run = len(blocks) - self.layer_idx
        for block in blocks[:n_to_run]:
            if (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            ):
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        init_device=None,
    ):
        """
        Args:
            arch, version: OpenCLIP architecture name and pretrained tag.
            device: stored on the module.
            max_length: sequence length used by ``repeat_to_max_len`` and the
                crop padding.
            freeze: if True, eval mode and no gradients.
            antialias: forwarded to the resize in ``preprocess``.
            ucg_rate: probability of zeroing a sample's embedding in ``forward``.
            unsqueeze_dim: insert a length-1 axis at dim 1 of the output.
            repeat_to_max_len: tile a single embedding to ``max_length`` tokens
                (ignored when crops are used).
            num_image_crops: number of crops per sample; > 0 enables padding
                to ``max_length``.
            output_tokens: also return the visual tower's token features.
            init_device: device the model is created on (defaults to CPU).
        """
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device(default(init_device, "cpu")),
            pretrained=version,
        )
        # Only the vision tower is used.
        del model.transformer
        self.model = model
        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        self.antialias = antialias

        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        """Resize to 224x224, map to [0, 1], then apply CLIP mean/std
        normalisation. Assumes ``x`` is in [-1, 1] -- TODO confirm with callers."""
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        x = (x + 1.0) / 2.0
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image, no_dropout=False):
        """Embed ``image`` with optional dropout, tiling or padding.

        Returns one of:
            ``(tokens, z)`` when ``output_tokens``;
            ``(z tiled to max_length, z)`` when ``repeat_to_max_len``;
            ``(z padded to max_length, first row)`` when crops are used;
            ``z`` otherwise.
        """
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            # Per-sample Bernoulli mask: zero the embedding with probability ucg_rate.
            z = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
                )[:, None]
                * z
            )
            if tokens is not None:
                tokens = (
                    expand_dims_like(
                        torch.bernoulli(
                            (1.0 - self.ucg_rate)
                            * torch.ones(tokens.shape[0], device=tokens.device)
                        ),
                        tokens,
                    )
                    * tokens
                )
        if self.unsqueeze_dim:
            z = z[:, None, :]
        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z
        if self.repeat_to_max_len:
            if z.dim() == 2:
                z_ = z[:, None, :]
            else:
                z_ = z
            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
        elif self.pad_to_max_len:
            assert z.dim() == 3
            # Match z's dtype so the concatenation does not upcast half-precision
            # embeddings to float32.
            z_pad = torch.cat(
                (
                    z,
                    torch.zeros(
                        z.shape[0],
                        self.max_length - z.shape[1],
                        z.shape[2],
                        device=z.device,
                        dtype=z.dtype,
                    ),
                ),
                1,
            )
            return z_pad, z_pad[:, 0, ...]
        return z

    def encode_with_vision_transformer(self, img):
        """Run the visual tower.

        A 5-D input is treated as ``(b, n_crops, c, h, w)`` and flattened to
        images first. With crops, the per-crop embeddings are regrouped to
        ``(b, n_crops, d)`` and randomly zeroed per crop with probability
        ``ucg_rate``.
        """
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)
        if not self.output_tokens:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        else:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            x = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate)
                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
                )
                * x
            )
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message."
                )
        if self.output_tokens:
            return x, tokens
        return x

    def encode(self, text):
        # ``text`` is actually an image batch; the name is kept for interface
        # compatibility with the text embedders.
        return self(text)
class FrozenCLIPT5Encoder(AbstractEmbModel):
    """Encode text with both a CLIP and a T5 text encoder.

    ``forward`` returns ``[clip_embedding, t5_embedding]``.
    """

    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(
            clip_version, device, max_length=clip_max_length
        )
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        clip_name = self.clip_encoder.__class__.__name__
        clip_millions = count_params(self.clip_encoder) * 1e-6
        t5_name = self.t5_encoder.__class__.__name__
        t5_millions = count_params(self.t5_encoder) * 1e-6
        print(
            f"{clip_name} has {clip_millions:.2f} M parameters, "
            f"{t5_name} comes with {t5_millions:.2f} M params."
        )

    def encode(self, text):
        return self(text)

    def forward(self, text):
        return [self.clip_encoder.encode(text), self.t5_encoder.encode(text)]
class SpatialRescaler(nn.Module):
    """Resize feature maps by a fixed factor, optionally remapping channels.

    Args:
        n_stages: number of successive resizes by ``multiplier``.
        method: interpolation mode for ``torch.nn.functional.interpolate``.
        multiplier: scale factor applied at each stage.
        in_channels: input channels of the optional remap convolution.
        out_channels: output channels of the remap convolution; setting it
            enables the remap.
        bias: whether the remap convolution has a bias.
        wrap_video: if True, 5-D ``(b, c, t, h, w)`` inputs are processed frame
            by frame and returned in the same layout.
        kernel_size: kernel size of the remap convolution (padded by
            ``kernel_size // 2``).
        remap_output: force the remap even if ``out_channels`` is None
            (``nn.Conv2d`` then needs a valid ``out_channels``).
    """

    def __init__(
        self,
        n_stages=1,
        method="bilinear",
        multiplier=0.5,
        in_channels=3,
        out_channels=None,
        bias=False,
        wrap_video=False,
        kernel_size=1,
        remap_output=False,
    ):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in [
            "nearest",
            "linear",
            "bilinear",
            "trilinear",
            "bicubic",
            "area",
        ]
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None or remap_output
        if self.remap_output:
            print(
                f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
            )
            self.channel_mapper = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                bias=bias,
                padding=kernel_size // 2,
            )
        self.wrap_video = wrap_video

    def forward(self, x):
        # Decide once whether this call is a video call. Re-checking only
        # ``self.wrap_video`` later would touch undefined shape variables for
        # non-5-D inputs.
        is_video = self.wrap_video and x.ndim == 5
        if is_video:
            b, c, t, h, w = x.shape
            # (b, c, t, h, w) -> (b*t, c, h, w): treat every frame as an image.
            x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)
        if self.remap_output:
            # Conv2d needs 4-D input, so remap before restoring the video layout.
            x = self.channel_mapper(x)
        if is_video:
            # (b*t, c', h', w') -> (b, c', t, h', w')
            x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4)
        return x

    def encode(self, x):
        return self(x)
class LowScaleEncoder(nn.Module):
    """Encode an image with a wrapped model, then noise the latent at a random level.

    ``forward`` returns the noised (and optionally resized) latent together with
    the sampled noise level per sample.
    """

    def __init__(
        self,
        model_config,
        linear_start,
        linear_end,
        timesteps=1000,
        max_noise_level=250,
        output_size=64,
        scale_factor=1.0,
    ):
        super().__init__()
        self.max_noise_level = max_noise_level
        self.model = instantiate_from_config(model_config)
        # register_schedule only registers buffers, so this attribute ends up None.
        self.augmentation_schedule = self.register_schedule(
            timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
        )
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(
        self,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Build the beta schedule and register its derived quantities as float32 buffers."""
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (n_steps,) = betas.shape
        self.num_timesteps = int(n_steps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        # Registered in this order; includes the quantities used by q(x_t | x_{t-1}).
        buffers = {
            "betas": betas,
            "alphas_cumprod": alphas_cumprod,
            "alphas_cumprod_prev": alphas_cumprod_prev,
            "sqrt_alphas_cumprod": np.sqrt(alphas_cumprod),
            "sqrt_one_minus_alphas_cumprod": np.sqrt(1.0 - alphas_cumprod),
            "log_one_minus_alphas_cumprod": np.log(1.0 - alphas_cumprod),
            "sqrt_recip_alphas_cumprod": np.sqrt(1.0 / alphas_cumprod),
            "sqrt_recipm1_alphas_cumprod": np.sqrt(1.0 / alphas_cumprod - 1),
        }
        for buffer_name, values in buffers.items():
            self.register_buffer(buffer_name, torch.tensor(values, dtype=torch.float32))

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to step ``t``:
        ``sqrt(abar_t) * x_start + sqrt(1 - abar_t) * noise``."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = extract_into_tensor(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return signal_scale * x_start + noise_scale * noise

    def forward(self, x):
        latent = self.model.encode(x)
        if isinstance(latent, DiagonalGaussianDistribution):
            latent = latent.sample()
        latent = latent * self.scale_factor
        noise_level = torch.randint(
            0, self.max_noise_level, (x.shape[0],), device=x.device
        ).long()
        latent = self.q_sample(latent, noise_level)
        if self.out_size is not None:
            latent = torch.nn.functional.interpolate(
                latent, size=self.out_size, mode="nearest"
            )
        return latent, noise_level

    def decode(self, z):
        return self.model.decode(z / self.scale_factor)
class ConcatTimestepEmbedderND(AbstractEmbModel):
    """embeds each dimension independently and concatenates them

    Input ``(b,)`` or ``(b, d)``; output ``(b, d * outdim)``.
    """

    def __init__(self, outdim):
        super().__init__()
        self.timestep = Timestep(outdim)
        self.outdim = outdim

    def forward(self, x):
        if x.ndim == 1:
            x = x.unsqueeze(1)
        assert len(x.shape) == 2
        batch_size, n_dims = x.shape
        per_value = self.timestep(rearrange(x, "b d -> (b d)"))
        return rearrange(
            per_value, "(b d) d2 -> b (d d2)", b=batch_size, d=n_dims, d2=self.outdim
        )
class GaussianEncoder(Encoder, AbstractEmbModel):
    """``Encoder`` whose output passes through a ``DiagonalGaussianRegularizer``.

    ``forward`` returns ``(log, z)``. ``log`` is the regulariser's log dict,
    extended with ``"loss"`` (copied from ``"kl_loss"``) and ``"weight"``.
    With ``flatten_output`` the latent is reshaped from ``(b, c, h, w)`` to
    ``(b, h*w, c)``.
    """

    def __init__(
        self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.posterior = DiagonalGaussianRegularizer()
        self.weight = weight
        self.flatten_output = flatten_output

    def forward(self, x) -> Tuple[Dict, torch.Tensor]:
        z, log = self.posterior(super().forward(x))
        log.update(loss=log["kl_loss"], weight=self.weight)
        if self.flatten_output:
            z = rearrange(z, "b c h w -> b (h w ) c")
        return log, z
class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
    """Encode conditioning frames and tile them for video prediction.

    Frames can optionally be noised with sigmas from ``sigma_sampler``. They are
    encoded in chunks, scaled, regrouped so that the ``n_cond_frames`` of each
    sample are stacked along channels, and repeated ``n_copies`` times.
    """

    def __init__(
        self,
        n_cond_frames: int,
        n_copies: int,
        encoder_config: dict,
        sigma_sampler_config: Optional[dict] = None,
        sigma_cond_config: Optional[dict] = None,
        is_ae: bool = False,
        scale_factor: float = 1.0,
        disable_encoder_autocast: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
    ):
        """
        Args:
            n_cond_frames: frames per sample in the flattened input batch.
            n_copies: how many times each encoded sample is repeated.
            encoder_config: config of the frame encoder.
            sigma_sampler_config: optional config; when set, input frames are
                noised with sampled sigmas.
            sigma_cond_config: optional config mapping sigmas to a conditioning
                vector. Requires ``sigma_sampler_config``.
            is_ae: call ``encoder.encode`` instead of ``encoder(...)``.
            scale_factor: multiplier applied to the encoded frames.
            disable_encoder_autocast: turn off CUDA autocast while encoding.
            en_and_decode_n_samples_a_time: chunk size for encoding (None = all).
        """
        super().__init__()

        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.encoder = instantiate_from_config(encoder_config)
        self.sigma_sampler = (
            instantiate_from_config(sigma_sampler_config)
            if sigma_sampler_config is not None
            else None
        )
        self.sigma_cond = (
            instantiate_from_config(sigma_cond_config)
            if sigma_cond_config is not None
            else None
        )
        self.is_ae = is_ae
        self.scale_factor = scale_factor
        self.disable_encoder_autocast = disable_encoder_autocast
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time

    def forward(
        self, vid: torch.Tensor
    ) -> Union[
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, dict],
        Tuple[Tuple[torch.Tensor, torch.Tensor], dict],
    ]:
        """Encode ``vid`` (flattened ``(b*n_cond_frames, c, h, w)``).

        Returns:
            the tiled encoding, or ``(encoding, sigma_cond)`` when a sigma
            conditioner is configured.

        Raises:
            ValueError: if ``sigma_cond`` is configured without ``sigma_sampler``.
        """
        if self.sigma_cond is not None and self.sigma_sampler is None:
            # sigma_cond is computed from sampled sigmas; without a sampler it
            # would be referenced before assignment below.
            raise ValueError(
                "sigma_cond_config requires sigma_sampler_config to be set"
            )
        if self.sigma_sampler is not None:
            b = vid.shape[0] // self.n_cond_frames
            sigmas = self.sigma_sampler(b).to(vid.device)
            if self.sigma_cond is not None:
                sigma_cond = self.sigma_cond(sigmas)
                sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
            sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
            noise = torch.randn_like(vid)
            vid = vid + noise * append_dims(sigmas, vid.ndim)

        with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
            n_samples = (
                self.en_and_decode_n_samples_a_time
                if self.en_and_decode_n_samples_a_time is not None
                else vid.shape[0]
            )
            n_rounds = math.ceil(vid.shape[0] / n_samples)
            all_out = []
            for n in range(n_rounds):
                chunk = vid[n * n_samples : (n + 1) * n_samples]
                if self.is_ae:
                    out = self.encoder.encode(chunk)
                else:
                    out = self.encoder(chunk)
                all_out.append(out)

        vid = torch.cat(all_out, dim=0)
        vid *= self.scale_factor

        # Stack each sample's conditioning frames along channels, then tile.
        vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
        vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)

        return (vid, sigma_cond) if self.sigma_cond is not None else vid
class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
    """Embed conditioning frames with an OpenCLIP image embedder and tile them.

    Input frames are flattened as ``(b*n_cond_frames, ...)``. The output is
    ``(b*n_copies, n_cond_frames, d)``.
    """

    def __init__(
        self,
        open_clip_embedding_config: Dict,
        n_cond_frames: int,
        n_copies: int,
    ):
        super().__init__()
        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.open_clip = instantiate_from_config(open_clip_embedding_config)

    def forward(self, vid):
        per_frame = rearrange(
            self.open_clip(vid), "(b t) d -> b t d", t=self.n_cond_frames
        )
        return repeat(per_frame, "b t d -> (b s) t d", s=self.n_copies)
class PixelNeRFEmbedder(AbstractEmbModel):
    """Produce PixelNeRF renders (RGB and feature maps) for target cameras.

    Source views are encoded with ``image_encoder``. The encoded features and
    the source/target camera matrices are then passed to ``pixelnerf_encoder``
    together with ``render_size``.

    Camera tensors are flat vectors along the last axis. Entries ``[:16]`` are
    reshaped to a 4x4 ``c2w`` matrix and entries ``[16:]`` to a 3x3 intrinsics
    matrix.
    """

    def __init__(
        self,
        image_encoder_config: dict,
        pixelnerf_encoder_config: dict,
        render_size: int,
        num_video_frames: int,
    ):
        super().__init__()
        self.render_size = render_size
        self.num_video_frames = num_video_frames
        self.image_encoder = instantiate_from_config(image_encoder_config)
        self.pixelnerf_encoder = instantiate_from_config(pixelnerf_encoder_config)

    @staticmethod
    def _split_cameras(cams):
        """Split flat camera vectors into ``(c2w 4x4, intrinsics 3x3)`` matrices."""
        lead_shape = cams.shape[:-1]
        c2w = cams[..., :16].reshape(*lead_shape, 4, 4)
        intrinsics = cams[..., 16:].reshape(*lead_shape, 3, 3)
        return c2w, intrinsics

    def forward(self, pixelnerf_input):
        """Render RGB and features for every target camera.

        Args:
            pixelnerf_input: dict with ``"cameras"`` and either ``"frames"``
                (used when ``"source_index"`` is absent: frame 0 and camera 0
                become the single source view) or ``"source_images"`` /
                ``"source_cameras"`` (used when ``"source_index"`` is present).

        Returns:
            ``(rgb, feats)``. The encoder returns each as ``(b, t, c, h, w)``;
            both are flattened here to ``(b * t, c, h, w)``.
        """
        if "source_index" in pixelnerf_input:
            # Multiple explicit source views: encode them as one flat batch,
            # then restore the view axis.
            views = pixelnerf_input["source_images"]
            n_views = views.shape[1]
            flat_feats = self.image_encoder(
                rearrange(views, "b t c h w -> (b t) c h w")
            )
            image_feats = rearrange(
                flat_feats, "(b t) c h w -> b t c h w", t=n_views
            )
            source_cameras = pixelnerf_input["source_cameras"]
        else:
            # Default: the first frame/camera is the only source view; add a
            # singleton view axis so both branches yield the same rank.
            image_feats = self.image_encoder(pixelnerf_input["frames"][:, 0])[:, None]
            source_cameras = pixelnerf_input["cameras"][:, :1]

        # Full-range slice kept from the original implementation.
        target_cameras = pixelnerf_input["cameras"][:, :]

        source_c2ws, source_intrinsics = self._split_cameras(source_cameras)
        target_c2ws, target_intrinsics = self._split_cameras(target_cameras)

        rgb, feats = self.pixelnerf_encoder(
            image_feats,
            source_c2ws,
            source_intrinsics,
            target_c2ws,
            target_intrinsics,
            self.render_size,
        )
        flatten = "b t c h w -> (b t) c h w"
        return rearrange(rgb, flatten), rearrange(feats, flatten)
class ExtraConditioner(GeneralConditioner):
    """Conditioner for video batches that also supports ``PixelNeRFEmbedder``.

    ``forward`` runs every embedder and merges its outputs by conditioning key.
    It adds two behaviors on top of the per-embedder loop:

    * The first output of a ``PixelNeRFEmbedder`` is stored under
      ``output["rgb"]``. Only its remaining outputs are merged as conditioning.
    * ``crossattn`` / ``concat`` embeddings whose leading dim differs from
      ``batch["frames"].shape[0]`` are repeated ``batch["num_video_frames"]``
      times along the leading dim before being merged.
    """

    def forward(
        self, batch: Dict, force_zero_embeddings: Optional[List] = None
    ) -> Dict:
        """Embed ``batch`` with all embedders and merge the results.

        Args:
            batch: Input dict. Must contain ``"frames"``, whose leading dim is
                the target batch size, and ``"num_video_frames"``, the repeat
                factor for embeddings that are not yet per-frame. It must also
                contain every key the embedders read via ``input_key`` /
                ``input_keys``.
            force_zero_embeddings: ``input_key`` values whose embeddings are
                replaced by zeros. ``None`` means no key is zeroed.

        Returns:
            Dict mapping conditioning keys (looked up via ``OUTPUT_DIM2KEYS``
            from each embedding's rank) to tensors. It also contains ``"rgb"``
            when a ``PixelNeRFEmbedder`` is present.

        Raises:
            ValueError: if an embedder has neither a non-None ``input_key`` nor
                ``input_keys``.
        """
        bs = batch["frames"].shape[0]
        num_frames = batch["num_video_frames"]
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []
        for embedder in self.embedders:
            # Frozen embedders run without autograd.
            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
            with embedding_context():
                if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    emb_out = embedder(batch[embedder.input_key])
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(*[batch[k] for k in embedder.input_keys])
                else:
                    # Without this branch, emb_out would be unbound for the
                    # first embedder, or silently reuse the previous one's output.
                    raise ValueError(
                        f"embedder {type(embedder).__name__} has neither a "
                        "non-None `input_key` nor `input_keys`"
                    )
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            if isinstance(embedder, PixelNeRFEmbedder):
                # A hack for the PixelNeRF input: its first output (rgb) is
                # passed through as-is rather than merged as conditioning.
                output["rgb"] = emb_out[0]
                emb_out = emb_out[1:]
            for emb in emb_out:
                out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    # Per-sample Bernoulli keep-mask: each sample's embedding is
                    # zeroed with probability ucg_rate.
                    emb = (
                        expand_dims_like(
                            torch.bernoulli(
                                (1.0 - embedder.ucg_rate)
                                * torch.ones(emb.shape[0], device=emb.device)
                            ),
                            emb,
                        )
                        * emb
                    )
                if (
                    hasattr(embedder, "input_key")
                    and embedder.input_key in force_zero_embeddings
                ):
                    emb = torch.zeros_like(emb)
                # Broadcast non-per-frame embeddings to per-frame BEFORE merging.
                # Repeating the accumulated tensor after concatenation breaks
                # when two such embeddings share a key: the second one would be
                # concatenated against an already-repeated tensor.
                # NOTE(review): assumes KEY2CATDIM for these keys is not 0.
                # TODO confirm against GeneralConditioner.
                if out_key in ("crossattn", "concat") and emb.shape[0] != bs:
                    emb = repeat(emb, "b ... -> (b t) ...", t=num_frames)
                if out_key in output:
                    output[out_key] = torch.cat(
                        (output[out_key], emb), self.KEY2CATDIM[out_key]
                    )
                else:
                    output[out_key] = emb
        return output