Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
from PIL import Image | |
import torchvision.transforms as transforms | |
# Autoencoder model
class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    Encoder: three stride-2 convolutions (1 -> 32 -> 64 -> 128 channels), each
    halving height and width, followed by a linear layer to a 32-d code.
    ``fc1`` is sized for the 128x8x8 feature map that a 64x64 input produces
    (64 -> 32 -> 16 -> 8).

    Decoder: a linear layer back to 128x8x8, then three stride-2 transposed
    convolutions that each double height and width, ending in a 1-channel
    map squashed to [0, 1] by a sigmoid.

    The layer attribute names (``conv1`` ... ``conv6``, ``fc1``, ``fc2``) are
    the keys of the saved state dict, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        flat_features = 128 * 8 * 8  # flattened size of the deepest feature map
        code_size = 32

        # Encoder convolutions: stride 2 with padding 1 halves each spatial dim.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

        # Bottleneck projections to and from the latent code.
        self.fc1 = nn.Linear(flat_features, code_size)
        self.fc2 = nn.Linear(code_size, flat_features)

        # Decoder transposed convolutions: output_padding=1 makes each layer
        # exactly double the spatial size.
        self.conv4 = nn.ConvTranspose2d(
            128, 64, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.conv5 = nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        self.conv6 = nn.ConvTranspose2d(
            32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
        )

    def encode(self, x):
        """Map a batched (batch, 1, 64, 64) tensor to a (batch, 32) code in [-1, 1]."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.tanh(conv(features))
        # Flatten everything except the batch dimension for the linear layer.
        features = features.view(features.size(0), -1)
        return torch.tanh(self.fc1(features))

    def decode(self, x):
        """Map a (batch, 32) code back to a (batch, 1, 64, 64) tensor in [0, 1]."""
        features = torch.tanh(self.fc2(x))
        # Un-flatten into 128 feature maps of 8x8.
        features = features.view(features.size(0), 128, 8, 8)
        for deconv in (self.conv4, self.conv5):
            features = torch.tanh(deconv(features))
        # Sigmoid keeps reconstructed pixels in [0, 1].
        return torch.sigmoid(self.conv6(features))

    def forward(self, x):
        """Reconstruct ``x`` by encoding it and decoding the resulting code."""
        latent = self.encode(x)
        return self.decode(latent)
# Load the trained autoencoder weights onto the CPU and switch to eval mode.
model = Autoencoder()
# weights_only=True limits unpickling to tensors and plain containers (all a
# state dict needs), so a tampered checkpoint cannot execute arbitrary code.
# It is also the default from PyTorch 2.6 on; the flag requires PyTorch >= 1.13.
model.load_state_dict(
    torch.load(
        "autoencoder.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
# Inference only: put the module in evaluation mode.
model.eval()
# Preprocessing pipeline applied to every uploaded image:
#   1. collapse to a single (grayscale) channel, matching conv1's 1 input channel;
#   2. resize to 64x64, the resolution the autoencoder's linear layers are sized for;
#   3. convert the PIL image to a float tensor with values in [0, 1].
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
    ]
)
# Reconstruction-error threshold (tunable): an image whose mean squared
# reconstruction error exceeds this value is reported as anomalous.
THRESHOLD = 0.01


# Prediction function
def detectar_anomalia(imagen):
    """Classify an image as normal or anomalous from its reconstruction error.

    The image is preprocessed with the module-level ``transform``, passed
    through the autoencoder ``model``, and the mean squared error between
    the input tensor and its reconstruction is compared against ``THRESHOLD``.

    Args:
        imagen: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        ``"Anómala"`` if the error exceeds ``THRESHOLD``, otherwise
        ``"Normal"``. Returns ``None`` when no image was given, which Gradio
        renders as an empty result instead of a crash inside the transforms.
    """
    if imagen is None:
        return None
    # Add a batch dimension: the model expects (batch, 1, 64, 64) input.
    img_tensor = transform(imagen).unsqueeze(0)
    with torch.no_grad():  # inference only; no autograd graph needed
        reconstruida = model(img_tensor)
    mse = torch.mean((img_tensor - reconstruida) ** 2).item()
    return "Anómala" if mse > THRESHOLD else "Normal"
# Gradio interface: one PIL image in, the "Anómala"/"Normal" label out.
demo = gr.Interface(
    fn=detectar_anomalia,
    inputs=gr.Image(type="pil", label="Sube una imagen para analizar"),
    outputs=gr.Label(label="Resultado"),
    # Example image paths offered to the user as sample inputs.
    examples=["anomalous.png", "normal.png"],
    title="Detección de Anomalías con Autoencoder (PyTorch)",
    description="Este Space utiliza un autoencoder entrenado con PyTorch para detectar anomalías en imágenes de textiles.",
)
demo.launch()