import glob
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from diffusers import UNet2DConditionModel

from pipeline_stable_diffusion_xl_vqvae import StableDiffusionXLPipeline


class Projection(nn.Module):
    """Projects VQ-VAE latents (256 channels) into the UNet's conditioning space (4 channels, 2x upsampled)."""

    def __init__(self) -> None:
        super().__init__()
        # Transposed conv doubles the spatial resolution and maps 256 -> 4 channels.
        self.deconv = nn.ConvTranspose2d(256, 4, 2, stride=2, padding=0, bias=False)
        self.norm = nn.LayerNorm(4)

    def save_pretrained(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_dir, "projection.pth"))

    def from_pretrained(self, load_dir):
        state_dict = torch.load(os.path.join(load_dir, "projection.pth"))
        msg = self.load_state_dict(state_dict, strict=False)
        print(msg)  # report any missing/unexpected keys

    def forward(self, x):
        x = self.deconv(x)
        # LayerNorm normalizes over the channel dim, so flatten the spatial dims first.
        B, C, H, W = x.shape
        x = x.reshape(B, C, -1).permute(0, 2, 1)  # (B, H*W, C)
        x = self.norm(x)
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        return x
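
# Shape sanity check (illustrative only, not part of the original script; assumes
# the tokenizer below yields a 32x32 grid of 256-dim codes for a 512px image):
#   z = torch.randn(1, 256, 32, 32)
#   Projection()(z).shape  # -> torch.Size([1, 4, 64, 64]), SDXL's latent layout at 512px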


resolution = 512

# Resize the shorter side to `resolution`, then center-crop to a square; a
# half-resolution copy is also produced (kept for parity with the training
# preprocessing, though unused in this script).
train_resize = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)
train_resize2 = transforms.Resize(resolution // 2, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(resolution)
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])


def preprocess_train(image):
    # SDXL micro-conditioning metadata: original (height, width) and crop offsets.
    original_size = (image.height, image.width)
    image = train_resize(image)

    # Center-crop offsets, computed the same way transforms.CenterCrop does.
    y1 = max(0, int(round((image.height - resolution) / 2.0)))
    x1 = max(0, int(round((image.width - resolution) / 2.0)))
    image = train_crop(image)
    image2 = train_resize2(image)

    crop_top_left = (y1, x1)
    image = train_transforms(image)
    image2 = train_transforms(image2)

    return {
        "original_sizes": original_size,
        "crop_top_lefts": crop_top_left,
        "pixel_values": image,
        "pixel_values2": image2,
    }
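
# Illustrative example (not in the original script): for a 1024x768 PIL image,
#   out = preprocess_train(Image.new("RGB", (1024, 768)))
#   out["pixel_values"].shape   # torch.Size([3, 512, 512]), normalized to [-1, 1]
#   out["pixel_values2"].shape  # torch.Size([3, 256, 256])
#   out["original_sizes"]       # (768, 1024), i.e. (height, width)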


def compute_vqvae_encodings(image, vqvae):
    # Add a batch dimension and move to the tokenizer's device/dtype.
    pixel_values = torch.stack([image])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    pixel_values = pixel_values.to(vqvae._device, dtype=vqvae._dtype)

    with torch.no_grad():
        # encode() returns (quantized latents, loss, info); keep only z_q.
        z_q, _, _ = vqvae._vq_model.encode(pixel_values)
    return z_q
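
# Sanity check (an assumption, inferred from Projection's 256 input channels and
# a 16x-downsampling VQGAN; verify against your tokenizer's config):
#   z_q = compute_vqvae_encodings(image_tensor, vq_tokenizer)
#   z_q.shape  # expected torch.Size([1, 256, 32, 32]) for a 512x512 input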


ckpt_path = "sdxl-vqvae-model-v2/checkpoint-4000/unet"
unet = UNet2DConditionModel.from_pretrained(
    ckpt_path,
    torch_dtype=torch.float16,
    use_safetensors=True,
)

projection_module = Projection()
projection_module.from_pretrained(ckpt_path)
projection_module.cuda()  # note: kept in fp32, unlike the fp16 UNet

pipe = StableDiffusionXLPipeline.from_pretrained(
    "../sdxl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
)
pipe.unet = unet  # swap in the fine-tuned UNet
pipe.to("cuda")


from image_tokenizer import ImageTokenizer

ckpt_dir = "/mnt/bn/hanjiaming-lq/code/mllm_bridge/anole/ckpts/Anole-7b-v0.1/tokenizer"
cfg_path = os.path.join(ckpt_dir, "vqgan.yaml")
vq_ckpt_path = os.path.join(ckpt_dir, "vqgan.ckpt")  # distinct name, so the UNet ckpt_path above isn't shadowed
vq_tokenizer = ImageTokenizer(cfg_path, vq_ckpt_path, "cuda")

data_dir = "/mnt/bn/hanjiaming-lq/data/JourneyDB/data/valid/imgs/"
image_paths = glob.glob(os.path.join(data_dir, "*/*.jpg"))[:256]  # first 256 validation images

os.makedirs("output", exist_ok=True)

for i, image_path in enumerate(image_paths):
    image = Image.open(image_path).convert("RGB")
    image_tensor = preprocess_train(image)["pixel_values"].cuda()

    # Un-normalize from [-1, 1] back to uint8 for the comparison strip.
    ori_image = (image_tensor * 0.5 + 0.5).clip(0.0, 1.0)
    ori_image = (ori_image * 255).permute(1, 2, 0).detach().cpu().numpy()
    ori_image = np.array(ori_image, dtype=np.uint8)

    # Quantized latents for the image and for an all-black image; the latter
    # serves as the unconditional embedding for classifier-free guidance.
    prompt_embeddings = compute_vqvae_encodings(image_tensor, vq_tokenizer).cuda()
    empty_prompt_embeddings = compute_vqvae_encodings(torch.zeros_like(image_tensor), vq_tokenizer).cuda()

    # Plain VQ-VAE reconstruction, used as a quality reference.
    rec_tensor = vq_tokenizer._vq_model.decode(prompt_embeddings)
    rec_image = vq_tokenizer._pil_from_chw_tensor(rec_tensor[0])
    rec_numpy = np.array(rec_image)

    # Project the latents into the UNet's conditioning space.
    prompt_embeddings = projection_module(prompt_embeddings)
    empty_prompt_embeddings = projection_module(empty_prompt_embeddings)

    gen_image = pipe(
        prompt="",
        height=512,
        width=512,
        guidance_scale=1.0,  # CFG is typically disabled at scale 1.0
        num_inference_steps=50,
        vqvae_embedding=prompt_embeddings,
        empty_vqvae_embedding=empty_prompt_embeddings,
    ).images[0]

    # Save [original | VQ-VAE reconstruction | diffusion output] side by side.
    new_image = np.concatenate([ori_image, rec_numpy, np.array(gen_image)], axis=1)
    Image.fromarray(new_image).save(f"output/{i}.png")