Towsif7's picture
firrst commit
59e40e1
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import pathlib
from typing import Union, List
import PIL
import numpy as np
import torch
from PIL import Image
from carvekit.ml.arch.basnet.basnet import BASNet
from carvekit.ml.files.models_loc import basnet_pretrained
from carvekit.utils.image_utils import convert_image, load_image
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
__all__ = ["BASNET"]
class BASNET(BASNet):
"""BASNet model interface"""
def __init__(
self,
device="cpu",
input_image_size: Union[List[int], int] = 320,
batch_size: int = 10,
load_pretrained: bool = True,
fp16: bool = False,
):
"""
Initialize the BASNET model
Args:
device: processing device
input_image_size: input image size
batch_size: the number of images that the neural network processes in one run
load_pretrained: loading pretrained model
fp16: use fp16 precision // not supported at this moment
"""
super(BASNET, self).__init__(n_channels=3, n_classes=1)
self.device = device
self.batch_size = batch_size
if isinstance(input_image_size, list):
self.input_image_size = input_image_size[:2]
else:
self.input_image_size = (input_image_size, input_image_size)
self.to(device)
if load_pretrained:
self.load_state_dict(
torch.load(basnet_pretrained(), map_location=self.device)
)
self.eval()
def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
"""
Transform input image to suitable data format for neural network
Args:
data: input image
Returns:
input for neural network
"""
resized = data.resize(self.input_image_size)
# noinspection PyTypeChecker
resized_arr = np.array(resized, dtype=np.float64)
temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3))
if np.max(resized_arr) != 0:
resized_arr /= np.max(resized_arr)
temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229
temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224
temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225
temp_image = temp_image.transpose((2, 0, 1))
temp_image = np.expand_dims(temp_image, 0)
return torch.from_numpy(temp_image).type(torch.FloatTensor)
@staticmethod
def data_postprocessing(
data: torch.tensor, original_image: PIL.Image.Image
) -> PIL.Image.Image:
"""
Transforms output data from neural network to suitable data
format for using with other components of this framework.
Args:
data: output data from neural network
original_image: input image which was used for predicted data
Returns:
Segmentation mask as PIL Image instance
"""
data = data.unsqueeze(0)
mask = data[:, 0, :, :]
ma = torch.max(mask) # Normalizes prediction
mi = torch.min(mask)
predict = ((mask - mi) / (ma - mi)).squeeze()
predict_np = predict.cpu().data.numpy() * 255
mask = Image.fromarray(predict_np).convert("L")
mask = mask.resize(original_image.size, resample=3)
return mask
def __call__(
self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
) -> List[PIL.Image.Image]:
"""
Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances
Args:
images: input images
Returns:
segmentation masks as for input images, as PIL.Image.Image instances
"""
collect_masks = []
for image_batch in batch_generator(images, self.batch_size):
images = thread_pool_processing(
lambda x: convert_image(load_image(x)), image_batch
)
batches = torch.vstack(
thread_pool_processing(self.data_preprocessing, images)
)
with torch.no_grad():
batches = batches.to(self.device)
masks, d2, d3, d4, d5, d6, d7, d8 = super(BASNET, self).__call__(
batches
)
masks_cpu = masks.cpu()
del d2, d3, d4, d5, d6, d7, d8, batches, masks
masks = thread_pool_processing(
lambda x: self.data_postprocessing(masks_cpu[x], images[x]),
range(len(images)),
)
collect_masks += masks
return collect_masks