Image-to-Image
Diffusers
ONNX
Safetensors
StableDiffusionXLInpaintPipeline
stable-diffusion-xl
inpainting
virtual try-on
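Test-time inference script for IDM-VTON-style virtual try-on on the VITON-HD test split: a customized SDXL inpainting pipeline conditioned on a garment image (through a reference garment UNet and IP-Adapter-style image embeddings), a DensePose rendering of the person, and an agnostic inpainting mask.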
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import transformers
import diffusers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from PIL import Image
from torchvision import transforms
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
)
from diffusers.utils.import_utils import is_xformers_available
from transformers import (
    AutoTokenizer,
    PretrainedConfig,
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModelWithProjection,
    CLIPTextModel,
    CLIPTokenizer,
)

from ip_adapter.ip_adapter import Resampler
from src.unet_hacked_tryon import UNet2DConditionModel
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline

logger = get_logger(__name__, log_level="INFO")
def parse_args():
    parser = argparse.ArgumentParser(description="IDM-VTON virtual try-on inference on the VITON-HD test set.")
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="yisol/IDM-VTON", required=False)
    parser.add_argument("--width", type=int, default=768)
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--output_dir", type=str, default="result")
    parser.add_argument("--unpaired", action="store_true")
    parser.add_argument("--data_dir", type=str, default="/home/omnious/workspace/yisol/Dataset/zalando")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--test_batch_size", type=int, default=2)
    parser.add_argument("--guidance_scale", type=float, default=2.0)
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"])
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
    args = parser.parse_args()
    return args
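# Example invocation (the file name and dataset path below are placeholders; adjust to your setup):
#   python inference_vitonhd.py --data_dir /path/to/VITON-HD --output_dir result --unpaired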
def pil_to_tensor(images):
    # PIL image (H, W, C) in [0, 255] -> float tensor (C, H, W) in [0, 1].
    images = np.array(images).astype(np.float32) / 255.0
    images = torch.from_numpy(images.transpose(2, 0, 1))
    return images
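# VitonHDTestDataset expects the standard VITON-HD layout under --data_dir:
#   test_pairs.txt                     person/garment pairs, one "im_name c_name" per line
#   test/image/                        person images
#   test/cloth/                        garment images
#   test/agnostic-mask/                garment-region masks, named like the person image with ".jpg" replaced by "_mask.png"
#   test/image-densepose/              DensePose renderings of the person images
#   test/vitonhd_test_tagged.json      garment tag annotations used to build the captions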
class VitonHDTestDataset(data.Dataset):
    def __init__(
        self,
        dataroot_path: str,
        phase: Literal["train", "test"],
        order: Literal["paired", "unpaired"] = "paired",
        size: Tuple[int, int] = (512, 384),
    ):
        super(VitonHDTestDataset, self).__init__()
        self.dataroot = dataroot_path
        self.phase = phase
        self.height = size[0]
        self.width = size[1]
        self.size = size
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.toTensor = transforms.ToTensor()

        # Build a short text annotation ("sleeveLength neckLine item") for every garment file.
        with open(
            os.path.join(dataroot_path, phase, "vitonhd_" + phase + "_tagged.json"), "r"
        ) as file1:
            data1 = json.load(file1)
        annotation_list = [
            "sleeveLength",
            "neckLine",
            "item",
        ]
        self.annotation_pair = {}
        for k, v in data1.items():
            for elem in v:
                annotation_str = ""
                for template in annotation_list:
                    for tag in elem["tag_info"]:
                        if (
                            tag["tag_name"] == template
                            and tag["tag_category"] is not None
                        ):
                            annotation_str += tag["tag_category"]
                            annotation_str += " "
                self.annotation_pair[elem["file_name"]] = annotation_str

        self.order = order

        # Read the person/garment pairs. In "paired" mode the garment file shares the person
        # image's name; in "unpaired" mode it comes from the second column of the pairs file.
        im_names = []
        c_names = []
        dataroot_names = []
        filename = os.path.join(dataroot_path, f"{phase}_pairs.txt")
        with open(filename, "r") as f:
            for line in f.readlines():
                if phase == "train":
                    im_name, _ = line.strip().split()
                    c_name = im_name
                else:
                    if order == "paired":
                        im_name, _ = line.strip().split()
                        c_name = im_name
                    else:
                        im_name, c_name = line.strip().split()
                im_names.append(im_name)
                c_names.append(c_name)
                dataroot_names.append(dataroot_path)
        self.im_names = im_names
        self.c_names = c_names
        self.dataroot_names = dataroot_names
        self.clip_processor = CLIPImageProcessor()
    def __getitem__(self, index):
        c_name = self.c_names[index]
        im_name = self.im_names[index]
        if c_name in self.annotation_pair:
            cloth_annotation = self.annotation_pair[c_name]
        else:
            cloth_annotation = "shirts"
        cloth = Image.open(os.path.join(self.dataroot, self.phase, "cloth", c_name))
        im_pil_big = Image.open(
            os.path.join(self.dataroot, self.phase, "image", im_name)
        ).resize((self.width, self.height))
        image = self.transform(im_pil_big)

        # The agnostic mask marks the garment region; keep a single channel and invert it
        # so that `im_mask` is the person image with the garment region zeroed out.
        mask = Image.open(
            os.path.join(self.dataroot, self.phase, "agnostic-mask", im_name.replace(".jpg", "_mask.png"))
        ).resize((self.width, self.height))
        mask = self.toTensor(mask)
        mask = mask[:1]
        mask = 1 - mask
        im_mask = image * mask

        pose_img = Image.open(
            os.path.join(self.dataroot, self.phase, "image-densepose", im_name)
        )
        pose_img = self.transform(pose_img)  # [-1, 1]

        result = {}
        result["c_name"] = c_name
        result["im_name"] = im_name
        result["image"] = image
        result["cloth_pure"] = self.transform(cloth)
        result["cloth"] = self.clip_processor(images=cloth, return_tensors="pt").pixel_values
        result["inpaint_mask"] = 1 - mask
        result["im_mask"] = im_mask
        result["caption_cloth"] = "a photo of " + cloth_annotation
        result["caption"] = "model is wearing a " + cloth_annotation
        result["pose_img"] = pose_img
        return result

    def __len__(self):
        # model images + cloth image
        return len(self.im_names)
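# Inference entry point: load the SDXL components and the two customized UNets from the
# IDM-VTON checkpoint, assemble the try-on pipeline, run it over the VITON-HD test split,
# and write one result image per person image into --output_dir.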
def main():
    args = parse_args()
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir)
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        project_config=accelerator_project_config,
    )
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the output directory creation.
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # All models are loaded and run in fp16.
    weight_dtype = torch.float16
    # if accelerator.mixed_precision == "fp16":
    #     weight_dtype = torch.float16
    #     args.mixed_precision = accelerator.mixed_precision
    # elif accelerator.mixed_precision == "bf16":
    #     weight_dtype = torch.bfloat16
    #     args.mixed_precision = accelerator.mixed_precision
    # Load the scheduler, tokenizers and models from the IDM-VTON checkpoint.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        torch_dtype=torch.float16,
    )
    # Main try-on UNet (customized SDXL inpainting UNet from src.unet_hacked_tryon).
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=torch.float16,
    )
    # CLIP vision encoder used for the garment image conditioning (ip_adapter_image).
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="image_encoder",
        torch_dtype=torch.float16,
    )
    # Reference garment UNet (src.unet_hacked_garmnet) that encodes the clothing image features.
    UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet_encoder",
        torch_dtype=torch.float16,
    )
    text_encoder_one = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        torch_dtype=torch.float16,
    )
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder_2",
        torch_dtype=torch.float16,
    )
    tokenizer_one = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=None,
        use_fast=False,
    )
    tokenizer_two = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer_2",
        revision=None,
        use_fast=False,
    )
    # Freeze all models; this script only runs inference.
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)
    UNet_Encoder.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)
    UNet_Encoder.to(accelerator.device, weight_dtype)
    unet.eval()
    UNet_Encoder.eval()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warning(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
    test_dataset = VitonHDTestDataset(
        dataroot_path=args.data_dir,
        phase="test",
        order="unpaired" if args.unpaired else "paired",
        size=(args.height, args.width),
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.test_batch_size,
        num_workers=4,
    )
    # Assemble the try-on pipeline from the individually loaded components.
    pipe = TryonPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=unet,
        vae=vae,
        feature_extractor=CLIPImageProcessor(),
        text_encoder=text_encoder_one,
        text_encoder_2=text_encoder_two,
        tokenizer=tokenizer_one,
        tokenizer_2=tokenizer_two,
        scheduler=noise_scheduler,
        image_encoder=image_encoder,
        torch_dtype=torch.float16,
    ).to(accelerator.device)
    pipe.unet_encoder = UNet_Encoder
    # Optional memory savings:
    # pipe.enable_sequential_cpu_offload()
    # pipe.enable_model_cpu_offload()
    # pipe.enable_vae_slicing()
    with torch.no_grad(), torch.cuda.amp.autocast():
        for sample in test_dataloader:
            # CLIP pixel values of every garment in the batch, used as the IP-Adapter image input.
            img_emb_list = []
            for i in range(sample["cloth"].shape[0]):
                img_emb_list.append(sample["cloth"][i])

            prompt = sample["caption"]
            num_prompts = sample["cloth"].shape[0]
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            if not isinstance(prompt, list):
                prompt = [prompt] * num_prompts
            if not isinstance(negative_prompt, list):
                negative_prompt = [negative_prompt] * num_prompts

            image_embeds = torch.cat(img_emb_list, dim=0)

            # Prompt embeddings for the person branch ("model is wearing a ...").
            with torch.inference_mode():
                (
                    prompt_embeds,
                    negative_prompt_embeds,
                    pooled_prompt_embeds,
                    negative_pooled_prompt_embeds,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )

            # Prompt embeddings for the garment branch ("a photo of ...").
            prompt = sample["caption_cloth"]
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            if not isinstance(prompt, list):
                prompt = [prompt] * num_prompts
            if not isinstance(negative_prompt, list):
                negative_prompt = [negative_prompt] * num_prompts
            with torch.inference_mode():
                (
                    prompt_embeds_c,
                    _,
                    _,
                    _,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=False,
                    negative_prompt=negative_prompt,
                )

            generator = torch.Generator(pipe.device).manual_seed(args.seed) if args.seed is not None else None
            images = pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                num_inference_steps=args.num_inference_steps,
                generator=generator,
                strength=1.0,
                pose_img=sample["pose_img"],
                text_embeds_cloth=prompt_embeds_c,
                cloth=sample["cloth_pure"].to(accelerator.device),
                mask_image=sample["inpaint_mask"],
                image=(sample["image"] + 1.0) / 2.0,
                height=args.height,
                width=args.width,
                guidance_scale=args.guidance_scale,
                ip_adapter_image=image_embeds,
            )[0]

            # Save one result per person image, reusing the original file name.
            for i in range(len(images)):
                x_sample = pil_to_tensor(images[i])
                torchvision.utils.save_image(x_sample, os.path.join(args.output_dir, sample["im_name"][i]))


if __name__ == "__main__":
    main()