import os

# external libraries
import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

# custom imports
from src.datasets.dresscode import DressCodeDataset
from src.datasets.vitonhd import VitonHDDataset
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.arg_parser import eval_parse_args
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.utils.set_seeds import set_seed

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")

# Opt in to tokenizer parallelism (also silences the Hugging Face fork warning)
# and start wandb in a thread rather than a separate process.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"


def main() -> None:
    args = eval_parse_args()
    accelerator = Accelerator(mixed_precision=args.mixed_precision)
    device = accelerator.device

    # Set the seed for reproducible evaluation
    if args.seed is not None:
        set_seed(args.seed)

    # Load scheduler, tokenizer, and models
    val_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    val_scheduler.set_timesteps(50, device=device)  # 50 DDIM sampling steps
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)

    # Load the pretrained MGD denoising U-Net via torch.hub
    unet = torch.hub.load(
        dataset=args.dataset,
        repo_or_dir="aimagelab/multimodal-garment-designer",
        source="github",
        model="mgd",
        pretrained=True,
    )

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Enable memory-efficient attention if requested
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly.")

    # Restrict to a single garment category if one was given; otherwise use all three
    category = [args.category] if args.category else ["dresses", "upper_body", "lower_body"]

    # Load the appropriate dataset
    if args.dataset == "dresscode":
        test_dataset = DressCodeDataset(
            dataroot_path=args.dataset_path,
            phase="test",
            order=args.test_order,
            radius=5,
            sketch_threshold_range=(20, 20),
            tokenizer=tokenizer,
            category=category,
            size=(512, 384),
        )
    elif args.dataset == "vitonhd":
        test_dataset = VitonHDDataset(
            dataroot_path=args.dataset_path,
            phase="test",
            order=args.test_order,
            sketch_threshold_range=(20, 20),
            radius=5,
            tokenizer=tokenizer,
            size=(512, 384),
        )
    else:
        raise NotImplementedError(f"Dataset {args.dataset} is not supported.")

    # Prepare the dataloader (shuffle=False keeps the output order deterministic)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.batch_size,
        num_workers=args.num_workers_test,
    )

    # Cast text_encoder and vae to half precision when running mixed-precision inference
    weight_dtype = torch.float32 if args.mixed_precision != "fp16" else torch.float16
    text_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)

    # Ensure the unet is in eval mode
    unet.eval()
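    # eval() only switches layers such as dropout to inference behavior;
    # the inference_mode context below additionally disables autograd tracking.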

    # Select the appropriate pipeline and run generation without autograd
    with torch.inference_mode():
        if args.disentagle:  # spelling matches the repository's argument parser
            val_pipe = MGDPipeDisentangled(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet.to(vae.dtype),
                tokenizer=tokenizer,
                scheduler=val_scheduler,
            ).to(device)
        else:
            val_pipe = MGDPipe(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet.to(vae.dtype),
                tokenizer=tokenizer,
                scheduler=val_scheduler,
            ).to(device)

        # Sanity check: ensure val_pipe is callable
        assert callable(val_pipe), "The pipeline object (val_pipe) is not callable. Check the MGDPipe implementation."

        # Enable attention slicing for memory efficiency
        val_pipe.enable_attention_slicing()

        # Let accelerate handle device placement (and sharding in distributed runs)
        test_dataloader = accelerator.prepare(test_dataloader)

        # Generate and save the images
        generate_images_from_mgd_pipe(
            test_order=args.test_order,
            pipe=val_pipe,
            test_dataloader=test_dataloader,
            save_name=args.save_name,
            dataset=args.dataset,
            output_dir=args.output_dir,
            guidance_scale=args.guidance_scale,
            guidance_scale_pose=args.guidance_scale_pose,
            guidance_scale_sketch=args.guidance_scale_sketch,
            sketch_cond_rate=args.sketch_cond_rate,
            start_cond_rate=args.start_cond_rate,
            no_pose=False,
            disentagle=args.disentagle,
            seed=args.seed,
        )


if __name__ == "__main__":
    main()
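
# Example invocation (a sketch only: the exact flag names depend on
# eval_parse_args, assumed here to mirror the attribute names used above;
# paths and the base checkpoint are placeholders):
#
#   python eval.py \
#       --pretrained_model_name_or_path <stable-diffusion-inpainting-checkpoint> \
#       --dataset vitonhd \
#       --dataset_path /path/to/vitonhd \
#       --output_dir ./results \
#       --save_name mgd_test \
#       --test_order paired \
#       --batch_size 8 \
#       --num_workers_test 4 \
#       --mixed_precision fp16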