rmayormartins's picture
Subindo arquivos2abcd
e56fc69
import os
import shutil
import joblib # Para salvar como .pkl
import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import io
from sklearn.metrics import classification_report, confusion_matrix
# Adicionando os comandos para instalar o protobuf-compiler
os.system('apt-get update')
os.system('apt-get install -y protobuf-compiler')
# Modelos para seleção
model_dict = {
'AlexNet': models.alexnet,
'ResNet18': models.resnet18,
'ResNet34': models.resnet34,
'ResNet50': models.resnet50,
'MobileNetV2': models.mobilenet_v2
}
# Variáveis globais
model = None
train_loader = None
val_loader = None
test_loader = None
dataset_path = 'dataset'
class_dirs = []
test_dataset_path = 'test_dataset'
test_class_dirs = []
num_classes = 2 # Valor padrão para o número de classes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Função para configurar as classes
def setup_classes(num_classes_value):
global class_dirs, dataset_path, num_classes
num_classes = int(num_classes_value) # Atualizar a variável global num_classes
# Limpar diretório antigo se existir
if os.path.exists(dataset_path):
shutil.rmtree(dataset_path)
os.makedirs(dataset_path)
# Criar diretórios para cada classe
class_dirs = [os.path.join(dataset_path, f'class_{i}') for i in range(num_classes)]
for class_dir in class_dirs:
os.makedirs(class_dir)
return f"Criados {num_classes} diretórios para classes."
# Função para upload de imagens
def upload_images(class_id, images):
class_dir = class_dirs[int(class_id)]
for image in images:
shutil.copy(image, class_dir)
return f"Imagens salvas na classe {class_id}."
# Função para preparação dos dados
def prepare_data(batch_size=32, resize=(224, 224)):
global train_loader, val_loader, test_loader, num_classes
# Transformações para os dados de treinamento e validação
transform = transforms.Compose([
transforms.Resize(resize),
transforms.ToTensor(),
])
dataset = datasets.ImageFolder(dataset_path, transform=transform)
if len(dataset.classes) != num_classes:
return f"Erro: Número de classes detectadas ({len(dataset.classes)}) não corresponde ao número esperado ({num_classes}). Verifique suas imagens."
# Divisão do conjunto de dados em treinamento, validação e teste
train_size = int(0.7 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return "Preparação dos dados concluída com sucesso."
# Função para iniciar o treinamento
def start_training(model_name, epochs, lr):
global model, train_loader, val_loader, device
if train_loader is None or val_loader is None:
return "Erro: Dados não preparados."
model = model_dict[model_name](pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=float(lr))
for epoch in range(int(epochs)):
model.train()
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")
torch.save(model.state_dict(), 'modelo.pth')
return f"Treinamento concluído com sucesso. Modelo salvo."
# Função para avaliação do modelo com conjunto de teste
def evaluate_model(loader):
global model, device, num_classes
if model is None:
return "Erro: Modelo não treinado."
if loader is None:
return "Erro: Conjunto de dados de teste não está preparado."
model.eval()
all_preds = []
all_labels = []
try:
with torch.no_grad():
for inputs, labels in loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
report = classification_report(all_labels, all_preds, labels=list(range(num_classes)), target_names=[f"class_{i}" for i in range(num_classes)], zero_division=0)
return report
except Exception as e:
return f"Erro durante a avaliação: {str(e)}"
# Função para mostrar a matriz de confusão
def show_confusion_matrix(loader):
global model, device, num_classes
if model is None:
return "Erro: Modelo não treinado."
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
for inputs, labels in loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
plt.figure(figsize=(6, 4.8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=[f"class_{i}" for i in range(num_classes)], yticklabels=[f"class_{i}" for i in range(num_classes)])
plt.xlabel('Predictions')
plt.ylabel('Actuals')
buf = io.BytesIO()
plt.savefig(buf, format='png')
plt.close()
buf.seek(0)
return Image.open(buf)
# Função para predição de imagens desconhecidas
def predict_images(images):
global model, device, num_classes
if model is None:
return "Erro: Modelo não treinado."
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
model.eval()
results = []
for image in images:
try:
img = transform(Image.open(image)).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(img)
_, preds = torch.max(outputs, 1)
predicted_class = preds.item()
results.append(f"Imagem {os.path.basename(image)} - Classe prevista: class_{predicted_class}")
except Exception as e:
results.append(f"Erro ao processar a imagem {image}: {str(e)}")
return results
# Função para exportar o modelo
def export_model(format):
global model
if model is None:
return "Erro: Modelo não treinado."
file_path = f"modelo_exportado.{format}"
if format == "pth":
torch.save(model.state_dict(), file_path)
elif format == "onnx":
try:
dummy_input = torch.randn(1, 3, 224, 224).to(device)
torch.onnx.export(model, dummy_input, file_path, export_params=True, opset_version=10, input_names=['input'], output_names=['output'])
except Exception as e:
return f"Erro ao exportar para ONNX: {str(e)}"
elif format == "pkl":
joblib.dump(model, file_path)
else:
return f"Formato {format} não suportado."
return f"Modelo exportado com sucesso para {file_path}"
# Função para configurar diretórios de teste
def setup_test_classes():
global test_class_dirs, test_dataset_path
if os.path.exists(test_dataset_path):
shutil.rmtree(test_dataset_path)
os.makedirs(test_dataset_path)
# Criar diretórios para cada classe
test_class_dirs = [os.path.join(test_dataset_path, f'class_{i}') for i in range(num_classes)]
for class_dir in test_class_dirs:
os.makedirs(class_dir)
return f"Criados {num_classes} diretórios para classes de teste."
# Função para upload de imagens de testee
def upload_test_images(class_id, images):
class_dir = test_class_dirs[int(class_id)]
for image in images:
shutil.copy(image, class_dir)
return f"Imagens de teste salvas na classe {class_id}."
# Função para preparar dados de teste
def prepare_test_data(batch_size=32, resize=(224, 224)):
global test_loader, num_classes
transform = transforms.Compose([
transforms.Resize(resize),
transforms.ToTensor(),
])
test_dataset = datasets.ImageFolder(test_dataset_path, transform=transform)
if len(test_dataset.classes) != num_classes:
return f"Erro: Número de classes detectadas ({len(test_dataset.classes)}) não corresponde ao número esperado ({num_classes}). Verifique suas imagens."
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
return "Preparação dos dados de teste concluída com sucesso."
# Interface Gradio
def main():
with gr.Blocks() as demo:
gr.Markdown("# Image Classification Training")
with gr.Tab("Configurar Classes"):
num_classes_input = gr.Number(label="Número de Classes", value=2, precision=0)
setup_button = gr.Button("Configurar Classes")
setup_output = gr.Textbox()
setup_button.click(setup_classes, inputs=num_classes_input, outputs=setup_output)
with gr.Tab("Upload de Imagens"):
upload_inputs = []
for i in range(num_classes):
with gr.Column():
gr.Markdown(f"### Classe {i}")
class_id = gr.Number(label=f"ID da Classe {i}", value=i, precision=0)
images = gr.File(label="Upload de Imagens", file_count="multiple", type="filepath")
upload_button = gr.Button("Upload")
upload_output = gr.Textbox()
upload_inputs.append((class_id, images, upload_button, upload_output))
upload_button.click(upload_images, inputs=[class_id, images], outputs=upload_output)
with gr.Tab("Preparação de Dados"):
batch_size = gr.Number(label="Tamanho do Batch", value=32)
resize = gr.Textbox(label="Resize (Ex: 224,224)", value="224,224")
prepare_button = gr.Button("Preparar Dados")
prepare_output = gr.Textbox()
prepare_button.click(lambda batch_size, resize: prepare_data(batch_size=batch_size, resize=tuple(map(int, resize.split(',')))), inputs=[batch_size, resize], outputs=prepare_output)
with gr.Tab("Treinamento"):
model_name = gr.Dropdown(label="Modelo", choices=list(model_dict.keys()))
epochs = gr.Number(label="Épocas", value=30)
lr = gr.Number(label="Taxa de Aprendizado", value=0.001)
train_button = gr.Button("Iniciar Treinamento")
train_output = gr.Textbox()
train_button.click(start_training, inputs=[model_name, epochs, lr], outputs=train_output)
with gr.Tab("Avaliação do Modelo"):
eval_button = gr.Button("Avaliar Modelo")
eval_output = gr.Textbox()
eval_button.click(lambda: evaluate_model(test_loader), outputs=eval_output)
cm_button = gr.Button("Mostrar Matriz de Confusão")
cm_output = gr.Image()
cm_button.click(lambda: show_confusion_matrix(test_loader), outputs=cm_output)
with gr.Tab("Predição e Avaliação"):
predict_images_input = gr.File(label="Upload de Imagens para Predição", file_count="multiple", type="filepath")
predict_button = gr.Button("Predizer")
predict_output = gr.Textbox()
predict_button.click(predict_images, inputs=predict_images_input, outputs=predict_output)
gr.Markdown("### Upload de Imagens de Teste")
setup_test_button = gr.Button("Configurar Diretórios de Teste")
setup_test_output = gr.Textbox()
setup_test_button.click(setup_test_classes, outputs=setup_test_output)
upload_test_inputs = []
for i in range(num_classes):
with gr.Column():
gr.Markdown(f"### Classe de Teste {i}")
test_class_id = gr.Number(label=f"ID da Classe {i}", value=i, precision=0)
test_images = gr.File(label="Upload de Imagens de Teste", file_count="multiple", type="filepath")
upload_test_button = gr.Button("Upload Imagens de Teste")
upload_test_output = gr.Textbox()
upload_test_inputs.append((test_class_id, test_images, upload_test_button, upload_test_output))
upload_test_button.click(upload_test_images, inputs=[test_class_id, test_images], outputs=upload_test_output)
prepare_test_button = gr.Button("Preparar Dados de Teste")
prepare_test_output = gr.Textbox()
prepare_test_button.click(lambda batch_size, resize: prepare_test_data(batch_size=batch_size, resize=tuple(map(int, resize.split(',')))), inputs=[batch_size, resize], outputs=prepare_test_output)
eval_test_button = gr.Button("Avaliar Conjunto de Teste")
eval_test_output = gr.Textbox()
eval_test_button.click(lambda: evaluate_model(test_loader), outputs=eval_test_output)
cm_test_button = gr.Button("Mostrar Matriz de Confusão do Conjunto de Teste")
cm_test_output = gr.Image()
cm_test_button.click(lambda: show_confusion_matrix(test_loader), outputs=cm_test_output)
with gr.Tab("Exportação"):
export_format = gr.Radio(label="Formato", choices=["pth", "onnx", "pkl"])
export_button = gr.Button("Exportar Modelo")
export_output = gr.Textbox()
export_button.click(export_model, inputs=export_format, outputs=export_output)
demo.launch()
if __name__ == "__main__":
main()