90AnimalClassification / src /img_processing.py
arturevs's picture
repo
526a74f
import os
import cv2
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet18 # Usaremos uma CNN pré-treinada (ResNet18)
# Configurações
DATASET_PATH = "./animals/" # Pasta principal com as subpastas de animais
CSV_FOLDER = "./csv_folder" # Pasta para salvar os CSVs
IMAGE_SIZE = (224, 224) # Tamanho das imagens para a CNN
# Criar a pasta csv_folder se não existir
if not os.path.exists(CSV_FOLDER):
os.makedirs(CSV_FOLDER)
# Carregar uma CNN pré-treinada (ResNet18) e remover a camada final (fully connected)
cnn_model = resnet18(pretrained=True)
cnn_model = nn.Sequential(*list(cnn_model.children())[:-1]) # Remove a última camada
cnn_model.eval() # Colocar o modelo em modo de avaliação
# Transformações para a imagem
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(IMAGE_SIZE),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalização para o modelo pré-treinado
])
# Função para extrair características de uma imagem usando a CNN
def extract_features(image):
with torch.no_grad(): # Desativar cálculo de gradientes
image_tensor = transform(image).unsqueeze(0) # Adicionar dimensão do batch
features = cnn_model(image_tensor) # Extrair características
return features.flatten().numpy() # Achatar e converter para numpy array
# Função para processar uma subpasta (espécie) e salvar em um DataFrame
def process_animal_folder(animal_class, class_path):
# Lista para armazenar os dados da subpasta
data = []
# Percorrer as imagens da subpasta
for image_name in os.listdir(class_path):
image_path = os.path.join(class_path, image_name)
try:
# Verificar se o arquivo é uma imagem válida
if not image_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
print(f"Ignorando arquivo não suportado: {image_path}")
continue
# Carregar a imagem
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"Falha ao carregar a imagem: {image_path}")
# Extrair características usando a CNN
features = extract_features(image)
# Adicionar ao dataset com o label sendo o nome da subpasta
data.append([animal_class] + list(features))
except Exception as e:
print(f"Erro ao processar {image_path}: {e}")
# Verificar se há dados antes de criar o DataFrame
if not data:
print(f"Nenhuma imagem válida encontrada na pasta: {class_path}")
return None
# Criar DataFrame
columns = ["label"] + [f"feature_{i}" for i in range(len(data[0]) - 1)]
df = pd.DataFrame(data, columns=columns)
return df
# Percorrer as subpastas
for animal_class in os.listdir(DATASET_PATH):
class_path = os.path.join(DATASET_PATH, animal_class)
# Verificar se é uma pasta
if os.path.isdir(class_path):
print(f"Processando imagens da classe: {animal_class}")
# Processar a subpasta e obter o DataFrame
df = process_animal_folder(animal_class, class_path)
if df is not None:
# Salvar CSV com o nome do animal
csv_filename = os.path.join(CSV_FOLDER, f"{animal_class}_dataset.csv")
try:
df.to_csv(csv_filename, index=False)
print(f"Dataset salvo como '{csv_filename}'")
except Exception as e:
print(f"Erro ao salvar o dataset {csv_filename}: {e}")
print("Processamento concluído!")