File size: 3,723 Bytes
464c12b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import cv2 as cv
import numpy as np
import torch
from PIL import Image, ImageOps

from comic_style.face_detection import align

# This module only runs forward passes, so autograd is switched off at import
# time (PyTorch grad mode is thread-local: this affects the importing thread).
torch.set_grad_enabled(False)
# Load the TorchScript-serialised model once, at import time; presumably a
# U^2-Net variant judging by the filename — confirm.
# NOTE(review): the path is relative, so it is resolved against the process's
# current working directory, not this file's location.
model = torch.jit.load('comic_style/u2net_bce_itr_16000_train_3.835149_tar_0.542587-400x_360x.jit.pt')
# Put the model in evaluation mode before it is used for inference.
model.eval()


# https://en.wikipedia.org/wiki/Unsharp_masking
# https://stackoverflow.com/a/55590133/1495606
def unsharp_mask(image, kernel_size=(5, 5), sigma=1.0, amount=2.0, threshold=0):
    """Return a sharpened uint8 version of the image, using an unsharp mask.

    The image is blurred with ``cv.GaussianBlur(image, kernel_size, sigma)``;
    the sharpened result ``(amount + 1) * image - amount * blurred`` is clipped
    to [0, 255], rounded and cast to ``np.uint8``.

    :param image: array accepted by ``cv.GaussianBlur`` (uint8 or float values
        on a 0-255 scale).
    :param kernel_size: Gaussian kernel size forwarded to ``cv.GaussianBlur``.
    :param sigma: Gaussian standard deviation forwarded to ``cv.GaussianBlur``.
    :param amount: sharpening strength.
    :param threshold: when > 0, pixels where ``|image - blurred| < threshold``
        keep their original value (clipped and rounded to uint8) instead of the
        sharpened one.
    :return: ``np.ndarray`` of dtype uint8 with the same shape as ``image``.
    """
    blurred = cv.GaussianBlur(image, kernel_size, sigma)
    sharpened = float(amount + 1) * image - float(amount) * blurred
    sharpened = np.clip(sharpened, 0, 255).round().astype(np.uint8)
    if threshold > 0:
        # Compare in float: with uint8 inputs ``image - blurred`` wraps around
        # (e.g. 99 - 100 -> 255) and would misclassify low-contrast pixels.
        contrast = np.absolute(
            np.asarray(image, dtype=np.float64) - np.asarray(blurred, dtype=np.float64)
        )
        low_contrast_mask = contrast < threshold
        # Convert the original pixels to uint8 explicitly: copying a float image
        # straight into the uint8 buffer fails under numpy's 'same_kind' casting.
        original = np.clip(image, 0, 255).round().astype(np.uint8)
        np.copyto(sharpened, original, where=low_contrast_mask)
    return sharpened


def normPRED(d):
    """Min-max normalise ``d`` into the range [0, 1].

    :param d: numeric numpy array (in this file, a model prediction).
    :return: ``(d - min(d)) / (max(d) - min(d))``. When ``d`` is constant the
        range is zero, so an all-zero array of the same shape (and the dtype the
        division would have produced) is returned instead of 0/0 = NaN.
    """
    lowest = np.min(d)
    span = np.max(d) - lowest
    shifted = d - lowest
    if span == 0:
        # result_type(shifted, 1.0) matches the dtype true division would give
        # (float32 stays float32, integer input becomes float64).
        return np.zeros_like(shifted, dtype=np.result_type(shifted, 1.0))
    return shifted / span


def array_to_np(array_in):
    """Turn a raw prediction into a channels-last float array on a 0-255 scale.

    The input is min-max normalised with ``normPRED``, multiplied by 255,
    stripped of its length-1 axes and transposed from (C, H, W) to (H, W, C).
    The squeezed array must be 3-D for the transpose to succeed; presumably the
    prediction arrives as a (1, C, H, W) batch of one — confirm against the model.
    """
    scaled = normPRED(array_in) * 255.0
    chw = np.squeeze(scaled)
    return chw.transpose((1, 2, 0))


def array_to_image(array_in):
    """Convert a raw prediction into an 8-bit PIL image.

    Normalisation, scaling to 0-255 and the channels-last reordering are done by
    ``array_to_np`` (this function used to duplicate that code line for line);
    the result is then truncated to ``np.uint8`` and wrapped with
    ``Image.fromarray``.

    :param array_in: prediction array accepted by ``array_to_np``.
    :return: ``PIL.Image.Image`` built from the uint8 array.
    """
    return Image.fromarray(array_to_np(array_in).astype(np.uint8))


