julien.blanchon
add app
c8c12e9
"""Dataset Utils."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
from typing import List, Optional, Tuple
import numpy as np
from torch import Tensor
class Denormalize:
    """Denormalize Torch Tensor into np image format.

    Applies ``x * std + mean`` per channel, scales the result by 255 and
    returns it as a ``uint8`` numpy array in (H, W, C) layout.
    """

    def __init__(self, mean: Optional[List[float]] = None, std: Optional[List[float]] = None):
        """Denormalize Torch Tensor into np image format.

        Args:
            mean: Per-channel mean. Defaults to the ImageNet values
                ``[0.485, 0.456, 0.406]`` when ``None``.
            std: Per-channel standard deviation. Defaults to the ImageNet
                values ``[0.229, 0.224, 0.225]`` when ``None``.
        """
        # If no mean and std provided, assign ImageNet values.
        if mean is None:
            mean = [0.485, 0.456, 0.406]

        if std is None:
            std = [0.229, 0.224, 0.225]

        self.mean = Tensor(mean)
        self.std = Tensor(std)

    def __call__(self, tensor: Tensor) -> np.ndarray:
        """Denormalize the input.

        The input tensor is left untouched; the computation runs on a copy.

        Args:
            tensor (Tensor): Input tensor image (C, H, W), or a batch of
                exactly one image (1, C, H, W).

        Returns:
            Denormalized numpy array (H, W, C) of dtype ``uint8``.

        Raises:
            ValueError: If a 4-dimensional tensor is given whose batch size
                is not 1.
        """
        if tensor.dim() == 4:
            # Only a single-image batch can be collapsed to (C, H, W).
            if tensor.size(0) != 1:
                raise ValueError(f"Tensor has batch size of {tensor.size(0)}. Only single batch is supported.")
            tensor = tensor.squeeze(0)

        # Work on a detached copy: the in-place ops below would otherwise
        # modify the caller's tensor (squeeze() returns a view, not a copy).
        tensor = tensor.detach().clone()

        # zip() stops at the shortest input, so only channels that have a
        # matching mean/std entry are denormalized.
        for channel, mean, std in zip(tensor, self.mean, self.std):
            channel.mul_(std).add_(mean)

        array = (tensor * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
        return array

    def __repr__(self) -> str:
        """Representational string."""
        return self.__class__.__name__ + "()"
class ToNumpy:
    """Turn a torch Tensor into a ``uint8`` numpy array."""

    def __call__(self, tensor: Tensor, dims: Optional[Tuple[int, ...]] = None) -> np.ndarray:
        """Scale a tensor by 255, permute its axes and return it as a numpy array.

        Args:
            tensor (Tensor): Tensor to convert. Input tensor in range 0-1.
            dims (Optional[Tuple[int, ...]], optional): Axis permutation applied
                before the conversion to numpy. When ``None``, a 4-dimensional
                tensor uses ``(0, 2, 3, 1)`` and anything else uses ``(1, 2, 0)``.

        Returns:
            Converted ``uint8`` ndarray. A leading axis of length 1 and then a
            trailing axis of length 1 are squeezed away.
        """
        if dims is None:
            # Default support is (C, H, W) or (N, C, H, W)
            if len(tensor.shape) == 4:
                dims = (0, 2, 3, 1)
            else:
                dims = (1, 2, 0)

        scaled = tensor * 255
        reordered = scaled.permute(dims)
        array = reordered.cpu().numpy().astype(np.uint8)

        # Drop the leading axis first, then the trailing one.
        if array.shape[0] == 1:
            array = np.squeeze(array, axis=0)
        if array.shape[-1] == 1:
            array = np.squeeze(array, axis=-1)

        return array

    def __repr__(self) -> str:
        """Representational string."""
        return f"{self.__class__.__name__}()"