| | import os
|
| | import torch
|
| | import numpy as np
|
| | from PIL import Image
|
| | from transformers import ConvNextImageProcessor, ConvNextForImageClassification
|
| | from rembg import remove
|
| |
|
| |
|
def segment_image(pil_image: Image.Image) -> Image.Image:
    """Remove the background of a PIL image and put it on a black backdrop.

    The image is passed to ``remove`` (imported from ``rembg``). When the
    result is in ``RGBA`` mode, it is pasted onto a solid black RGB canvas of
    the same size using its fourth band as the paste mask. Any other mode is
    simply converted to RGB.

    This is a best-effort step: if anything raises, a warning is printed and
    the original image converted to RGB is returned instead.

    Args:
        pil_image: The image to process.

    Returns:
        An RGB image: the segmented result, or the RGB-converted input when
        segmentation failed.
    """
    try:
        cutout = remove(pil_image)

        # The canvas is created before the mode check so that a failure here
        # is handled by the same fallback path regardless of the mode.
        black_canvas = Image.new("RGB", cutout.size, (0, 0, 0))

        if cutout.mode != 'RGBA':
            return cutout.convert("RGB")

        # Fourth band of an RGBA image is the alpha channel; used as the mask
        # so only the foreground pixels land on the black canvas.
        alpha_band = cutout.split()[3]
        black_canvas.paste(cutout, mask=alpha_band)
        return black_canvas

    except Exception as e:
        print(f"AVISO: Falha na segmentação ({e}). Retornando original.")
        return pil_image.convert("RGB")
|
| |
|
| |
|
| |
|
class FeatureExtractor:
    """Extract ConvNeXt feature vectors from background-segmented images.

    A pretrained ConvNeXt image-classification checkpoint is loaded through
    ``transformers`` and its ``classifier`` module is replaced with
    ``torch.nn.Identity``, so the model's ``logits`` output is whatever the
    network feeds into the classifier instead of class scores
    (NOTE(review): presumably the pooled backbone features - confirm against
    the ``transformers`` implementation in use).
    """

    # Checkpoint used when the caller does not pick one explicitly.
    DEFAULT_MODEL_NAME = "facebook/convnext-large-224-22k-1k"

    def __init__(self, device=None, model_name: str = DEFAULT_MODEL_NAME):
        """Load the image processor and the model onto the chosen device.

        Args:
            device: Device to run on. When falsy, ``"cuda"`` is used if
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the processor and the model.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Usando dispositivo: {self.device}")

        self.processor = ConvNextImageProcessor.from_pretrained(model_name)
        self.model = ConvNextForImageClassification.from_pretrained(
            model_name
        ).to(self.device)

        # Drop the classification head so ``logits`` carries features.
        self.model.classifier = torch.nn.Identity()
        self.model.eval()

    def extract_convnext(self, image_path: str) -> np.ndarray:
        """Return the flattened feature vector for the image at ``image_path``.

        The image is opened, converted to RGB, run through ``segment_image``,
        preprocessed by the ConvNeXt processor and passed through the model
        with gradients disabled.

        Args:
            image_path: Path of the image file to process.

        Returns:
            A 1-D NumPy array holding the flattened model output.
        """
        print(f"Processando imagem: {os.path.basename(image_path)}")

        # ``Image.open`` is lazy and keeps the underlying file open; the
        # context manager guarantees the handle is released. ``convert``
        # returns an independent image that stays valid after the close.
        with Image.open(image_path) as raw_image:
            input_img = raw_image.convert("RGB")

        final_image = segment_image(input_img)

        inputs = self.processor(final_image, return_tensors="pt").to(self.device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            features = self.model(**inputs).logits

        # Move to host memory before converting to NumPy, then flatten.
        return features.cpu().numpy().flatten()
|
| |
|
def process_single_image(image_path: str, extractor=None) -> np.ndarray:
    """Extract the ConvNeXt feature vector for a single image file.

    Building a ``FeatureExtractor`` loads a pretrained checkpoint, so doing it
    on every call is wasteful. When no ``extractor`` is supplied, a default
    one is created on first use and reused by later calls.

    Args:
        image_path: Path of the image file to process.
        extractor: Optional object exposing ``extract_convnext(image_path)``.
            When ``None``, a default ``FeatureExtractor()`` is created on the
            first call and cached on this function for subsequent calls.

    Returns:
        Whatever ``extractor.extract_convnext(image_path)`` returns.
    """
    if extractor is None:
        # Lazily build the default extractor once and keep it on the function
        # object so repeated calls do not reload the model.
        extractor = getattr(process_single_image, "_default_extractor", None)
        if extractor is None:
            extractor = FeatureExtractor()
            process_single_image._default_extractor = extractor
    return extractor.extract_convnext(image_path)