import os

import numpy as np
import cv2
from PIL import Image
import torch

from models import model, basic


def setup_model(checkpt_path, device="cuda"):
    """Load the model into memory to make running multiple predictions efficient."""
    # Fix the random seeds so anchor sampling is reproducible across runs.
    seed = 130
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    colorLabeler = basic.ColorLabel(device=device)
    colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True, colorLabeler=colorLabeler)
    colorizer = colorizer.to(device)
    assert os.path.exists(checkpt_path), f"checkpoint not found: {checkpt_path}"
    data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
    colorizer.load_state_dict(data_dict['state_dict'])
    colorizer.eval()
    return colorizer, colorLabeler


def resize_ab2l(gray_img, lab_imgs, vis=False):
    """Resize the predicted ab channels back to the original resolution and
    reattach them to an L channel."""
    H, W = gray_img.shape[:2]
    resized_ab = cv2.resize(lab_imgs[:,:,1:], (W,H), interpolation=cv2.INTER_LINEAR)
    if vis:
        # For visualization, also resize the (hint-marked) L channel of the prediction.
        resized_l = cv2.resize(lab_imgs[:,:,:1], (W,H), interpolation=cv2.INTER_LINEAR)
        return np.concatenate((resized_l[:,:,np.newaxis], resized_ab), axis=2)
    else:
        # Keep the full-resolution grayscale input as the L channel.
        return np.concatenate((gray_img, resized_ab), axis=2)

def prepare_data(rgb_img, target_res):
    # Convert the RGB input to CIE Lab and keep the full-resolution L channel
    # (normalized to [-1, 1]) for reconstructing the output later.
    rgb_img = np.array(rgb_img / 255., np.float32)
    lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
    org_grays = (lab_img[:,:,[0]]-50.) / 50.
    lab_img = cv2.resize(lab_img, target_res, interpolation=cv2.INTER_LINEAR)

    # Network inputs: L in [-1, 1], ab scaled by 1/110, with a batch dimension added.
    lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
    gray_img = (lab_img[0:1,:,:]-50.) / 50.
    ab_chans = lab_img[1:3,:,:] / 110.
    input_grays = gray_img.unsqueeze(0)
    input_colors = ab_chans.unsqueeze(0)
    return input_grays, input_colors, org_grays


def colorize_grayscale(colorizer, color_class, rgb_img, hint_img, n_anchors, is_high_res, is_editable, device="cuda"):
    # Clamp the number of color anchors to the supported range [3, 14].
    n_anchors = min(max(int(n_anchors), 3), 14)
    target_res = (512,512) if is_high_res else (256,256)
    input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
    input_grays = input_grays.to(device)
    input_colors = input_colors.to(device)

    if is_editable:
        # Editable mode: take anchor colors from the user-provided hint image.
        print('>>>:editable mode')
        sampled_T = -1
        _, input_colors, _ = prepare_data(hint_img, target_res)
        input_colors = input_colors.to(device)
    else:
        # Automatic mode: let the model sample anchor colors itself.
        print('>>>:automatic mode')
        sampled_T = 0
    pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(
        input_grays, input_colors, n_anchors, sampled_T)

    # Reassemble Lab, resize ab back to the input resolution, and convert to RGB.
    pred_labs = torch.cat((input_grays, enhanced_ab), dim=1)
    lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
    lab_imgs = resize_ab2l(org_grays, lab_imgs)

    lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
    lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
    rgb_output = cv2.cvtColor(lab_imgs, cv2.COLOR_LAB2RGB)
    return (rgb_output*255.0).astype(np.uint8)


def predict_anchors(colorizer, color_class, rgb_img, n_anchors, is_high_res, is_editable, device="cuda"):
    # Clamp the number of color anchors to the supported range [3, 14].
    n_anchors = min(max(int(n_anchors), 3), 14)
    target_res = (512,512) if is_high_res else (256,256)
    input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
    input_grays = input_grays.to(device, non_blocking=True)
    input_colors = input_colors.to(device, non_blocking=True)

    # Run the model once to obtain the anchor (hint) locations and colors.
    sampled_T, sp_size = 0, 16
    pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(
        input_grays, input_colors, n_anchors, sampled_T)
    # Decode the predicted anchor colors, upsample them from the superpixel grid
    # to pixel resolution, and mark the anchor locations on the grayscale input.
    guided_colors = color_class.decode_ind2ab(ref_logit, T=0)
    guided_colors = basic.upfeat(guided_colors, affinity_map, sp_size, sp_size)
    anchor_masks = basic.upfeat(hint_mask, affinity_map, sp_size, sp_size)
    marked_labs = basic.mark_color_hints(input_grays, guided_colors, anchor_masks, base_ABs=None)
    lab_imgs = basic.tensor2array(marked_labs).squeeze(axis=0)
    lab_imgs = resize_ab2l(org_grays, lab_imgs, vis=True)

    lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
    lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
    rgb_output = cv2.cvtColor(lab_imgs, cv2.COLOR_LAB2RGB)
    return (rgb_output*255.0).astype(np.uint8)
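

# A minimal usage sketch (not part of the original script): load the model and run
# automatic colorization on a single image. The checkpoint path matches the one
# referenced in the repository; the input/output filenames are placeholders.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    colorizer, color_class = setup_model("./checkpoints/disco-beta.pth.rar", device=device)
    rgb_img = np.array(Image.open("input.jpg").convert("RGB"))
    with torch.no_grad():
        colorized = colorize_grayscale(colorizer, color_class, rgb_img, hint_img=None,
                                       n_anchors=8, is_high_res=False, is_editable=False,
                                       device=device)
    Image.fromarray(colorized).save("colorized.png")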