|
import os |
|
import numpy as np |
|
from torchvision import transforms |
|
import torch |
|
import torch.nn as nn |
|
import PIL |
|
import clip |
|
import open_clip |
|
from functools import partial |
|
|
|
|
|
from dalle2_pytorch import DiffusionPrior |
|
from dalle2_pytorch.dalle2_pytorch import l2norm, default, exists |
|
from tqdm.auto import tqdm |
|
import random |
|
import json |
|
from dalle2_pytorch.train_configs import DiffusionPriorNetworkConfig |
|
|
|
from dalle2_pytorch.dalle2_pytorch import RotaryEmbedding, CausalTransformer, SinusoidalPosEmb, MLP, Rearrange, repeat, rearrange, prob_mask_like, LayerNorm, RelPosBias, FeedForward, Attention |
|
|
|
|
|
from diffusers import StableDiffusionImageVariationPipeline, VersatileDiffusionDualGuidedPipeline |
|
from typing import Callable, List, Optional, Union |
|
|
|
from diffusers.models.vae import Decoder

|
class Clipper(torch.nn.Module): |
|
def __init__(self, clip_variant, clamp_embs=False, norm_embs=False, |
|
hidden_state=False, device=torch.device('cpu')): |
|
super().__init__() |
|
assert clip_variant in ("RN50", "ViT-L/14", "ViT-B/32", "RN50x64"), \ |
|
"clip_variant must be one of RN50, ViT-L/14, ViT-B/32, RN50x64" |
|
print(clip_variant, device) |
|
|
|
if clip_variant=="ViT-L/14" and hidden_state: |
|
|
|
|
|
from transformers import CLIPVisionModelWithProjection |
|
sd_cache_dir = '/fsx/proj-fmri/shared/cache/models--shi-labs--versatile-diffusion/snapshots/2926f8e11ea526b562cd592b099fcf9c2985d0b7' |
|
image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_cache_dir, subfolder='image_encoder').eval() |
|
image_encoder = image_encoder.to(device) |
|
for param in image_encoder.parameters(): |
|
param.requires_grad = False |
|
self.image_encoder = image_encoder |
|
elif hidden_state: |
|
            raise Exception("hidden_state embeddings only work with ViT-L/14 right now")
|
|
|
clip_model, preprocess = clip.load(clip_variant, device=device) |
|
clip_model.eval() |
|
for param in clip_model.parameters(): |
|
param.requires_grad = False |
|
|
|
self.clip = clip_model |
|
self.clip_variant = clip_variant |
|
if clip_variant == "RN50x64": |
|
self.clip_size = (448,448) |
|
else: |
|
self.clip_size = (224,224) |
|
|
|
preproc = transforms.Compose([ |
|
transforms.Resize(size=self.clip_size[0], interpolation=transforms.InterpolationMode.BICUBIC), |
|
transforms.CenterCrop(size=self.clip_size), |
|
transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) |
|
]) |
|
self.preprocess = preproc |
|
self.hidden_state = hidden_state |
|
self.mean = np.array([0.48145466, 0.4578275, 0.40821073]) |
|
self.std = np.array([0.26862954, 0.26130258, 0.27577711]) |
|
self.normalize = transforms.Normalize(self.mean, self.std) |
|
self.denormalize = transforms.Normalize((-self.mean / self.std).tolist(), (1.0 / self.std).tolist()) |
|
self.clamp_embs = clamp_embs |
|
self.norm_embs = norm_embs |
|
        self.device = device
|
|
|
        def versatile_normalize_embeddings(encoder_output):
            # Mirror Versatile Diffusion's image conditioning: apply the vision model's
            # final layernorm, then the visual projection, keeping all 257 hidden-state
            # tokens (CLS + patches). Only used when hidden_state=True.
            embeds = encoder_output.last_hidden_state
            embeds = self.image_encoder.vision_model.post_layernorm(embeds)
            embeds = self.image_encoder.visual_projection(embeds)
            return embeds
        self.versatile_normalize_embeddings = versatile_normalize_embeddings
|
|
|
def resize_image(self, image): |
|
|
|
return transforms.Resize(self.clip_size)(image.to(self.device)) |
|
|
|
def embed_image(self, image): |
|
"""Expects images in -1 to 1 range""" |
|
if self.hidden_state: |
|
|
|
clip_emb = self.preprocess((image).to(self.device)) |
|
clip_emb = self.image_encoder(clip_emb) |
|
clip_emb = self.versatile_normalize_embeddings(clip_emb) |
|
else: |
|
clip_emb = self.preprocess(image.to(self.device)) |
|
clip_emb = self.clip.encode_image(clip_emb) |
|
|
|
if self.clamp_embs: |
|
clip_emb = torch.clamp(clip_emb, -1.5, 1.5) |
|
if self.norm_embs: |
|
if self.hidden_state: |
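                # scale all 257 hidden-state tokens by the L2 norm of the CLS token (index 0)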
|
|
|
clip_emb = clip_emb / torch.norm(clip_emb[:, 0], dim=-1).reshape(-1, 1, 1) |
|
else: |
|
clip_emb = nn.functional.normalize(clip_emb, dim=-1) |
|
return clip_emb |
|
|
|
def embed_text(self, text_samples): |
|
clip_text = clip.tokenize(text_samples).to(self.device) |
|
clip_text = self.clip.encode_text(clip_text) |
|
if self.clamp_embs: |
|
clip_text = torch.clamp(clip_text, -1.5, 1.5) |
|
if self.norm_embs: |
|
clip_text = nn.functional.normalize(clip_text, dim=-1) |
|
return clip_text |
|
|
|
def embed_curated_annotations(self, annots): |
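        # annots is expected to hold several candidate captions per item at b[0, 0..4];
        # sample one non-empty caption per item at random, then embed the whole batch.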
|
for i,b in enumerate(annots): |
|
t = '' |
|
while t == '': |
|
rand = torch.randint(5,(1,1))[0][0] |
|
t = b[0,rand] |
|
if i==0: |
|
txt = np.array(t) |
|
else: |
|
txt = np.vstack((txt,t)) |
|
txt = txt.flatten() |
|
return self.embed_text(txt) |
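
# Hedged usage sketch (illustration only, not part of the original pipeline): how the
# Clipper wrapper is typically driven. The batch size, image resolution, and device
# below are arbitrary assumptions made for the example.
def _example_clipper_usage():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    clip_extractor = Clipper("ViT-L/14", hidden_state=False, norm_embs=True, device=device)
    images = torch.rand(4, 3, 256, 256, device=device)    # dummy image batch
    image_embs = clip_extractor.embed_image(images)        # (4, 768) for ViT-L/14
    text_embs = clip_extractor.embed_text(["a photo of a dog"] * 4)  # (4, 768)
    return image_embs, text_embs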
|
|
|
class OpenClipper(torch.nn.Module): |
|
def __init__(self, clip_variant, norm_embs=False, device=torch.device('cpu')): |
|
super().__init__() |
|
print(clip_variant, device) |
|
assert clip_variant == 'ViT-H-14' |
|
|
|
clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14', |
|
pretrained='laion2b_s32b_b79k', device=device) |
|
clip_model.eval() |
|
for param in clip_model.parameters(): |
|
param.requires_grad = False |
|
|
|
|
|
preprocess = transforms.Compose([ |
|
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC, antialias=None), |
|
transforms.CenterCrop(224), |
|
transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) |
|
]) |
|
|
|
tokenizer = open_clip.get_tokenizer('ViT-H-14') |
|
|
|
self.clip = clip_model |
|
self.norm_embs = norm_embs |
|
self.preprocess = preprocess |
|
self.tokenizer = tokenizer |
|
self.device = device |
|
|
|
def embed_image(self, image): |
|
"""Expects images in -1 to 1 range""" |
|
image = self.preprocess(image).to(self.device) |
|
with torch.no_grad(), torch.cuda.amp.autocast(): |
|
image_features = self.clip.encode_image(image) |
|
if self.norm_embs: |
|
image_features = nn.functional.normalize(image_features, dim=-1) |
|
return image_features |
|
|
|
def embed_text(self, text_samples): |
|
text = self.tokenizer(text_samples).to(self.device) |
|
with torch.no_grad(), torch.cuda.amp.autocast(): |
|
text_features = self.clip.encode_text(text) |
|
if self.norm_embs: |
|
text_features = nn.functional.normalize(text_features, dim=-1) |
|
return text_features |
|
|
|
def embed_curated_annotations(self, annots): |
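        # annots is expected to hold several candidate captions per item at b[0, 0..4];
        # sample one non-empty caption per item at random, then embed the whole batch.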
|
for i,b in enumerate(annots): |
|
t = '' |
|
while t == '': |
|
rand = torch.randint(5,(1,1))[0][0] |
|
t = b[0,rand] |
|
if i==0: |
|
txt = np.array(t) |
|
else: |
|
txt = np.vstack((txt,t)) |
|
txt = txt.flatten() |
|
return self.embed_text(txt) |
|
|
|
class BrainNetwork(nn.Module): |
|
def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15): |
|
super().__init__() |
|
norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h) |
|
act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU |
|
act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn) |
|
|
|
self.lin0 = nn.Sequential( |
|
nn.Linear(in_dim, h), |
|
*[item() for item in act_and_norm], |
|
nn.Dropout(drop1), |
|
) |
|
self.mlp = nn.ModuleList([ |
|
nn.Sequential( |
|
nn.Linear(h, h), |
|
*[item() for item in act_and_norm], |
|
nn.Dropout(drop2) |
|
) for _ in range(n_blocks) |
|
]) |
|
self.lin1 = nn.Linear(h, out_dim, bias=True) |
|
self.n_blocks = n_blocks |
|
self.clip_size = clip_size |
|
|
|
self.use_projector = use_projector |
|
if use_projector: |
|
self.projector = nn.Sequential( |
|
nn.LayerNorm(clip_size), |
|
nn.GELU(), |
|
nn.Linear(clip_size, 2048), |
|
nn.LayerNorm(2048), |
|
nn.GELU(), |
|
nn.Linear(2048, 2048), |
|
nn.LayerNorm(2048), |
|
nn.GELU(), |
|
nn.Linear(2048, clip_size) |
|
) |
|
|
|
def forward(self, x): |
|
        '''
        (bs, in_dim) voxels      -> (bs, h)        via lin0
        (bs, h)                  -> (bs, h)        through n_blocks residual MLP blocks
        (bs, h)                  -> (bs, out_dim)  via lin1
        if use_projector, also returns a (bs, out_dim // clip_size, clip_size) projection
        '''
|
if x.ndim == 4: |
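            # un-flattened 3D fMRI volume: (bs, 81, 104, 83) is flattened below,
            # so in_dim must equal 81*104*83 in that case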
|
|
|
assert x.shape[1] == 81 and x.shape[2] == 104 and x.shape[3] == 83 |
|
|
|
x = x.reshape(x.shape[0], -1) |
|
x = self.lin0(x) |
|
residual = x |
|
for res_block in range(self.n_blocks): |
|
x = self.mlp[res_block](x) |
|
x += residual |
|
residual = x |
|
x = x.reshape(len(x), -1) |
|
x = self.lin1(x) |
|
if self.use_projector: |
|
return x, self.projector(x.reshape(len(x), -1, self.clip_size)) |
|
return x |
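
# Hedged shape sketch (illustration only): with the default arguments the backbone maps
# flattened voxels to a single CLIP-sized vector, and the optional projector re-projects
# it token-wise. The voxel count matches the default in_dim; the batch size is arbitrary.
def _example_brainnetwork_shapes():
    voxel2clip = BrainNetwork(out_dim=768, in_dim=15724, clip_size=768)
    voxels = torch.randn(2, 15724)
    backbone_out, projected = voxel2clip(voxels)
    # backbone_out: (2, 768); projected: (2, 1, 768) after reshaping into clip_size tokens
    return backbone_out.shape, projected.shape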
|
|
|
class BrainDiffusionPriorOld(DiffusionPrior): |
|
""" |
|
Differences from original: |
|
- Allow for passing of generators to torch random functions |
|
- Option to include the voxel2clip model and pass voxels into forward method |
|
- Return predictions when computing loss |
|
- Load pretrained model from @nousr trained on LAION aesthetics |
|
""" |
|
def __init__(self, *args, **kwargs): |
|
voxel2clip = kwargs.pop('voxel2clip', None) |
|
super().__init__(*args, **kwargs) |
|
self.voxel2clip = voxel2clip |
|
|
|
@torch.no_grad() |
|
def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1., |
|
generator=None): |
|
b, *_, device = *x.shape, x.device |
|
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale) |
|
if generator is None: |
|
noise = torch.randn_like(x) |
|
else: |
|
|
|
noise = torch.randn(x.size(), device=x.device, dtype=x.dtype, generator=generator) |
|
|
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) |
|
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise |
|
return pred, x_start |
|
|
|
@torch.no_grad() |
|
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1., generator=None): |
|
batch, device = shape[0], self.device |
|
|
|
if generator is None: |
|
image_embed = torch.randn(shape, device = device) |
|
else: |
|
image_embed = torch.randn(shape, device = device, generator=generator) |
|
x_start = None |
|
|
|
if self.init_image_embed_l2norm: |
|
image_embed = l2norm(image_embed) * self.image_embed_scale |
|
|
|
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps, disable=True): |
|
times = torch.full((batch,), i, device = device, dtype = torch.long) |
|
|
|
self_cond = x_start if self.net.self_cond else None |
|
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale, |
|
generator=generator) |
|
|
|
if self.sampling_final_clamp_l2norm and self.predict_x_start: |
|
image_embed = self.l2norm_clamp_embed(image_embed) |
|
|
|
return image_embed |
|
|
|
def p_losses(self, image_embed, times, text_cond, noise = None): |
|
noise = default(noise, lambda: torch.randn_like(image_embed)) |
|
|
|
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise) |
|
|
|
self_cond = None |
|
if self.net.self_cond and random.random() < 0.5: |
|
with torch.no_grad(): |
|
self_cond = self.net(image_embed_noisy, times, **text_cond).detach() |
|
|
|
pred = self.net( |
|
image_embed_noisy, |
|
times, |
|
self_cond = self_cond, |
|
text_cond_drop_prob = self.text_cond_drop_prob, |
|
image_cond_drop_prob = self.image_cond_drop_prob, |
|
**text_cond |
|
) |
|
|
|
if self.predict_x_start and self.training_clamp_l2norm: |
|
pred = self.l2norm_clamp_embed(pred) |
|
|
|
if self.predict_v: |
|
target = self.noise_scheduler.calculate_v(image_embed, times, noise) |
|
elif self.predict_x_start: |
|
target = image_embed |
|
else: |
|
target = noise |
|
|
|
loss = self.noise_scheduler.loss_fn(pred, target) |
|
return loss, pred |
|
|
|
def forward( |
|
self, |
|
text = None, |
|
image = None, |
|
voxel = None, |
|
text_embed = None, |
|
image_embed = None, |
|
text_encodings = None, |
|
*args, |
|
**kwargs |
|
): |
|
assert exists(text) ^ exists(text_embed) ^ exists(voxel), 'either text, text embedding, or voxel must be supplied' |
|
assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied' |
|
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization' |
|
|
|
if exists(voxel): |
|
assert exists(self.voxel2clip), 'voxel2clip must be trained if you wish to pass in voxels' |
|
assert not exists(text_embed), 'cannot pass in both text and voxels' |
|
text_embed = self.voxel2clip(voxel) |
|
|
|
if exists(image): |
|
image_embed, _ = self.clip.embed_image(image) |
|
|
|
|
|
|
|
if exists(text): |
|
text_embed, text_encodings = self.clip.embed_text(text) |
|
|
|
text_cond = dict(text_embed = text_embed) |
|
|
|
if self.condition_on_text_encodings: |
|
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified' |
|
text_cond = {**text_cond, 'text_encodings': text_encodings} |
|
|
|
|
|
|
|
batch, device = image_embed.shape[0], image_embed.device |
|
times = self.noise_scheduler.sample_random_times(batch) |
|
|
|
|
|
|
|
image_embed *= self.image_embed_scale |
|
|
|
|
|
|
|
loss, pred = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs) |
|
|
|
return loss, pred |
|
|
|
@staticmethod |
|
def from_pretrained(net_kwargs={}, prior_kwargs={}, voxel2clip_path=None, ckpt_dir='./checkpoints'): |
|
|
|
        with open(os.path.join(ckpt_dir, "prior_config.json")) as f:
            config = json.load(f)
|
|
|
config['prior']['net']['max_text_len'] = 256 |
|
config['prior']['net'].update(net_kwargs) |
|
|
|
net_config = DiffusionPriorNetworkConfig(**config['prior']['net']) |
|
|
|
kwargs = config['prior'] |
|
kwargs.pop('clip') |
|
kwargs.pop('net') |
|
kwargs.update(prior_kwargs) |
|
|
|
|
|
diffusion_prior_network = net_config.create() |
|
diffusion_prior = BrainDiffusionPriorOld(net=diffusion_prior_network, clip=None, **kwargs).to(torch.device('cpu')) |
|
|
|
|
|
ckpt_url = os.path.join(ckpt_dir, 'best.pth') |
|
ckpt = torch.load(ckpt_url, map_location=torch.device('cpu')) |
|
|
|
|
|
|
|
|
|
diffusion_prior.load_state_dict(ckpt, strict=False) |
|
|
|
|
|
|
|
if voxel2clip_path: |
|
|
|
checkpoint = torch.load(voxel2clip_path, map_location=torch.device('cpu')) |
|
|
|
state_dict = checkpoint['model_state_dict'] |
|
for key in list(state_dict.keys()): |
|
if 'module.' in key: |
|
state_dict[key.replace('module.', '')] = state_dict[key] |
|
del state_dict[key] |
|
diffusion_prior.voxel2clip.load_state_dict(state_dict) |
|
|
|
return diffusion_prior |
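
# Hedged usage sketch (illustration only): loading the pretrained prior assumes a local
# `ckpt_dir` that already contains the downloaded `prior_config.json` and `best.pth`;
# the directory name and the voxel2clip settings below are assumptions for the example.
def _example_load_pretrained_prior():
    # this older prior expects voxel2clip to return a single tensor, so the projector is off
    voxel2clip = BrainNetwork(out_dim=768, in_dim=15724, use_projector=False)
    prior = BrainDiffusionPriorOld.from_pretrained(
        net_kwargs={},
        prior_kwargs={'voxel2clip': voxel2clip},
        ckpt_dir='./checkpoints',
    )
    return prior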
|
|
|
class BrainDiffusionPrior(DiffusionPrior): |
|
""" |
|
Differences from original: |
|
- Allow for passing of generators to torch random functions |
|
- Option to include the voxel2clip model and pass voxels into forward method |
|
- Return predictions when computing loss |
|
- Load pretrained model from @nousr trained on LAION aesthetics |
|
""" |
|
def __init__(self, *args, **kwargs): |
|
voxel2clip = kwargs.pop('voxel2clip', None) |
|
super().__init__(*args, **kwargs) |
|
self.voxel2clip = voxel2clip |
|
|
|
@torch.no_grad() |
|
def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1., |
|
generator=None): |
|
b, *_, device = *x.shape, x.device |
|
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale) |
|
if generator is None: |
|
noise = torch.randn_like(x) |
|
else: |
|
|
|
noise = torch.randn(x.size(), device=x.device, dtype=x.dtype, generator=generator) |
|
|
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) |
|
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise |
|
return pred, x_start |
|
|
|
@torch.no_grad() |
|
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1., generator=None): |
|
batch, device = shape[0], self.device |
|
|
|
if generator is None: |
|
image_embed = torch.randn(shape, device = device) |
|
else: |
|
image_embed = torch.randn(shape, device = device, generator=generator) |
|
x_start = None |
|
|
|
if self.init_image_embed_l2norm: |
|
image_embed = l2norm(image_embed) * self.image_embed_scale |
|
|
|
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps, disable=True): |
|
times = torch.full((batch,), i, device = device, dtype = torch.long) |
|
|
|
self_cond = x_start if self.net.self_cond else None |
|
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale, |
|
generator=generator) |
|
|
|
if self.sampling_final_clamp_l2norm and self.predict_x_start: |
|
image_embed = self.l2norm_clamp_embed(image_embed) |
|
|
|
return image_embed |
|
|
|
def p_losses(self, image_embed, times, text_cond, noise = None): |
|
noise = default(noise, lambda: torch.randn_like(image_embed)) |
|
|
|
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise) |
|
|
|
self_cond = None |
|
if self.net.self_cond and random.random() < 0.5: |
|
with torch.no_grad(): |
|
self_cond = self.net(image_embed_noisy, times, **text_cond).detach() |
|
|
|
pred = self.net( |
|
image_embed_noisy, |
|
times, |
|
self_cond = self_cond, |
|
text_cond_drop_prob = self.text_cond_drop_prob, |
|
image_cond_drop_prob = self.image_cond_drop_prob, |
|
**text_cond |
|
) |
|
|
|
if self.predict_x_start and self.training_clamp_l2norm: |
|
pred = self.l2norm_clamp_embed(pred) |
|
|
|
if self.predict_v: |
|
target = self.noise_scheduler.calculate_v(image_embed, times, noise) |
|
elif self.predict_x_start: |
|
target = image_embed |
|
else: |
|
target = noise |
|
|
|
loss = self.noise_scheduler.loss_fn(pred, target) |
|
return loss, pred |
|
|
|
def forward( |
|
self, |
|
text = None, |
|
image = None, |
|
voxel = None, |
|
text_embed = None, |
|
image_embed = None, |
|
text_encodings = None, |
|
*args, |
|
**kwargs |
|
): |
|
assert exists(text) ^ exists(text_embed) ^ exists(voxel), 'either text, text embedding, or voxel must be supplied' |
|
assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied' |
|
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization' |
|
|
|
if exists(voxel): |
|
assert exists(self.voxel2clip), 'voxel2clip must be trained if you wish to pass in voxels' |
|
assert not exists(text_embed), 'cannot pass in both text and voxels' |
|
if self.voxel2clip.use_projector: |
|
clip_voxels_mse, clip_voxels = self.voxel2clip(voxel) |
|
text_embed = clip_voxels_mse |
|
else: |
|
clip_voxels = self.voxel2clip(voxel) |
|
text_embed = clip_voxels_mse = clip_voxels |
|
|
|
|
|
if exists(image): |
|
image_embed, _ = self.clip.embed_image(image) |
|
|
|
|
|
|
|
if exists(text): |
|
text_embed, text_encodings = self.clip.embed_text(text) |
|
|
|
text_cond = dict(text_embed = text_embed) |
|
|
|
if self.condition_on_text_encodings: |
|
assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified' |
|
text_cond = {**text_cond, 'text_encodings': text_encodings} |
|
|
|
|
|
|
|
batch, device = image_embed.shape[0], image_embed.device |
|
times = self.noise_scheduler.sample_random_times(batch) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss, pred = self.p_losses(image_embed*self.image_embed_scale, times, text_cond = text_cond, *args, **kwargs) |
|
|
|
|
|
return loss, pred |
|
|
|
class BrainSD(StableDiffusionImageVariationPipeline): |
|
""" |
|
Differences from original: |
|
- Keep generated images on GPU and return tensors |
|
- No NSFW checker |
|
    - Can pass in an image or image_embeddings to generate a variation
    NOTE: requires a recent version of diffusers so the latent dimensions come out correct.
|
""" |
|
|
|
def decode_latents(self, latents): |
|
latents = 1 / 0.18215 * latents |
|
image = self.vae.decode(latents).sample |
|
image = (image / 2 + 0.5).clamp(0, 1) |
|
|
|
|
|
return image |
|
|
|
@torch.no_grad() |
|
def __call__( |
|
self, |
|
image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
num_inference_steps: int = 50, |
|
guidance_scale: float = 7.5, |
|
num_images_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[torch.FloatTensor] = None, |
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
|
callback_steps: Optional[int] = 1, |
|
image_embeddings: Optional[torch.FloatTensor] = None, |
|
): |
|
|
|
|
|
height = height or self.unet.config.sample_size * self.vae_scale_factor |
|
width = width or self.unet.config.sample_size * self.vae_scale_factor |
|
|
|
device = self._execution_device |
|
|
|
|
|
|
|
do_classifier_free_guidance = guidance_scale > 1.0 |
|
|
|
if image_embeddings is None: |
|
assert image is not None, "If image_embeddings is None, image must not be None" |
|
|
|
|
|
tform = transforms.Compose([ |
|
|
|
transforms.Resize( |
|
(224, 224), |
|
interpolation=transforms.InterpolationMode.BICUBIC, |
|
antialias=False, |
|
), |
|
transforms.Normalize( |
|
[0.48145466, 0.4578275, 0.40821073], |
|
[0.26862954, 0.26130258, 0.27577711]), |
|
]) |
|
image = tform(image) |
|
|
|
|
|
self.check_inputs(image, height, width, callback_steps) |
|
|
|
|
|
if isinstance(image, PIL.Image.Image): |
|
batch_size = 1 |
|
elif isinstance(image, list): |
|
batch_size = len(image) |
|
else: |
|
batch_size = image.shape[0] |
|
|
|
|
|
image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) |
|
else: |
|
batch_size = image_embeddings.shape[0] // 2 |
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=device) |
|
timesteps = self.scheduler.timesteps |
|
|
|
|
|
num_channels_latents = self.unet.in_channels |
|
|
|
latents = self.prepare_latents( |
|
batch_size * num_images_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
image_embeddings.dtype, |
|
device, |
|
generator, |
|
latents, |
|
) |
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
|
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
|
with self.progress_bar(total=num_inference_steps) as progress_bar: |
|
for i, t in enumerate(timesteps): |
|
|
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample |
|
|
|
|
|
if do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
|
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
|
|
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
if callback is not None and i % callback_steps == 0: |
|
callback(i, t, latents) |
|
|
|
|
|
image = self.decode_latents(latents) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return image |
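
# Hedged usage sketch (illustration only): BrainSD inherits `from_pretrained` from the
# diffusers pipeline it subclasses. The hub id and loading kwargs below are assumptions
# (an image-variation checkpoint is required); `image_embeddings` is expected to already
# be duplicated for classifier-free guidance (unconditional half followed by conditional half).
def _example_brainsd_call(image_embeddings):
    # image_embeddings: (2*B, 1, 768) CLIP image embeddings (uncond + cond halves)
    sd_pipe = BrainSD.from_pretrained(
        "lambdalabs/sd-image-variations-diffusers",  # assumed checkpoint
        safety_checker=None, requires_safety_checker=False,
    ).to(image_embeddings.device)
    images = sd_pipe(image_embeddings=image_embeddings,
                     num_inference_steps=20, guidance_scale=7.5)
    return images  # (B, 3, H, W) tensor in [0, 1], kept on the same device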
|
|
|
class Voxel2StableDiffusionModel(torch.nn.Module): |
|
def __init__(self, in_dim=15724, h=4096, n_blocks=4, use_cont=False): |
|
super().__init__() |
|
self.lin0 = nn.Sequential( |
|
nn.Linear(in_dim, h, bias=False), |
|
nn.LayerNorm(h), |
|
nn.SiLU(inplace=True), |
|
nn.Dropout(0.5), |
|
) |
|
|
|
self.mlp = nn.ModuleList([ |
|
nn.Sequential( |
|
nn.Linear(h, h, bias=False), |
|
nn.LayerNorm(h), |
|
nn.SiLU(inplace=True), |
|
nn.Dropout(0.25) |
|
) for _ in range(n_blocks) |
|
]) |
|
self.lin1 = nn.Linear(h, 16384, bias=False) |
|
self.norm = nn.LayerNorm(512) |
|
|
|
self.register_parameter('queries', nn.Parameter(torch.randn(1, 256, 512) * 0.044)) |
|
self.transformer = nn.TransformerDecoder( |
|
nn.TransformerDecoderLayer(d_model=512, nhead=8, norm_first=True, |
|
dim_feedforward=1024, activation=nn.functional.gelu, |
|
batch_first=True, dropout=0.25), |
|
num_layers=n_blocks |
|
        )

|
if use_cont: |
|
self.maps_projector = nn.Sequential( |
|
nn.LayerNorm(512), |
|
nn.Linear(512, 512), |
|
nn.LayerNorm(512), |
|
nn.ReLU(True), |
|
nn.Linear(512, 512), |
|
nn.LayerNorm(512), |
|
nn.ReLU(True), |
|
nn.Linear(512, 512) |
|
) |
|
else: |
|
self.maps_projector = nn.Identity() |
|
|
|
self.upsampler = nn.Sequential( |
|
nn.GroupNorm(1, 32), |
|
nn.SiLU(inplace=True), |
|
nn.Conv2d(32, 320, 3, padding=1), |
|
nn.GroupNorm(32, 320), |
|
nn.SiLU(inplace=True), |
|
nn.Conv2d(320, 320, 3, padding=1), |
|
nn.GroupNorm(32, 320), |
|
nn.SiLU(inplace=True), |
|
nn.Conv2d(320, 4, 3, padding=1) |
|
) |
|
|
|
def forward(self, x, return_transformer_feats=False): |
|
x = self.lin0(x) |
|
residual = x |
|
for res_block in self.mlp: |
|
x = res_block(x) |
|
x = x + residual |
|
residual = x |
|
x = x.reshape(len(x), -1) |
|
x = self.lin1(x) |
|
|
|
|
|
|
|
|
|
|
|
|
|
x = self.norm(x.reshape(x.shape[0], 32, 512)) |
|
preds = self.transformer(self.queries.expand(x.shape[0], -1, -1), x) |
|
sd_embeds = preds.permute(0,2,1).reshape(-1, 512, 16, 16) |
|
sd_embeds = nn.functional.pixel_shuffle(sd_embeds, 4) |
|
|
|
|
|
if return_transformer_feats: |
|
return self.upsampler(sd_embeds), self.maps_projector(preds) |
|
|
|
return self.upsampler(sd_embeds) |
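
# Hedged shape sketch (illustration only): the voxel-to-latent model emits a
# (B, 4, 64, 64) Stable Diffusion latent, which a separate SD VAE would then decode
# into a blurry reconstruction. Batch size and the optional feature head are shown
# purely for illustration.
def _example_voxel2sd_shapes():
    voxel2sd = Voxel2StableDiffusionModel(in_dim=15724, use_cont=True)
    voxels = torch.randn(2, 15724)
    latents, feats = voxel2sd(voxels, return_transformer_feats=True)
    # latents: (2, 4, 64, 64) SD latents; feats: (2, 256, 512) contrastive features
    return latents.shape, feats.shape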
|
|
|
class BrainNetworkDETR(BrainNetwork): |
|
|
|
def __init__(self, out_dim=768, in_dim=15724, h=4096, n_blocks=4, norm_type='ln', act_first=False, |
|
encoder_tokens=32, decoder_tokens=257): |
|
|
|
        # use keyword arguments so the call stays aligned with BrainNetwork's signature
        # (which also takes clip_size and use_projector); the projector head is disabled
        # because this subclass expects super().forward to return a single tensor
        super().__init__(out_dim=out_dim*encoder_tokens, in_dim=in_dim, h=h, n_blocks=n_blocks,
                         norm_type=norm_type, act_first=act_first, use_projector=False)
|
self.norm = nn.LayerNorm(out_dim) |
|
self.encoder_tokens = encoder_tokens |
|
|
|
self.register_parameter('queries', nn.Parameter(torch.randn(1, decoder_tokens, out_dim))) |
|
self.transformer = nn.TransformerDecoder( |
|
nn.TransformerDecoderLayer(d_model=out_dim, nhead=8, |
|
dim_feedforward=1024, |
|
batch_first=True, dropout=0.25), |
|
num_layers=n_blocks |
|
) |
|
self.decoder_projector = nn.Sequential( |
|
nn.LayerNorm(out_dim), |
|
nn.Linear(out_dim, out_dim) |
|
) |
|
|
|
|
|
def forward(self, x): |
|
enc = super().forward(x) |
|
enc = self.norm(enc.reshape(enc.shape[0], self.encoder_tokens, -1)) |
|
|
|
dec = self.transformer(self.queries.expand(x.shape[0], -1, -1), enc) |
|
dec = self.decoder_projector(dec) |
|
return dec |
|
|
|
class VersatileDiffusionPriorNetwork(nn.Module): |
|
def __init__( |
|
self, |
|
dim, |
|
num_timesteps = None, |
|
num_time_embeds = 1, |
|
|
|
|
|
num_tokens = 257, |
|
causal = True, |
|
learned_query_mode = 'none', |
|
**kwargs |
|
): |
|
super().__init__() |
|
self.dim = dim |
|
self.num_time_embeds = num_time_embeds |
|
self.continuous_embedded_time = not exists(num_timesteps) |
|
self.learned_query_mode = learned_query_mode |
|
|
|
self.to_time_embeds = nn.Sequential( |
|
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), |
|
Rearrange('b (n d) -> b n d', n = num_time_embeds) |
|
) |
|
|
|
if self.learned_query_mode == 'token': |
|
self.learned_query = nn.Parameter(torch.randn(num_tokens, dim)) |
|
if self.learned_query_mode == 'pos_emb': |
|
scale = dim ** -0.5 |
|
self.learned_query = nn.Parameter(torch.randn(num_tokens, dim) * scale) |
|
if self.learned_query_mode == 'all_pos_emb': |
|
scale = dim ** -0.5 |
|
self.learned_query = nn.Parameter(torch.randn(num_tokens*2+1, dim) * scale) |
|
self.causal_transformer = FlaggedCausalTransformer(dim = dim, causal=causal, **kwargs) |
|
|
|
self.null_brain_embeds = nn.Parameter(torch.randn(num_tokens, dim)) |
|
self.null_image_embed = nn.Parameter(torch.randn(num_tokens, dim)) |
|
|
|
self.num_tokens = num_tokens |
|
self.self_cond = False |
|
|
|
def forward_with_cond_scale( |
|
self, |
|
*args, |
|
cond_scale = 1., |
|
**kwargs |
|
): |
|
logits = self.forward(*args, **kwargs) |
|
|
|
if cond_scale == 1: |
|
return logits |
|
|
|
null_logits = self.forward(*args, brain_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs) |
|
return null_logits + (logits - null_logits) * cond_scale |
|
|
|
def forward( |
|
self, |
|
image_embed, |
|
diffusion_timesteps, |
|
*, |
|
self_cond=None, |
|
brain_embed=None, |
|
text_embed=None, |
|
brain_cond_drop_prob = 0., |
|
text_cond_drop_prob = None, |
|
image_cond_drop_prob = 0. |
|
): |
|
if text_embed is not None: |
|
brain_embed = text_embed |
|
if text_cond_drop_prob is not None: |
|
brain_cond_drop_prob = text_cond_drop_prob |
|
|
|
image_embed = image_embed.view(len(image_embed),-1,768) |
|
|
|
brain_embed = brain_embed.view(len(brain_embed),-1,768) |
|
|
|
|
|
|
|
batch, _, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype |
|
|
|
|
|
|
|
brain_keep_mask = prob_mask_like((batch,), 1 - brain_cond_drop_prob, device = device) |
|
brain_keep_mask = rearrange(brain_keep_mask, 'b -> b 1 1') |
|
|
|
image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device) |
|
image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1') |
|
|
|
|
|
|
|
|
|
null_brain_embeds = self.null_brain_embeds.to(brain_embed.dtype) |
|
brain_embed = torch.where( |
|
brain_keep_mask, |
|
brain_embed, |
|
null_brain_embeds[None] |
|
) |
|
|
|
|
|
null_image_embed = self.null_image_embed.to(image_embed.dtype) |
|
image_embed = torch.where( |
|
image_keep_mask, |
|
image_embed, |
|
null_image_embed[None] |
|
) |
|
|
|
|
|
|
|
if self.continuous_embedded_time: |
|
|
|
diffusion_timesteps = diffusion_timesteps.type(dtype) |
|
time_embed = self.to_time_embeds(diffusion_timesteps) |
|
|
|
if self.learned_query_mode == 'token': |
|
learned_queries = repeat(self.learned_query, 'n d -> b n d', b = batch) |
|
elif self.learned_query_mode == 'pos_emb': |
|
pos_embs = repeat(self.learned_query, 'n d -> b n d', b = batch) |
|
image_embed = image_embed + pos_embs |
|
learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device) |
|
elif self.learned_query_mode == 'all_pos_emb': |
|
pos_embs = repeat(self.learned_query, 'n d -> b n d', b = batch) |
|
learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device) |
|
else: |
|
learned_queries = torch.empty((batch, 0, dim), device=brain_embed.device) |
|
|
|
tokens = torch.cat(( |
|
brain_embed, |
|
time_embed, |
|
image_embed, |
|
learned_queries |
|
), dim = -2) |
|
if self.learned_query_mode == 'all_pos_emb': |
|
tokens = tokens + pos_embs |
|
|
|
|
|
tokens = self.causal_transformer(tokens) |
|
|
|
|
|
pred_image_embed = tokens[..., -self.num_tokens:, :] |
|
|
|
return pred_image_embed |
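
# Hedged shape sketch (illustration only): the prior network concatenates brain tokens,
# a single time token, and the (noised) image tokens, then reads its prediction off the
# last `num_tokens` positions. dim=768 / num_tokens=257 match the CLIP ViT-L/14
# hidden-state layout used elsewhere in this file; depth and heads here are arbitrary.
def _example_prior_network_shapes():
    net = VersatileDiffusionPriorNetwork(dim=768, depth=2, dim_head=64, heads=8,
                                         num_tokens=257, learned_query_mode='pos_emb')
    image_embed = torch.randn(2, 257, 768)   # noised CLIP image tokens
    brain_embed = torch.randn(2, 257, 768)   # voxel2clip output reshaped to tokens
    times = torch.randint(0, 1000, (2,))
    pred = net(image_embed, times, brain_embed=brain_embed)  # (2, 257, 768)
    return pred.shape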
|
|
class FlaggedCausalTransformer(nn.Module): |
|
def __init__( |
|
self, |
|
*, |
|
dim, |
|
depth, |
|
dim_head = 64, |
|
heads = 8, |
|
ff_mult = 4, |
|
norm_in = False, |
|
norm_out = True, |
|
attn_dropout = 0., |
|
ff_dropout = 0., |
|
final_proj = True, |
|
normformer = False, |
|
rotary_emb = False, |
|
causal=True |
|
): |
|
super().__init__() |
|
self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() |
|
|
|
self.rel_pos_bias = RelPosBias(heads = heads) |
|
|
|
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None |
|
|
|
self.layers = nn.ModuleList([]) |
|
for _ in range(depth): |
|
self.layers.append(nn.ModuleList([ |
|
Attention(dim = dim, causal = causal, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb), |
|
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer) |
|
])) |
|
|
|
self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() |
|
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity() |
|
|
|
def forward(self, x): |
|
n, device = x.shape[1], x.device |
|
|
|
x = self.init_norm(x) |
|
|
|
attn_bias = self.rel_pos_bias(n, n + 1, device = device) |
|
|
|
for attn, ff in self.layers: |
|
x = attn(x, attn_bias = attn_bias) + x |
|
x = ff(x) + x |
|
|
|
out = self.norm(x) |
|
return self.project_out(out) |
|
|
|
class BrainVD(VersatileDiffusionDualGuidedPipeline): |
|
""" |
|
Differences from original: |
|
- Keep generated images on GPU and return tensors |
|
- No NSFW checker |
|
    - Can pass in an image or image_embeddings to generate a variation
    NOTE: requires a recent version of diffusers so the latent dimensions come out correct.
|
""" |
|
|
|
def decode_latents(self, latents): |
|
latents = 1 / self.vae.config.scaling_factor * latents |
|
image = self.vae.decode(latents).sample |
|
image = (image / 2 + 0.5).clamp(0, 1) |
|
|
|
|
|
return image |
|
|
|
def check_inputs(self, prompt, image, height, width, callback_steps): |
|
if prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list): |
|
raise ValueError(f"`prompt` has to be of type None, `str` or `list` but is {type(prompt)}") |
|
if image is not None and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): |
|
raise ValueError(f"`image` has to be of type None, `PIL.Image` or `list` but is {type(image)}") |
|
|
|
if height % 8 != 0 or width % 8 != 0: |
|
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") |
|
|
|
if (callback_steps is None) or ( |
|
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) |
|
): |
|
raise ValueError( |
|
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" |
|
f" {type(callback_steps)}." |
|
) |
|
|
|
@torch.no_grad() |
|
def __call__( |
|
self, |
|
        prompt: Optional[Union[str, List[str]]] = None,
        image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image]]] = None,
|
text_to_image_strength: float = 0.5, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
num_inference_steps: int = 50, |
|
guidance_scale: float = 7.5, |
|
num_images_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[torch.FloatTensor] = None, |
|
output_type: Optional[str] = "pil", |
|
return_dict: bool = True, |
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
|
callback_steps: Optional[int] = 1, |
|
image_embeddings: Optional[torch.FloatTensor] = None, |
|
prompt_embeddings: Optional[torch.FloatTensor] = None, |
|
**kwargs, |
|
): |
|
|
|
height = height or self.image_unet.config.sample_size * self.vae_scale_factor |
|
width = width or self.image_unet.config.sample_size * self.vae_scale_factor |
|
|
|
self.check_inputs(prompt, image, height, width, callback_steps) |
|
|
|
prompt = [prompt] if prompt is not None and not isinstance(prompt, list) else prompt |
|
image = [image] if image is not None and not isinstance(image, list) else image |
|
device = self._execution_device |
|
|
|
|
|
|
|
do_classifier_free_guidance = guidance_scale > 1.0 |
|
|
|
|
|
|
|
if image_embeddings is None: |
|
if image is not None: |
|
image_embeddings = self._encode_image_prompt( |
|
image, device, num_images_per_prompt, do_classifier_free_guidance |
|
) |
|
batch_size = len(image) |
|
else: |
|
image_embeddings = None |
|
|
|
if prompt_embeddings is None: |
|
if prompt is not None: |
|
prompt_embeddings = self._encode_text_prompt( |
|
prompt, device, num_images_per_prompt, do_classifier_free_guidance |
|
) |
|
batch_size = len(prompt) |
|
else: |
|
prompt_embeddings = None |
|
if image_embeddings is not None: |
|
batch_size = image_embeddings.shape[0] // 2 |
|
elif prompt_embeddings is not None: |
|
batch_size = prompt_embeddings.shape[0] // 2 |
|
|
|
        if image_embeddings is not None and prompt_embeddings is not None:
            dual_prompt_embeddings = torch.cat([prompt_embeddings, image_embeddings], dim=1)
        elif prompt_embeddings is not None:
            dual_prompt_embeddings = prompt_embeddings
            text_to_image_strength = 1.
        elif image_embeddings is not None:
            dual_prompt_embeddings = image_embeddings
            text_to_image_strength = 0.
        else:
            raise ValueError("at least one of image / image_embeddings / prompt / prompt_embeddings must be provided")
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=device) |
|
timesteps = self.scheduler.timesteps |
|
|
|
|
|
num_channels_latents = self.image_unet.in_channels |
|
latents = self.prepare_latents( |
|
batch_size * num_images_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
dual_prompt_embeddings.dtype, |
|
device, |
|
generator, |
|
latents, |
|
) |
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
|
|
self.set_transformer_params(text_to_image_strength, ("text", "image")) |
|
|
|
|
|
for i, t in enumerate(self.progress_bar(timesteps)): |
|
|
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample |
|
|
|
|
|
if do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
|
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
|
|
|
|
|
if callback is not None and i % callback_steps == 0: |
|
callback(i, t, latents) |
|
|
|
|
|
image = self.decode_latents(latents) |
|
|
|
return image |
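
# Hedged usage sketch (illustration only): BrainVD is driven with pre-computed dual
# embeddings rather than raw prompts. The hub id below is an assumption, and
# `image_embeddings` is expected to already be duplicated for classifier-free guidance
# (unconditional half followed by conditional half), as in BrainSD above. Whether
# image-only guidance suffices depends on how the dual-guided UNet was configured.
def _example_brainvd_call(image_embeddings):
    # image_embeddings: (2*B, 257, 768) CLIP hidden-state tokens (uncond + cond halves)
    vd_pipe = BrainVD.from_pretrained("shi-labs/versatile-diffusion").to(image_embeddings.device)
    images = vd_pipe(image_embeddings=image_embeddings,
                     text_to_image_strength=0.0,   # pure image guidance
                     num_inference_steps=20, guidance_scale=3.5)
    return images  # (B, 3, H, W) tensor in [0, 1]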
|
|