"""Gradio demo: run ``inference_img`` on an uploaded image with a pretrained model.

Loading this module builds the model, restores its weights from
``stylematte_synth.pth`` and starts a publicly shared Gradio interface.
"""

import gradio as gr
import torch

from test import inference_img

# NOTE(review): ``MaskForm`` is instantiated below but is not imported anywhere
# in this file. Import it from the module that defines the model class --
# TODO confirm its location; it is not visible from this script.

device = "cuda"
checkpoint = "stylematte_synth.pth"

# Build the network on the target device, then restore the trained weights.
model = MaskForm()
model = model.to(device)
state_dict = torch.load(checkpoint, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode before serving requests


def predict(inp):
    """Run the loaded model on one input image.

    Args:
        inp: The image supplied by the Gradio ``Image`` input component,
            which is configured with ``type="numpy"``.

    Returns:
        Whatever ``inference_img(model, inp)`` returns; it is handed to a
        Gradio ``Image`` output component configured with ``type="numpy"``.
    """
    return inference_img(model, inp)


# Bind the interface to a name so it is launched exactly once. The original
# launched an anonymous interface and then called ``launch()`` on an
# undefined ``iface``.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
    examples=["./logo.jpeg"],
)
iface.launch(share=True)