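"""Stage-3 refinement inference script.

Polishes a coarse generated image with a fine-tuned Stable Diffusion 2.1 UNet
(8-channel input) conditioned on DINOv2-giant features of a source image,
projected into the prompt space by a small MLP (ImageProjModel_p).
"""
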
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from diffusers import UniPCMultistepScheduler, UNet2DConditionModel
from transformers import CLIPImageProcessor, Dinov2Model

from src.pipelines.stage3_refined_pipeline import Stage3_RefinedPipeline

def split_list_into_chunks(lst, n):
    """Split ``lst`` into ``n`` roughly equal chunks; the last chunk absorbs
    any remainder."""
    chunk_size = max(1, len(lst) // n)  # guard against n > len(lst)
    chunks = [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
    # Merge any surplus chunks into the last one so that exactly n remain
    # (a single `if` would miss the case of more than one surplus chunk).
    while len(chunks) > n:
        last_chunk = chunks.pop()
        chunks[-1].extend(last_chunk)
    return chunks
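
# Example: split_list_into_chunks(list(range(10)), 3)
#   -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
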
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` into a single ``rows`` x ``cols`` grid image."""
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

def zero_module(module):
    """Zero-initialize all parameters of ``module`` and return it."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
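
# zero_module is the usual trick (as in ControlNet) for initializing newly
# added layers so they start as a no-op and do not perturb the pretrained
# model at the beginning of fine-tuning. It is not called in this inference
# script.
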
class ImageProjModel_p(torch.nn.Module):
    """Projects image-encoder tokens into the SD cross-attention (prompt) space."""

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):  # x: (batch, num_tokens, in_dim)
        return self.net(x)
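
# Shape sanity check (a minimal sketch; dims assume DINOv2-giant tokens of
# width 1536 projected to SD 2.1's 1024-d cross-attention space):
#   proj = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024)
#   tokens = torch.randn(2, 257, 1536)           # (batch, seq_len, in_dim)
#   assert proj(tokens).shape == (2, 257, 1024)
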
def inference():
    device = "cuda"
    generator = torch.Generator(device=device).manual_seed(42)
    clip_image_processor = CLIPImageProcessor()
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # map pixels to [-1, 1] for the VAE
    ])

    # Model definitions: DINOv2-giant as the image encoder, plus the projection
    # head that maps its 1536-d tokens to the UNet's 1024-d context space.
    image_proj_model_p_dict = {}
    unet_dict = {}
    image_encoder_p = Dinov2Model.from_pretrained('facebook/dinov2-giant').to(device).eval()
    image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).eval()

    # model_ckpt = "{}/mp_rank_00_model_states.pt".format('{save_ckpt}')
    model_ckpt = "s3_512.pt"
    with torch.no_grad():
        # DeepSpeed-style checkpoint: the weights live under the "module" key.
        model_sd = torch.load(model_ckpt, map_location="cpu")["module"]
        for k in model_sd.keys():
            if k.startswith("image_proj_model_p"):
                image_proj_model_p_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
            elif k.startswith("unet"):
                unet_dict[k.replace("unet.", "")] = model_sd[k]
            else:
                print(k)  # surface any unexpected keys
    image_proj_model_p.load_state_dict(image_proj_model_p_dict)

    pipe = Stage3_RefinedPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16).to(device)
    # The refined UNet takes 8 input channels (presumably the noisy latents
    # concatenated with the coarse-image latents), so conv_in is re-shaped on load.
    pipe.unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", subfolder="unet",
        in_channels=8, low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device)
    pipe.unet.load_state_dict(unet_dict)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_xformers_memory_efficient_attention()
    # Inputs: the source conditioning image and the coarse result to refine.
    s_img_path = 'imgs/sm.png'
    # t_img_path = 'imgs/expected.png'
    gen_t_img_path = 'imgs/coarse.png'
    s_img = Image.open(s_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)
    # t_img = Image.open(t_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)
    gen_t_img = Image.open(gen_t_img_path).convert("RGB").resize((512, 512), Image.BICUBIC)

    # Encode the source image with DINOv2 and project it into prompt space.
    clip_processor_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
    with torch.no_grad():
        s_img_f = image_encoder_p(clip_processor_s_img.to(device)).last_hidden_state
        s_img_proj_f = image_proj_model_p(s_img_f)
    # The coarse image is VAE-encoded inside the pipeline.
    vae_gen_t_image = torch.unsqueeze(img_transform(gen_t_img), 0)

    output = pipe(
        height=512,
        width=512,
        guidance_rescale=2.0,
        vae_gen_t_image=vae_gen_t_image,
        s_img_proj_f=s_img_proj_f,
        num_images_per_prompt=4,
        guidance_scale=1.0,
        generator=generator,
        num_inference_steps=20,
    )
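
    # NOTE (assumption): each entry of output.images appears to be a 1024x512
    # side-by-side canvas with the refined result on the right half; the crop
    # below picks out that half before building the comparison grid.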
    for i, r in enumerate(output.images):
        r.save(f'out{i}.png')

    # Compose a [source | coarse | refined] comparison strip.
    save_output = []
    result = output.images[0].crop((512, 0, 512 * 2, 512))
    save_output.append(result.resize((352, 512), Image.BICUBIC))
    save_output.insert(0, gen_t_img.resize((352, 512), Image.BICUBIC))
    save_output.insert(0, s_img.resize((352, 512), Image.BICUBIC))
    grid = image_grid(save_output, 1, 3)
    grid.save("out.png")
| if __name__ == "__main__": | |
| inference() | |