import os
from typing import Union

from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms  # , utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from .data_loader import RescaleT
from .data_loader import ToTensor
from .data_loader import ToTensorLab
from .data_loader import SalObjDataset

from .u2net import U2NET  # full size version 173.6 MB
from .u2net import U2NETP  # small version u2net 4.7 MB


def normPRED(d):
    """Min-max normalize a predicted SOD probability map into [0, 1].

    Args:
        d: tensor of raw predictions (any shape).

    Returns:
        A tensor shaped like ``d`` with ``(d - min) / (max - min)`` applied.
        If every element is equal (``max == min``), a tensor of zeros is
        returned instead of the NaNs that 0/0 would produce.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        # A constant map has no saliency contrast; avoid 0/0 -> NaN.
        return torch.zeros_like(d)
    return (d - mi) / span


def save_output(image_name, pred, d_dir):
    """Save a predicted mask as a PNG resized to the source image's size.

    Args:
        image_name: path to the original input image. It is read only to
            obtain its height and width.
        pred: prediction tensor. Singleton dimensions are squeezed away and
            the values are multiplied by 255; presumably they are already in
            [0, 1] (e.g. produced by ``normPRED``) -- TODO confirm at callers.
        d_dir: output location prefix. It is concatenated directly with the
            file name, so it should end with a path separator.

    The file written is ``d_dir + <input basename without its last extension>
    + '.png'``.
    """
    predict_np = pred.squeeze().detach().cpu().numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')

    image = io.imread(image_name)
    # PIL's resize takes (width, height); numpy shapes are (height, width, ...).
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    # Strip only the final extension so "a.b.jpg" -> "a.b". Unlike a manual
    # split on ".", this also works for names that have no extension at all.
    imidx = os.path.splitext(os.path.basename(image_name))[0]
    imo.save(d_dir + imidx + '.png')


def get_u2net_model(model_pth="models/u2net.pth"):
    """Build the full-size U2NET (3 input channels, 1 output) and load weights.

    Args:
        model_pth: path to the saved state dict. Defaults to the previously
            hard-coded ``"models/u2net.pth"`` (relative to the current
            working directory).

    Returns:
        The network in eval mode, placed on ``cuda:0`` when CUDA is available
        and on the CPU otherwise.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = U2NET(3, 1)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    net.load_state_dict(torch.load(model_pth, map_location=device))
    # load_state_dict copies values into the module's existing (CPU)
    # parameters, so the module itself must be moved to the target device.
    net.to(device)
    net.eval()
    return net


def _model_device(model):
    """Return the device of ``model``'s first parameter, else a CUDA/CPU default."""
    parameters = getattr(model, "parameters", None)
    first = next(parameters(), None) if callable(parameters) else None
    if first is not None:
        return first.device
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_saliency_mask(model, image_or_image_path: Union[str, np.ndarray]):
    """Run a U2NET-style model on one image and return its saliency mask.

    Args:
        model: network whose call returns a sequence of side outputs; only
            the first one (``d1``) is used. Inputs are sent to the device
            holding the model's parameters.
        image_or_image_path: a file path (read with ``skimage.io.imread``) or
            an already-loaded image array. A 2-D (grayscale) array is given
            a trailing channel axis before preprocessing.

    Returns:
        A ``PIL.Image`` in RGB mode at the preprocessed resolution (the
        input is rescaled by ``RescaleT(320)``), holding the min-max
        normalized prediction scaled to 0..255.
    """
    if isinstance(image_or_image_path, str):
        image = io.imread(image_or_image_path)
    else:
        image = image_or_image_path
    if image.ndim == 2:
        # The label below is built from image.shape[:-1], which assumes a
        # trailing channel axis. Presumably this matches how SalObjDataset
        # treats 2-D images -- confirm against data_loader.
        image = image[:, :, np.newaxis]

    preprocess = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
    sample = preprocess({
        'imidx': np.array([0]),
        'image': image,
        # Dummy all-zero label: the transforms expect one, inference ignores it.
        'label': np.expand_dims(np.zeros(image.shape[:-1]), -1),
    })

    device = _model_device(model)
    input_test = sample["image"].unsqueeze(0).float().to(device)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        d1 = model(input_test)[0]

    pred = normPRED(d1[:, 0, :, :]).squeeze()
    predict_np = pred.cpu().numpy()
    return Image.fromarray(predict_np * 255).convert("RGB")