File size: 835 Bytes
cff9442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8d559d
81c7e37
23fa3d9
81c7e37
cff9442
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import gradio as gr
from test import inference_img
from models import *

# Inference runs on the CPU; checkpoint tensors are remapped onto it at load time.
device = 'cpu'

# Build the StyleMatte network and place it on the target device.
model = StyleMatte().to(device)

# Restore the trained weights. torch.load unpickles the file, so this should
# only ever point at a checkpoint that is trusted.
checkpoint = "stylematte.pth"
state_dict = torch.load(checkpoint, map_location=device)
model.load_state_dict(state_dict)

# Put the network in evaluation mode before serving predictions.
model.eval()

# Matte values strictly above this are kept as foreground in the cut-out.
MASK_THRESHOLD = 0.1


def predict(inp):
    """Cut the main subject out of ``inp`` using the StyleMatte alpha matte.

    Args:
        inp: Image array delivered by the ``gr.Image(type="numpy")`` input,
            or ``None`` when the user submits without an image. Presumably
            shaped (H, W, C) -- TODO confirm against ``inference_img``.

    Returns:
        A new array with the same dtype as ``inp`` in which every pixel whose
        matte value is not above ``MASK_THRESHOLD`` is zeroed, or ``None``
        when ``inp`` is ``None`` (the Gradio output is then left empty).
    """
    if inp is None:
        # Nothing to segment; avoid handing None to the model.
        return None
    print("***********Inference****************")
    res = inference_img(model, inp)
    print(res.shape, res.dtype)
    # Binary foreground mask with a trailing axis so it broadcasts over
    # inp's channel dimension.
    mask = (res > MASK_THRESHOLD)[:, :, np.newaxis]
    # The multiplication allocates a fresh array, so inp itself is never
    # modified and no explicit copy is needed.
    return inp * mask

print("MODEL LOADED")
print("************************************")

# Both ends of the demo exchange images as numpy arrays, matching what
# predict() consumes and returns.
image_in = gr.Image(type="numpy")
image_out = gr.Image(type="numpy")

iface = gr.Interface(
    fn=predict,
    inputs=image_in,
    outputs=image_out,
    examples=["./logo.jpeg"],
)
print("****************Interface created******************")

# Starts the Gradio server; note this runs whenever the module is executed.
iface.launch()