File size: 4,201 Bytes
def3395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# https://github.com/MarcoForte/FBA_Matting
import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from networks.models import build_model
from networks.transforms import trimap_transform, normalise_image

# Hugging Face Hub repository that hosts the FBA Matting checkpoint.
REPO_ID = "leonelhs/FBA-Matting"

# Resolve the "FBA.pth" checkpoint from the Hub; `weights` is the value
# returned by hf_hub_download (presumably a local file path — confirm).
weights = hf_hub_download(repo_id=REPO_ID, filename="FBA.pth")
# Build the network from the checkpoint once at import time so every request
# reuses the same model instance.
model = build_model(weights)
# Switch to evaluation mode and keep the model on the CPU.
model.eval().cpu()


def np_to_torch(x, permute=True):
    ''' Convert a NumPy array into a batched float32 CPU tensor.
        Parameters:
        x -- NumPy array to convert.
        permute -- when True, reorder the axes with permute(2, 0, 1)
                   (channels-last to channels-first) before batching.
        Returns:
        A float32 tensor on the CPU with a leading batch axis of size 1.
    '''
    tensor = torch.from_numpy(x)
    if permute:
        tensor = tensor.permute(2, 0, 1)
    # Prepend the batch axis using the same indexing expression in both cases.
    batched = tensor[None, :, :, :]
    return batched.float().cpu()


def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray:
    ''' Resize x by scale, rounding each side up to a multiple of 8.
        Parameters:
        x -- array whose first two axes are height and width.
        scale -- factor applied to both sides before rounding up.
        scale_type -- interpolation flag forwarded to cv2.resize.
        Returns:
        The resized array.
    '''
    height, width = x.shape[:2]
    target_h, target_w = (
        int(np.ceil(scale * side / 8) * 8) for side in (height, width)
    )
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(x, (target_w, target_h), interpolation=scale_type)


def inference(
        image_np: np.ndarray,
        trimap_np: np.ndarray) -> "tuple[np.ndarray, np.ndarray, np.ndarray]":
    ''' Predict alpha, foreground and background.
        Parameters:
        image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2)
        Returns:
        fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
        Raises:
        ValueError -- if image and trimap differ in height or width.
    '''
    # A size mismatch would otherwise surface later as an obscure tensor or
    # boolean-index shape error; fail early with a clear message instead.
    if image_np.shape[:2] != trimap_np.shape[:2]:
        raise ValueError(
            "image and trimap must have the same height and width, got "
            f"{image_np.shape[:2]} and {trimap_np.shape[:2]}")

    h, w = trimap_np.shape[:2]
    # Inputs are resized so both sides are multiples of 8 before the forward pass.
    image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
    trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)

    with torch.no_grad():
        image_torch = np_to_torch(image_scale_np)
        trimap_torch = np_to_torch(trimap_scale_np)

        # trimap_transform's output is batched without permuting its axes.
        trimap_transformed_torch = np_to_torch(
            trimap_transform(trimap_scale_np), permute=False)
        # Normalise a copy so the raw image tensor is passed to the model untouched.
        image_transformed_torch = normalise_image(
            image_torch.clone())

        output = model(
            image_torch,
            trimap_torch,
            image_transformed_torch,
            trimap_transformed_torch)

    # Back to (h, w, channels) at the original resolution. The interpolation
    # flag must be passed by keyword: the third positional parameter of
    # cv2.resize is `dst`, not `interpolation`.
    output = cv2.resize(
        output[0].cpu().numpy().transpose((1, 2, 0)),
        (w, h),
        interpolation=cv2.INTER_LANCZOS4)
    # Lanczos resampling can overshoot; keep the documented [0, 1] range.
    output = np.clip(output, 0.0, 1.0)

    # Channel layout of the model output: alpha, then fg rgb, then bg rgb.
    alpha = output[:, :, 0]
    fg = output[:, :, 1:4]
    bg = output[:, :, 4:7]

    # Pixels the trimap marks as known background/foreground override the prediction.
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1
    # Where alpha is fully opaque/transparent, the input pixel is the fg/bg colour.
    fg[alpha == 1] = image_np[alpha == 1]
    bg[alpha == 0] = image_np[alpha == 0]

    return fg, bg, alpha


