import os

import cv2
import numpy as np
import torch
import yaml
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data._utils.collate import default_collate

from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.evaluation.utils import move_to_device, load_image, prepare_image, pad_img_to_modulo, scale_image
from saicinpainting.evaluation.refinement import refine_predict

# Default settings for the iterative refinement pass; these keys are
# forwarded verbatim to refine_predict via **self.refiner_config.
refiner_config = {
    'gpu_ids': '0,',         # comma-separated CUDA device ids used by the refiner
    'modulo': 8,             # pad images so height/width are multiples of this
    'n_iters': 15,           # refinement iterations per scale
    'lr': 0.002,             # learning rate of the refinement optimizer
    'min_side': 512,         # lower bound on image side lengths across scales
    'max_scales': 3,         # max number of downscaling levels in the image-mask pyramid
    'px_budget': 1800000     # pixel budget; larger images are resized to fit it
}
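
# A lighter-weight preset can be derived with a plain dict merge; the
# 'fast_refiner_config' name and its values below are illustrative only,
# not part of this module:
#
#     fast_refiner_config = {**refiner_config, 'n_iters': 5, 'max_scales': 1}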

class Inpainter:
    def __init__(self, config):
        self.model = None
        self.config = config
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.scale_factor = config['scale_factor']
        self.pad_out_to_modulo = config['pad_out_to_modulo']
        self.predict_config = config['predict']
        self.predict_config['model_path'] = 'big-lama'
        self.predict_config['model_checkpoint'] = 'best.ckpt'
        self.refiner_config = refiner_config
    
    def load_model_from_checkpoint(self, model_path, checkpoint):
        """Loads the model weights and training config from the given directory."""
        train_config_path = os.path.join(model_path, 'config.yaml')
        with open(train_config_path, 'r') as f:
            train_config = OmegaConf.create(yaml.safe_load(f))

        # Inference-only mode: skip training-time components and visualization.
        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        checkpoint_path = os.path.join(model_path, 'models', checkpoint)
        self.model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
        
    def load_batch_data(self, img_, mask_):
        """Prepares an already-loaded image and mask into a model-ready sample.

        Returns a dict with a channel-first 'image', a 'mask' with an added
        channel dimension, and, when padding is applied, the pre-padding
        'unpad_to_size' so the output can be cropped back afterwards.
        """
        image = prepare_image(img_, mode='RGB')
        mask = prepare_image(mask_, mode='L')

        result = dict(image=image, mask=mask[None, ...])

        if self.scale_factor is not None:
            result['image'] = scale_image(result['image'], self.scale_factor)
            result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)

        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            # Remember the original (H, W) so the result can be cropped back.
            result['unpad_to_size'] = result['image'].shape[1:]
            result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
            result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)

        return result
    
    def inpaint_img(self, original_img, mask_img, refine=False) -> Image.Image:
        """Inpaints the image region defined by the given mask.

        White mask pixels mark the region to be inpainted; black pixels are kept.

        Args:
            original_img: the RGB image, or a path to it.
            mask_img: the single-channel mask, or a path to it.
            refine: if True, runs the refinement pass to enhance the
                inpainting result, at the cost of speed.

        Returns:
            The inpainted image as a PIL Image.
        """
        # in case we are given filenames instead of images
        if isinstance(original_img, str):
            original_img = load_image(original_img, mode='RGB')
            mask_img = load_image(mask_img, mode='L')
        
        self.model.eval()
        if not refine:
            # The refiner handles device placement itself (via gpu_ids).
            self.model.to(self.device)
        # prepare the image/mask pair as a single-sample batch
        batch = default_collate([self.load_batch_data(original_img, mask_img)])
        
        if refine:
            assert 'unpad_to_size' in batch, "Unpadded size is required for the refinement"
            # Image unpadding is taken care of in the refiner, so the output
            # image is the same size as the input image.
            cur_res = refine_predict(batch, self.model, **self.refiner_config)
            cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
        else:
            with torch.no_grad():
                batch = move_to_device(batch, self.device)
                # binarize the mask: any non-zero pixel is treated as masked
                batch['mask'] = (batch['mask'] > 0) * 1
                batch = self.model(batch)
                cur_res = batch[self.predict_config['out_key']][0].permute(1, 2, 0).detach().cpu().numpy()
                # crop off the padding added in load_batch_data
                unpad_to_size = batch.get('unpad_to_size', None)
                if unpad_to_size is not None:
                    orig_height, orig_width = unpad_to_size
                    cur_res = cur_res[:orig_height, :orig_width]

        # convert from float [0, 1] back to 8-bit RGB
        cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
        rslt_image = Image.fromarray(cur_res, 'RGB')

        return rslt_image
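

if __name__ == '__main__':
    # A minimal usage sketch (assumptions: a 'big-lama' directory with
    # config.yaml and models/best.ckpt next to this script, an input
    # 'photo.png', and 'inpainted' as the model's output key, which is
    # LaMa's usual default; adjust all of these to your setup).
    config = {
        'scale_factor': None,
        'pad_out_to_modulo': 8,
        'predict': {'out_key': 'inpainted'},
    }
    inpainter = Inpainter(config)
    inpainter.load_model_from_checkpoint('big-lama', 'best.ckpt')

    # Build a mask: white (255) pixels are inpainted, black (0) pixels kept.
    img = Image.open('photo.png').convert('RGB')
    mask = np.zeros((img.height, img.width), dtype=np.uint8)
    mask[100:200, 150:300] = 255
    Image.fromarray(mask, 'L').save('mask.png')

    # Passing file paths exercises the load_image branch of inpaint_img.
    result = inpainter.inpaint_img('photo.png', 'mask.png', refine=False)
    result.save('photo_inpainted.png')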