"""Gradio demo: salient-object segmentation with U2Net.

The app exposes a dropdown to choose between the full-precision U2Net and a
TorchScript-quantized "Lite" variant, runs the selected network on an
uploaded image, and returns the input image with the predicted saliency mask
attached as an extra (alpha) channel.
"""

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from inference import TestData, RescaleT, ToTensorLab, normPRED
from u2net import U2NET


def load_model(model_type):
    """Return the requested segmentation model in eval mode.

    Parameters
    ----------
    model_type : str
        ``"U2Net"`` loads the full-precision weights into a fresh ``U2NET``
        instance; any other value loads the TorchScript quantized model.

    Returns
    -------
    torch.nn.Module
        The model switched to evaluation mode.
    """
    if model_type == "U2Net":
        model = U2NET(3, 1)
        state = torch.load("weights/u2net.pth", map_location=torch.device("cpu"))
        model.load_state_dict(state)
    else:
        # The quantized model is stored as a TorchScript archive, not a
        # state dict, so it must be loaded with torch.jit.load.
        model = torch.jit.load("weights/quant_model_u2net.pth")
    return model.eval()


def normPred(d):
    """Min-max normalize a prediction tensor into the [0, 1] range.

    NOTE(review): divides by zero when ``d`` is constant — callers must not
    pass a flat tensor. Kept for backward compatibility even though
    ``segment`` uses the imported ``normPRED`` instead.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    return (d - mi) / (ma - mi)


def segment(model_type, img):
    """Run saliency segmentation on a single image.

    Parameters
    ----------
    model_type : str
        Dropdown choice, forwarded to :func:`load_model`.
    img : numpy.ndarray
        Input image; Gradio resizes it to 512x512 before calling us.

    Returns
    -------
    numpy.ndarray or None
        The input image with the saliency mask stacked as an additional
        (alpha) channel, or ``None`` if the dataloader produced no batches
        (the original raised ``NameError`` in that case).
    """
    model = load_model(model_type)

    test_dataset = TestData(
        img_name_list=[img],
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(512), ToTensorLab(flag=0)]),
    )
    # num_workers=0: don't fork a worker process for every web request.
    test_dataloader = DataLoader(
        test_dataset, batch_size=1, shuffle=False, num_workers=0
    )

    for data_test in test_dataloader:
        inputs_test = data_test["image"].type(torch.FloatTensor)

        # Inference only — disable autograd so the graph is not retained
        # per request (the original used the deprecated Variable wrapper
        # and kept gradients alive).
        with torch.no_grad():
            d1, *_ = model(inputs_test)

        pred = normPRED(d1[:, 0, :, :])

        # (1, H, W) float tensor -> (H, W) uint8 grayscale image.
        mask = pred.detach().numpy()
        mask = np.squeeze(np.transpose(mask, (1, 2, 0)), axis=2)
        mask = Image.fromarray((mask * 255).astype(np.uint8))

        # Attach the mask to the original image as an alpha channel.
        # batch_size=1 with a single input image: one iteration only.
        return np.dstack((img, mask))

    # Empty dataloader: return None instead of crashing with NameError.
    return None


iface = gr.Interface(
    fn=segment,
    inputs=[
        gr.inputs.Dropdown(["Lite U2Net", "U2Net"]),
        gr.Image(shape=(512, 512)),
    ],
    outputs=gr.Image(shape=(512, 512)),
)

if __name__ == "__main__":
    iface.launch()