Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import torch | |
import os | |
from skimage import io, transform | |
import torch | |
import torchvision | |
from torch.autograd import Variable | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
import numpy as np | |
from PIL import Image | |
import glob | |
import cv2 | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from u2net import U2NET | |
from inference import TestData, RescaleT, ToTensorLab, normPRED | |
def load_model(model_type): | |
model = U2NET(3,1) | |
if model_type == "U2Net": | |
model_path = "weights/u2net.pth" | |
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) | |
else: | |
model_path = "weights/quant_model_u2net.pth" | |
model = torch.jit.load(model_path) | |
return model.eval() | |
def normPred(d): | |
ma = torch.max(d) | |
mi = torch.min(d) | |
dn = (d-mi)/(ma-mi) | |
return dn | |
def segment(model_type, img): | |
#img = cv2.imread(img) | |
src = img | |
#img = cv2.resize(img, dsize = (512, 512)) | |
#img = np.moveaxis(img, -1, 0) | |
#img = np.array(img) / 255.0 | |
#img = np.expand_dims(img, axis = 0) | |
#img = img.astype(np.float32) | |
model = load_model(model_type) | |
#output = model.predict(img).round() | |
# with torch.no_grad(): | |
# d1,d2,d3,d4,d5,d6,d7 = model(torch.from_numpy(img)) | |
# output = d1[:,0,:,:] | |
# output = normPred(output) | |
test_dataset = TestData(img_name_list = [img], lbl_name_list = [], | |
transform = transforms.Compose([RescaleT(512), ToTensorLab(flag = 0)])) | |
test_dataloader = DataLoader(test_dataset, batch_size = 1, shuffle = False, num_workers = 1) | |
for i_test, data_test in enumerate(test_dataloader): | |
#print("Inferencing : ", img_name_list[i_test].split(os.sep)[-1]) | |
inputs_test = data_test['image'] | |
inputs_test = inputs_test.type(torch.FloatTensor) | |
inputs_test = Variable(inputs_test) | |
d1, d2, d3, d4, d5, d6, d7 = model(inputs_test) | |
pred = d1[:,0,:,:] | |
pred = normPRED(pred) | |
#output = output[...,0]#.squeeze() #* 255.0 | |
# segmented = superimpose | |
#output = output.squeeze(axis = 0) | |
#output = #torch.argmax(output, dim = 1) | |
#print("output -> ", output.shape) | |
#print(output) | |
#output = cv2.cvtColor(output, cv2.COLOR_GRAY2RGB) | |
#mask2 = np.stack((output,)*3, axis=-1) | |
#segmented = superimpose(src / 255 , mask2) | |
#segmented = pcv.visualize.overlay_two_imgs(img1=src, img2=output, alpha=0.5) | |
#output = #np.moveaxis(output, -1, 0) | |
#print(pred.shape) | |
pred = pred.detach().numpy() | |
#print(pred) | |
pred = np.transpose(pred, (1,2,0)) | |
pred = np.squeeze(pred, axis = 2) | |
pred = Image.fromarray((pred*255).astype(np.uint8)) | |
#segmented = pcv.visualize.overlay_two_imgs(img1=src, img2=np.expand_dims(pred, axis =2), alpha=0.5) | |
#from PIL import ImageChops | |
#im2 = Image.fromarray(src.astype(np.uint8)) | |
#segmented = ImageChops.logical_xor(pred, im2) | |
#print(pred.shape) | |
#return pred | |
segmented = np.dstack((src, pred)) | |
return segmented | |
#return output#segmented | |
iface = gr.Interface(fn=segment, inputs=[gr.inputs.Dropdown(["Lite U2Net", "U2Net"]), gr.Image(shape = (512, 512))], outputs= gr.Image(shape = (512,512))) | |
iface.launch() |