Towsif7's picture
first commit
59e40e1
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import pathlib
from typing import List, Union
import PIL.Image
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet101
from carvekit.ml.files.models_loc import deeplab_pretrained
from carvekit.utils.image_utils import convert_image, load_image
from carvekit.utils.models_utils import get_precision_autocast, cast_network
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
__all__ = ["DeepLabV3"]
class DeepLabV3:
    """Wrapper around torchvision's ``deeplabv3_resnet101`` producing PIL masks.

    Calling an instance with a list of images (paths or PIL images) loads and
    converts them, runs them through the network one by one and returns one
    single-channel ("L") mask per input, resized to the input image size.
    """

    def __init__(
        self,
        device="cpu",
        batch_size: int = 10,
        input_image_size: Union[List[int], int] = 1024,
        load_pretrained: bool = True,
        fp16: bool = False,
    ):
        """
        Initialize the DeepLabV3 model

        Args:
            device: processing device
            input_image_size: input image size; either a single int (used for
                both sides) or a sequence whose first two items are used
            batch_size: the number of images that the neural network processes in one run
            load_pretrained: loading pretrained model
            fp16: use half precision
        """
        self.device = device
        self.batch_size = batch_size
        # Weights come from the project checkpoint below, not from torchvision.
        # NOTE(review): newer torchvision releases deprecate `pretrained` /
        # `pretrained_backbone` in favour of `weights` / `weights_backbone` --
        # confirm against the pinned torchvision version before changing.
        self.network = deeplabv3_resnet101(
            pretrained=False, pretrained_backbone=False, aux_loss=True
        )
        self.network.to(self.device)
        if load_pretrained:
            # NOTE(review): torch.load unpickles the file; only trusted
            # checkpoints should ever be returned by deeplab_pretrained().
            self.network.load_state_dict(
                torch.load(deeplab_pretrained(), map_location=self.device)
            )
        # Accept tuples as well as lists; previously a tuple fell through to
        # the scalar branch and produced a nested ((h, w), (h, w)) size.
        if isinstance(input_image_size, (list, tuple)):
            self.input_image_size = tuple(input_image_size[:2])
        else:
            self.input_image_size = (input_image_size, input_image_size)
        self.network.eval()
        self.fp16 = fp16
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def to(self, device: str):
        """
        Moves neural network to specified processing device

        Args:
            device (:class:`torch.device`): the desired device.

        Returns:
            None
        """
        self.network.to(device)
        # Keep the recorded device in sync: __call__ uses self.device both to
        # place input tensors and to choose the autocast context, so leaving
        # it stale would put inputs and weights on different devices.
        self.device = device

    def data_preprocessing(self, data: "PIL.Image.Image") -> torch.Tensor:
        """
        Transform input image to suitable data format for neural network

        Args:
            data: input image

        Returns:
            input for neural network
        """
        # Work on a copy: thumbnail() resizes in place and the caller's image
        # is still needed at its original size for post-processing.
        resized = data.copy()
        # resample=3 is Pillow's numeric code for bicubic resampling.
        resized.thumbnail(self.input_image_size, resample=3)
        return self.transform(resized)

    @staticmethod
    def data_postprocessing(
        data: torch.Tensor, original_image: "PIL.Image.Image"
    ) -> "PIL.Image.Image":
        """
        Transforms output data from neural network to suitable data
        format for using with other components of this framework.

        Args:
            data: output data from neural network (CPU tensor)
            original_image: input image which was used for predicted data

        Returns:
            Segmentation mask as PIL Image instance
        """
        mask = data.numpy()
        if mask.dtype.kind in "bui":
            # Integer class-index map: multiplying uint8 indices by 255 wraps
            # modulo 256 (index k became 256 - k), so the foreground was never
            # exactly 255. Map every non-zero index to full white instead.
            mask = (mask != 0).astype("uint8") * 255
        else:
            # Non-integer input keeps the original plain scaling.
            mask = mask * 255
        return Image.fromarray(mask).convert("L").resize(original_image.size)

    def __call__(
        self, images: "List[Union[str, pathlib.Path, PIL.Image.Image]]"
    ) -> "List[PIL.Image.Image]":
        """
        Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances

        Args:
            images: input images

        Returns:
            segmentation masks as for input images, as PIL.Image.Image instances
        """
        collect_masks = []
        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
        with autocast:
            cast_network(self.network, dtype)
            for image_batch in batch_generator(images, self.batch_size):
                # Distinct local names: the `images` argument is no longer
                # rebound while it is still being batched.
                loaded = thread_pool_processing(
                    lambda x: convert_image(load_image(x)), image_batch
                )
                tensors = thread_pool_processing(self.data_preprocessing, loaded)
                with torch.no_grad():
                    # Images are fed one at a time (unsqueeze adds the batch
                    # dimension); the "out" entry of the network output is
                    # reduced with argmax over dim 0 and moved to CPU as uint8.
                    raw_masks = [
                        self.network(tensor.to(self.device).unsqueeze(0))["out"][0]
                        .argmax(0)
                        .byte()
                        .cpu()
                        for tensor in tensors
                    ]
                    del tensors
                # Pair each raw mask with its source image so the mask can be
                # resized back to that image's size.
                collect_masks += thread_pool_processing(
                    lambda pair: self.data_postprocessing(pair[0], pair[1]),
                    list(zip(raw_masks, loaded)),
                )
        return collect_masks