import os
from typing import Optional, Union

import torch
from PIL import Image
from torchvision import transforms
|
|
|
|
def get_default_transform(img_size: int = 224) -> transforms.Compose:
    """Build the standard ImageNet preprocessing pipeline.

    Applies Resize + CenterCrop + ToTensor + Normalize using the usual
    ImageNet channel statistics, so the output is compatible with
    torchvision's pretrained classification models.

    Args:
        img_size: Side length of the square model input (default: 224).

    Returns:
        A transforms.Compose performing the full preprocessing.
    """
    # Preserve the conventional 256:224 resize-to-crop ratio for any img_size.
    scaled = int(img_size * 256 / 224)
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(scaled),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
|
|
|
|
def preprocess_image(
    image: Union[str, "os.PathLike[str]", Image.Image],
    transform: Optional[transforms.Compose] = None,
) -> torch.Tensor:
    """Load and transform an image, returning a 1xCxHxW tensor.

    Args:
        image: Path to an image file (str or os.PathLike, e.g. pathlib.Path)
            or an already-loaded PIL image.
        transform: Preprocessing pipeline to apply; defaults to
            get_default_transform().

    Returns:
        Preprocessed tensor with a leading batch dimension (1, C, H, W).

    Raises:
        ValueError: If ``image`` is neither a path nor a PIL.Image.
    """
    # Explicit `is None` check: `transform or ...` would silently replace any
    # falsy (but valid) transform object with the default.
    if transform is None:
        transform = get_default_transform()

    # Accept os.PathLike in addition to str — PIL.Image.open handles both.
    if isinstance(image, (str, os.PathLike)):
        # Force RGB so grayscale/RGBA inputs match the 3-channel Normalize.
        img = Image.open(image).convert('RGB')
    elif isinstance(image, Image.Image):
        img = image.convert('RGB')
    else:
        raise ValueError("Imagem inválida: informe caminho ou PIL.Image")

    # unsqueeze(0) adds the batch dimension expected by model forward passes.
    return transform(img).unsqueeze(0)
|
|