customdiffusion360's picture
first commit
ad7bc89
raw
history blame
No virus
40.8 kB
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
from packaging import version
import kornia
import numpy as np
import open_clip
from open_clip.tokenizer import SimpleTokenizer
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from torch.utils.checkpoint import checkpoint
import transformers
from transformers import (ByT5Tokenizer, CLIPTextModel, CLIPTokenizer,
T5EncoderModel, T5Tokenizer)
from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
from ...modules.diffusionmodules.model import Encoder
from ...modules.diffusionmodules.openaimodel import Timestep
from ...modules.diffusionmodules.util import (extract_into_tensor,
make_beta_schedule)
from ...modules.distributions.distributions import DiagonalGaussianDistribution
from ...util import (append_dims, autocast, count_params, default,
disabled_train, expand_dims_like, instantiate_from_config)
class AbstractEmbModel(nn.Module):
    """Base class for conditioning embedders.

    Stores three bookkeeping values -- ``is_trainable``, ``ucg_rate`` and
    ``input_key`` -- in private attributes that start out as ``None`` and are
    exposed through properties with a getter, a setter and a deleter each.
    """

    def __init__(self):
        super().__init__()
        self._is_trainable = None
        self._ucg_rate = None
        self._input_key = None

    # ---- is_trainable ------------------------------------------------------
    @property
    def is_trainable(self) -> bool:
        """Value last assigned to ``is_trainable`` (``None`` until assigned)."""
        return self._is_trainable

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    # ---- ucg_rate ----------------------------------------------------------
    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        """Value last assigned to ``ucg_rate`` (``None`` until assigned)."""
        return self._ucg_rate

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    # ---- input_key ---------------------------------------------------------
    @property
    def input_key(self) -> str:
        """Value last assigned to ``input_key`` (``None`` until assigned)."""
        return self._input_key

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @input_key.deleter
    def input_key(self):
        del self._input_key
