File size: 519 Bytes
076bd8f
ad54d7a
076bd8f
ad54d7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
076bd8f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import gradio as gr
import torch

from test import inference_img

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# demo still starts on CPU-only hosts instead of failing in `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): `MaskForm` is not imported anywhere in this file; it must be
# brought into scope from the project's model definitions — confirm its module.
model = MaskForm()
model = model.to(device)

# Model weights, resolved relative to the directory the app is launched from.
checkpoint = "stylematte_synth.pth"
# map_location remaps tensors saved on another device (e.g. a GPU-trained
# checkpoint) onto the device selected above.
state_dict = torch.load(checkpoint, map_location=device)

model.load_state_dict(state_dict)
# Put the model in inference mode (affects layers such as dropout / batch
# norm, if the model contains any).
model.eval()

def predict(inp):
    """Run the matting model on one image submitted through the Gradio UI.

    Args:
        inp: The input image as a NumPy array (the UI uses
            ``gr.Image(type="numpy")``), or ``None`` when the user submits
            without providing an image.

    Returns:
        The result of ``inference_img(model, inp)``, which is displayed by the
        ``gr.Image`` output, or ``None`` when no image was supplied.
    """
    # Gradio's Image component hands the function None when no image is
    # provided; skip inference instead of passing None into the model code.
    if inp is None:
        return None
    return inference_img(model, inp)


# Web UI: one image in, the model's output image out, plus a bundled example.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
    examples=["./logo.jpeg"],
)
# Launch exactly once; share=True additionally requests a public share link.
iface.launch(share=True)