def read_image(name):
    ''' Load an image file as a float array in rgb channel order.
        Parameters:
        name -- path of the image file, read with cv2.imread.
        Returns:
        The decoded pixels divided by 255, with the last axis reversed
        (OpenCV's bgr order to rgb).
        Raises:
        ValueError -- if OpenCV cannot read or decode the file.
    '''
    bgr = cv2.imread(name)
    # cv2.imread signals failure by returning None instead of raising.
    if bgr is None:
        raise ValueError(f"Could not read image file: {name!r}")
    return (bgr / 255.0)[:, :, ::-1]


def read_trimap(name):
    ''' Load a trimap file and expand it into a two channel mask.
        Parameters:
        name -- path of the trimap file, read with cv2.imread using flag 0.
        Returns:
        Array of dimensions (h, w, 2): channel 0 is 1 where the scaled trimap
        equals 0, channel 1 is 1 where it equals 1; every other entry is 0.
        Raises:
        ValueError -- if OpenCV cannot read or decode the file.
    '''
    raw = cv2.imread(name, 0)
    # cv2.imread signals failure by returning None instead of raising.
    if raw is None:
        raise ValueError(f"Could not read trimap file: {name!r}")
    trimap_im = raw / 255.0
    h, w = trimap_im.shape
    trimap_np = np.zeros((h, w, 2))
    trimap_np[trimap_im == 1, 1] = 1
    trimap_np[trimap_im == 0, 0] = 1
    return trimap_np


def predict(image, trimap):
    ''' Run matting on an image file and its trimap file.
        Parameters:
        image -- path of the input image, loaded with read_image.
        trimap -- path of the trimap image, loaded with read_trimap.
        Returns:
        The (fg, bg, alpha) result of inference.
    '''
    return inference(read_image(image), read_trimap(trimap))


# HTML snippet rendered at the bottom of the page, linking to the upstream project.
footer = r"""
<center>
<b>
Demo for <a href='https://github.com/MarcoForte/FBA_Matting'>FBA Matting</a>
</b>
</center>
"""

# Page layout: inputs and the run button on the left, the three results on the right.
with gr.Blocks(title="FBA Matting") as app:
    gr.HTML("<center><h1>FBA Matting</h1></center>")
    gr.HTML("<center><h3>Foreground, Background, Alpha Matting Generator.</h3></center>")
    # NOTE(review): Row.style(...) appears to have been removed in newer Gradio
    # releases — confirm against the Gradio version this app is pinned to.
    with gr.Row().style(equal_height=False):
        with gr.Column():
            # Both inputs are handed to `predict` as file paths.
            input_img = gr.Image(type="filepath", label="Input image")
            input_trimap = gr.Image(type="filepath", label="Trimap image")
            run_btn = gr.Button(variant="primary")
        with gr.Column():
            fg = gr.Image(type="numpy", label="Foreground")
            bg = gr.Image(type="numpy", label="Background")
            alpha = gr.Image(type="numpy", label="Alpha")

    # The outputs are listed in the same order as predict's return values.
    run_btn.click(predict, [input_img, input_trimap], [fg, bg, alpha])

    with gr.Row():
        # Sample paths examples/01.jpg .. examples/03.jpg; clicking a sample only
        # fills the input image — the trimap is not populated here.
        examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
        examples = gr.Dataset(components=[input_img], samples=examples_data)
        examples.click(lambda x: x[0], [examples], [input_img])

    with gr.Row():
        gr.HTML(footer)

# Runs at import time (there is no __main__ guard).
# NOTE(review): `enable_queue` appears to be unsupported by newer Gradio
# releases — confirm against the pinned Gradio version.
app.launch(share=False, debug=True, enable_queue=True, show_error=True)