import os
import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import PIL
import clip
import open_clip
from functools import partial
# for prior
from dalle2_pytorch import DiffusionPrior
from dalle2_pytorch.dalle2_pytorch import l2norm, default, exists
from tqdm.auto import tqdm
import random
import json
from dalle2_pytorch.train_configs import DiffusionPriorNetworkConfig
# vd prior
from dalle2_pytorch.dalle2_pytorch import RotaryEmbedding, CausalTransformer, SinusoidalPosEmb, MLP, Rearrange, repeat, rearrange, prob_mask_like, LayerNorm, RelPosBias, FeedForward, Attention
# for pipeline
from diffusers import StableDiffusionImageVariationPipeline, VersatileDiffusionDualGuidedPipeline
from typing import Callable, List, Optional, Union
from diffusers.models.vae import Decoder
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# class BrainMLP(nn.Module):
# def __init__(self, out_dim=257*768, in_dim=15724, clip_size=768, h=4096):
# super().__init__()
#         self.lin0 = nn.Sequential(
#             nn.Linear(in_dim, h, bias=False),
#             nn.LayerNorm(h),
#             nn.GELU(),  # nn.GELU takes no inplace argument
#             nn.Dropout(0.5))
#         self.mlp = nn.ModuleList([
#             nn.Sequential(
#                 nn.Linear(h, h),
#                 nn.LayerNorm(h),
#                 nn.GELU(),
#                 nn.Dropout(0.15)
#             ) for _ in range(4)])
#         self.n_blocks = 4
#         self.clip_size = clip_size
#         self.lin1 = nn.Linear(h, out_dim, bias=True)
# self.proj = nn.Sequential(
# nn.LayerNorm(clip_size),
# nn.GELU(),
# nn.Linear(clip_size, 2048),
# nn.LayerNorm(2048),
# nn.GELU(),
# nn.Linear(2048, 2048),
# nn.LayerNorm(2048),
# nn.GELU(),
# nn.Linear(2048, clip_size))
# def forward(self, x):
# x = self.lin0(x)
# residual = x
# for res_block in range(self.n_blocks):
# x = self.mlp[res_block](x)
# x += residual
# residual = x
# diffusion_prior_input = self.lin1(x.reshape(len(x), -1))
# disjointed_clip_fmri = self.proj(diffusion_prior_input.reshape(
# len(x),-1, self.clip_size))
# return diffusion_prior_input, disjointed_clip_fmri
class Clipper(torch.nn.Module):
def __init__(self, clip_variant, clamp_embs=False, norm_embs=False,
hidden_state=False, device=torch.device('cpu')):
super().__init__()
assert clip_variant in ("RN50", "ViT-L/14", "ViT-B/32", "RN50x64"), \
"clip_variant must be one of RN50, ViT-L/14, ViT-B/32, RN50x64"
print(clip_variant, device)
if clip_variant=="ViT-L/14" and hidden_state:
# from transformers import CLIPVisionModelWithProjection
# image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14",cache_dir="/fsx/proj-medarc/fmri/cache")
from transformers import CLIPVisionModelWithProjection
sd_cache_dir = '/fsx/proj-fmri/shared/cache/models--shi-labs--versatile-diffusion/snapshots/2926f8e11ea526b562cd592b099fcf9c2985d0b7'
image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_cache_dir, subfolder='image_encoder').eval()
image_encoder = image_encoder.to(device)
for param in image_encoder.parameters():
                param.requires_grad = False # don't need to calculate gradients
self.image_encoder = image_encoder
elif hidden_state:
raise Exception("hidden_state embeddings only works with ViT-L/14 right now")
clip_model, preprocess = clip.load(clip_variant, device=device)
        clip_model.eval() # don't want to train the model
for param in clip_model.parameters():
            param.requires_grad = False # don't need to calculate gradients
self.clip = clip_model
self.clip_variant = clip_variant
if clip_variant == "RN50x64":
self.clip_size = (448,448)
else:
self.clip_size = (224,224)
preproc = transforms.Compose([
transforms.Resize(size=self.clip_size[0], interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(size=self.clip_size),
transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
])
self.preprocess = preproc
self.hidden_state = hidden_state
self.mean = np.array([0.48145466, 0.4578275, 0.40821073])
self.std = np.array([0.26862954, 0.26130258, 0.27577711])
self.normalize = transforms.Normalize(self.mean, self.std)
self.denormalize = transforms.Normalize((-self.mean / self.std).tolist(), (1.0 / self.std).tolist())
self.clamp_embs = clamp_embs
self.norm_embs = norm_embs
        self.device = device
        # note: closes over the local image_encoder, so it is only usable when hidden_state=True
        def versatile_normalize_embeddings(encoder_output):
embeds = encoder_output.last_hidden_state
embeds = image_encoder.vision_model.post_layernorm(embeds)
embeds = image_encoder.visual_projection(embeds)
return embeds
self.versatile_normalize_embeddings = versatile_normalize_embeddings
def resize_image(self, image):
# note: antialias should be False if planning to use Pinkney's Image Variation SD model
return transforms.Resize(self.clip_size)(image.to(self.device))
def embed_image(self, image):
"""Expects images in -1 to 1 range"""
if self.hidden_state:
# clip_emb = self.preprocess((image/1.5+.25).to(self.device)) # for some reason the /1.5+.25 prevents oversaturation
clip_emb = self.preprocess((image).to(self.device))
clip_emb = self.image_encoder(clip_emb)
clip_emb = self.versatile_normalize_embeddings(clip_emb)
else:
clip_emb = self.preprocess(image.to(self.device))
clip_emb = self.clip.encode_image(clip_emb)
# input is now in CLIP space, but mind-reader preprint further processes embeddings:
if self.clamp_embs:
clip_emb = torch.clamp(clip_emb, -1.5, 1.5)
if self.norm_embs:
if self.hidden_state:
# normalize all tokens by cls token's norm
clip_emb = clip_emb / torch.norm(clip_emb[:, 0], dim=-1).reshape(-1, 1, 1)
else:
clip_emb = nn.functional.normalize(clip_emb, dim=-1)
return clip_emb
def embed_text(self, text_samples):
clip_text = clip.tokenize(text_samples).to(self.device)
clip_text = self.clip.encode_text(clip_text)
if self.clamp_embs:
clip_text = torch.clamp(clip_text, -1.5, 1.5)
if self.norm_embs:
clip_text = nn.functional.normalize(clip_text, dim=-1)
return clip_text
    def embed_curated_annotations(self, annots):
        # randomly pick one of the 5 curated COCO captions per sample, re-drawing if it is empty
        for i, b in enumerate(annots):
t = ''
while t == '':
rand = torch.randint(5,(1,1))[0][0]
t = b[0,rand]
if i==0:
txt = np.array(t)
else:
txt = np.vstack((txt,t))
txt = txt.flatten()
return self.embed_text(txt)
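# A minimal usage sketch (commented out so the module stays import-safe). It assumes a CUDA
# device, the versatile-diffusion image encoder cached at the path above, and image tensors
# already scaled the way embed_image expects.
# clip_extractor = Clipper("ViT-L/14", hidden_state=True, norm_embs=True,
#                          device=torch.device("cuda"))
# images = torch.rand(4, 3, 224, 224)                           # fake image batch
# image_embs = clip_extractor.embed_image(images)               # (4, 257, 768) hidden-state tokens
# text_embs = clip_extractor.embed_text(["a photo of a dog"])   # (1, 768) CLIP text embedding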
class OpenClipper(torch.nn.Module):
def __init__(self, clip_variant, norm_embs=False, device=torch.device('cpu')):
super().__init__()
print(clip_variant, device)
        assert clip_variant == 'ViT-H-14' # not set up for other models yet
clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14',
pretrained='laion2b_s32b_b79k', device=device)
        clip_model.eval() # don't want to train the model
for param in clip_model.parameters():
            param.requires_grad = False # don't need to calculate gradients
# overwrite preprocess to accept torch inputs instead of PIL Image
preprocess = transforms.Compose([
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC, antialias=None),
transforms.CenterCrop(224),
transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
])
tokenizer = open_clip.get_tokenizer('ViT-H-14')
self.clip = clip_model
self.norm_embs = norm_embs
self.preprocess = preprocess
self.tokenizer = tokenizer
self.device = device
def embed_image(self, image):
"""Expects images in -1 to 1 range"""
image = self.preprocess(image).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
image_features = self.clip.encode_image(image)
if self.norm_embs:
image_features = nn.functional.normalize(image_features, dim=-1)
return image_features
def embed_text(self, text_samples):
text = self.tokenizer(text_samples).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip.encode_text(text)
if self.norm_embs:
text_features = nn.functional.normalize(text_features, dim=-1)
return text_features
    def embed_curated_annotations(self, annots):
        # randomly pick one of the 5 curated COCO captions per sample, re-drawing if it is empty
        for i, b in enumerate(annots):
t = ''
while t == '':
rand = torch.randint(5,(1,1))[0][0]
t = b[0,rand]
if i==0:
txt = np.array(t)
else:
txt = np.vstack((txt,t))
txt = txt.flatten()
return self.embed_text(txt)
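# Equivalent sketch for the open_clip variant (commented out; first use downloads the
# laion2b_s32b_b79k ViT-H-14 weights, whose embeddings are 1024-dim).
# open_clipper = OpenClipper("ViT-H-14", norm_embs=True, device=torch.device("cuda"))
# image_embs = open_clipper.embed_image(torch.rand(2, 3, 256, 256))  # (2, 1024)
# text_embs = open_clipper.embed_text(["a photo of a cat"])          # (1, 1024)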
class BrainNetwork(nn.Module):
def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):
super().__init__()
norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
# self.temp = nn.Parameter(torch.tensor(.006))
self.lin0 = nn.Sequential(
nn.Linear(in_dim, h),
*[item() for item in act_and_norm],
nn.Dropout(drop1),
)
self.mlp = nn.ModuleList([
nn.Sequential(
nn.Linear(h, h),
*[item() for item in act_and_norm],
nn.Dropout(drop2)
) for _ in range(n_blocks)
])
self.lin1 = nn.Linear(h, out_dim, bias=True)
self.n_blocks = n_blocks
self.clip_size = clip_size
self.use_projector = use_projector
if use_projector:
self.projector = nn.Sequential(
nn.LayerNorm(clip_size),
nn.GELU(),
nn.Linear(clip_size, 2048),
nn.LayerNorm(2048),
nn.GELU(),
nn.Linear(2048, 2048),
nn.LayerNorm(2048),
nn.GELU(),
nn.Linear(2048, clip_size)
)
def forward(self, x):
        '''
        bs, in_dim -> bs, h         (lin0)
        bs, h -> bs, h              (n_blocks residual MLP blocks)
        bs, h -> bs, out_dim        (lin1; optionally also projected to bs, -1, clip_size)
        '''
if x.ndim == 4:
# case when we passed 3D data of shape [N, 81, 104, 83]
assert x.shape[1] == 81 and x.shape[2] == 104 and x.shape[3] == 83
# [N, 699192]
x = x.reshape(x.shape[0], -1)
x = self.lin0(x) # bs, h
residual = x
for res_block in range(self.n_blocks):
x = self.mlp[res_block](x)
x += residual
residual = x
x = x.reshape(len(x), -1)
x = self.lin1(x)
if self.use_projector:
return x, self.projector(x.reshape(len(x), -1, self.clip_size))
return x
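# Shape sketch for the voxel-to-CLIP backbone (commented out). With out_dim=257*768 and
# use_projector=True, forward returns both the flat output used as the MSE/prior target and
# the token-shaped projection used for contrastive losses.
# voxel2clip = BrainNetwork(out_dim=257*768, in_dim=15724, clip_size=768)
# backbone, projected = voxel2clip(torch.randn(2, 15724))
# # backbone: (2, 257*768), projected: (2, 257, 768)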
class BrainDiffusionPriorOld(DiffusionPrior):
"""
Differences from original:
- Allow for passing of generators to torch random functions
- Option to include the voxel2clip model and pass voxels into forward method
- Return predictions when computing loss
- Load pretrained model from @nousr trained on LAION aesthetics
"""
def __init__(self, *args, **kwargs):
voxel2clip = kwargs.pop('voxel2clip', None)
super().__init__(*args, **kwargs)
self.voxel2clip = voxel2clip
@torch.no_grad()
def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.,
generator=None):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
if generator is None:
noise = torch.randn_like(x)
else:
#noise = torch.randn_like(x)
noise = torch.randn(x.size(), device=x.device, dtype=x.dtype, generator=generator)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return pred, x_start
@torch.no_grad()
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1., generator=None):
batch, device = shape[0], self.device
if generator is None:
image_embed = torch.randn(shape, device = device)
else:
image_embed = torch.randn(shape, device = device, generator=generator)
x_start = None # for self-conditioning
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps, disable=True):
times = torch.full((batch,), i, device = device, dtype = torch.long)
self_cond = x_start if self.net.self_cond else None
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale,
generator=generator)
if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
self_cond = None
if self.net.self_cond and random.random() < 0.5:
with torch.no_grad():
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
pred = self.net(
image_embed_noisy,
times,
self_cond = self_cond,
text_cond_drop_prob = self.text_cond_drop_prob,
image_cond_drop_prob = self.image_cond_drop_prob,
**text_cond
)
if self.predict_x_start and self.training_clamp_l2norm:
pred = self.l2norm_clamp_embed(pred)
if self.predict_v:
target = self.noise_scheduler.calculate_v(image_embed, times, noise)
elif self.predict_x_start:
target = image_embed
else:
target = noise
loss = self.noise_scheduler.loss_fn(pred, target)
return loss, pred
def forward(
self,
text = None,
image = None,
voxel = None,
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
image_embed = None,
text_encodings = None, # as well as CLIP text encodings
*args,
**kwargs
):
        # note: a chained ^ is also True if all three are supplied; callers should pass exactly one
        assert exists(text) ^ exists(text_embed) ^ exists(voxel), 'either text, text embedding, or voxel must be supplied'
assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
if exists(voxel):
assert exists(self.voxel2clip), 'voxel2clip must be trained if you wish to pass in voxels'
assert not exists(text_embed), 'cannot pass in both text and voxels'
text_embed = self.voxel2clip(voxel)
if exists(image):
image_embed, _ = self.clip.embed_image(image)
# calculate text conditionings, based on what is passed in
if exists(text):
text_embed, text_encodings = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
text_cond = {**text_cond, 'text_encodings': text_encodings}
# timestep conditioning from ddpm
batch, device = image_embed.shape[0], image_embed.device
times = self.noise_scheduler.sample_random_times(batch)
# scale image embed (Katherine)
image_embed *= self.image_embed_scale
# calculate forward loss
loss, pred = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
        return loss, pred  # , text_embed
@staticmethod
def from_pretrained(net_kwargs={}, prior_kwargs={}, voxel2clip_path=None, ckpt_dir='./checkpoints'):
# "https://huggingface.co/nousr/conditioned-prior/raw/main/vit-l-14/aesthetic/prior_config.json"
config_url = os.path.join(ckpt_dir, "prior_config.json")
config = json.load(open(config_url))
config['prior']['net']['max_text_len'] = 256
config['prior']['net'].update(net_kwargs)
# print('net_config', config['prior']['net'])
net_config = DiffusionPriorNetworkConfig(**config['prior']['net'])
kwargs = config['prior']
kwargs.pop('clip')
kwargs.pop('net')
kwargs.update(prior_kwargs)
# print('prior_config', kwargs)
diffusion_prior_network = net_config.create()
diffusion_prior = BrainDiffusionPriorOld(net=diffusion_prior_network, clip=None, **kwargs).to(torch.device('cpu'))
# 'https://huggingface.co/nousr/conditioned-prior/resolve/main/vit-l-14/aesthetic/best.pth'
ckpt_url = os.path.join(ckpt_dir, 'best.pth')
ckpt = torch.load(ckpt_url, map_location=torch.device('cpu'))
# Note these keys will be missing (maybe due to an update to the code since training):
# "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed"
# I don't think these get used if `cond_drop_prob = 0` though (which is the default here)
diffusion_prior.load_state_dict(ckpt, strict=False)
# keys = diffusion_prior.load_state_dict(ckpt, strict=False)
# print("missing keys in prior checkpoint (probably ok)", keys.missing_keys)
if voxel2clip_path:
# load the voxel2clip weights
checkpoint = torch.load(voxel2clip_path, map_location=torch.device('cpu'))
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
if 'module.' in key:
state_dict[key.replace('module.', '')] = state_dict[key]
del state_dict[key]
diffusion_prior.voxel2clip.load_state_dict(state_dict)
return diffusion_prior
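# Loading sketch (commented out): assumes prior_config.json and best.pth from
# https://huggingface.co/nousr/conditioned-prior (vit-l-14/aesthetic) have been downloaded
# into ./checkpoints.
# prior = BrainDiffusionPriorOld.from_pretrained(
#     prior_kwargs={"voxel2clip": BrainNetwork(out_dim=768, use_projector=False)},
# )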
class BrainDiffusionPrior(DiffusionPrior):
"""
Differences from original:
- Allow for passing of generators to torch random functions
- Option to include the voxel2clip model and pass voxels into forward method
- Return predictions when computing loss
- Load pretrained model from @nousr trained on LAION aesthetics
"""
def __init__(self, *args, **kwargs):
voxel2clip = kwargs.pop('voxel2clip', None)
super().__init__(*args, **kwargs)
self.voxel2clip = voxel2clip
@torch.no_grad()
def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.,
generator=None):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
if generator is None:
noise = torch.randn_like(x)
else:
#noise = torch.randn_like(x)
noise = torch.randn(x.size(), device=x.device, dtype=x.dtype, generator=generator)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return pred, x_start
@torch.no_grad()
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1., generator=None):
batch, device = shape[0], self.device
if generator is None:
image_embed = torch.randn(shape, device = device)
else:
image_embed = torch.randn(shape, device = device, generator=generator)
x_start = None # for self-conditioning
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps, disable=True):
times = torch.full((batch,), i, device = device, dtype = torch.long)
self_cond = x_start if self.net.self_cond else None
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale,
generator=generator)
if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
self_cond = None
if self.net.self_cond and random.random() < 0.5:
with torch.no_grad():
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
pred = self.net(
image_embed_noisy,
times,
self_cond = self_cond,
text_cond_drop_prob = self.text_cond_drop_prob,
image_cond_drop_prob = self.image_cond_drop_prob,
**text_cond
)
if self.predict_x_start and self.training_clamp_l2norm:
pred = self.l2norm_clamp_embed(pred)
if self.predict_v:
target = self.noise_scheduler.calculate_v(image_embed, times, noise)
elif self.predict_x_start:
target = image_embed
else:
target = noise
loss = self.noise_scheduler.loss_fn(pred, target)
return loss, pred
def forward(
self,
text = None,
image = None,
voxel = None,
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
image_embed = None,
text_encodings = None, # as well as CLIP text encodings
*args,
**kwargs
):
assert exists(text) ^ exists(text_embed) ^ exists(voxel), 'either text, text embedding, or voxel must be supplied'
assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
if exists(voxel):
assert exists(self.voxel2clip), 'voxel2clip must be trained if you wish to pass in voxels'
assert not exists(text_embed), 'cannot pass in both text and voxels'
if self.voxel2clip.use_projector:
clip_voxels_mse, clip_voxels = self.voxel2clip(voxel)
text_embed = clip_voxels_mse
else:
clip_voxels = self.voxel2clip(voxel)
text_embed = clip_voxels_mse = clip_voxels
# text_embed = self.voxel2clip(voxel)
if exists(image):
image_embed, _ = self.clip.embed_image(image)
# calculate text conditionings, based on what is passed in
if exists(text):
text_embed, text_encodings = self.clip.embed_text(text)
text_cond = dict(text_embed = text_embed)
if self.condition_on_text_encodings:
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
text_cond = {**text_cond, 'text_encodings': text_encodings}
# timestep conditioning from ddpm
batch, device = image_embed.shape[0], image_embed.device
times = self.noise_scheduler.sample_random_times(batch)
        # PS: I don't think we need this; note that uncommenting it would scale image_embed in place
# scale image embed (Katherine)
# image_embed *= self.image_embed_scale
# calculate forward loss
loss, pred = self.p_losses(image_embed*self.image_embed_scale, times, text_cond = text_cond, *args, **kwargs)
        # the target is scaled on the fly (rather than in place) so image_embed stays usable for real MSE loss and reconstruction
return loss, pred
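# Training-step sketch (commented out): given a constructed prior (see
# VersatileDiffusionPriorNetwork below) whose voxel2clip has use_projector=True, the loss is
# computed against the scaled CLIP image tokens while image_embed itself is left untouched.
# voxels = torch.randn(2, 15724)
# clip_target = torch.randn(2, 257, 768)  # stand-in for Clipper.embed_image output
# loss, pred = diffusion_prior(voxel=voxels, image_embed=clip_target)
# loss.backward()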
class BrainSD(StableDiffusionImageVariationPipeline):
"""
Differences from original:
- Keep generated images on GPU and return tensors
- No NSFW checker
- Can pass in image or image_embedding to generate a variation
    NOTE: requires a recent version of diffusers; older versions produce incorrect latent dims.
"""
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
        # # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
# image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
@torch.no_grad()
def __call__(
self,
image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
image_embeddings: Optional[torch.FloatTensor] = None,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if image_embeddings is None:
assert image is not None, "If image_embeddings is None, image must not be None"
# resize and normalize the way that's recommended
tform = transforms.Compose([
#transforms.ToTensor(), ## don't need this since we've already got tensors
transforms.Resize(
(224, 224),
interpolation=transforms.InterpolationMode.BICUBIC,
antialias=False,
),
transforms.Normalize(
[0.48145466, 0.4578275, 0.40821073],
[0.26862954, 0.26130258, 0.27577711]),
])
image = tform(image)
# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width, callback_steps)
# 2. Define call parameters
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
# 3. Encode input image
image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
        else:
            # image_embeddings are expected pre-doubled for CFG: (uncond, cond) along the batch dim
            batch_size = image_embeddings.shape[0] // 2
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
image_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
# # 9. Run safety checker
# image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
# # 10. Convert to PIL
# if output_type == "pil":
# image = self.numpy_to_pil(image)
# if not return_dict:
# return (image, has_nsfw_concept)
# return StableDiffusionPipelineOutput(images=image)
return image
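# Generation sketch (commented out): image_embeddings must already be doubled for
# classifier-free guidance, i.e. torch.cat([uncond, cond]) along the batch dimension.
# The model id follows Pinkney's image-variation weights referenced above.
# sd_pipe = BrainSD.from_pretrained("lambdalabs/sd-image-variations-diffusers").to("cuda")
# embeds = torch.randn(2, 1, 768, device="cuda")   # (uncond, cond) CLIP image embeddings
# recons = sd_pipe(image_embeddings=embeds, num_inference_steps=20)  # (1, 3, H, W) in [0, 1]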
class Voxel2StableDiffusionModel(torch.nn.Module):
def __init__(self, in_dim=15724, h=4096, n_blocks=4, use_cont=False):
super().__init__()
self.lin0 = nn.Sequential(
nn.Linear(in_dim, h, bias=False),
nn.LayerNorm(h),
nn.SiLU(inplace=True),
nn.Dropout(0.5),
)
self.mlp = nn.ModuleList([
nn.Sequential(
nn.Linear(h, h, bias=False),
nn.LayerNorm(h),
nn.SiLU(inplace=True),
nn.Dropout(0.25)
) for _ in range(n_blocks)
])
self.lin1 = nn.Linear(h, 16384, bias=False)
self.norm = nn.LayerNorm(512)
self.register_parameter('queries', nn.Parameter(torch.randn(1, 256, 512) * 0.044))
self.transformer = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model=512, nhead=8, norm_first=True,
dim_feedforward=1024, activation=nn.functional.gelu,
batch_first=True, dropout=0.25),
num_layers=n_blocks
)
# option 1 -> 124.56M
# self.lin1 = nn.Linear(h, 32768, bias=True)
# self.upsampler = Decoder(
# in_channels=64,
# out_channels=4,
# up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
# block_out_channels=[64, 128, 256, 256],
# layers_per_block=1,
# )
# option2 -> 132.52M
# self.lin1 = nn.Linear(h, 1024, bias=True)
# self.upsampler = Decoder(
# in_channels=64,
# out_channels=4,
# up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D", "UpDecoderBlock2D"],
# block_out_channels=[64, 128, 256, 256, 512],
# layers_per_block=1,
# )
if use_cont:
self.maps_projector = nn.Sequential(
nn.LayerNorm(512),
nn.Linear(512, 512),
nn.LayerNorm(512),
nn.ReLU(True),
nn.Linear(512, 512),
nn.LayerNorm(512),
nn.ReLU(True),
nn.Linear(512, 512)
)
else:
self.maps_projector = nn.Identity()
self.upsampler = nn.Sequential(
nn.GroupNorm(1, 32),
nn.SiLU(inplace=True),
nn.Conv2d(32, 320, 3, padding=1),
nn.GroupNorm(32, 320),
nn.SiLU(inplace=True),
nn.Conv2d(320, 320, 3, padding=1),
nn.GroupNorm(32, 320),
nn.SiLU(inplace=True),
nn.Conv2d(320, 4, 3, padding=1)
)
def forward(self, x, return_transformer_feats=False):
x = self.lin0(x)
residual = x
for res_block in self.mlp:
x = res_block(x)
x = x + residual
residual = x
x = x.reshape(len(x), -1)
        x = self.lin1(x) # bs, 16384
# # x = x.reshape(x.shape[0], -1, 8, 8).contiguous() # bs, 64, 8, 8
# x = x.reshape(x.shape[0], -1, 64, 64).contiguous()
# return self.upsampler(x)
# decoder
x = self.norm(x.reshape(x.shape[0], 32, 512))
preds = self.transformer(self.queries.expand(x.shape[0], -1, -1), x)
sd_embeds = preds.permute(0,2,1).reshape(-1, 512, 16, 16)
        sd_embeds = nn.functional.pixel_shuffle(sd_embeds, 4) # bs, 32, 64, 64
# contrastive
if return_transformer_feats:
return self.upsampler(sd_embeds), self.maps_projector(preds)
return self.upsampler(sd_embeds)
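# Shape sketch (commented out): voxels -> 32 encoder tokens (via lin1 + reshape) -> 256
# transformer-decoded queries -> pixel-shuffled conv input -> SD latents.
# v2sd = Voxel2StableDiffusionModel(in_dim=15724)
# latents = v2sd(torch.randn(2, 15724))  # (2, 4, 64, 64), matching the SD VAE latent grid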
class BrainNetworkDETR(BrainNetwork):
# 133M
def __init__(self, out_dim=768, in_dim=15724, h=4096, n_blocks=4, norm_type='ln', act_first=False,
encoder_tokens=32, decoder_tokens=257):
        # encoder; keyword args matter here: BrainNetwork's third positional parameter is
        # clip_size, so positional passing would misalign h/n_blocks/norm_type, and the
        # projector is disabled since forward reshapes the raw backbone output itself
        super().__init__(out_dim=out_dim*encoder_tokens, in_dim=in_dim, h=h, n_blocks=n_blocks,
                         norm_type=norm_type, act_first=act_first, use_projector=False)
self.norm = nn.LayerNorm(out_dim)
self.encoder_tokens = encoder_tokens
self.register_parameter('queries', nn.Parameter(torch.randn(1, decoder_tokens, out_dim)))
self.transformer = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model=out_dim, nhead=8,
dim_feedforward=1024,
batch_first=True, dropout=0.25),
num_layers=n_blocks
)
self.decoder_projector = nn.Sequential(
nn.LayerNorm(out_dim),
nn.Linear(out_dim, out_dim)
)
def forward(self, x):
enc = super().forward(x)
enc = self.norm(enc.reshape(enc.shape[0], self.encoder_tokens, -1))
dec = self.transformer(self.queries.expand(x.shape[0], -1, -1), enc)
dec = self.decoder_projector(dec)
return dec
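# Shape sketch (commented out): the BrainNetwork backbone emits encoder_tokens tokens, and
# the transformer decoder reads them out into decoder_tokens query tokens.
# detr = BrainNetworkDETR(out_dim=768, encoder_tokens=32, decoder_tokens=257)
# out = detr(torch.randn(2, 15724))  # (2, 257, 768)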
class VersatileDiffusionPriorNetwork(nn.Module):
def __init__(
self,
dim,
num_timesteps = None,
num_time_embeds = 1,
# num_image_embeds = 1,
# num_brain_embeds = 1,
num_tokens = 257,
causal = True,
learned_query_mode = 'none',
**kwargs
):
super().__init__()
self.dim = dim
self.num_time_embeds = num_time_embeds
self.continuous_embedded_time = not exists(num_timesteps)
self.learned_query_mode = learned_query_mode
self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds)
)
if self.learned_query_mode == 'token':
self.learned_query = nn.Parameter(torch.randn(num_tokens, dim))
if self.learned_query_mode == 'pos_emb':
scale = dim ** -0.5
self.learned_query = nn.Parameter(torch.randn(num_tokens, dim) * scale)
if self.learned_query_mode == 'all_pos_emb':
scale = dim ** -0.5
self.learned_query = nn.Parameter(torch.randn(num_tokens*2+1, dim) * scale)
self.causal_transformer = FlaggedCausalTransformer(dim = dim, causal=causal, **kwargs)
self.null_brain_embeds = nn.Parameter(torch.randn(num_tokens, dim))
self.null_image_embed = nn.Parameter(torch.randn(num_tokens, dim))
self.num_tokens = num_tokens
self.self_cond = False
def forward_with_cond_scale(
self,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(*args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(*args, brain_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
self,
image_embed,
diffusion_timesteps,
*,
self_cond=None,
brain_embed=None,
text_embed=None,
brain_cond_drop_prob = 0.,
text_cond_drop_prob = None,
image_cond_drop_prob = 0.
):
if text_embed is not None:
brain_embed = text_embed
if text_cond_drop_prob is not None:
brain_cond_drop_prob = text_cond_drop_prob
image_embed = image_embed.view(len(image_embed),-1,768)
# text_embed = text_embed.view(len(text_embed),-1,768)
brain_embed = brain_embed.view(len(brain_embed),-1,768)
# print(*image_embed.shape)
# print(*image_embed.shape, image_embed.device, image_embed.dtype)
batch, _, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
# num_time_embeds, num_image_embeds, num_brain_embeds = self.num_time_embeds, self.num_image_embeds, self.num_brain_embeds
# classifier free guidance masks
brain_keep_mask = prob_mask_like((batch,), 1 - brain_cond_drop_prob, device = device)
brain_keep_mask = rearrange(brain_keep_mask, 'b -> b 1 1')
image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
# mask out brain embeddings with null brain embeddings
# import pdb; pdb.set_trace()
null_brain_embeds = self.null_brain_embeds.to(brain_embed.dtype)
brain_embed = torch.where(
brain_keep_mask,
brain_embed,
null_brain_embeds[None]
)
# mask out image embeddings with null image embeddings
null_image_embed = self.null_image_embed.to(image_embed.dtype)
image_embed = torch.where(
image_keep_mask,
image_embed,
null_image_embed[None]
)
        # whether the brain embedding is kept or replaced by the null embedding is what implements
        # classifier-free guidance here (the paper's prior DDPM does not seem to use it, as the
        # objective is different), but let's just do it right
if self.continuous_embedded_time:
# if continuous cast to flat, else keep int for indexing embeddings
diffusion_timesteps = diffusion_timesteps.type(dtype)
time_embed = self.to_time_embeds(diffusion_timesteps)
if self.learned_query_mode == 'token':
learned_queries = repeat(self.learned_query, 'n d -> b n d', b = batch)
elif self.learned_query_mode == 'pos_emb':
pos_embs = repeat(self.learned_query, 'n d -> b n d', b = batch)
image_embed = image_embed + pos_embs
learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device)
elif self.learned_query_mode == 'all_pos_emb':
pos_embs = repeat(self.learned_query, 'n d -> b n d', b = batch)
learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device)
else:
learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device)
tokens = torch.cat((
brain_embed, # 257
time_embed, # 1
image_embed, # 257
learned_queries # 257
), dim = -2)
if self.learned_query_mode == 'all_pos_emb':
tokens = tokens + pos_embs
# attend
tokens = self.causal_transformer(tokens)
# get learned query, which should predict the image embedding (per DDPM timestep)
pred_image_embed = tokens[..., -self.num_tokens:, :]
return pred_image_embed
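# Forward sketch (commented out): the transformer attends over the concatenated
# [brain tokens | time token | (noised) image tokens | learned queries] sequence, and the
# last num_tokens positions are read out as the denoised image-token prediction.
# prior_net = VersatileDiffusionPriorNetwork(dim=768, depth=6, dim_head=64, heads=12,
#                                            learned_query_mode="pos_emb")
# noised = torch.randn(2, 257, 768)
# brain = torch.randn(2, 257, 768)
# t = torch.randint(0, 1000, (2,))
# pred = prior_net(noised, t, brain_embed=brain)  # (2, 257, 768)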
# import math
# from collections import namedtuple
# from einops import rearrange, repeat, reduce, pack, unpack
# from einops.layers.torch import Rearrange
# from torch import einsum
# class Attention(nn.Module):
# def __init__(
# self,
# dim,
# *,
# dim_head = 64,
# heads = 8,
# dropout = 0.,
# causal = False,
# rotary_emb = None,
# cosine_sim = True,
# cosine_sim_scale = 16
# ):
# super().__init__()
# self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
# self.cosine_sim = cosine_sim
# self.heads = heads
# inner_dim = dim_head * heads
# self.dim = dim
# self.inner_dim = inner_dim
# self.causal = causal
# self.norm = LayerNorm(dim)
# self.dropout = nn.Dropout(dropout)
# self.null_kv = nn.Parameter(torch.randn(2, dim_head))
# self.to_q = nn.Linear(dim, inner_dim, bias = False)
# self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
# self.rotary_emb = rotary_emb
# self.to_out = nn.Sequential(
# nn.Linear(inner_dim, dim, bias = False),
# LayerNorm(dim)
# )
# def forward(self, x, mask = None, attn_bias = None):
# b, n, device = *x.shape[:2], x.device
# print("xinit", torch.any(torch.isnan(x)))
# x = self.norm(x)
# print("xnorm", torch.any(torch.isnan(x)))
# print("xnorm.shape", x.shape)
# q = self.to_q(x)
# print("q0", torch.any(torch.isnan(q)))
# k, v = self.to_kv(x).chunk(2, dim = -1)
# print("k0", torch.any(torch.isnan(k)))
# print("v0", torch.any(torch.isnan(v)))
# q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
# q = q * self.scale
# # rotary embeddings
# if exists(self.rotary_emb):
# q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
# # add null key / value for classifier free guidance in prior net
# nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
# k = torch.cat((nk, k), dim = -2)
# v = torch.cat((nv, v), dim = -2)
# # whether to use cosine sim
# if self.cosine_sim:
# q, k = map(l2norm, (q, k))
# q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
# print("q2", torch.any(torch.isnan(q)))
# print("k2", torch.any(torch.isnan(k)))
# # calculate query / key similarities
# sim = einsum('b h i d, b j d -> b h i j', q, k)
# # relative positional encoding (T5 style)
# if exists(attn_bias):
# sim = sim + attn_bias
# # masking
# max_neg_value = -torch.finfo(sim.dtype).max
# print("sim1", torch.any(torch.isnan(sim)))
# if exists(mask):
# mask = F.pad(mask, (1, 0), value = True)
# mask = rearrange(mask, 'b j -> b 1 1 j')
# sim = sim.masked_fill(~mask, max_neg_value)
# print("sim2", torch.any(torch.isnan(sim)))
# if self.causal:
# i, j = sim.shape[-2:]
# causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
# sim = sim.masked_fill(causal_mask, max_neg_value)
# # attention
# print("simFinal", torch.any(torch.isnan(sim)))
# attn = sim.softmax(dim = -1, dtype = torch.float32)
# print("attn", torch.any(torch.isnan(attn)))
# attn = attn.type(sim.dtype)
# attn = self.dropout(attn)
# # aggregate values
# out = einsum('b h i j, b j d -> b h i d', attn, v)
# out = rearrange(out, 'b h n d -> b n (h d)')
# return self.to_out(out)
class FlaggedCausalTransformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
dim_head = 64,
heads = 8,
ff_mult = 4,
norm_in = False,
norm_out = True,
attn_dropout = 0.,
ff_dropout = 0.,
final_proj = True,
normformer = False,
rotary_emb = False,
causal=True
):
super().__init__()
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
self.rel_pos_bias = RelPosBias(heads = heads)
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = causal, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
def forward(self, x):
n, device = x.shape[1], x.device
x = self.init_norm(x)
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
for attn, ff in self.layers:
x = attn(x, attn_bias = attn_bias) + x
x = ff(x) + x
out = self.norm(x)
return self.project_out(out)
class BrainVD(VersatileDiffusionDualGuidedPipeline):
"""
Differences from original:
- Keep generated images on GPU and return tensors
- No NSFW checker
- Can pass in image or image_embedding to generate a variation
    NOTE: requires a recent version of diffusers; older versions produce incorrect latent dims.
"""
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
# image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def check_inputs(self, prompt, image, height, width, callback_steps):
if prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type None, `str` or `list` but is {type(prompt)}")
if image is not None and not isinstance(image, PIL.Image.Image) and not isinstance(image, list):
raise ValueError(f"`image` has to be of type None, `PIL.Image` or `list` but is {type(image)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
@torch.no_grad()
def __call__(
self,
        prompt: Union[str, List[str]] = None,
        image: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
text_to_image_strength: float = 0.5,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
image_embeddings: Optional[torch.FloatTensor] = None,
prompt_embeddings: Optional[torch.FloatTensor] = None,
**kwargs,
):
height = height or self.image_unet.config.sample_size * self.vae_scale_factor
width = width or self.image_unet.config.sample_size * self.vae_scale_factor
self.check_inputs(prompt, image, height, width, callback_steps)
prompt = [prompt] if prompt is not None and not isinstance(prompt, list) else prompt
image = [image] if image is not None and not isinstance(image, list) else image
device = self._execution_device
        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
if image_embeddings is None:
if image is not None:
image_embeddings = self._encode_image_prompt(
image, device, num_images_per_prompt, do_classifier_free_guidance
)
batch_size = len(image)
else:
image_embeddings = None
if prompt_embeddings is None:
if prompt is not None:
prompt_embeddings = self._encode_text_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance
)
batch_size = len(prompt)
else:
prompt_embeddings = None
if image_embeddings is not None:
batch_size = image_embeddings.shape[0] // 2
elif prompt_embeddings is not None:
batch_size = prompt_embeddings.shape[0] // 2
if image_embeddings is not None and prompt_embeddings is not None:
dual_prompt_embeddings = torch.cat([prompt_embeddings, image_embeddings], dim=1)
elif image_embeddings is None:
dual_prompt_embeddings = prompt_embeddings
text_to_image_strength = 1.
elif prompt_embeddings is None:
dual_prompt_embeddings = image_embeddings
text_to_image_strength = 0.
else:
raise ValueError()
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.image_unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
dual_prompt_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Combine the attention blocks of the image and text UNets
self.set_transformer_params(text_to_image_strength, ("text", "image"))
# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
        # 9. Post-processing
image = self.decode_latents(latents)
return image
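# Generation sketch (commented out): like BrainSD above, image_embeddings are expected
# CFG-doubled; with no prompt_embeddings, text_to_image_strength is forced to 0.
# vd_pipe = BrainVD.from_pretrained("shi-labs/versatile-diffusion").to("cuda")
# embeds = torch.randn(2, 257, 768, device="cuda")  # (uncond, cond) CLIP image tokens
# recons = vd_pipe(image_embeddings=embeds, num_inference_steps=20)  # tensor in [0, 1]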