AntoreepJana's picture
Update app.py
5523f82
raw
history blame contribute delete
No virus
3.14 kB
import gradio as gr
import os
import torch
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
from PIL import Image
import glob
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from u2net import U2NET
from inference import TestData, RescaleT, ToTensorLab, normPRED
# Loaded models keyed by variant ("U2Net" / "lite"). The UI calls load_model on
# every request, so caching keeps each checkpoint from being re-read from disk
# each time. Only two keys are ever stored.
_MODEL_CACHE = {}


def load_model(model_type):
    """Return the segmentation model selected in the UI, in eval mode.

    Args:
        model_type: ``"U2Net"`` builds ``U2NET(3, 1)`` and loads the state dict
            ``weights/u2net.pth`` onto the CPU. Any other value, including
            ``None`` when nothing is selected in the dropdown, loads the
            TorchScript archive ``weights/quant_model_u2net.pth`` (by its
            filename, the quantized "Lite" variant).

    Returns:
        The model after ``.eval()``. Repeated calls for the same variant return
        the same cached object.
    """
    key = "U2Net" if model_type == "U2Net" else "lite"
    model = _MODEL_CACHE.get(key)
    if model is None:
        if key == "U2Net":
            model = U2NET(3, 1)
            model.load_state_dict(
                torch.load("weights/u2net.pth", map_location=torch.device("cpu"))
            )
        else:
            # A TorchScript archive carries its own graph, so no U2NET()
            # instance has to be constructed for this branch.
            model = torch.jit.load(
                "weights/quant_model_u2net.pth", map_location=torch.device("cpu")
            )
        model = model.eval()
        _MODEL_CACHE[key] = model
    return model
def normPred(d):
    """Min-max normalise tensor ``d`` to the range [0, 1].

    Args:
        d: A floating-point tensor of any shape.

    Returns:
        ``(d - min(d)) / (max(d) - min(d))``. When every element is equal, the
        denominator is zero, and the plain formula would yield NaN everywhere.
        In that case a tensor of zeros with ``d``'s shape and dtype is returned.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        return torch.zeros_like(d)
    return (d - mi) / span
def _prediction_to_mask(pred, size):
    """Turn a normalised (1, H, W) prediction into an 8-bit PIL mask of ``size`` (width, height)."""
    # pred[0] drops the batch axis (batch_size is 1); this matches the former
    # transpose((1, 2, 0)) + squeeze(axis=2).
    mask_array = (pred[0].detach().cpu().numpy() * 255).astype(np.uint8)
    mask = Image.fromarray(mask_array)
    # RescaleT(512) resizes the network input, so the mask presumably comes out
    # at that resolution. Scale it back so it can be stacked with the source.
    if mask.size != size:
        mask = mask.resize(size, resample=Image.BILINEAR)
    return mask


def segment(model_type, img):
    """Predict a foreground mask for ``img`` and append it as an extra channel.

    Args:
        model_type: Dropdown choice forwarded to ``load_model``. ``"U2Net"``
            selects the full model; anything else selects the TorchScript one.
        img: The image delivered by the ``gr.Image`` input. It is passed
            unchanged to ``TestData`` and stacked with the mask. It is
            presumably an (H, W, 3) uint8 array; TODO confirm against TestData.

    Returns:
        ``np.dstack((img, mask))``: the input image with the predicted uint8
        mask (0-255, resized to the input's height and width) appended as the
        last channel.

    Raises:
        RuntimeError: if the data loader yields no batch for the image.
    """
    model = load_model(model_type)

    dataset = TestData(
        img_name_list=[img],
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(512), ToTensorLab(flag=0)]),
    )
    # One in-memory image: load it in the main process rather than spawning
    # a worker process on every request.
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    src = np.asarray(img)
    target_size = (src.shape[1], src.shape[0])  # PIL sizes are (width, height)

    mask = None
    # Inference only: skip autograd bookkeeping, which the previous version
    # paid for on every request.
    with torch.no_grad():
        for batch in loader:
            inputs = batch["image"].type(torch.FloatTensor)
            d1 = model(inputs)[0]  # only the first of the model's outputs is used
            pred = normPRED(d1[:, 0, :, :])
            mask = _prediction_to_mask(pred, target_size)

    if mask is None:
        raise RuntimeError("segmentation produced no prediction for the input image")
    return np.dstack((img, mask))
#return output#segmented
# Gradio UI: pick a model variant, upload an image, get the image with the
# predicted mask stacked on as an extra channel (see segment).
iface = gr.Interface(
    fn=segment,
    inputs=[
        # gr.Dropdown replaces the deprecated gr.inputs.Dropdown namespace and
        # matches the gr.Image components used alongside it.
        gr.Dropdown(["Lite U2Net", "U2Net"], label="Model"),
        gr.Image(shape=(512, 512), label="Input image"),
    ],
    outputs=gr.Image(shape=(512, 512), label="Segmented image"),
)

# Launch only when run as a script (`python app.py`), not as an import side effect.
if __name__ == "__main__":
    iface.launch()