import argparse
import json
import os
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from skimage.metrics import structural_similarity as compare_ssim
from torchvision import transforms
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers import UniPCMultistepScheduler
from transformers import CLIPImageProcessor
from transformers import Dinov2Model

from src.pipelines.stage3_refined_pipeline import Stage3_RefinedPipeline


def split_list_into_chunks(lst, n):
    """Split *lst* into at most *n* contiguous chunks, preserving order.

    Each chunk holds ``len(lst) // n`` items; any remainder is folded into the
    last chunk, so exactly *n* chunks are returned when ``len(lst) >= n``.
    When ``len(lst) < n`` every item gets its own chunk (fewer than *n*
    chunks). An empty *lst* yields ``[]``.

    Raises:
        ValueError: if *n* is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    # max(1, ...) avoids a zero range() step (ValueError) when len(lst) < n.
    chunk_size = max(1, len(lst) // n)
    chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
    # The remainder can produce several trailing chunks (e.g. 11 items, n=4
    # -> chunk_size 2 -> 6 chunks); fold all of them, not just one.
    while len(chunks) > n:
        last_chunk = chunks.pop()
        chunks[-1].extend(last_chunk)
    return chunks


def image_grid(imgs, rows, cols):
    """Paste *imgs* row-major into a ``rows`` x ``cols`` RGB grid image.

    Every tile is sized like ``imgs[0]``; images of other sizes are pasted at
    the same offsets without resizing.

    Raises:
        ValueError: if ``len(imgs) != rows * cols`` or *imgs* is empty.
    """
    if not imgs or len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
        )
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=((i % cols) * w, (i // cols) * h))
    return grid


def zero_module(module):
    """Zero every parameter of *module* in place and return it."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


class ImageProjModel_p(torch.nn.Module):
    """Two-layer MLP projecting image-encoder features to ``out_dim``.

    Linear(in_dim -> hidden_dim) -> GELU -> Dropout -> LayerNorm
    -> Linear(hidden_dim -> out_dim) -> Dropout.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: (..., in_dim); every layer acts on the last dimension only.
        # Presumably (b, num_tokens, in_dim) encoder hidden states -- TODO confirm.
        return self.net(x)


def _split_checkpoint(model_ckpt):
    """Load *model_ckpt* and split its ``"module"`` state dict by key prefix.

    Returns:
        ``(image_proj_model_p_state, unet_state)`` with the
        ``image_proj_model_p.`` / ``unet.`` prefixes removed. Keys matching
        neither prefix are printed and ignored.
    """
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # trusted checkpoints. map_location="cpu" removes the dependency on the
    # device the checkpoint was saved from; load_state_dict later copies the
    # tensors into each module's parameters on its own device.
    model_sd = torch.load(model_ckpt, map_location="cpu")["module"]
    proj_sd, unet_sd = {}, {}
    for k, v in model_sd.items():
        if k.startswith("image_proj_model_p"):
            proj_sd[k.replace("image_proj_model_p.", "")] = v
        elif k.startswith("unet"):
            unet_sd[k.replace("unet.", "")] = v
        else:
            print(k)
    return proj_sd, unet_sd


def inference(
    model_ckpt="s3_512.pt",
    s_img_path="imgs/sm.png",
    gen_t_img_path="imgs/coarse.png",
    device="cuda",
    seed=42,
):
    """Run ``Stage3_RefinedPipeline`` on one image pair and save the results.

    Restores the projection MLP and an 8-input-channel UNet from *model_ckpt*.
    Conditions the pipeline on DINOv2-giant features of *s_img_path* and on
    the image at *gen_t_img_path*, then generates 4 images. Each image is
    written to ``out{i}.png`` and a 1x3 comparison grid to ``out.png``.

    Args:
        model_ckpt: checkpoint whose ``"module"`` entry holds both state dicts.
        s_img_path: image encoded by DINOv2 and projected for conditioning.
        gen_t_img_path: image passed to the pipeline as ``vae_gen_t_image``.
        device: torch device for all models.
        seed: seed of the sampling generator.
    """
    generator = torch.Generator(device=device).manual_seed(seed)
    clip_image_processor = CLIPImageProcessor()
    # PIL image -> float tensor scaled to [-1, 1].
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    image_encoder_p = Dinov2Model.from_pretrained("facebook/dinov2-giant").to(device).eval()
    image_proj_model_p = ImageProjModel_p(
        in_dim=1536, hidden_dim=768, out_dim=1024
    ).to(device).eval()

    with torch.no_grad():
        image_proj_model_p_dict, unet_dict = _split_checkpoint(model_ckpt)
        image_proj_model_p.load_state_dict(image_proj_model_p_dict)

        pipe = Stage3_RefinedPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
        ).to(device)
        # Swap in a UNet with in_channels=8; ignore_mismatched_sizes lets the
        # differently-shaped weights load, and the strict load_state_dict below
        # then overwrites every UNet weight from the checkpoint.
        # NOTE(review): this UNet stays float32 while the rest of the pipe is
        # float16 -- confirm the pipeline handles the mixed dtypes.
        pipe.unet = UNet2DConditionModel.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base",
            subfolder="unet",
            in_channels=8,
            low_cpu_mem_usage=False,
            ignore_mismatched_sizes=True,
        ).to(device)
        pipe.unet.load_state_dict(unet_dict)
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()

        s_img = Image.open(s_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)
        gen_t_img = Image.open(gen_t_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)

        clip_processor_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
        s_img_f = image_encoder_p(clip_processor_s_img.to(device)).last_hidden_state
        s_img_proj_f = image_proj_model_p(s_img_f)

        # Add a batch dimension: (C, H, W) -> (1, C, H, W).
        vae_gen_t_image = torch.unsqueeze(img_transform(gen_t_img), 0)
        output = pipe(
            height=512,
            width=512,
            guidance_rescale=2.0,
            vae_gen_t_image=vae_gen_t_image,
            s_img_proj_f=s_img_proj_f,
            num_images_per_prompt=4,
            guidance_scale=1.0,
            generator=generator,
            num_inference_steps=20,
        )

    for i, r in enumerate(output.images):
        r.save("out" + str(i) + ".png")

    # NOTE(review): width=512 was requested, so this crop box (x 512..1024)
    # lies outside a 512-px-wide image unless the pipeline returns wider
    # (e.g. side-by-side) outputs -- confirm against Stage3_RefinedPipeline.
    result = output.images[0].crop((512, 0, 512 * 2, 512))
    save_output = [
        s_img.resize((352, 512), Image.BICUBIC),
        gen_t_img.resize((352, 512), Image.BICUBIC),
        result.resize((352, 512), Image.BICUBIC),
    ]
    grid = image_grid(save_output, 1, 3)
    grid.save("out.png")


if __name__ == "__main__":
    # Defaults reproduce the previous hard-coded behaviour of `inference()`.
    parser = argparse.ArgumentParser(description="Run Stage3_RefinedPipeline inference.")
    parser.add_argument("--model_ckpt", default="s3_512.pt")
    parser.add_argument("--s_img_path", default="imgs/sm.png")
    parser.add_argument("--gen_t_img_path", default="imgs/coarse.png")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--seed", type=int, default=42)
    inference(**vars(parser.parse_args()))