Spaces:
Runtime error
Runtime error
from UNet import * | |
import torch # ; print('Using torch version -', torch.__version__) | |
# Target device for the model and its inputs: the GPU when PyTorch reports one
# as available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from torch.nn import Module, Conv2d, ReLU, ModuleList, MaxPool2d, ConvTranspose2d, BCELoss, BCEWithLogitsLoss, functional as F | |
from torch.optim import Adam | |
from torchvision import transforms | |
from torchvision.transforms import CenterCrop | |
from torch.utils.data import Dataset, DataLoader | |
import cv2 | |
import gradio as gr | |
# Checkpoint used by the demo. The original code calls `.to(device)` directly on
# the result of torch.load, so this file holds a whole pickled module rather
# than a state_dict.
MODEL_PATH = "unet_06_07_2022_17_13_42_256_256.pth"

# Models already loaded, keyed by checkpoint path. The checkpoint is read from
# disk once per process instead of once per request.
_UNET_CACHE = {}


def _load_unet(model_path=MODEL_PATH):
    """Return the model stored at *model_path*, loading it on first use.

    ``map_location=device`` remaps tensors that were saved on a GPU onto the
    device chosen at import time. Without it, a checkpoint written on CUDA
    cannot be deserialized on a CPU-only host.
    """
    model = _UNET_CACHE.get(model_path)
    if model is None:
        # NOTE(review): unpickling a full module needs the UNet class importable
        # (here presumably via `from UNet import *`). torch >= 2.6 defaults to
        # weights_only=True, which rejects full-module checkpoints. Confirm the
        # installed torch version, or pass weights_only=False for this trusted file.
        model = torch.load(model_path, map_location=device).to(device)
        _UNET_CACHE[model_path] = model
    return model


def getoutput(input_img):
    """Gradio callback: segment *input_img* with the UNet and return the result.

    The model is loaded once and reused across calls. Prediction is delegated
    to ``make_predictions`` with a 0.5 threshold. That helper is not defined in
    this file; presumably it comes from ``from UNet import *``. The input and
    output formats are whatever it expects and returns (probably the numpy
    array produced by ``gr.Image``; TODO confirm).
    """
    unet = _load_unet()
    return make_predictions(unet, input_img, threshold=0.5)
# Build the web UI: one image input, routed through `getoutput`, with an image
# output. NOTE(review): `shape=` is a Gradio 3.x option that resizes the
# uploaded image before it reaches the callback. Gradio 4 removed it, so this
# line pins the app to Gradio 3.x. Confirm the Space's gradio version.
demo = gr.Interface(
    fn=getoutput,
    inputs=gr.Image(shape=(200, 200)),
    outputs="image",
)
# share=True requests a public share link in addition to the local server.
demo.launch(share=True)