"""Gradio demo that runs a pre-trained UNet on an uploaded image.

``UNet`` and ``make_predictions`` are expected to come from the local
``UNet`` module via the star import below.
"""
import functools

from UNet import *
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

from torch.nn import Module, Conv2d, ReLU, ModuleList, MaxPool2d, ConvTranspose2d, BCELoss, BCEWithLogitsLoss, functional as F
from torch.optim import Adam
from torchvision import transforms
from torchvision.transforms import CenterCrop
from torch.utils.data import Dataset, DataLoader
import cv2
import gradio as gr

# Serialized (whole-module) UNet checkpoint used for inference.
MODEL_PATH = "unet_06_07_2022_17_13_42_256_256.pth"


@functools.lru_cache(maxsize=1)
def _load_unet():
    """Load the UNet checkpoint once and reuse it for every request."""
    # map_location lets a checkpoint that was saved from a GPU load on a
    # CPU-only host (and vice versa).
    # NOTE(review): the file holds a pickled nn.Module, not a state_dict.
    # On torch>=2.6, torch.load defaults to weights_only=True and would reject
    # it. Pass weights_only=False there, and only for this trusted local file.
    model = torch.load(MODEL_PATH, map_location=device).to(device)
    # Put dropout/batch-norm layers (if any) into inference mode. This is
    # harmless if make_predictions also does it.
    model.eval()
    return model


def getoutput(input_img):
    """Segment ``input_img`` with the cached UNet and return the predicted mask.

    Args:
        input_img: image supplied by the Gradio ``Image`` input component.

    Returns:
        Whatever ``make_predictions`` produces for a 0.5 threshold; it is
        displayed by the Gradio ``"image"`` output component.
    """
    # Previously the model was rebuilt and re-read from disk on every call.
    unet = _load_unet()
    output_img = make_predictions(unet, input_img, threshold=0.5)
    return output_img


demo = gr.Interface(getoutput, gr.Image(shape=(200, 200)), "image")
demo.launch(share=True)