Spaces:
Running
Running
| import torch | |
| import torchvision.transforms as transforms | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import gradio as gr | |
| import os | |
# === Simple CNN Model Definition ===
class SimpleCNN(nn.Module):
    """Small two-convolution CNN producing 10 class logits for 32x32 RGB images.

    Layout: conv(3->32) -> ReLU -> 2x2 max-pool -> conv(32->64) -> ReLU ->
    2x2 max-pool -> linear(64*8*8 -> 512) -> ReLU -> linear(512 -> 10).
    The 64 * 8 * 8 feature count assumes a 32x32 input (two 2x2 pools: 32 -> 16 -> 8).
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding=1 keep the spatial size; only pooling shrinks it.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return raw logits of shape (batch, 10) for x of shape (batch, 3, 32, 32)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten from dim 1 so the batch dimension is preserved: a wrongly-sized
        # input then fails loudly in fc1 instead of being silently folded into
        # extra "batch" rows, which is what view(-1, 64 * 8 * 8) would do.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# === Model Loading ===
model = SimpleCNN()
model_path = 'simple_cnn_dclr_tuned.pth'
if os.path.exists(model_path):
    # The checkpoint is a plain state_dict, so restrict unpickling to tensors and
    # primitive containers (weights_only=True); an unrestricted torch.load can
    # execute arbitrary code embedded in a tampered file.
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
    )
    print(f"Model loaded successfully from {model_path}")
else:
    # NOTE(review): the app keeps running with randomly initialised weights here,
    # so predictions are meaningless until the checkpoint is provided.
    print(f"Warning: Model file '{model_path}' not found. Please run train_dclr_model.py first.")
# Switch to inference mode whether or not weights were loaded.
model.eval()
# === CIFAR-10 Class Labels ===
# Position i names the class for output logit i of the model (fc2 has 10 outputs).
class_labels = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
# === Image Preprocessing ===
# Resize to exactly 32x32: SimpleCNN's fixed 64 * 8 * 8 flatten size only holds
# for 32x32 inputs. A bare Resize(32) scales only the shorter edge (keeping the
# aspect ratio), so non-square uploads would reach the model at e.g. 32x48.
# For square images the result is identical to Resize(32).
# NOTE(review): these look like the standard ImageNet mean/std rather than
# CIFAR-10 statistics; they must match whatever train_dclr_model.py used —
# confirm before changing either side.
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# === Inference Function ===
def inference(input_image: Image.Image):
    """Classify one image with the module-level ``model`` and return confidences.

    Args:
        input_image: PIL image to classify. It is converted to RGB first so that
            RGBA, palette ("P") or grayscale ("L") images still provide the three
            input channels ``conv1`` is built for.

    Returns:
        dict mapping every label in ``class_labels`` to its softmax probability.
    """
    # Defensive: ensure the model is in eval mode before predicting.
    if model.training:
        model.eval()
    # preprocess yields a single (C, H, W) tensor; unsqueeze adds the batch dim.
    batch = preprocess(input_image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        probabilities = F.softmax(model(batch), dim=1)[0]
    return {label: float(prob) for label, prob in zip(class_labels, probabilities)}
# === Results Viewer Function ===
def show_results(input_image: "Image.Image"):
    """Predict on the uploaded image and collect saved benchmark artifacts.

    Args:
        input_image: uploaded PIL image, or ``None`` when the form is submitted
            without an image.

    Returns:
        A 4-tuple matching the interface outputs:
        (label -> confidence dict, or None when no image was given;
         path to "training_performance.png" or None if missing;
         path to "final_test_accuracy.png" or None if missing;
         accuracy summary text).
    """
    # Skip inference for a missing image instead of crashing inside
    # preprocessing, so the benchmark plots and text are still displayed.
    preds = inference(input_image) if input_image is not None else None

    # Plot paths are only returned when the files exist in the working directory.
    perf_plot = "training_performance.png" if os.path.exists("training_performance.png") else None
    acc_plot = "final_test_accuracy.png" if os.path.exists("final_test_accuracy.png") else None

    # The text file is expected to hold just the accuracy number; "%" is appended here.
    test_acc_text = "Final test accuracy not available."
    if os.path.exists("final_test_accuracy.txt"):
        with open("final_test_accuracy.txt", "r", encoding="utf-8") as fh:
            test_acc_value = fh.read().strip()
        test_acc_text = f"Final Test Accuracy: {test_acc_value}%"
    return preds, perf_plot, acc_plot, test_acc_text
# === Gradio Interface Setup ===
# Currently empty: no example inputs are offered to the interface.
example_images = []


def _build_interface():
    """Create the Gradio Interface that feeds one uploaded image to show_results.

    The four output widgets line up, in order, with the 4-tuple show_results returns.
    """
    upload_widget = gr.Image(type='pil', label='Upload Image')
    result_widgets = [
        gr.Label(num_top_classes=3, label='Predictions'),
        gr.Image(type='filepath', label='Training Performance'),
        gr.Image(type='filepath', label='Final Test Accuracy Plot'),
        gr.Textbox(label='Final Test Accuracy'),
    ]
    return gr.Interface(
        fn=show_results,
        inputs=upload_widget,
        outputs=result_widgets,
        title='CIFAR-10 Image Classification with DCLR Optimizer',
        description='Upload an image to see predictions. Training/test plots and accuracy show benchmark results on CIFAR-10.',
        examples=example_images,
    )


interface = _build_interface()

if __name__ == '__main__':
    interface.launch()