class GeneralConditioner(nn.Module):
    """Runs a list of embedders over a batch and merges their outputs.

    Each tensor an embedder returns is routed to an output key chosen from
    its number of dimensions (``OUTPUT_DIM2KEYS``). Tensors that land on the
    same key are concatenated, by default along the axis in ``KEY2CATDIM``.
    Embedders configured with ``input_keys`` are run once per key; their
    result is split in two halves, the second of which is accumulated under
    ``<key>_ref`` and finally appended along dim 0.
    """

    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

    def __init__(self, emb_models: Union[List, ListConfig]):
        """Instantiate and configure one embedder per entry of ``emb_models``.

        Each config may provide ``is_trainable`` (default ``False``),
        ``ucg_rate`` (default ``0.0``) and ``legacy_ucg_value``; it must
        provide either ``input_key`` or a comma-separated ``input_keys``.

        Raises:
            AssertionError: an instantiated embedder is not an
                ``AbstractEmbModel``.
            KeyError: a config has neither ``input_key`` nor ``input_keys``.
        """
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                # Freeze: block later .train() calls, stop gradients, eval mode.
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )
            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"].split(',')
            else:
                raise KeyError(
                    f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                )
            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()
            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
        """Randomly overwrite entries of ``batch[embedder.input_key]`` in place.

        Each entry is replaced by ``embedder.legacy_ucg_val`` with probability
        ``embedder.ucg_rate``, drawn from ``embedder.ucg_prng``. The (mutated)
        batch is returned.
        """
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def forward(
        self, batch: Dict, force_zero_embeddings: Optional[List] = None, force_ref_zero_embeddings: bool = False
    ) -> Dict:
        """Embed ``batch`` with every embedder and merge the results.

        Args:
            batch: mapping from input key to the raw conditioning input.
            force_zero_embeddings: input keys whose embeddings are replaced by
                zeros.
            force_ref_zero_embeddings: when truthy, ``input_keys`` embedders
                are run on their first key only and no ``_ref`` half is
                produced.

        Returns:
            Dict from output key (``"vector"``, ``"crossattn"``, ``"concat"``)
            to the merged tensor.
        """
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []
        for embedder in self.embedders:
            # Only some embedders define ``modifier_token``; a missing
            # attribute means "no modifier token" instead of raising.
            has_modifier_token = getattr(embedder, "modifier_token", None) is not None
            embedding_context = (
                nullcontext if (embedder.is_trainable or has_modifier_token) else torch.no_grad
            )
            with embedding_context():
                if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    emb_out = embedder(batch[embedder.input_key])
                elif hasattr(embedder, "input_keys"):
                    if force_ref_zero_embeddings:
                        emb_out = embedder(batch[embedder.input_keys[0]])
                    else:
                        # One pass per key; results are stacked along dim 0 and
                        # split again with ``chunk(2)`` further down.
                        emb_out = [embedder(batch[k]) for k in embedder.input_keys]
                        if isinstance(emb_out[0], tuple):
                            emb_out = [torch.cat([x[0] for x in emb_out]), torch.cat([x[1] for x in emb_out])]
                        else:
                            emb_out = torch.cat(emb_out)
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            for emb in emb_out:
                out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    # Zero whole samples with probability ``ucg_rate``.
                    emb = (
                        expand_dims_like(
                            torch.bernoulli(
                                (1.0 - embedder.ucg_rate)
                                * torch.ones(emb.shape[0], device=emb.device)
                            ),
                            emb,
                        )
                        * emb
                    )
                if (
                    hasattr(embedder, "input_key")
                    and embedder.input_key in force_zero_embeddings
                ):
                    emb = torch.zeros_like(emb)
                # NOTE(review): ``input_keys`` is a list, so this only matches
                # when that exact list is an element of
                # ``force_zero_embeddings``; confirm whether a per-key check
                # was intended before changing it.
                if (
                    hasattr(embedder, "input_keys")
                    and embedder.input_keys in force_zero_embeddings
                ):
                    emb = torch.zeros_like(emb)
                has_input_keys = hasattr(embedder, "input_keys")
                if out_key in output:
                    if has_input_keys:
                        # List membership: matches a key that is exactly 'pose'.
                        catdim = 1 if ('pose' in embedder.input_keys) else self.KEY2CATDIM[out_key]
                        if not force_ref_zero_embeddings:
                            c, c1 = emb.chunk(2)
                            output[out_key] = torch.cat(
                                (output[out_key], c), catdim
                            )
                            output[out_key+'_ref'] = torch.cat(
                                (output[out_key+'_ref'], c1), catdim
                            )
                        else:
                            output[out_key] = torch.cat(
                                (output[out_key], emb), catdim
                            )
                    else:
                        # Substring test on the single key name.
                        catdim = 1 if ('pose' in embedder.input_key and emb.size(1) != 77) else self.KEY2CATDIM[out_key]
                        output[out_key] = torch.cat(
                            (output[out_key], emb), catdim
                        )
                else:
                    if has_input_keys and not force_ref_zero_embeddings:
                        c, c1 = emb.chunk(2)
                        output[out_key] = c
                        output[out_key+'_ref'] = c1
                    else:
                        output[out_key] = emb
        # Append each accumulated ``_ref`` half to its main entry along dim 0.
        for out_key in self.OUTPUT_DIM2KEYS.values():
            if out_key+'_ref' in output and not force_ref_zero_embeddings:
                output[out_key] = torch.cat([output[out_key], output[out_key+'_ref']], 0)
                del output[out_key+'_ref']
        return output

    def get_unconditional_conditioning(
        self,
        batch_c: Dict,
        batch_uc: Optional[Dict] = None,
        force_uc_zero_embeddings: Optional[List[str]] = None,
        force_ref_zero_embeddings: Optional[bool] = None,
    ):
        """Return ``(c, uc)`` computed with conditioning dropout disabled.

        ``c`` embeds ``batch_c``; ``uc`` embeds ``batch_uc`` (or ``batch_c``
        when ``batch_uc`` is ``None``) with ``force_uc_zero_embeddings``
        applied. ``force_ref_zero_embeddings`` is forwarded to both passes and
        is only tested for truthiness there.
        """
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = [embedder.ucg_rate for embedder in self.embedders]
        for embedder in self.embedders:
            embedder.ucg_rate = 0.0
        try:
            c = self(batch_c, force_ref_zero_embeddings=force_ref_zero_embeddings)
            uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings, force_ref_zero_embeddings)
        finally:
            # Restore the dropout rates even if one of the passes raised.
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate
        return c, uc
class InceptionV3(nn.Module):
    """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
    port with an additional squeeze at the end"""

    def __init__(self, normalize_input=False, **kwargs):
        """Build the wrapped network; ``resize_input`` is always forced to True."""
        super().__init__()
        from pytorch_fid import inception

        options = dict(kwargs, resize_input=True)
        self.model = inception.InceptionV3(normalize_input=normalize_input, **options)

    def forward(self, inp):
        """Run the wrapped model; a single-element result is unwrapped and squeezed."""
        features = self.model(inp)
        return features[0].squeeze() if len(features) == 1 else features
class IdentityEncoder(AbstractEmbModel):
    """Embedder that hands its input back unchanged."""

    def encode(self, x):
        """Return ``x`` as is."""
        return x

    def forward(self, x):
        """Return ``x`` as is."""
        return x
