""" Source url: https://github.com/OPHoperHPO/image-background-remove-tool Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. License: Apache License 2.0 """ import pathlib from typing import List, Union import PIL.Image import numpy as np import torch from PIL import Image from carvekit.ml.arch.u2net.u2net import U2NETArchitecture from carvekit.ml.files.models_loc import u2net_full_pretrained from carvekit.utils.image_utils import load_image, convert_image from carvekit.utils.pool_utils import thread_pool_processing, batch_generator __all__ = ["U2NET"] class U2NET(U2NETArchitecture): """U^2-Net model interface""" def __init__( self, layers_cfg="full", device="cpu", input_image_size: Union[List[int], int] = 320, batch_size: int = 10, load_pretrained: bool = True, fp16: bool = False, ): """ Initialize the U2NET model Args: layers_cfg: neural network layers configuration device: processing device input_image_size: input image size batch_size: the number of images that the neural network processes in one run load_pretrained: loading pretrained model fp16: use fp16 precision // not supported at this moment. """ super(U2NET, self).__init__(cfg_type=layers_cfg, out_ch=1) self.device = device self.batch_size = batch_size if isinstance(input_image_size, list): self.input_image_size = input_image_size[:2] else: self.input_image_size = (input_image_size, input_image_size) self.to(device) if load_pretrained: self.load_state_dict( torch.load(u2net_full_pretrained(), map_location=self.device) ) self.eval() def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: """ Transform input image to suitable data format for neural network Args: data: input image Returns: input for neural network """ resized = data.resize(self.input_image_size, resample=3) # noinspection PyTypeChecker resized_arr = np.array(resized, dtype=float) temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3)) if np.max(resized_arr) != 0: resized_arr /= np.max(resized_arr) temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229 temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224 temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225 temp_image = temp_image.transpose((2, 0, 1)) temp_image = np.expand_dims(temp_image, 0) return torch.from_numpy(temp_image).type(torch.FloatTensor) @staticmethod def data_postprocessing( data: torch.tensor, original_image: PIL.Image.Image ) -> PIL.Image.Image: """ Transforms output data from neural network to suitable data format for using with other components of this framework. Args: data: output data from neural network original_image: input image which was used for predicted data Returns: Segmentation mask as PIL Image instance """ data = data.unsqueeze(0) mask = data[:, 0, :, :] ma = torch.max(mask) # Normalizes prediction mi = torch.min(mask) predict = ((mask - mi) / (ma - mi)).squeeze() predict_np = predict.cpu().data.numpy() * 255 mask = Image.fromarray(predict_np).convert("L") mask = mask.resize(original_image.size, resample=3) return mask def __call__( self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] ) -> List[PIL.Image.Image]: """ Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances Args: images: input images Returns: segmentation masks as for input images, as PIL.Image.Image instances """ collect_masks = [] for image_batch in batch_generator(images, self.batch_size): images = thread_pool_processing( lambda x: convert_image(load_image(x)), image_batch ) batches = torch.vstack( thread_pool_processing(self.data_preprocessing, images) ) with torch.no_grad(): batches = batches.to(self.device) masks, d2, d3, d4, d5, d6, d7 = super(U2NET, self).__call__(batches) masks_cpu = masks.cpu() del d2, d3, d4, d5, d6, d7, batches, masks masks = thread_pool_processing( lambda x: self.data_postprocessing(masks_cpu[x], images[x]), range(len(images)), ) collect_masks += masks return collect_masks