# face-segmenter/image_processor.py
from typing import Dict, List, Optional, Tuple, Union, Iterable
import numpy as np
import torch
import transformers
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
ChannelDimension,
get_resize_output_image_size,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
get_channel_dimension_axis,
make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import is_torch_tensor


class FaceSegformerImageProcessor(BaseImageProcessor):
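    """
    Image processor for a Segformer-based face segmentation model.

    Configurable via kwargs: `image_size` (default (224, 224)), `normalize_mean`
    and `normalize_std` (ImageNet statistics by default), `resample` (bilinear),
    and `data_format` (channels-first).
    """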
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.image_size = kwargs.get("image_size", (224, 224))
self.normalize_mean = kwargs.get("normalize_mean", [0.485, 0.456, 0.406])
self.normalize_std = kwargs.get("normalize_std", [0.229, 0.224, 0.225])
self.resample = kwargs.get("resample", PILImageResampling.BILINEAR)
        self.data_format = kwargs.get("data_format", ChannelDimension.FIRST)

    @staticmethod
def normalize(
image: np.ndarray,
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
max_pixel_value: float = 255.0,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Copied from:
https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/image_transforms.py#L209
BUT uses the formula from albumentations:
https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize
img = (img - mean * max_pixel_value) / (std * max_pixel_value)
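        For example, a uint8 value of 128 with mean 0.485 and std 0.229 becomes
        (128 - 0.485 * 255) / (0.229 * 255) ≈ 0.074.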
"""
if not isinstance(image, np.ndarray):
raise ValueError("image must be a numpy array")
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
channel_axis = get_channel_dimension_axis(
image, input_data_format=input_data_format
)
num_channels = image.shape[channel_axis]
# We cast to float32 to avoid errors that can occur when subtracting uint8 values.
# We preserve the original dtype if it is a float type to prevent upcasting float16.
if not np.issubdtype(image.dtype, np.floating):
image = image.astype(np.float32)
if isinstance(mean, Iterable):
if len(mean) != num_channels:
raise ValueError(
f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}"
)
else:
mean = [mean] * num_channels
mean = np.array(mean, dtype=image.dtype)
if isinstance(std, Iterable):
if len(std) != num_channels:
raise ValueError(
f"std must have {num_channels} elements if it is an iterable, got {len(std)}"
)
else:
std = [std] * num_channels
std = np.array(std, dtype=image.dtype)
# Uses max_pixel_value for normalization
if input_data_format == ChannelDimension.LAST:
image = (image - mean * max_pixel_value) / (std * max_pixel_value)
else:
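            # Channels-first: transpose so the trailing channel axis broadcasts
            # against the per-channel mean/std, then transpose back.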
image = ((image.T - mean * max_pixel_value) / (std * max_pixel_value)).T
image = (
to_channel_dimension_format(image, data_format, input_data_format)
if data_format is not None
else image
)
        return image

    def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Copied from:
https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py
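        `size` must contain either `"shortest_edge"` (the shorter side is scaled to
        that value, preserving aspect ratio) or both `"height"` and `"width"`
        (the image is resized to exactly that shape).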
"""
default_to_square = True
if "shortest_edge" in size:
size = size["shortest_edge"]
default_to_square = False
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
raise ValueError(
"Size must contain either 'shortest_edge' or 'height' and 'width'."
)
output_size = get_resize_output_image_size(
image,
size=size,
default_to_square=default_to_square,
input_data_format=input_data_format,
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
        )

    def __call__(self, images: ImageInput, masks: ImageInput = None, **kwargs):
"""
Adapted from:
https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py
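        `masks` is reserved for a (not yet implemented) training path. For
        inference, images are resized, normalized, converted to the configured
        channel format, and returned as a `BatchFeature` of PyTorch tensors
        under `"pixel_values"`.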
"""
# single to iterable if needed
images = make_list_of_images(images)
# validate
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
# make numpy arrays
images = [to_numpy_array(image) for image in images]
# get channel dimensions
input_data_format = kwargs.get("input_data_format")
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
        # Check whether this is a training call.
        # TODO: training mode could also be inferred from whether masks are passed.
        if kwargs.get("do_training", False) is True:
            if masks is None:
                raise ValueError("masks must be passed when do_training is True.")
            # TODO: implement the training transformations for images and masks.
            raise NotImplementedError("Training transformations are not implemented yet.")
        else:
            # Inference transformations: resize and normalize each image.
images = [
self.resize(
image=image,
size={"height": self.image_size[0], "width": self.image_size[1]},
resample=kwargs.get("resample") or self.resample,
input_data_format=input_data_format,
)
for image in images
]
images = [
self.normalize(
image=image,
mean=kwargs.get("normalize_mean") or self.normalize_mean,
std=kwargs.get("normalize_std") or self.normalize_std,
input_data_format=input_data_format,
)
for image in images
]
            # Convert each image to the configured channel dimension format.
images = [
to_channel_dimension_format(
image,
kwargs.get("data_format") or self.data_format,
input_channel_dim=input_data_format,
)
for image in images
]
data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type="pt")

    # Copied from transformers.models.segformer.image_processing_segformer.SegformerImageProcessor.post_process_semantic_segmentation
    def post_process_semantic_segmentation(
self, outputs, target_sizes: List[Tuple] = None
):
"""
Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`SegformerForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
                specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
"""
# TODO: add support for other frameworks
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0),
size=target_sizes[idx],
mode="bilinear",
align_corners=False,
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [
semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])
]
return semantic_segmentation
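

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the processor itself).
    # Assumptions: a local RGB image at "face.jpg" and a compatible
    # SegformerForSemanticSegmentation checkpoint; the repo id below is a
    # placeholder, not a verified model name.
    from PIL import Image
    from transformers import SegformerForSemanticSegmentation

    processor = FaceSegformerImageProcessor()
    model = SegformerForSemanticSegmentation.from_pretrained(
        "your-username/face-segmenter"  # hypothetical checkpoint id
    )
    model.eval()

    image = Image.open("face.jpg").convert("RGB")
    # __call__ resizes, normalizes, and returns a BatchFeature of torch tensors.
    inputs = processor(image)
    with torch.no_grad():
        outputs = model(**inputs)

    # Upsample logits to the original (height, width) and take the per-pixel argmax.
    segmentation = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )
    print(segmentation[0].shape)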