img_backswapper / inpainter.py
jgurzoni's picture
creating gradio app
d7713d2
raw
history blame
No virus
4.72 kB
import os
import cv2
import numpy as np
import torch
import tqdm
import yaml
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data._utils.collate import default_collate
from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.evaluation.utils import move_to_device, load_image, prepare_image, pad_img_to_modulo, scale_image
from saicinpainting.evaluation.refinement import refine_predict
# Default settings for the optional refinement pass. The whole mapping is
# forwarded as keyword arguments to ``refine_predict`` by
# ``Inpainter.inpaint_img(..., refine=True)``.
# NOTE(review): the per-key meanings are defined by saicinpainting's
# refinement module, not by this file -- confirm there before tuning.
refiner_config = dict(
    gpu_ids='0,',
    modulo=8,
    n_iters=15,
    lr=0.002,
    min_side=512,
    max_scales=3,
    px_budget=1_800_000,
)
class Inpainter():
    """Wrapper around a saicinpainting model for mask-based image inpainting.

    Typical use::

        inpainter = Inpainter(config)
        inpainter.load_model_from_checkpoint()
        result = inpainter.inpaint_img(image, mask)
    """

    def __init__(self, config):
        """Store the configuration; the model itself is loaded separately.

        args:
            config: mapping providing the keys 'scale_factor',
                'pad_out_to_modulo' and 'predict' (itself a mutable mapping).
        """
        self.model = None
        self.config = config
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.scale_factor = config['scale_factor']
        self.pad_out_to_modulo = config['pad_out_to_modulo']
        self.predict_config = config['predict']
        # The model location is fixed here and overrides whatever the
        # incoming config carried under these two keys.
        self.predict_config['model_path'] = 'big-lama'
        self.predict_config['model_checkpoint'] = 'best.ckpt'
        self.refiner_config = refiner_config

    @staticmethod
    def _is_path(value):
        """Return True if *value* is a filesystem path (str or os.PathLike)."""
        return isinstance(value, (str, os.PathLike))

    def load_model_from_checkpoint(self, model_path=None, checkpoint=None):
        """Load the model weights and store the model in ``self.model``.

        The training config is read from ``<model_path>/config.yaml`` and the
        weights from ``<model_path>/models/<checkpoint>``. The checkpoint is
        mapped to the CPU; ``inpaint_img`` moves the model to the target
        device when needed.

        args:
            model_path: directory of the model. Defaults to
                ``predict_config['model_path']``.
            checkpoint: checkpoint file name. Defaults to
                ``predict_config['model_checkpoint']``.
        """
        if model_path is None:
            model_path = self.predict_config['model_path']
        if checkpoint is None:
            checkpoint = self.predict_config['model_checkpoint']

        train_config_path = os.path.join(model_path, 'config.yaml')
        with open(train_config_path, 'r') as f:
            train_config = OmegaConf.create(yaml.safe_load(f))

        # Inference only: no training machinery, no visualizer.
        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        checkpoint_path = os.path.join(model_path, 'models', checkpoint)
        self.model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')

    def load_batch_data(self, img_, mask_):
        """Turn an already-loaded image and mask into a single sample dict.

        Both inputs go through ``prepare_image``; the result is then
        optionally rescaled by ``self.scale_factor`` and padded so its size
        is a multiple of ``self.pad_out_to_modulo``.

        returns: dict with 'image' and 'mask', plus 'unpad_to_size' (the
            pre-padding size) when padding was applied.
        """
        image = prepare_image(img_, mode='RGB')
        mask = prepare_image(mask_, mode='L')
        # The mask gets a leading axis so it has a channel dimension too.
        result = dict(image=image, mask=mask[None, ...])

        if self.scale_factor is not None:
            result['image'] = scale_image(result['image'], self.scale_factor)
            # Nearest-neighbour keeps the mask free of interpolated values.
            result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)

        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            # Remember the size before padding so the output can be cropped back.
            # assumes the image is channel-first, i.e. shape[1:] is (H, W) -- TODO confirm
            result['unpad_to_size'] = result['image'].shape[1:]
            result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
            result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)

        return result

    def inpaint_img(self, original_img, mask_img, refine=False) -> "Image.Image":
        """Inpaint the image region defined by the given mask.

        White pixels are to be masked and black pixels kept.

        args:
            original_img: the image, or a path (str / os.PathLike) to it.
            mask_img: the mask, or a path (str / os.PathLike) to it.
            refine: if True, uses the refinement model to enhance the
                inpainting result, at the cost of speed.
        returns: the inpainted image as an RGB PIL image.
        raises:
            RuntimeError: if no model has been loaded yet.
            ValueError: if refine is True but the batch carries no
                unpadded size (padding is disabled in the config).
        """
        if self.model is None:
            raise RuntimeError(
                "No model loaded; call load_model_from_checkpoint() before inpaint_img()."
            )

        # Each argument may independently be a path instead of an image.
        if self._is_path(original_img):
            original_img = load_image(original_img, mode='RGB')
        if self._is_path(mask_img):
            mask_img = load_image(mask_img, mode='L')

        self.model.eval()
        if not refine:
            # The refinement path is handed the model as-is and is not moved here.
            self.model.to(self.device)

        # Build a batch of one sample.
        batch = default_collate([self.load_batch_data(original_img, mask_img)])

        if refine:
            if 'unpad_to_size' not in batch:
                raise ValueError(
                    "Unpadded size is required for the refinement; "
                    "set 'pad_out_to_modulo' > 1 in the config."
                )
            # Image unpadding is taken care of in the refiner, so that the
            # output image is the same size as the input image.
            cur_res = refine_predict(batch, self.model, **self.refiner_config)
            cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
        else:
            with torch.no_grad():
                batch = move_to_device(batch, self.device)
                # Binarize the mask: any non-zero value means "inpaint here".
                batch['mask'] = (batch['mask'] > 0) * 1
                batch = self.model(batch)
                cur_res = batch[self.predict_config['out_key']][0].permute(1, 2, 0).detach().cpu().numpy()

                # Crop away the padding added in load_batch_data.
                unpad_to_size = batch.get('unpad_to_size', None)
                if unpad_to_size is not None:
                    orig_height, orig_width = unpad_to_size
                    cur_res = cur_res[:orig_height, :orig_width]

        # Scale by 255 and clamp to the uint8 range.
        cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
        return Image.fromarray(cur_res, 'RGB')