StyleRes / inference.py
hamzapehlivan
Initial Commit
6709fc9
import os
from PIL import Image
import torch
from torch.utils.data import DataLoader
from datasets.inference_dataset import InferenceDataset
from datasets.process_image import ImageProcessor
from models.styleres import StyleRes
from options.inference_options import InferenceOptions
from options import Settings
from utils import parse_config
from tqdm import tqdm
def initialize_styleres(checkpoint_path, device):
    """Construct a StyleRes model from a checkpoint and freeze it for inference.

    Args:
        checkpoint_path: Checkpoint location handed to ``StyleRes.load_ckpt``.
        device: Stored on the global ``Settings.device`` before the model is
            built. ``send_to_device`` is then invoked on the model; presumably
            it reads that setting (TODO confirm).

    Returns:
        The StyleRes instance, switched to eval mode, with ``requires_grad``
        cleared on every tensor yielded by ``parameters()``.
    """
    # Must happen before construction: the model (or its helpers) may consult
    # the global setting when it is built. NOTE(review): confirm in StyleRes.
    Settings.device = device

    styleres = StyleRes()
    styleres.load_ckpt(checkpoint_path)
    styleres.send_to_device()
    styleres.eval()

    # Inference only: no gradients are ever needed for the model weights.
    for weight in styleres.parameters():
        weight.requires_grad = False

    return styleres
def _prepare_output_dirs(output_dir, edit_configs):
    """Create ``output_dir`` plus one sub-directory per edit config.

    Each config is given an ``outdir`` attribute equal to the underscore-joined
    string form of its ``values()``. Edited images for that config are saved
    into ``os.path.join(output_dir, edit_config.outdir)``.
    """
    os.makedirs(output_dir, exist_ok=True)
    for edit_config in edit_configs:
        edit_config.outdir = '_'.join(str(value) for value in edit_config.values())
        os.makedirs(os.path.join(output_dir, edit_config.outdir), exist_ok=True)


def run():
    """Invert and edit every dataset image with StyleRes and save the results.

    Options come from ``InferenceOptions().parse()``. For each batch of the
    ``InferenceDataset`` built from ``args.datadir``, every config in
    ``args.edit_configs`` is applied via ``model.edit_images``. The outputs are
    post-processed, resized to 1024x1024 (or 256x256 when
    ``args.resize_outputs`` is set), and written to
    ``<args.outdir>/<edit_config.outdir>/<image name>``.

    At most ``args.n_images`` images are processed. When the option is
    ``None``, the whole dataset is processed.
    """
    args = InferenceOptions().parse()
    edit_configs = parse_config(args.edit_configs)

    # Always define the device. Previously this sat inside an
    # ``if torch.cuda.is_available()`` block, so CPU-only runs hit a NameError
    # when the model was initialised below.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = InferenceDataset(args.datadir, aligner_path=args.aligner_path)
    print(f"Dataset is created. Number of images is {len(dataset)}")
    dataloader = DataLoader(
        dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=int(args.test_workers),
        drop_last=False,
    )
    if args.n_images is None:
        args.n_images = len(dataset)

    # Create output directories.
    output_dir = args.outdir
    _prepare_output_dirs(output_dir, edit_configs)

    resize_amount = (256, 256) if args.resize_outputs else (1024, 1024)

    # Setup model.
    model = initialize_styleres(args.checkpoint_path, device)

    n_images = 0
    for data in tqdm(dataloader):
        remaining = args.n_images - n_images
        if remaining <= 0:
            break
        # Trim the last batch so that no more than ``args.n_images`` images
        # are edited and saved. Previously whole batches overshot the limit.
        batch_images = data['image'][:remaining]
        batch_names = data['name'][:remaining]
        n_images += batch_images.shape[0]

        for edit_config in edit_configs:
            images = model.edit_images(batch_images, edit_config)
            images = ImageProcessor.postprocess_image(images.detach().cpu().numpy())
            for image, save_name in zip(images, batch_names):
                pil_img = Image.fromarray(image).resize(resize_amount)
                pil_img.save(os.path.join(output_dir, edit_config.outdir, save_name))
if __name__ == '__main__':
    # Script entry point: start inference only when executed directly,
    # never as a side effect of importing this module.
    run()