|
import gradio as gr |
|
import torch |
|
from torch import nn |
|
from torch.utils.data import DataLoader |
|
from torchvision import datasets |
|
from torchvision.transforms import ToTensor |
|
import torch.nn.functional as F |
|
|
|
# Pick the fastest available backend: NVIDIA CUDA first, then Apple MPS,
# falling back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
|
|
|
class CNN(nn.Module):
    """Small three-stage convolutional classifier for single-channel images.

    Each stage is a 3x3 convolution (padding 1, so the spatial size is kept),
    a ReLU and a 2x2 max-pool. For a 28x28 input the feature map shrinks
    28 -> 14 -> 7 -> 3, giving 128 * 3 * 3 features. These feed a 256-unit
    hidden layer with dropout and a final linear layer producing 10 logits.

    The parameter attribute names (conv1..conv3, fc1, fc2) determine the
    ``state_dict`` keys loaded from disk, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)

        # 128 channels on a 3x3 map: what three 2x2 poolings leave of 28x28.
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 10)

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch of shape (N, 1, 28, 28)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        # Flatten per sample so the batch dimension is preserved. A
        # view(-1, 1152) would silently regroup rows if the spatial size were
        # ever wrong; with flatten, a size mismatch raises in fc1 instead.
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
|
|
|
# Build the network once on the selected device and load the trained weights.
model = CNN().to(device)

# NOTE(review): torch.load unpickles the checkpoint file; only load trusted
# checkpoints.
model.load_state_dict(torch.load("model_mnist_cnn.pth", map_location=device))

# The app only runs inference, so disable dropout up front.
model.eval()
|
|
|
# Spanish display names for the digits 0-9, indexed by the model's class output.
_CLASS_NAMES = (
    "Cero",
    "Uno",
    "Dos",
    "Tres",
    "Cuatro",
    "Cinco",
    "Seis",
    "Siete",
    "Ocho",
    "Nueve",
)


def _to_model_input(composite):
    """Turn an RGBA sketchpad drawing into a (1, 1, 28, 28) float tensor.

    Only the alpha channel is used: it is non-zero exactly where the user
    painted on the transparent canvas. ``ToTensor`` returns channels first
    with values scaled to [0, 1], so channel index 3 is alpha.
    """
    alpha = ToTensor()(composite)[3]
    # interpolate expects (batch, channel, H, W); "area" averages each source
    # region, which gives a smooth downscale of the large canvas to 28x28.
    return F.interpolate(alpha[None, None], size=(28, 28), mode="area")


def predict(im):
    """Classify the digit drawn on the sketchpad.

    Args:
        im: Value of the ``gr.Sketchpad`` component, a mapping whose
            ``"composite"`` entry holds the flattened RGBA drawing (a PIL
            image, since the component is created with ``type="pil"``).
            NOTE(review): assumed to be ``None`` or have a ``None`` composite
            when the canvas is cleared; verify against the Gradio version.

    Returns:
        ``(preview, label)``: the composite image unchanged, and the Spanish
        name of the predicted digit. The label is an empty string when there
        is nothing to classify.
    """
    if im is None or im.get("composite") is None:
        return None, ""
    composite = im["composite"]

    x = _to_model_input(composite)
    if not x.any():
        # Blank canvas: whatever class the model picked would be meaningless.
        return composite, ""

    # Dropout must be off for inference; calling eval() again is harmless.
    model.eval()
    with torch.no_grad():
        logits = model(x.to(device))

    return composite, _CLASS_NAMES[logits[0].argmax().item()]
|
|
|
|
|
# Gradio UI: a drawing canvas on the left; the predicted digit and a preview
# of the drawing on the right.
with gr.Blocks() as demo:
    descripcion = "# MNIST\n\nCreado por Gabriel Olmos Leiva"
    gr.Markdown(descripcion)

    with gr.Row():
        with gr.Column():
            # RGBA so the pen strokes can be read from the alpha channel.
            im = gr.Sketchpad(type="pil", image_mode="RGBA")
        with gr.Column():
            # "Predicción" is Spanish for "Prediction".
            prediction_text = gr.Textbox(label="Predicción")
            im_preview = gr.Image()

    # Re-run the classifier every time the drawing changes.
    im.change(
        predict,
        inputs=im,
        outputs=[im_preview, prediction_text],
        show_progress="full",
    )

demo.launch(share=True, debug=False)
|
|