File size: 1,204 Bytes
076bd8f
ad54d7a
3c9f42d
eb35177
3139aef
076bd8f
26ee91d
8029b4a
ad54d7a
39fedb9
ad54d7a
 
 
 
 
 
d3426a1
eb35177
 
83e14fc
37ae518
 
3470fd4
416769d
37ae518
b40acfb
d3426a1
416769d
ad54d7a
d3426a1
 
ad54d7a
d3426a1
ad54d7a
416769d
d3426a1
 
076bd8f
42ae73b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import gradio as gr
from test import inference_img
from models import *
import numpy as np 
from PIL import Image

# Inference is pinned to the CPU; the same device string is used both for the
# module placement and for remapping the checkpoint tensors on load.
device = 'cpu'
checkpoint = "stylematte.pth"

# NOTE(review): `StyleMatte` and `torch` are not imported by name in this
# file; presumably both arrive through `from models import *` -- confirm.
model = StyleMatte().to(device)
state_dict = torch.load(checkpoint, map_location=device)
model.load_state_dict(state_dict)

# Put the network in evaluation mode before it is used by the UI callback.
model.eval()

def predict(inp):
    """Matte an uploaded image and return its mask plus an RGBA cut-out.

    Args:
        inp: Image delivered by the Gradio input component as a NumPy array
            (presumably H x W x 3 uint8 RGB, the component default -- TODO
            confirm), or ``None`` when the form was submitted without an image.

    Returns:
        A two-item list ``[mask, fg_pil]``:
        ``mask`` is the output of ``inference_img`` limited to ``[0, 1]``;
        ``fg_pil`` is a PIL ``RGBA`` image whose colour channels are ``inp``
        scaled by the mask and whose alpha channel is the mask scaled to 0-255.

    Raises:
        gr.Error: If no image was supplied.
    """
    if inp is None:
        # Gradio hands the callback None for an empty image input; surface a
        # readable message in the UI instead of failing inside the model code.
        raise gr.Error("Please upload an image before submitting.")

    print("***********Inference****************")
    mask = inference_img(model, inp)
    # Out-of-range values are not representable in the uint8 casts below, so
    # keep the matte inside [0, 1] (a no-op for an already valid matte).
    mask = np.clip(mask, 0.0, 1.0)

    inp_np = np.array(inp)
    # Add a trailing axis so the 2-D mask broadcasts over the colour channels.
    fg = np.uint8(mask[:, :, None] * inp_np)
    alpha_channel = (mask * 255).astype(np.uint8)
    print(fg.max(), alpha_channel.max(), fg.shape, alpha_channel.shape)
    print("***********Inference finish****************")

    # Append the alpha plane as a fourth channel and wrap the result for Gradio.
    fg = np.dstack((fg, alpha_channel))
    fg_pil = Image.fromarray(fg, 'RGBA')

    return [mask, fg_pil]

# Startup markers written to stdout once the model section above has run.
print("MODEL LOADED")
print("************************************")

# One NumPy image goes in; `predict` returns a two-item list, shown as a
# NumPy image and a PIL image in RGBA mode respectively.
image_input = gr.Image(type="numpy")
mask_output = gr.Image(type="numpy")
cutout_output = gr.Image(type="pil", image_mode='RGBA')

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[mask_output, cutout_output],
    examples=["./logo.jpeg"],
)
print("****************Interface created******************")

# Start serving the interface.
iface.launch()