# ELITE / model.py
# Author: hysts (HF staff)
# Commit: "Update" (71235ec)
from __future__ import annotations
import pathlib
import random
import sys
from typing import Any
import cv2
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import tqdm.auto
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from huggingface_hub import hf_hub_download
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
# Directory that contains this file.
repo_dir = pathlib.Path(__file__).parent
# The 'ELITE' sub-directory next to this file; presumably a checkout of the
# upstream ELITE code that provides `train_local` -- verify against the repo.
submodule_dir = repo_dir / 'ELITE'
# Put that directory first on the import path so modules inside it take
# precedence when imported by bare name.
sys.path.insert(0, submodule_dir.as_posix())
from train_local import (Mapper, MapperLocal, inj_forward_crossattention,
inj_forward_text, th2image, value_local_list)
def get_tensor_clip(normalize=True, toTensor=True):
    """Compose the torchvision transform used for the CLIP image inputs.

    Args:
        normalize: When true, append a ``T.Normalize`` step using the fixed
            per-channel mean/std constants below.
        toTensor: When true, start the pipeline with ``T.ToTensor()``.

    Returns:
        A ``T.Compose`` holding the selected steps, in the order
        ``ToTensor`` then ``Normalize``. With both flags false the pipeline
        is empty.
    """
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)

    steps = []
    if toTensor:
        steps.append(T.ToTensor())
    if normalize:
        steps.append(T.Normalize(mean, std))
    return T.Compose(steps)
def process(image: np.ndarray, size: int = 512) -> torch.Tensor:
    """Resize an image to a square and rescale its values for the model.

    Args:
        image: Image array; the final ``permute(2, 0, 1)`` requires it to be
            three-dimensional (assumed height x width x channels -- confirm
            against callers).
        size: Edge length of the square output in pixels.

    Returns:
        A float32 tensor with the last axis moved to the front, whose values
        are ``pixel / 127.5 - 1.0`` (so 0 maps to -1.0 and 255 to 1.0).
    """
    resized = cv2.resize(image, (size, size), interpolation=cv2.INTER_CUBIC)
    scaled = resized.astype(np.float32) / 127.5 - 1.0
    tensor = torch.from_numpy(scaled)
    return tensor.permute(2, 0, 1)