class ClassEmbedder(AbstractEmbModel):
    """Looks up a learned embedding vector for integer class labels."""

    def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
        """
        Args:
            embed_dim: size of each embedding vector.
            n_classes: number of rows in the embedding table.
            add_sequence_dim: if true, ``forward`` inserts a length-1 axis at
                position 1 of its result.
        """
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.add_sequence_dim = add_sequence_dim

    def forward(self, c):
        """Embed the class indices ``c``."""
        c = self.embedding(c)
        if self.add_sequence_dim:
            c = c[:, None, :]
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return a batch of ``bs`` labels all set to the last class index.

        The result is a one-entry dict keyed by ``self.key`` when that
        attribute has been set, otherwise by ``self.input_key``. The original
        read ``self.key`` unconditionally, but this class never assigns it.
        """
        uc_class = (
            self.n_classes - 1
        )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc = torch.ones((bs,), device=device) * uc_class
        key = getattr(self, "key", None)
        if key is None:
            key = self.input_key
        return {key: uc.long()}
class ClassEmbedderForMultiCond(ClassEmbedder):
    """``ClassEmbedder`` variant that reads from and writes back into a batch dict."""

    def forward(self, batch, key=None, disable_dropout=False):
        """Replace ``batch[key]`` with its class embedding and return ``batch``.

        Args:
            batch: mapping that holds the class labels under ``key``; it is
                modified in place. If the entry is a list, only its first
                element is embedded and the result is wrapped in a list again.
            key: entry to embed. Defaults to ``self.key`` when that attribute
                exists, otherwise to ``self.input_key``.
            disable_dropout: accepted for interface compatibility; unused,
                because the parent ``forward`` takes no such option.
        """
        if key is None:
            # ``self.key`` is never assigned by this class hierarchy, so it is
            # looked up lazily and ``input_key`` serves as the fallback.
            key = getattr(self, "key", None)
            if key is None:
                key = self.input_key
        out = batch
        islist = isinstance(batch[key], list)
        if islist:
            batch[key] = batch[key][0]
        # The parent forward takes only the label tensor.
        c_out = super().forward(batch[key])
        out[key] = [c_out] if islist else c_out
        return out
class FrozenT5Embedder(AbstractEmbModel):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
    ):
        """Load tokenizer and encoder for ``version``; optionally freeze the encoder.

        ``google/t5-v1_1-xl`` is another checkpoint name mentioned by the
        original author.
        """
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Switch the encoder to eval mode and disable gradients on all parameters."""
        self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` (padded/truncated to ``max_length``) and return the
        encoder's last hidden state."""
        tokenizer_options = dict(
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        encoded = self.tokenizer(text, **tokenizer_options)
        input_ids = encoded["input_ids"].to(self.device)
        # CUDA autocast is switched off for the encoder pass.
        with torch.autocast("cuda", enabled=False):
            encoder_out = self.transformer(input_ids=input_ids)
        return encoder_out.last_hidden_state

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)
class FrozenByT5Embedder(AbstractEmbModel):
    """
    Uses the ByT5 transformer encoder for text. Is character-aware.
    """

    def __init__(
        self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
    ):
        """Load the ByT5 tokenizer and a T5 encoder for ``version``; optionally
        freeze the encoder."""
        super().__init__()
        self.tokenizer = ByT5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Switch the encoder to eval mode and disable gradients on all parameters."""
        self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` (padded/truncated to ``max_length``) and return the
        encoder's last hidden state."""
        tokenizer_options = dict(
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        encoded = self.tokenizer(text, **tokenizer_options)
        input_ids = encoded["input_ids"].to(self.device)
        # CUDA autocast is switched off for the encoder pass.
        with torch.autocast("cuda", enabled=False):
            encoder_out = self.transformer(input_ids=input_ids)
        return encoder_out.last_hidden_state

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)
class FrozenCLIPEmbedder(AbstractEmbModel):
    """Uses the CLIP transformer encoder for text (from huggingface)

    When ``modifier_token`` is given (one token, or several joined by ``+``),
    the tokens are added to the tokenizer and only their embedding rows keep
    receiving gradients in ``forward``.
    """

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        modifier_token=None,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
        always_return_pooled=False,
    ):  # clip-vit-base-patch32
        """Load tokenizer and text model, register modifier tokens, optionally freeze.

        ``layer``, ``layer_idx`` and ``always_return_pooled`` are validated and
        stored, but ``forward`` in this class does not read them.
        """
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.modifier_token = modifier_token
        if self.modifier_token is not None:
            if '+' in self.modifier_token:
                self.modifier_token = self.modifier_token.split('+')
            else:
                self.modifier_token = [self.modifier_token]
            self.add_token()
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        self.return_pooled = always_return_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def add_token(self):
        """Add the modifier tokens to the tokenizer and seed their embeddings.

        The embedding matrix is resized to the new vocabulary size. Counting
        from the last modifier token backwards, the new rows are initialised
        from existing rows 42170, 47629 and 43514. The original checked
        ``len(...) == 2`` / ``== 3`` separately, which left the middle token
        unseeded when three tokens were given.
        """
        self.modifier_token_id = []
        for each_modifier_token in self.modifier_token:
            print(each_modifier_token, "adding new token")
            self.tokenizer.add_tokens(each_modifier_token)
            modifier_token_id = self.tokenizer.convert_tokens_to_ids(each_modifier_token)
            self.modifier_token_id.append(modifier_token_id)
        self.transformer.resize_token_embeddings(len(self.tokenizer))
        token_embeds = self.transformer.get_input_embeddings().weight.data
        seed_token_ids = (42170, 47629, 43514)
        for new_id, seed_id in zip(reversed(self.modifier_token_id), seed_token_ids):
            token_embeds[new_id] = token_embeds[seed_id].clone()

    def freeze(self):
        """Put the text model in eval mode and disable gradients.

        With modifier tokens, everything is frozen except the input embedding
        table; without them, every parameter of this module is frozen.
        """
        self.transformer = self.transformer.eval()
        if self.modifier_token is not None:
            text_model = self.transformer.text_model
            for frozen_part in (text_model.encoder, text_model.final_layer_norm, text_model.embeddings):
                for param in frozen_part.parameters():
                    param.requires_grad = False
            # Re-enabled last: the input embedding is part of ``embeddings``.
            for param in self.transformer.get_input_embeddings().parameters():
                param.requires_grad = True
            print("making grad true")
        else:
            for param in self.parameters():
                param.requires_grad = False

    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        """Return a ``(bsz, 1, seq_len, seq_len)`` additive mask: the dtype's
        minimum strictly above the diagonal, zero elsewhere."""
        # pytorch uses additive attention mask; fill with the most negative value
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the diagonal and everything below it
        mask = mask.unsqueeze(1)  # add the head axis
        return mask

    @autocast
    def custom_forward(self, hidden_states, input_ids):
        r"""Run the CLIP text encoder on pre-computed embeddings.

        Args:
            hidden_states: embedded tokens; the first two dims are read as
                batch size and sequence length.
            input_ids: unused; kept for call compatibility.

        Returns:
            The encoder output after the final layer norm.
        """
        input_shape = hidden_states.size()
        bsz, seq_len = input_shape[:2]
        # From transformers 4.21 on the local mask builder is used; older
        # versions use the helper on the text model.
        if version.parse(transformers.__version__) >= version.parse('4.21'):
            causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
                hidden_states.device
            )
        else:
            causal_attention_mask = self.transformer.text_model._build_causal_attention_mask(bsz, seq_len).to(
                hidden_states.device
            )
        encoder_outputs = self.transformer.text_model.encoder(
            inputs_embeds=hidden_states,
            causal_attention_mask=causal_attention_mask,
        )
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.transformer.text_model.final_layer_norm(last_hidden_state)
        return last_hidden_state

    @autocast
    def forward(self, text):
        """Tokenize ``text`` and return the final-layer-norm hidden states.

        With modifier tokens, embeddings at all other positions are detached
        so gradients reach only the modifier-token rows.
        """
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt"
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        if self.modifier_token is not None:
            # 1 where the position holds any modifier token, else 0.
            indices = tokens == self.modifier_token_id[-1]
            for token_id in self.modifier_token_id:
                indices |= tokens == token_id
            indices = (indices*1).unsqueeze(-1)
        input_shape = tokens.size()
        tokens = tokens.view(-1, input_shape[-1])
        hidden_states = self.transformer.text_model.embeddings(input_ids=tokens)
        if self.modifier_token is not None:
            hidden_states = (1-indices)*hidden_states.detach() + indices*hidden_states
        z = self.custom_forward(hidden_states, tokens)
        return z

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)
class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
    """
    Uses the OpenCLIP transformer encoder for text
    """

    LAYERS = ["pooled", "last", "penultimate"]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        always_return_pooled=False,
        legacy=True,
    ):
        """Create the OpenCLIP model on CPU, drop its visual tower and
        optionally freeze it.

        Raises:
            NotImplementedError: ``layer`` is neither ``"last"`` nor
                ``"penultimate"``.
        """
        super().__init__()
        assert layer in self.LAYERS
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        del clip_model.visual
        self.model = clip_model
        self.modifier_token = None
        self.device = device
        self.max_length = max_length
        self.return_pooled = always_return_pooled
        if freeze:
            self.freeze()
        self.layer = layer
        layer_to_idx = {"last": 0, "penultimate": 1}
        if self.layer not in layer_to_idx:
            raise NotImplementedError()
        self.layer_idx = layer_to_idx[self.layer]
        self.legacy = legacy

    def freeze(self):
        """Switch to eval mode and disable gradients on all parameters."""
        self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    @autocast
    def forward(self, text):
        """Tokenize and encode ``text``.

        Returns the selected layer's features; in non-legacy mode with
        ``return_pooled`` set, a ``(features, pooled)`` pair.
        """
        token_ids = open_clip.tokenize(text)
        z = self.encode_with_transformer(token_ids.to(self.device))
        if self.return_pooled:
            assert not self.legacy
            return z[self.layer], z["pooled"]
        if self.legacy:
            # Legacy mode: the encoder already picked the layer.
            return z
        return z[self.layer]

    def encode_with_transformer(self, text):
        """Embed token ids, add positional embeddings and run the transformer.

        In legacy mode the selected layer is returned after ``ln_final``;
        otherwise the dict of layer outputs is returned with a ``"pooled"``
        entry added.
        """
        hidden = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        hidden = hidden + self.model.positional_embedding
        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        layer_outputs = self.text_transformer_forward(hidden, attn_mask=self.model.attn_mask)
        if self.legacy:
            return self.model.ln_final(layer_outputs[self.layer])
        # non-legacy: keep the dict and attach the pooled vector
        normed_last = self.model.ln_final(layer_outputs["last"])
        layer_outputs["pooled"] = self.pool(normed_last, text)
        return layer_outputs

    def pool(self, x, text):
        """Project the features at each sequence's arg-max token id."""
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        row_idx = torch.arange(x.shape[0])
        eot_idx = text.argmax(dim=-1)
        return x[row_idx, eot_idx] @ self.model.text_projection

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run all residual blocks, recording the input to the last block as
        ``"penultimate"`` and the final output as ``"last"``."""
        blocks = self.model.transformer.resblocks
        final_block = len(blocks) - 1
        collected = {}
        for idx, block in enumerate(blocks):
            if idx == final_block:
                collected["penultimate"] = x.permute(1, 0, 2)  # LND -> NLD
            use_checkpoint = (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            )
            if use_checkpoint:
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        collected["last"] = x.permute(1, 0, 2)  # LND -> NLD
        return collected

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)
class FrozenOpenCLIPEmbedder(AbstractEmbModel):
    """OpenCLIP text encoder with optional extra modifier tokens.

    When ``modifier_token`` is given (one token, or several joined by ``+``),
    the token embedding table is enlarged and only the modifier positions
    keep receiving gradients in ``encode_with_transformer``.
    """

    LAYERS = [
        # "pooled",
        "last",
        "penultimate",
    ]

    def __init__(
        self,
        modifier_token=None,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        always_return_pooled=False,
        legacy=True,
    ):
        """Create the OpenCLIP model on CPU, drop its visual tower, set up the
        tokenizer (with modifier tokens if any) and optionally freeze.

        Raises:
            NotImplementedError: ``layer`` is neither ``"last"`` nor
                ``"penultimate"``.
        """
        super().__init__()
        assert layer in self.LAYERS
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device("cpu"), pretrained=version
        )
        del clip_model.visual
        self.model = clip_model
        self.device = device
        self.max_length = max_length
        self.modifier_token = modifier_token
        self.return_pooled = always_return_pooled
        if self.modifier_token is None:
            self.tokenizer = SimpleTokenizer()
        else:
            if '+' in self.modifier_token:
                self.modifier_token = self.modifier_token.split('+')
            else:
                self.modifier_token = [self.modifier_token]
            self.tokenizer = SimpleTokenizer(additional_special_tokens=self.modifier_token)
            self.add_token()
        if freeze:
            self.freeze()
        self.layer = layer
        layer_to_idx = {"last": 0, "penultimate": 1}
        if self.layer not in layer_to_idx:
            raise NotImplementedError()
        self.layer_idx = layer_to_idx[self.layer]
        self.legacy = legacy

    def tokenize(self, texts, context_length=77):
        """Tokenize ``texts`` with this embedder's tokenizer."""
        return self.tokenizer(texts, context_length=context_length)

    def add_token(self):
        """Enlarge the token embedding table for the modifier tokens.

        Existing rows are copied over. The last modifier token's row is
        seeded from row 42170; with exactly two modifier tokens the first is
        seeded from row 47629.
        """
        old_weights = self.model.token_embedding.weight.data
        self.modifier_token_id = []
        for token in self.modifier_token:
            self.modifier_token_id.append(self.tokenizer.encoder[token])
        old_vocab, width = old_weights.shape[0], old_weights.shape[1]
        self.model.token_embedding = nn.Embedding(old_vocab + len(self.modifier_token), width)
        new_weights = self.model.token_embedding.weight.data
        new_weights[:old_vocab] = old_weights
        new_weights[self.modifier_token_id[-1]] = old_weights[42170]
        if len(self.modifier_token) == 2:
            new_weights[self.modifier_token_id[-2]] = old_weights[47629]

    def freeze(self):
        """Switch to eval mode and disable gradients.

        With modifier tokens, everything is frozen except the token embedding
        table; without them, every parameter of this module is frozen.
        """
        self.model.eval()
        if self.modifier_token is None:
            for weight in self.parameters():
                weight.requires_grad = False
            return
        for frozen_part in (self.model.transformer, self.model.ln_final):
            for weight in frozen_part.parameters():
                weight.requires_grad = False
        self.model.text_projection.requires_grad = False
        self.model.positional_embedding.requires_grad = False
        for weight in self.model.token_embedding.parameters():
            weight.requires_grad = True
        print("making grad true")

    @autocast
    def forward(self, text):
        """Tokenize and encode ``text``.

        Returns the selected layer's features; in non-legacy mode with
        ``return_pooled`` set, a ``(features, pooled)`` pair.
        """
        token_ids = self.tokenize(text)
        z = self.encode_with_transformer(token_ids.to(self.device))
        if self.return_pooled:
            assert not self.legacy
            return z[self.layer], z["pooled"]
        if self.legacy:
            # Legacy mode: the encoder already picked the layer.
            return z
        return z[self.layer]

    def encode_with_transformer(self, text):
        """Embed token ids, add positional embeddings and run the transformer.

        With modifier tokens, embeddings at all other positions are detached.
        In legacy mode the selected layer is returned after ``ln_final``;
        otherwise the dict of layer outputs is returned with a ``"pooled"``
        entry added.
        """
        hidden = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        if self.modifier_token is not None:
            # 1 where the position holds any modifier token, else 0.
            is_modifier = text == self.modifier_token_id[-1]
            for token_id in self.modifier_token_id:
                is_modifier |= text == token_id
            is_modifier = (is_modifier*1).unsqueeze(-1)
            hidden = (1-is_modifier)*hidden.detach() + is_modifier*hidden
        hidden = hidden + self.model.positional_embedding
        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        layer_outputs = self.text_transformer_forward(hidden, attn_mask=self.model.attn_mask)
        if self.legacy:
            return self.model.ln_final(layer_outputs[self.layer])
        # non-legacy: keep the dict and attach the pooled vector
        normed_last = self.model.ln_final(layer_outputs["last"])
        layer_outputs["pooled"] = self.pool(normed_last, text)
        return layer_outputs

    def pool(self, x, text):
        """Project the features at each sequence's arg-max token id."""
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        row_idx = torch.arange(x.shape[0])
        eot_idx = text.argmax(dim=-1)
        return x[row_idx, eot_idx] @ self.model.text_projection

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run all residual blocks, recording the input to the last block as
        ``"penultimate"`` and the final output as ``"last"``."""
        blocks = self.model.transformer.resblocks
        final_block = len(blocks) - 1
        collected = {}
        for idx, block in enumerate(blocks):
            if idx == final_block:
                collected["penultimate"] = x.permute(1, 0, 2)  # LND -> NLD
            use_checkpoint = (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            )
            if use_checkpoint:
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        collected["last"] = x.permute(1, 0, 2)  # LND -> NLD
        return collected

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        init_device=None,
    ):
        """Create the OpenCLIP model (on ``init_device`` or CPU), drop its text
        transformer, optionally freeze, and register the normalisation
        constants as non-persistent buffers."""
        super().__init__()
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device(default(init_device, "cpu")),
            pretrained=version,
        )
        del clip_model.transformer
        self.model = clip_model
        self.max_crops = num_image_crops
        # Padding takes precedence over repeating when crops are used.
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.antialias = antialias
        channel_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
        self.register_buffer("mean", channel_mean, persistent=False)
        channel_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
        self.register_buffer("std", channel_std, persistent=False)
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        """Resize to 224x224 (bicubic), map via ``(x + 1) / 2`` and normalise
        with the registered mean/std."""
        resized = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # normalize to [0,1]
        unit_range = (resized + 1.0) / 2.0
        # renormalize according to clip
        return kornia.enhance.normalize(unit_range, self.mean, self.std)

    def freeze(self):
        """Switch to eval mode and disable gradients on all parameters."""
        self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        """Encode ``image`` and apply the configured dropout / reshaping.

        Returns, depending on configuration: ``(tokens, z)`` when
        ``output_tokens`` is set; ``(repeated, z)`` when repeating to
        ``max_length``; ``(padded, padded[:, 0])`` when padding crops to
        ``max_length``; otherwise ``z``.
        """
        encoded = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = encoded[0], encoded[1]
        else:
            z = encoded
        z = z.to(image.dtype)
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            # Per-sample dropout of the embedding; tokens get their own draw.
            keep = torch.bernoulli(
                (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
            )
            z = keep[:, None] * z
            if tokens is not None:
                token_keep = torch.bernoulli(
                    (1.0 - self.ucg_rate)
                    * torch.ones(tokens.shape[0], device=tokens.device)
                )
                tokens = expand_dims_like(token_keep, tokens) * tokens
        if self.unsqueeze_dim:
            z = z[:, None, :]
        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z
        if self.repeat_to_max_len:
            sequence = z[:, None, :] if z.dim() == 2 else z
            return repeat(sequence, "b 1 d -> b n d", n=self.max_length), z
        if self.pad_to_max_len:
            assert z.dim() == 3
            padding = torch.zeros(
                z.shape[0],
                self.max_length - z.shape[1],
                z.shape[2],
                device=z.device,
            )
            z_pad = torch.cat((z, padding), 1)
            return z_pad, z_pad[:, 0, ...]
        return z

    def encode_with_vision_transformer(self, img):
        """Preprocess ``img`` and run the visual tower.

        A 5-D input is treated as ``(b, n, c, h, w)`` with ``n == max_crops``
        and flattened into the batch. With crops enabled, the embeddings are
        regrouped per sample and randomly dropped per crop. Returns ``x`` or
        ``(x, tokens)`` when ``output_tokens`` is set.
        """
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)
        if self.output_tokens:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        else:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            crop_keep = torch.bernoulli(
                (1.0 - self.ucg_rate)
                * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
            )
            x = crop_keep * x
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message."
                )
        if self.output_tokens:
            return x, tokens
        return x

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)
class FrozenCLIPT5Encoder(AbstractEmbModel):
    """Encodes text with both a CLIP text encoder and a T5 encoder."""

    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        """Build both sub-encoders and print their parameter counts."""
        super().__init__()
        # Keyword arguments are required here: the first positional parameter
        # of FrozenCLIPEmbedder is ``modifier_token``, so passing the version
        # and device positionally would bind them to the wrong parameters.
        self.clip_encoder = FrozenCLIPEmbedder(
            version=clip_version, device=device, max_length=clip_max_length
        )
        self.t5_encoder = FrozenT5Embedder(
            version=t5_version, device=device, max_length=t5_max_length
        )
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
        )

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)

    def forward(self, text):
        """Return ``[clip_embedding, t5_embedding]`` for ``text``."""
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]
class SpatialRescaler(nn.Module):
    """Rescales its input ``n_stages`` times by ``multiplier`` and optionally
    remaps the channels with a convolution."""

    def __init__(
        self,
        n_stages=1,
        method="bilinear",
        multiplier=0.5,
        in_channels=3,
        out_channels=None,
        bias=False,
        wrap_video=False,
        kernel_size=1,
        remap_output=False,
    ):
        """
        Args:
            n_stages: number of successive interpolation passes (>= 0).
            method: interpolation mode passed to ``torch.nn.functional.interpolate``.
            multiplier: scale factor applied in every pass.
            in_channels, out_channels, bias, kernel_size: settings of the
                channel-mapping ``nn.Conv2d``, created when ``out_channels``
                is given or ``remap_output`` is true.
            wrap_video: fold the time axis of 5-D inputs into the batch while
                rescaling.
            remap_output: force creation of the channel mapper.
        """
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in [
            "nearest",
            "linear",
            "bilinear",
            "trilinear",
            "bicubic",
            "area",
        ]
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None or remap_output
        if self.remap_output:
            print(
                f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
            )
            self.channel_mapper = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                bias=bias,
                padding=kernel_size // 2,
            )
        self.wrap_video = wrap_video

    def forward(self, x):
        """Rescale ``x`` and, if configured, map its channels.

        With ``wrap_video`` set, a 5-D ``(b, c, t, h, w)`` input has its time
        axis folded into the batch for rescaling and restored afterwards.
        Inputs that are not 5-D are rescaled directly; the original tried to
        unfold them too and failed on names bound only in the 5-D branch.
        """
        folded = self.wrap_video and x.ndim == 5
        if folded:
            B, C, T = x.shape[:3]
            x = rearrange(x, "b c t h w -> (b t) c h w")
        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)
        if folded:
            x = rearrange(x, "(b t) c h w -> b c t h w", b=B, t=T, c=C)
        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        """Alias for calling the module."""
        return self(x)
class LowScaleEncoder(nn.Module):
    """Encodes an input with a wrapped model and noises the latent at a
    randomly drawn level of a precomputed diffusion schedule."""

    def __init__(
        self,
        model_config,
        linear_start,
        linear_end,
        timesteps=1000,
        max_noise_level=250,
        output_size=64,
        scale_factor=1.0,
    ):
        """Instantiate the wrapped model from ``model_config`` and register the
        noise schedule buffers."""
        super().__init__()
        self.max_noise_level = max_noise_level
        self.model = instantiate_from_config(model_config)
        # register_schedule has no return statement, so this stores None.
        self.augmentation_schedule = self.register_schedule(
            timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
        )
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(
        self,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Compute the beta schedule and register it, together with derived
        cumulative-product quantities, as float32 buffers."""
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        as_float32 = partial(torch.tensor, dtype=torch.float32)
        schedule_buffers = {
            "betas": betas,
            "alphas_cumprod": alphas_cumprod,
            "alphas_cumprod_prev": alphas_cumprod_prev,
            # calculations for diffusion q(x_t | x_{t-1}) and others
            "sqrt_alphas_cumprod": np.sqrt(alphas_cumprod),
            "sqrt_one_minus_alphas_cumprod": np.sqrt(1.0 - alphas_cumprod),
            "log_one_minus_alphas_cumprod": np.log(1.0 - alphas_cumprod),
            "sqrt_recip_alphas_cumprod": np.sqrt(1.0 / alphas_cumprod),
            "sqrt_recipm1_alphas_cumprod": np.sqrt(1.0 / alphas_cumprod - 1),
        }
        for buffer_name, values in schedule_buffers.items():
            self.register_buffer(buffer_name, as_float32(values))

    def q_sample(self, x_start, t, noise=None):
        """Mix ``x_start`` with noise using the schedule coefficients at step
        ``t``; fresh Gaussian noise is drawn when ``noise`` is not given."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal_coef = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_coef = extract_into_tensor(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return signal_coef * x_start + noise_coef * noise

    def forward(self, x):
        """Encode ``x``, scale, noise and optionally resize the latent.

        Returns:
            ``(z, noise_level)``, where ``noise_level`` holds one integer
            drawn from ``[0, max_noise_level)`` per sample.
        """
        latent = self.model.encode(x)
        if isinstance(latent, DiagonalGaussianDistribution):
            latent = latent.sample()
        latent = latent * self.scale_factor
        noise_level = torch.randint(
            0, self.max_noise_level, (x.shape[0],), device=x.device
        ).long()
        latent = self.q_sample(latent, noise_level)
        if self.out_size is not None:
            latent = torch.nn.functional.interpolate(
                latent, size=self.out_size, mode="nearest"
            )
        return latent, noise_level

    def decode(self, z):
        """Undo the scale factor and decode with the wrapped model."""
        return self.model.decode(z / self.scale_factor)
class ConcatTimestepEmbedderND(AbstractEmbModel):
    """embeds each dimension independently and concatenates them"""

    def __init__(self, outdim):
        """Create the ``Timestep`` embedding with ``outdim`` features per scalar."""
        super().__init__()
        self.timestep = Timestep(outdim)
        self.outdim = outdim
        self.modifier_token = None

    def forward(self, x):
        """Embed every scalar of a 1-D or 2-D ``x``.

        A 1-D input is treated as one column. Each scalar is embedded on its
        own and the embeddings of one row are concatenated, giving a
        ``(b, d * outdim)`` result.
        """
        if x.ndim == 1:
            x = x[:, None]
        assert len(x.shape) == 2
        batch, n_dims = x.shape[0], x.shape[1]
        flat = rearrange(x, "b d -> (b d)")
        per_scalar = self.timestep(flat)
        return rearrange(
            per_scalar, "(b d) d2 -> b (d d2)", b=batch, d=n_dims, d2=self.outdim
        )
class GaussianEncoder(Encoder, AbstractEmbModel):
    """``Encoder`` whose output is passed through a diagonal Gaussian regularizer."""

    def __init__(
        self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
    ):
        """
        Args:
            weight: stored in the returned log under ``"weight"``.
            flatten_output: if true, the latent is reshaped from
                ``b c h w`` to ``b (h w) c``.
            *args, **kwargs: forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.posterior = DiagonalGaussianRegularizer()
        self.weight = weight
        self.flatten_output = flatten_output

    def forward(self, x) -> Tuple[Dict, torch.Tensor]:
        """Encode ``x`` and regularize it.

        Returns:
            ``(log, z)``: the regularizer's log, extended with ``"loss"``
            (a copy of its ``"kl_loss"`` entry) and ``"weight"``, and the
            latent.
        """
        encoded = super().forward(x)
        latent, log = self.posterior(encoded)
        log["loss"] = log["kl_loss"]
        log["weight"] = self.weight
        if not self.flatten_output:
            return log, latent
        return log, rearrange(latent, "b c h w -> b (h w ) c")