"""Evaluation entry point for fyp-deploy (src/eval.py).

Generates images with the Multimodal Garment Designer (MGD) pipeline on the
Dress Code or VITON-HD test set and writes them to the output directory.
"""
import os
# external libraries
import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer
# custom imports
from src.datasets.dresscode import DressCodeDataset
from src.datasets.vitonhd import VitonHDDataset
from src.mgd_pipelines.mgd_pipe import MGDPipe
from src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from src.utils.arg_parser import eval_parse_args
from src.utils.image_from_pipe import generate_images_from_mgd_pipe
from src.utils.set_seeds import set_seed
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__, log_level="INFO")
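# Enable tokenizer parallelism in dataloader workers and make wandb start in a thread instead of forking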
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"
def main() -> None:
args = eval_parse_args()
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
)
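    # Accelerator selects the device and handles mixed-precision execution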
device = accelerator.device
    # Set the seed for reproducible evaluation
if args.seed is not None:
set_seed(args.seed)
# Load scheduler, tokenizer, and models
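    # (a DDIM sampler with 50 steps is a common quality/speed trade-off for evaluation)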
val_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
val_scheduler.set_timesteps(50, device=device)
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
# Load unet
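    # torch.hub downloads the pretrained MGD U-Net matching the requested dataset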
    unet = torch.hub.load(
        repo_or_dir="aimagelab/multimodal-garment-designer",
        model="mgd",
        source="github",
        pretrained=True,
        dataset=args.dataset,  # forwarded to the hub entry point to select the checkpoint
    )
# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# Enable memory efficient attention if requested
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# Set the dataset category
category = [args.category] if args.category else ["dresses", "upper_body", "lower_body"]
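    # Dress Code is split into three garment categories; VITON-HD contains upper-body garments only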
# Load the appropriate dataset
if args.dataset == "dresscode":
test_dataset = DressCodeDataset(
dataroot_path=args.dataset_path,
phase="test",
order=args.test_order,
radius=5,
sketch_threshold_range=(20, 20),
tokenizer=tokenizer,
category=category,
size=(512, 384),
)
elif args.dataset == "vitonhd":
test_dataset = VitonHDDataset(
dataroot_path=args.dataset_path,
phase="test",
order=args.test_order,
sketch_threshold_range=(20, 20),
radius=5,
tokenizer=tokenizer,
size=(512, 384),
)
else:
raise NotImplementedError(f"Dataset {args.dataset} is not supported.")
# Prepare the dataloader
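    # shuffle=False keeps generated images aligned with the dataset's fixed ordering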
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
shuffle=False,
batch_size=args.batch_size,
num_workers=args.num_workers_test,
)
    # Cast text_encoder and vae to half precision when running mixed-precision inference
    weight_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.float32
text_encoder.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
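    # The U-Net is cast to the VAE's dtype when the pipeline is assembled below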
# Ensure unet is in eval mode
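    # (eval() disables dropout and other training-time behavior)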
unet.eval()
# Select the appropriate pipeline
with torch.inference_mode():
        if args.disentagle:  # note: spelling matches the CLI flag defined by eval_parse_args
val_pipe = MGDPipeDisentangled(
text_encoder=text_encoder,
vae=vae,
unet=unet.to(vae.dtype),
tokenizer=tokenizer,
scheduler=val_scheduler,
).to(device)
else:
val_pipe = MGDPipe(
text_encoder=text_encoder,
vae=vae,
unet=unet.to(vae.dtype),
tokenizer=tokenizer,
scheduler=val_scheduler,
).to(device)
        # Sanity check: a diffusers pipeline implements __call__, so val_pipe must be callable
        assert callable(val_pipe), "The pipeline object (val_pipe) is not callable. Check MGDPipe implementation."
# Enable attention slicing for memory efficiency
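        # (computes attention in slices: lower peak VRAM for a small speed cost)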
val_pipe.enable_attention_slicing()
# Prepare dataloader with accelerator
test_dataloader = accelerator.prepare(test_dataloader)
# Call the image generation function
generate_images_from_mgd_pipe(
test_order=args.test_order,
pipe=val_pipe,
test_dataloader=test_dataloader,
save_name=args.save_name,
dataset=args.dataset,
output_dir=args.output_dir,
guidance_scale=args.guidance_scale,
guidance_scale_pose=args.guidance_scale_pose,
guidance_scale_sketch=args.guidance_scale_sketch,
sketch_cond_rate=args.sketch_cond_rate,
start_cond_rate=args.start_cond_rate,
no_pose=False,
disentagle=args.disentagle,
seed=args.seed,
)
if __name__ == "__main__":
main()
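# Example invocation (flag names assumed from src/utils/arg_parser.py; adjust paths to your setup):
#   python -m src.eval \
#       --pretrained_model_name_or_path runwayml/stable-diffusion-inpainting \
#       --dataset vitonhd \
#       --dataset_path /path/to/vitonhd \
#       --output_dir results \
#       --batch_size 4 \
#       --mixed_precision fp16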