"""Gradio demo: cut the main subject out of an image with a StyleMatte model.

The model is built and its weights are loaded from ``stylematte.pth`` once,
at import time; the Gradio interface is then created and launched.
"""
import gradio as gr
import numpy as np
import torch

from models import StyleMatte
from test import inference_img

device = "cpu"
checkpoint = "stylematte.pth"
# Pixels whose predicted matte value is strictly greater than this are kept;
# all others are zeroed in the output image.
MASK_THRESHOLD = 0.1

model = StyleMatte().to(device)
state_dict = torch.load(checkpoint, map_location=device)
model.load_state_dict(state_dict)
model.eval()


def predict(inp):
    """Return ``inp`` with every pixel outside the predicted subject set to 0.

    Args:
        inp: Image array from the ``gr.Image(type="numpy")`` input component,
            or ``None`` when the user submits without providing an image.

    Returns:
        A new array shaped like ``inp`` where pixels whose matte value is
        ``<= MASK_THRESHOLD`` are zeroed, or ``None`` when ``inp`` is ``None``.
    """
    if inp is None:
        # Gradio passes None for an empty image input; there is nothing to cut.
        return None
    print("***********Inference****************")
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        res = inference_img(model, inp)
    print(res.shape, res.dtype)
    # NOTE(review): assumes ``res`` is an (H, W) matte aligned with ``inp``;
    # the added trailing axis broadcasts the mask over the colour channels.
    mask = (res > MASK_THRESHOLD)[:, :, np.newaxis]
    # Multiplication allocates a new array, so ``inp`` itself is not modified.
    return inp * mask


print("MODEL LOADED")
print("************************************")

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
    examples=["./logo.jpeg"],
)
print("****************Interface created******************")

iface.launch()