import os
from PIL import Image
import torch
from torch.utils.data import DataLoader
from datasets.inference_dataset import InferenceDataset
from datasets.process_image import ImageProcessor
from models.styleres import StyleRes
from options.inference_options import InferenceOptions
from options import Settings
from utils import parse_config
from tqdm import tqdm


def initialize_styleres(checkpoint_path, device):
    """Build a StyleRes model from a checkpoint, ready for inference.

    Stores ``device`` in the global ``Settings.device`` before constructing the
    model. It then loads the checkpoint, moves the model to the device with
    ``send_to_device()``, switches it to eval mode and freezes every parameter.

    Args:
        checkpoint_path: Checkpoint path passed to ``StyleRes.load_ckpt``.
        device: ``torch.device`` recorded in ``Settings.device``.

    Returns:
        The StyleRes model in eval mode with ``requires_grad=False`` on all
        parameters.
    """
    # NOTE(review): Settings.device is set before StyleRes() is constructed,
    # presumably because construction / send_to_device() read it — confirm.
    Settings.device = device
    model = StyleRes()
    model.load_ckpt(checkpoint_path)
    model.send_to_device()
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model


def run():
    """Apply every configured edit to the input images and save the results.

    Options come from ``InferenceOptions``; edit configs are parsed from
    ``args.edit_configs`` with ``parse_config``. Each edit config gets its own
    sub-directory of ``args.outdir``, named by joining the config's values
    with ``_``. At most ``args.n_images`` images are processed; when it is
    unset, the whole dataset is used. Outputs are resized to 1024x1024, or to
    256x256 when ``args.resize_outputs`` is set.
    """
    args = InferenceOptions().parse()
    edit_configs = parse_config(args.edit_configs)

    # Always define `device`. Previously it was only assigned when CUDA was
    # available, so CPU-only runs crashed with NameError further down.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = InferenceDataset(args.datadir, aligner_path=args.aligner_path)
    print(f"Dataset is created. Number of images is {len(dataset)}")
    dataloader = DataLoader(
        dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=int(args.test_workers),
        drop_last=False,
    )

    if args.n_images is None:
        args.n_images = len(dataset)

    # Create output directories: the root plus one sub-directory per edit.
    output_dir = args.outdir
    os.makedirs(output_dir, exist_ok=True)
    for edit_config in edit_configs:
        # The name is built from the config's values before `outdir` is added.
        edit_config.outdir = '_'.join(str(i) for i in edit_config.values())
        os.makedirs(os.path.join(output_dir, edit_config.outdir), exist_ok=True)

    resize_amount = (256, 256) if args.resize_outputs else (1024, 1024)

    # Setup model
    model = initialize_styleres(args.checkpoint_path, device)

    n_processed = 0
    for data in tqdm(dataloader):
        remaining = args.n_images - n_processed
        if remaining <= 0:
            break
        # Trim the batch so that exactly `args.n_images` images are edited;
        # checking only between batches would overshoot by up to
        # batch_size - 1 images.
        batch_images = data['image'][:remaining]
        batch_names = data['name'][:remaining]
        n_processed += batch_images.shape[0]

        for edit_config in edit_configs:
            images = model.edit_images(batch_images, edit_config)
            images = ImageProcessor.postprocess_image(images.detach().cpu().numpy())
            for j, image in enumerate(images):
                save_name = batch_names[j]
                pil_img = Image.fromarray(image).resize(resize_amount)
                pil_img.save(os.path.join(output_dir, edit_config.outdir, save_name))


if __name__ == '__main__':
    run()