Spaces:
Running
on
Zero
Running
on
Zero
| # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates | |
| # // | |
| # // Licensed under the Apache License, Version 2.0 (the "License"); | |
| # // you may not use this file except in compliance with the License. | |
| # // You may obtain a copy of the License at | |
| # // | |
| # // http://www.apache.org/licenses/LICENSE-2.0 | |
| # // | |
| # // Unless required by applicable law or agreed to in writing, software | |
| # // distributed under the License is distributed on an "AS IS" BASIS, | |
| # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # // See the License for the specific language governing permissions and | |
| # // limitations under the License. | |
| import math | |
| import random | |
| from typing import Union | |
| import torch | |
| from PIL import Image | |
| from torchvision.transforms import functional as TVF | |
| from torchvision.transforms.functional import InterpolationMode | |
class AreaResize:
    """Resize an image so that its pixel area becomes roughly ``max_area``.

    Both sides are multiplied by the same factor
    ``sqrt(max_area / (height * width))`` and rounded to the nearest integer,
    so the aspect ratio is kept up to rounding.
    """

    def __init__(
        self,
        max_area: float,
        downsample_only: bool = False,
        interpolation: InterpolationMode = InterpolationMode.BICUBIC,
    ):
        """Store the resize configuration.

        Args:
            max_area (float): Target area (height * width) after resizing.
            downsample_only (bool): When true, the scale factor is capped at 1
                so that images at or below ``max_area`` are never enlarged.
            interpolation (InterpolationMode): Mode forwarded to ``TVF.resize``.
        """
        self.max_area = max_area
        self.downsample_only = downsample_only
        self.interpolation = interpolation

    def __call__(self, image: Union[torch.Tensor, Image.Image]):
        """Resize ``image`` towards the configured area.

        Args:
            image (torch.Tensor or PIL Image): Input image. For tensors the
                last two dimensions are read as (height, width).

        Returns:
            The value returned by ``TVF.resize`` for the computed size.

        Raises:
            NotImplementedError: If ``image`` is neither a tensor nor a PIL
                image.
            ZeroDivisionError: If the input has zero area.
        """
        if isinstance(image, torch.Tensor):
            src_h, src_w = image.shape[-2:]
        elif isinstance(image, Image.Image):
            # PIL reports (width, height), the reverse of the tensor layout.
            src_w, src_h = image.size
        else:
            raise NotImplementedError

        factor = math.sqrt(self.max_area / (src_h * src_w))
        # In downsample-only mode an image that already fits keeps its size.
        if factor >= 1 and self.downsample_only:
            factor = 1

        target_size = (round(src_h * factor), round(src_w * factor))
        return TVF.resize(
            image,
            size=target_size,
            interpolation=self.interpolation,
        )
class AreaRandomCrop:
    """Randomly crop an image to a window of roughly ``max_area`` pixels.

    The window has the same aspect ratio as the input image (up to rounding).
    An image that is no larger than the window on both axes is returned at its
    full extent.
    """

    def __init__(
        self,
        max_area: float,
    ):
        """Store the crop configuration.

        Args:
            max_area (float): Target area (height * width) of the crop window.
        """
        self.max_area = max_area

    def get_params(self, input_size, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            input_size (tuple): ``(height, width)`` of the image to be cropped.
            output_size (tuple): ``(height, width)`` requested for the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
            When the input fits inside the requested size on both axes the
            whole input ``(0, 0, height, width)`` is returned. When only one
            axis of the request exceeds the input, that axis is clamped to the
            input extent and the other axis is cropped randomly.
        """
        h, w = input_size
        th, tw = output_size
        if w <= tw and h <= th:
            return 0, 0, h, w

        # Clamp each axis on its own: a request that is larger than the input
        # on just one side would otherwise give random.randint a negative
        # upper bound and raise ValueError.
        th, tw = min(th, h), min(tw, w)
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, image: Union[torch.Tensor, Image.Image]):
        """Crop ``image`` to a random window of about ``max_area`` pixels.

        Args:
            image (torch.Tensor or PIL Image): Input image. For tensors the
                last two dimensions are read as (height, width).

        Returns:
            The value returned by ``TVF.crop`` for the sampled window.

        Raises:
            NotImplementedError: If ``image`` is neither a tensor nor a PIL
                image.
            ZeroDivisionError: If the input height or width is zero.
        """
        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
        elif isinstance(image, Image.Image):
            # PIL reports (width, height), the reverse of the tensor layout.
            width, height = image.size
        else:
            raise NotImplementedError

        # Solve crop_height * crop_width == max_area while keeping
        # crop_width / crop_height equal to the input aspect ratio.
        # NOTE(review): for extreme aspect ratios or a tiny max_area one side
        # can round to 0, which yields an empty crop — confirm whether callers
        # can hit this.
        aspect_ratio = width / height
        crop_height = math.sqrt(self.max_area / aspect_ratio)
        crop_width = aspect_ratio * crop_height
        crop_height, crop_width = round(crop_height), round(crop_width)

        i, j, h, w = self.get_params((height, width), (crop_height, crop_width))
        return TVF.crop(image, i, j, h, w)
class ScaleResize:
    """Resize an image by multiplying both sides with a fixed factor."""

    def __init__(
        self,
        scale: float,
    ):
        """Store the scale factor.

        Args:
            scale (float): Multiplier applied to both height and width.
        """
        self.scale = scale

    def __call__(self, image: Union[torch.Tensor, Image.Image]):
        """Resize ``image`` to ``round(side * scale)`` on each axis.

        Tensors are resized with bilinear interpolation and PIL images with
        Lanczos interpolation.

        Args:
            image (torch.Tensor or PIL Image): Input image. For tensors the
                last two dimensions are read as (height, width).

        Returns:
            The value returned by ``TVF.resize`` for the computed size.

        Raises:
            NotImplementedError: If ``image`` is neither a tensor nor a PIL
                image.
        """
        if isinstance(image, torch.Tensor):
            src_h, src_w = image.shape[-2:]
            mode = InterpolationMode.BILINEAR
            # Only 4-D tensors request antialiasing explicitly.
            # NOTE(review): "warn" is forwarded verbatim to TVF.resize; how it
            # is interpreted depends on the installed torchvision — confirm.
            if image.ndim == 4:
                smoothing = True
            else:
                smoothing = "warn"
        elif isinstance(image, Image.Image):
            # PIL reports (width, height), the reverse of the tensor layout.
            src_w, src_h = image.size
            mode = InterpolationMode.LANCZOS
            smoothing = "warn"
        else:
            raise NotImplementedError

        factor = self.scale
        target_size = (round(src_h * factor), round(src_w * factor))
        return TVF.resize(
            image,
            size=target_size,
            interpolation=mode,
            antialias=smoothing,
        )