import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import os


# === Simple CNN Model Definition ===
class SimpleCNN(nn.Module):
    """Small two-convolution CNN that outputs 10 class logits.

    ``fc1`` is sized for 64 feature maps of 8x8, which is what a 3x32x32
    input produces after the two 2x2 max-pools in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for ``x`` of shape (batch, 3, 32, 32)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, 64*8*8), this never folds extra
        # spatial data into a fake batch row; a wrong input size now raises in
        # fc1 instead of silently producing garbage predictions.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# === Model Loading ===
model = SimpleCNN()
model_path = 'simple_cnn_dclr_tuned.pth'
if os.path.exists(model_path):
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # you trust. On torch >= 1.13 consider passing weights_only=True.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    print(f"Model loaded successfully from {model_path}")
else:
    print(f"Warning: Model file '{model_path}' not found. Please run train_dclr_model.py first.")
# Switch to inference mode whether or not trained weights were found.
model.eval()

# === CIFAR-10 Class Labels ===
class_labels = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# === Image Preprocessing ===
# Resize to an explicit (height, width). A single int would only scale the
# shorter edge and keep the aspect ratio, so non-square uploads would not
# match the 32x32 input that SimpleCNN's fc1 layer is sized for.
# NOTE(review): these are ImageNet mean/std values; confirm they match the
# normalization used in train_dclr_model.py.
preprocess = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# === Inference Function ===
def inference(input_image: Image.Image):
    """Classify a PIL image.

    Returns a dict mapping every label in ``class_labels`` to its softmax
    probability as a Python float.
    """
    if model.training:
        model.eval()
    # Uploads may be RGBA, palette or grayscale; the first conv layer
    # expects exactly 3 input channels.
    rgb_image = input_image.convert('RGB')
    processed_image = preprocess(rgb_image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(processed_image)
        probabilities = F.softmax(outputs, dim=1)[0]
    return {label: float(prob) for label, prob in zip(class_labels, probabilities)}


def _existing_path(path):
    """Return ``path`` if the file exists, otherwise None."""
    return path if os.path.exists(path) else None


# === Results Viewer Function ===
def show_results(input_image: Image.Image):
    """Produce the four values shown by the Gradio interface.

    Returns a tuple of:
    - the prediction dict, or None when no image was supplied;
    - the training-performance plot path, or None;
    - the test-accuracy plot path, or None;
    - a text line with the final test accuracy.
    """
    # Gradio's image input yields None when the user submits without an upload.
    preds = inference(input_image) if input_image is not None else None

    # Load plots if they exist
    perf_plot = _existing_path("training_performance.png")
    acc_plot = _existing_path("final_test_accuracy.png")

    # Load final test accuracy number
    test_acc_text = "Final test accuracy not available."
    if os.path.exists("final_test_accuracy.txt"):
        with open("final_test_accuracy.txt", "r", encoding="utf-8") as f:
            test_acc_value = f.read().strip()
        # An empty file would otherwise render as "Final Test Accuracy: %".
        if test_acc_value:
            test_acc_text = f"Final Test Accuracy: {test_acc_value}%"
    return preds, perf_plot, acc_plot, test_acc_text


# === Gradio Interface Setup ===
example_images = []

interface = gr.Interface(
    fn=show_results,
    inputs=gr.Image(type='pil', label='Upload Image'),
    outputs=[
        gr.Label(num_top_classes=3, label='Predictions'),
        gr.Image(type='filepath', label='Training Performance'),
        gr.Image(type='filepath', label='Final Test Accuracy Plot'),
        gr.Textbox(label='Final Test Accuracy'),
    ],
    title='CIFAR-10 Image Classification with DCLR Optimizer',
    description='Upload an image to see predictions. Training/test plots and accuracy show benchmark results on CIFAR-10.',
    examples=example_images,
)

if __name__ == '__main__':
    interface.launch()