"""Gradio demo that runs a pre-trained UNet on an uploaded image."""

# Star imports are kept first (as in the original) and deliberately:
# ``make_predictions`` comes from here, and ``torch.load`` below unpickles a
# whole model object, which needs the ``UNet`` class importable under the name
# it was pickled with. Explicit imports follow so they take precedence over
# any same-named objects the star imports might re-export.
from UNet import *
from make_predictions import *

import functools

import gradio as gr
import torch

# Serialized checkpoint. The object returned by torch.load is used directly as
# a model, so this file holds a full pickled model rather than a state_dict.
MODEL_PATH = "unet_06_07_2022_22_43_42_1024_1024.pth"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the UNet checkpoint onto the CPU once and reuse it for every request."""
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # trusted checkpoint files. On torch >= 2.6 the default weights_only=True
    # rejects full-model pickles, so weights_only=False may be needed there.
    return torch.load(MODEL_PATH, map_location="cpu").to("cpu")


def getoutput(threshold, input_img):
    """Run the cached UNet on ``input_img`` and return the prediction.

    Args:
        threshold: Value from the Gradio "number" input, forwarded to
            ``make_predictions`` as its ``threshold`` keyword.
        input_img: Image supplied by the Gradio image input component.

    Returns:
        Whatever ``make_predictions`` returns; Gradio renders it through the
        "image" output.
    """
    # The model used to be re-read from disk on every request (and a fresh
    # UNet() was built and immediately discarded); now it is loaded once.
    unet = _load_model()
    return make_predictions(unet, input_img, threshold=threshold)


demo = gr.Interface(
    fn=getoutput,
    # NOTE(review): ``shape=`` is a Gradio 3.x argument of gr.Image that was
    # removed in Gradio 4 — keep gradio pinned accordingly.
    inputs=["number", gr.Image(shape=(200, 200))],
    outputs=["image"],
)

if __name__ == "__main__":
    # Only start the web server when run as a script, not on import.
    demo.launch()