def image_as_array(image_in):
    """Convert an image into a normalised ``(1, 3, H, W)`` float64 model input.

    Pixel values are divided by the image's maximum value, then each channel is
    standardised with the mean/std constants below. Inputs with fewer than three
    channels (2-D single-band arrays, single-channel or grayscale+alpha images)
    are replicated into three channels from their first channel using the first
    mean/std pair; inputs with more than three channels (e.g. RGBA) use only the
    first three.

    :param image_in: PIL image or anything ``np.array`` turns into an
        (H, W) or (H, W, C) array.
    :return: ``np.ndarray`` of shape ``(1, 3, H, W)`` and dtype float64.
    """
    pixels = np.array(image_in, np.float32)
    if pixels.ndim == 2:
        # Single-band images (e.g. PIL mode 'L') become 2-D arrays; add a channel axis.
        pixels = pixels[:, :, np.newaxis]

    peak = np.max(pixels)
    if peak != 0:
        # Guard: an all-black image would otherwise become 0/0 = NaN everywhere.
        pixels = pixels / peak

    if pixels.shape[2] < 3:
        channels = np.repeat(pixels[:, :, :1], 3, axis=2)
        mean = np.array([0.485, 0.485, 0.485], dtype=np.float32)
        std = np.array([0.229, 0.229, 0.229], dtype=np.float32)
    else:
        channels = pixels[:, :, :3]
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # Standardise in float32 (as per-channel arithmetic with Python floats did),
    # then widen to float64 for the caller.
    normalized = ((channels - mean) / std).astype(np.float64)
    return np.expand_dims(normalized.transpose((2, 0, 1)), 0)


def find_aligned_face(image_in, size=400):
    """Ask ``align`` for face number 0 of ``image_in`` at output size ``size``.

    Thin wrapper: the three values returned by ``align`` are unpacked and
    handed back as ``(aligned_image, n_faces, quad)``.
    """
    face_crop, face_count, face_quad = align(image_in, output_size=size, face_index=0)
    return face_crop, face_count, face_quad


def align_first_face(image_in, size=400):
    """Build the model input from the first detected face, or the whole image.

    ``find_aligned_face`` is asked for a ``size``-pixel aligned crop of face 0.
    If it reports zero faces, the whole image is EXIF-rotated (best effort),
    resized to ``size`` x ``size`` and used instead.

    :param image_in: input image; the fallback path calls ``.resize`` on it, so
        presumably a PIL image.
    :param size: output size for the aligner and for the fallback resize.
    :return: the array produced by ``image_as_array`` for the chosen image.
    """
    aligned_image, n_faces, _ = find_aligned_face(image_in, size=size)
    if n_faces == 0:
        try:
            image_in = ImageOps.exif_transpose(image_in)
        except Exception:
            # Best effort: a bad EXIF block must not abort the fallback, but a
            # bare ``except`` would also swallow KeyboardInterrupt/SystemExit.
            print("exif problem, not rotating")
        source = image_in.resize((size, size))
    else:
        source = aligned_image
    return image_as_array(source)


def img_concat_h(im1, im2):
    """Place ``im2`` to the right of ``im1`` on a new RGB canvas.

    The canvas is ``im1.width + im2.width`` wide and as tall as the taller of
    the two images, so neither is cropped; area not covered by either image
    keeps ``Image.new``'s default fill.

    :param im1: left image (PIL image).
    :param im2: right image (PIL image).
    :return: the combined ``PIL.Image.Image`` in mode 'RGB'.
    """
    canvas = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
    canvas.paste(im1, (0, 0))
    canvas.paste(im2, (im1.width, 0))
    return canvas


def face2hero(
        img: Image.Image,
        size: int
) -> Image.Image:
    """Stylise the first face in ``img`` with the module-level TorchScript model.

    ``align_first_face`` builds a normalised input array at ``size``; the model
    runs on it as a float32 tensor; output ``[1]`` is rescaled to a 0-255
    channels-last array by ``array_to_np``, sharpened with ``unsharp_mask`` and
    wrapped in a PIL image.

    :param img: input PIL image.
    :param size: side length forwarded to ``align_first_face``
        (``inference`` passes 400, which is also ``align_first_face``'s default).
    :return: the stylised PIL image, or ``None`` when ``align_first_face``
        returns ``None``.
    """
    model_input = align_first_face(img, size=size)
    if model_input is None:
        # Return explicitly: this branch used to fall through to an unbound name.
        return None

    # torch.Tensor(...) copies the numpy array into a float32 tensor.
    results = model(torch.Tensor(model_input))
    # NOTE(review): index 1 selects one of the model's outputs; which one it is
    # is not visible here — confirm against the exported model.
    hero_np_image = array_to_np(results[1].detach().numpy())
    # Drop the reference to all model outputs as soon as the one we need is extracted.
    del results
    return Image.fromarray(unsharp_mask(hero_np_image))


def inference(img):
    """Entry point: run ``face2hero`` on ``img`` with ``size=400`` and return its result."""
    return face2hero(img, size=400)