Spaces:
Runtime error
Runtime error
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO) [https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import pathlib | |
from typing import List, Union | |
import PIL.Image | |
import torch | |
from PIL import Image | |
from torchvision import transforms | |
from torchvision.models.segmentation import deeplabv3_resnet101 | |
from carvekit.ml.files.models_loc import deeplab_pretrained | |
from carvekit.utils.image_utils import convert_image, load_image | |
from carvekit.utils.models_utils import get_precision_autocast, cast_network | |
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing | |
# Public names exported by `from <this module> import *`.
__all__ = [
    "DeepLabV3",
]
class DeepLabV3:
    """
    Wrapper around torchvision's ``deeplabv3_resnet101`` network that turns
    input images into single-channel ("L" mode) segmentation masks.
    """

    def __init__(
        self,
        device="cpu",
        batch_size: int = 10,
        input_image_size: Union[List[int], int] = 1024,
        load_pretrained: bool = True,
        fp16: bool = False,
    ):
        """
        Initialize the DeepLabV3 model

        Args:
            device: processing device
            input_image_size: maximum input image size. A single int is used
                for both dimensions; for a list/tuple the first two items are
                used. Images are downscaled (aspect ratio preserved) to fit.
            batch_size: the number of images that the neural network processes in one run
            load_pretrained: loading pretrained model
            fp16: use half precision
        """
        self.device = device
        self.batch_size = batch_size
        self.network = deeplabv3_resnet101(
            pretrained=False, pretrained_backbone=False, aux_loss=True
        )
        self.network.to(self.device)
        if load_pretrained:
            # NOTE(review): torch.load unpickles the checkpoint file, which can
            # execute arbitrary code; only load weights from a trusted source.
            self.network.load_state_dict(
                torch.load(deeplab_pretrained(), map_location=self.device)
            )
        # Accept tuples as well as lists: previously a tuple fell through to
        # the scalar branch and produced a nested (invalid) size.
        if isinstance(input_image_size, (list, tuple)):
            self.input_image_size = input_image_size[:2]
        else:
            self.input_image_size = (input_image_size, input_image_size)
        self.network.eval()
        self.fp16 = fp16
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # Standard ImageNet mean/std normalization used by
                # torchvision's pretrained models.
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def to(self, device: str):
        """
        Moves neural network to specified processing device

        Args:
            device (str or :class:`torch.device`): the desired device.
        Returns:
            None
        """
        self.network.to(device)
        # Keep self.device in sync: __call__ moves input tensors to self.device
        # and passes it to get_precision_autocast, so a stale value would leave
        # the inputs on a different device than the network weights.
        self.device = device

    def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor:
        """
        Transform input image to suitable data format for neural network

        Args:
            data: input image
        Returns:
            input for neural network
        """
        # thumbnail() resizes in place, so work on a copy to leave the
        # caller's image untouched.
        copy = data.copy()
        copy.thumbnail(self.input_image_size, resample=3)  # 3 == PIL BICUBIC
        return self.transform(copy)

    @staticmethod
    def data_postprocessing(
        data: torch.Tensor, original_image: PIL.Image.Image
    ) -> PIL.Image.Image:
        """
        Transforms output data from neural network to suitable data
        format for using with other components of this framework.

        Args:
            data: output data from neural network. In this class it is the
                per-pixel argmax class index as a uint8 CPU tensor. Class 0 is
                presumably background; confirm against the label set.
            original_image: input image which was used for predicted data
        Returns:
            Segmentation mask as PIL Image instance (0 = background,
            255 = any other class), resized to the original image size
        """
        # Binarize before scaling. Multiplying raw uint8 class indices by 255
        # wraps around (class 2 -> 254, class 15 -> 241), which made the mask
        # intensity depend on the predicted class.
        foreground = (data > 0).to(torch.uint8) * 255
        return (
            Image.fromarray(foreground.numpy())
            .convert("L")
            .resize(original_image.size)
        )

    def __call__(
        self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
    ) -> List[PIL.Image.Image]:
        """
        Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances

        Args:
            images: input images
        Returns:
            segmentation masks as for input images, as PIL.Image.Image instances
        """
        collect_masks = []
        autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16)
        with autocast:
            cast_network(self.network, dtype)
            for image_batch in batch_generator(images, self.batch_size):
                # Use a separate name instead of rebinding the `images` argument.
                loaded_images = thread_pool_processing(
                    lambda x: convert_image(load_image(x)), image_batch
                )
                batches = thread_pool_processing(self.data_preprocessing, loaded_images)
                with torch.no_grad():
                    # One forward pass per image. This is presumably because
                    # thumbnail() keeps each image's aspect ratio, so the
                    # tensors may differ in shape and cannot be stacked.
                    raw_masks = [
                        self.network(i.to(self.device).unsqueeze(0))["out"][0]
                        .argmax(0)
                        .byte()
                        .cpu()
                        for i in batches
                    ]
                del batches
                masks = thread_pool_processing(
                    lambda idx: self.data_postprocessing(
                        raw_masks[idx], loaded_images[idx]
                    ),
                    range(len(loaded_images)),
                )
                collect_masks += masks
        return collect_masks