# t2i/test_sdxl_vqvae_v2.py
import glob
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from diffusers import UNet2DConditionModel
from pipeline_stable_diffusion_xl_vqvae import StableDiffusionXLPipeline
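
# Projection: a small adapter that lifts VQ-VAE latents into the conditioning
# space the fine-tuned UNet expects. For a 512px input, the Anole VQGAN
# (assumed 16x downsampling with 256 latent channels) yields a 256x32x32 grid;
# the transposed conv upsamples 2x and maps it to 4x64x64, followed by
# LayerNorm over the channel dimension.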
class Projection(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 2x upsampling transposed conv: 256 VQ-VAE channels -> 4 channels.
        self.deconv = nn.ConvTranspose2d(256, 4, 2, stride=2, padding=0, bias=False)
        self.norm = nn.LayerNorm(4)

    def save_pretrained(self, dir):
        if not os.path.exists(dir):
            os.makedirs(dir)
        torch.save(self.state_dict(), os.path.join(dir, "projection.pth"))

    def from_pretrained(self, dir):
        state_dict = torch.load(os.path.join(dir, "projection.pth"), map_location="cpu")
        msg = self.load_state_dict(state_dict, strict=False)
        print(msg)  # surfaces any missing/unexpected keys

    def forward(self, x):
        x = self.deconv(x)
        # LayerNorm over channels: flatten spatial dims, normalize, restore shape.
        B, C, H, W = x.shape
        x = x.reshape(B, C, -1).permute(0, 2, 1)
        x = self.norm(x)
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        return x

resolution = 512
# Preprocessing the datasets.
train_resize = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)
train_resize2 = transforms.Resize(resolution//2, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(resolution)
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
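
# preprocess_train mirrors the usual SDXL training preprocessing: resize the
# short side to `resolution`, center-crop, and record original_sizes /
# crop_top_lefts (SDXL's micro-conditioning inputs). pixel_values2 is a
# half-resolution copy of the cropped image.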
def preprocess_train(image):
    original_size = (image.height, image.width)
    image = train_resize(image)
    y1 = max(0, int(round((image.height - resolution) / 2.0)))
    x1 = max(0, int(round((image.width - resolution) / 2.0)))
    image = train_crop(image)
    image2 = train_resize2(image)
    crop_top_left = (y1, x1)
    image = train_transforms(image)
    image2 = train_transforms(image2)
    examples = {
        "original_sizes": original_size,
        "crop_top_lefts": crop_top_left,
        "pixel_values": image,
        "pixel_values2": image2,
    }
    return examples
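
# Encode one image to quantized VQ-VAE latents z_q. `vqvae` is an
# ImageTokenizer; the private _device / _dtype / _vq_model attributes come
# from Anole's image_tokenizer module.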
def compute_vqvae_encodings(image, vqvae):
    pixel_values = torch.stack([image])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    pixel_values = pixel_values.to(vqvae._device, dtype=vqvae._dtype)
    with torch.no_grad():
        z_q, _, _ = vqvae._vq_model.encode(pixel_values)
    return z_q
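
# Load the fine-tuned UNet and the matching projection weights from the same
# training checkpoint directory.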
ckpt_path = "sdxl-vqvae-model-v2/checkpoint-4000/unet"
unet = UNet2DConditionModel.from_pretrained(
    ckpt_path,
    # subfolder="",
    # encoder_hid_dim=256,
    # projection_class_embeddings_input_dim=1792,
    torch_dtype=torch.float16,
    use_safetensors=True,
)
projection_module = Projection()
projection_module.from_pretrained(ckpt_path)
projection_module.cuda()
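# Note: the projection runs in float32; casting its output to the UNet's fp16
# is assumed to happen inside pipeline_stable_diffusion_xl_vqvae.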
pipe = StableDiffusionXLPipeline.from_pretrained(
    "../sdxl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
)
pipe.unet = unet
pipe.to("cuda")
# VQ-VAE: Anole's VQGAN image tokenizer, used both to encode conditioning
# latents and to decode reconstructions.
from image_tokenizer import ImageTokenizer

ckpt_dir = "/mnt/bn/hanjiaming-lq/code/mllm_bridge/anole/ckpts/Anole-7b-v0.1/tokenizer"
cfg_path = os.path.join(ckpt_dir, 'vqgan.yaml')
vq_ckpt_path = os.path.join(ckpt_dir, 'vqgan.ckpt')
vq_tokenizer = ImageTokenizer(cfg_path, vq_ckpt_path, "cuda")
# image_path = "/mnt/bn/hanjiaming-lq/code/mllm_bridge/lumina_t2i/bridge.jpg"
data_dir = "/mnt/bn/hanjiaming-lq/data/JourneyDB/data/valid/imgs/"
datas = glob.glob(os.path.join(data_dir, "*/*"+".jpg"))[:256]
# datas = glob.glob("/mnt/bn/hanjiaming-lq/code/mllm_bridge/lumina_t2i/*.jpg")
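
# For each validation image: encode it with the VQ-VAE, project the latents,
# condition the SDXL pipeline on them (empty prompt, guidance disabled), and
# save a side-by-side panel of original | VQ-VAE reconstruction | generation.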
os.makedirs("output", exist_ok=True)
for i, image_path in enumerate(datas):
    image = Image.open(image_path).convert('RGB')
    image_tensor = preprocess_train(image)['pixel_values'].cuda()
    # Undo the [-1, 1] normalization for visualization.
    ori_image = (image_tensor * 0.5 + 0.5).clip(0.0, 1.0)
    ori_image = (ori_image * 255).permute(1, 2, 0).detach().cpu().numpy()
    ori_image = np.array(ori_image, dtype=np.uint8)

    prompt_embeddings = compute_vqvae_encodings(image_tensor, vq_tokenizer).cuda()
    empty_prompt_embeddings = compute_vqvae_encodings(torch.zeros_like(image_tensor), vq_tokenizer).cuda()

    # VQ-VAE reconstruction (decode the quantized latents back to an image).
    rec_embeddings = vq_tokenizer._vq_model.decode(prompt_embeddings)
    rec_image = vq_tokenizer._pil_from_chw_tensor(rec_embeddings[0])
    rec_numpy = np.array(rec_image)

    prompt_embeddings = projection_module(prompt_embeddings)
    empty_prompt_embeddings = projection_module(empty_prompt_embeddings)

    # guidance_scale=1.0 disables classifier-free guidance.
    images = pipe(
        prompt="",
        height=512,
        width=512,
        guidance_scale=1.0,
        num_inference_steps=50,
        vqvae_embedding=prompt_embeddings,
        empty_vqvae_embedding=empty_prompt_embeddings,
    ).images[0]

    new_image = np.concatenate([ori_image, rec_numpy, np.array(images)], axis=1)
    new_image = Image.fromarray(new_image)
    new_image.save(f"output/{i}.png")