File size: 1,036 Bytes
076bd8f
ad54d7a
3c9f42d
eb35177
3139aef
076bd8f
26ee91d
8029b4a
ad54d7a
39fedb9
ad54d7a
 
 
 
 
 
d3426a1
eb35177
 
14cca43
3470fd4
416769d
e5113ec
d3426a1
416769d
ad54d7a
d3426a1
 
ad54d7a
d3426a1
ad54d7a
416769d
d3426a1
 
076bd8f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import gradio as gr
import numpy as np
import torch
from PIL import Image

from test import inference_img
from models import *  # NOTE: star import; provides StyleMatte

# Build the StyleMatte model on CPU and load its weights from the local
# checkpoint file (relative path, resolved against the working directory).
device = 'cpu'
model = StyleMatte().to(device)
checkpoint = "stylematte.pth"
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted
# files; consider weights_only=True once the minimum supported torch has it.
state_dict = torch.load(checkpoint, map_location=device)

model.load_state_dict(state_dict)
# Put the module in evaluation mode before serving inference requests.
model.eval()

def predict(inp):
    """Run StyleMatte on one image and return the matte and masked foreground.

    Args:
        inp: Image from the ``gr.Image(type="numpy")`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        A two-element list ``[mask, fg_pil]`` matching the two Gradio
        outputs: the matte returned by ``inference_img`` and a PIL image of
        ``inp`` multiplied by that matte. Returns ``[None, None]`` (clearing
        both outputs) when ``inp`` is ``None``.
    """
    if inp is None:
        # Skip inference instead of crashing inside inference_img.
        return [None, None]

    print("***********Inference****************")
    mask = inference_img(model, inp)
    inp_np = np.asarray(inp)
    # assumes mask is (H, W) with values in [0, 1] and inp_np is (H, W, C)
    # — TODO confirm against inference_img. The matte is broadcast over the
    # channel axis; clipping keeps out-of-range products from wrapping
    # around when cast to uint8 (in-range values are unaffected).
    fg = np.clip(mask[:, :, None] * inp_np, 0, 255).astype(np.uint8)
    print("***********Inference finish****************")
    fg_pil = Image.fromarray(fg)

    return [mask, fg_pil]

print("MODEL LOADED")
print("************************************")

# Single image input; two outputs: the raw matte (numpy) and the masked
# foreground (PIL, displayed in RGBA mode). Components are created in the
# same order as before: input first, then outputs.
_input_image = gr.Image(type="numpy")
_output_images = [
    gr.Image(type="numpy"),
    gr.Image(type="pil", image_mode='RGBA'),
]
iface = gr.Interface(
    fn=predict,
    inputs=_input_image,
    outputs=_output_images,
    examples=["./logo.jpeg"],
)
print("****************Interface created******************")

iface.launch()