class Model:
    """Inference wrapper that wires the ELITE mappers into Stable Diffusion.

    Construction loads every network (VAE, UNet, CLIP text and vision
    encoders, global and local mappers) together with the noise scheduler and
    stores them as attributes. :meth:`run` then generates one image per call
    from a reference image, a mask and a text prompt.

    NOTE(review): all weights are loaded as / cast to float16. The device
    falls back to CPU when CUDA is unavailable, but whether half precision
    actually runs there depends on the installed torch build -- confirm.
    """

    def __init__(self):
        # Use the first CUDA device when present, otherwise the CPU.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        (self.vae, self.unet, self.text_encoder, self.tokenizer,
         self.image_encoder, self.mapper, self.mapper_local,
         self.scheduler) = self.load_model()

    def download_mappers(self) -> tuple[str, str]:
        """Fetch both mapper checkpoints from the Hugging Face Hub.

        Returns:
            ``(global_mapper_path, local_mapper_path)`` as returned by
            ``hf_hub_download`` for ``checkpoints/global_mapper.pt`` and
            ``checkpoints/local_mapper.pt`` of the ``ELITE-library/ELITE``
            model repository.
        """
        global_mapper_path = hf_hub_download('ELITE-library/ELITE',
                                             'global_mapper.pt',
                                             subfolder='checkpoints',
                                             repo_type='model')
        local_mapper_path = hf_hub_download('ELITE-library/ELITE',
                                            'local_mapper.pt',
                                            subfolder='checkpoints',
                                            repo_type='model')
        return global_mapper_path, local_mapper_path

    @staticmethod
    def _cross_attention_modules(unet) -> list[tuple[str, Any]]:
        """List the UNet attention layers that receive ELITE projections.

        A layer is selected when its class is named ``CrossAttention`` and
        its qualified name does not contain ``attn1``. The returned name has
        its dots replaced by underscores so it can serve as a sub-module name.
        """
        return [(name.replace('.', '_'), module)
                for name, module in unet.named_modules()
                if module.__class__.__name__ == 'CrossAttention'
                and 'attn1' not in name]

    def load_model(
        self,
        scheduler_type=LMSDiscreteScheduler
    ) -> tuple[AutoencoderKL, UNet2DConditionModel, CLIPTextModel,
               CLIPTokenizer, CLIPVisionModel, Mapper, MapperLocal,
               LMSDiscreteScheduler]:
        """Load all networks, patch them for ELITE and build the scheduler.

        Args:
            scheduler_type: Scheduler class to instantiate; it is called with
                ``beta_start``, ``beta_end``, ``beta_schedule`` and
                ``num_train_timesteps`` keyword arguments.

        Returns:
            ``(vae, unet, text_encoder, tokenizer, image_encoder, mapper,
            mapper_local, scheduler)``. All modules are in eval mode and have
            been moved to ``self.device``.
        """
        diffusion_model_id = 'CompVis/stable-diffusion-v1-4'
        clip_model_id = 'openai/clip-vit-large-patch14'

        vae = AutoencoderKL.from_pretrained(
            diffusion_model_id,
            subfolder='vae',
            torch_dtype=torch.float16,
        )
        tokenizer = CLIPTokenizer.from_pretrained(
            clip_model_id,
            torch_dtype=torch.float16,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            clip_model_id,
            torch_dtype=torch.float16,
        )
        image_encoder = CLIPVisionModel.from_pretrained(
            clip_model_id,
            torch_dtype=torch.float16,
        )

        # Replace the call behaviour of the CLIP text transformer with the
        # ELITE version. The assignment targets the *class*, so it affects
        # every instance of that class in the process.
        for module in text_encoder.modules():
            if module.__class__.__name__ == 'CLIPTextTransformer':
                module.__class__.__call__ = inj_forward_text

        unet = UNet2DConditionModel.from_pretrained(
            diffusion_model_id,
            subfolder='unet',
            torch_dtype=torch.float16,
        )

        mapper = Mapper(input_dim=1024, output_dim=768)
        mapper_local = MapperLocal(input_dim=1024, output_dim=768)

        # Snapshot the layers up front: sub-modules are attached to them
        # further down, which must not happen while iterating named_modules().
        attn_layers = self._cross_attention_modules(unet)

        # The extra key/value projections must exist on the mappers *before*
        # the checkpoints are loaded so that load_state_dict() finds them.
        for flat_name, attn in attn_layers:
            # Class-level patch, same caveat as for the text transformer.
            attn.__class__.__call__ = inj_forward_crossattention

            k_out, k_in = attn.to_k.weight.shape
            v_out, v_in = attn.to_v.weight.shape
            mapper.add_module(f'{flat_name}_to_k',
                              nn.Linear(k_in, k_out, bias=False))
            mapper.add_module(f'{flat_name}_to_v',
                              nn.Linear(v_in, v_out, bias=False))
            mapper_local.add_module(f'{flat_name}_to_v',
                                    nn.Linear(v_in, v_out, bias=False))
            # NOTE(review): to_k_local is sized from to_v's weight shape;
            # confirm that to_k and to_v always share a shape.
            mapper_local.add_module(f'{flat_name}_to_k',
                                    nn.Linear(v_in, v_out, bias=False))

        global_mapper_path, local_mapper_path = self.download_mappers()
        mapper.load_state_dict(
            torch.load(global_mapper_path, map_location='cpu'))
        mapper.half()
        mapper_local.load_state_dict(
            torch.load(local_mapper_path, map_location='cpu'))
        mapper_local.half()

        # Expose the (now loaded) projections on the attention layers
        # themselves under fixed attribute names.
        for flat_name, attn in attn_layers:
            attn.add_module('to_k_global',
                            getattr(mapper, f'{flat_name}_to_k'))
            attn.add_module('to_v_global',
                            getattr(mapper, f'{flat_name}_to_v'))
            attn.add_module('to_v_local',
                            getattr(mapper_local, f'{flat_name}_to_v'))
            attn.add_module('to_k_local',
                            getattr(mapper_local, f'{flat_name}_to_k'))

        for network in (vae, unet, text_encoder, image_encoder, mapper,
                        mapper_local):
            network.eval().to(self.device)

        scheduler = scheduler_type(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule='scaled_linear',
            num_train_timesteps=1000,
        )
        return (vae, unet, text_encoder, tokenizer, image_encoder, mapper,
                mapper_local, scheduler)

    def prepare_data(self,
                     image: PIL.Image.Image,
                     mask: PIL.Image.Image,
                     text: str,
                     placeholder_string: str = 'S') -> dict[str, Any]:
        """Turn the raw inputs into the batch dictionary consumed by `run`.

        Args:
            image: Reference image.
            mask: Mask image of the same size as ``image``; it is converted
                to RGB, divided by 255 and multiplied into the image.
            text: Prompt. It is split on single spaces to locate the
                placeholder word.
            placeholder_string: Word in ``text`` that marks the placeholder.

        Returns:
            A dict with the original ``text`` plus tensors ``index``,
            ``input_ids``, ``pixel_values``, ``pixel_values_obj``,
            ``pixel_values_clip`` and ``pixel_values_seg``. Every tensor is on
            ``self.device`` and carries a leading batch axis of size 1.
        """
        # 1-based position of the last word equal to the placeholder; stays 0
        # when the placeholder does not occur. (The +1 presumably skips the
        # tokenizer's start token -- confirm against inj_forward_text.)
        placeholder_index = 0
        for position, word in enumerate(text.strip().split(' ')):
            if word == placeholder_string:
                placeholder_index = position + 1

        input_ids = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors='pt',
        ).input_ids[0]

        image_np = np.array(image.convert('RGB'))
        mask_np = np.array(mask.convert('RGB')) / 255.0
        object_np = image_np * mask_np  # reference image with the mask applied

        clip_size = (224, 224)
        object_pil = PIL.Image.fromarray(object_np.astype('uint8')).resize(
            clip_size, resample=PIL.Image.Resampling.BICUBIC)
        image_pil = PIL.Image.fromarray(image_np.astype('uint8')).resize(
            clip_size, resample=PIL.Image.Resampling.BICUBIC)

        # The uint8 cast truncates, so only mask values of exactly 1.0
        # survive as 255 in the segmentation image.
        seg_pil = PIL.Image.fromarray(mask_np.astype('uint8') * 255)
        seg = get_tensor_clip(normalize=False)(seg_pil)
        seg = F.interpolate(seg.unsqueeze(0), size=(128, 128),
                            mode='nearest').squeeze(0)

        # Use the device chosen in __init__ so the CPU fallback keeps working.
        device = self.device
        data: dict[str, Any] = {
            'text': text,
            'index': torch.tensor(placeholder_index).to(device).long(),
            'input_ids': input_ids.to(device),
            'pixel_values': process(image_np).to(device),
            'pixel_values_obj':
            get_tensor_clip()(object_pil).to(device).half(),
            'pixel_values_clip':
            get_tensor_clip()(image_pil).to(device).half(),
            'pixel_values_seg': seg.to(device).half(),
        }

        # Add the batch axis to every tensor entry.
        for key, value in list(data.items()):
            if isinstance(value, torch.Tensor):
                data[key] = value.unsqueeze(0)
        return data

    def _encode_reference(self,
                          pixel_values: torch.Tensor) -> list[torch.Tensor]:
        """Run the CLIP vision encoder and collect the mapper inputs.

        The input is resized to 224x224 (bilinear). The result is the first
        encoder output followed by entries 4, 8, 12 and 16 of the hidden-state
        tuple (third encoder output), each detached.
        """
        resized = F.interpolate(pixel_values, (224, 224), mode='bilinear')
        features = self.image_encoder(resized, output_hidden_states=True)
        hidden_states = features[2]
        selected = [features[0]]
        selected.extend(hidden_states[layer] for layer in (4, 8, 12, 16))
        return [feature.detach() for feature in selected]

    @torch.inference_mode()
    def run(
        self,
        image: dict[str, PIL.Image.Image],
        text: str,
        seed: int,
        guidance_scale: float,
        lambda_: float,
        num_steps: int,
    ) -> PIL.Image.Image:
        """Generate an image from a masked reference image and a prompt.

        Args:
            image: Dict with the reference picture under ``'image'`` and its
                mask under ``'mask'``.
            text: Prompt; see :meth:`prepare_data` for the placeholder word.
            seed: Seed for the initial latents; ``-1`` draws a random seed
                from ``[0, 1000000]``.
            guidance_scale: Weight ``g`` in
                ``uncond + g * (text - uncond)`` applied to the two noise
                predictions.
            lambda_: Forwarded to the UNet under the ``'LAMBDA'`` key
                (presumably the strength of the local embedding -- confirm
                against ``inj_forward_crossattention``).
            num_steps: Number of scheduler timesteps.

        Returns:
            The first decoded image, converted by ``th2image``.
        """
        data = self.prepare_data(image['image'], image['mask'], text)
        batch_size = data['pixel_values'].shape[0]

        # Embedding of the empty prompt for the unconditional branch.
        uncond_input = self.tokenizer(
            [''] * batch_size,
            padding='max_length',
            max_length=self.tokenizer.model_max_length,
            return_tensors='pt',
        )
        uncond_embeddings = self.text_encoder(
            {'input_ids': uncond_input.input_ids.to(self.device)})[0]

        if seed == -1:
            seed = random.randint(0, 1000000)
        generator = torch.Generator().manual_seed(seed)
        latents = torch.randn(
            (batch_size, self.unet.in_channels, 64, 64),
            generator=generator,
        )
        # Match device and dtype of the (half precision) CLIP input.
        latents = latents.to(data['pixel_values_clip'])
        self.scheduler.set_timesteps(num_steps)
        latents = latents * self.scheduler.init_noise_sigma

        placeholder_idx = data['index']

        # Global branch: embed the whole reference image and inject the first
        # mapped token into the text encoder at the placeholder position.
        inj_embedding = self.mapper(
            self._encode_reference(data['pixel_values_clip']))
        inj_embedding = inj_embedding[:, 0:1, :]
        encoder_hidden_states = self.text_encoder({
            'input_ids': data['input_ids'],
            'inj_embedding': inj_embedding,
            'inj_index': placeholder_idx,
        })[0]

        # Local branch: embed the masked object and zero the embedding
        # wherever the 16x16 down-sampled mask is zero.
        inj_embedding_local = self.mapper_local(
            self._encode_reference(data['pixel_values_obj']))
        mask = F.interpolate(data['pixel_values_seg'], (16, 16),
                             mode='nearest')
        mask = mask[:, 0].reshape(mask.shape[0], -1, 1)
        inj_embedding_local = inj_embedding_local * mask

        conditional_context = {
            'CONTEXT_TENSOR': encoder_hidden_states,
            'LOCAL': inj_embedding_local,
            'LOCAL_INDEX': placeholder_idx.detach(),
            'LAMBDA': lambda_,
        }
        unconditional_context = {'CONTEXT_TENSOR': uncond_embeddings}

        for t in tqdm.auto.tqdm(self.scheduler.timesteps):
            # Both branches see the same scaled latents, so scale only once.
            latent_model_input = self.scheduler.scale_model_input(latents, t)

            noise_pred_text = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=conditional_context).sample
            # Reset the module-level list imported from train_local after
            # each UNet call, as it is shared across calls.
            value_local_list.clear()

            noise_pred_uncond = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=unconditional_context).sample
            value_local_list.clear()

            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Undo the latent scaling factor before decoding.
        decoded = self.vae.decode(1 / 0.18215 * latents).sample
        return th2image(decoded[0])