# DCLR_Optimiser / app.py
# RFTSystems's picture
# Update app.py
# 1faeebc verified
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import os
# === Simple CNN Model Definition ===
class SimpleCNN(nn.Module):
    """Two-convolution CNN that maps 3-channel 32x32 images to 10 class logits.

    Each 3x3 convolution (padding 1) keeps the spatial size and is followed by
    a 2x2 max-pool that halves it, so a 32x32 input reaches the first linear
    layer as 64 feature maps of 8x8. Attribute names are part of the saved
    state-dict keys and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return raw (pre-softmax) logits of shape (batch, 10).

        Args:
            x: Image tensor of shape (batch, 3, 32, 32); a single unbatched
                (3, 32, 32) tensor is treated as a batch of one.

        Raises:
            RuntimeError: If the spatial size is not 32x32, because the
                flattened features no longer match ``fc1``.
        """
        # Accept a single unbatched image by adding the batch dimension.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 64 * 8 * 8), this never spills surplus spatial elements
        # into the batch dimension when the input has the wrong size.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# === Model Loading ===
# Checkpoint expected next to this script; the warning below names the
# training script that produces it.
model_path = 'simple_cnn_dclr_tuned.pth'
model = SimpleCNN()
if not os.path.exists(model_path):
    # Without a checkpoint the model keeps its freshly initialised weights
    # and stays in training mode.
    print(f"Warning: Model file '{model_path}' not found. Please run train_dclr_model.py first.")
else:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Model loaded successfully from {model_path}")
# === CIFAR-10 Class Labels ===
# Position i in this list is the label reported for model output i.
class_labels = ['plane','car','bird','cat','deer','dog','frog','horse','ship','truck']
# === Image Preprocessing ===
preprocess = transforms.Compose([
    # Force an exact 32x32 result. Resize(32) with a bare int only scales the
    # shorter edge, so non-square uploads would keep their aspect ratio and
    # no longer match the 64 * 8 * 8 features the model's first linear layer
    # expects.
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    # NOTE(review): these look like the usual ImageNet channel statistics;
    # confirm the training script normalises with the same values.
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
# === Inference Function ===
def inference(input_image: Image.Image):
    """Classify one image and return the softmax probability of every class.

    Args:
        input_image: PIL image to classify, or ``None`` when no image was
            supplied (for example, the form was submitted empty).

    Returns:
        A dict mapping each entry of ``class_labels`` to its probability as a
        Python float. An empty dict is returned when ``input_image`` is
        ``None``.
    """
    if input_image is None:
        return {}
    # The model is only switched to eval mode at load time when the weights
    # file exists, so make sure it is in eval mode here as well.
    if model.training:
        model.eval()
    # The first convolution takes exactly 3 channels; RGBA, palette and
    # greyscale images would otherwise produce 4- or 1-channel tensors.
    if input_image.mode != 'RGB':
        input_image = input_image.convert('RGB')
    batch = preprocess(input_image).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        logits = model(batch)
        probabilities = F.softmax(logits, dim=1)[0].tolist()
    return dict(zip(class_labels, probabilities))
# === Results Viewer Function ===
def show_results(input_image: Image.Image):
    """Run inference and gather the saved benchmark artifacts for display.

    Args:
        input_image: Image forwarded unchanged to ``inference``.

    Returns:
        A 4-tuple ``(predictions, training_plot, accuracy_plot, accuracy_text)``
        where each plot entry is a file path, or ``None`` when that file does
        not exist in the working directory.
    """
    def path_if_present(filename):
        # Report a plot only when its file is actually on disk.
        return filename if os.path.exists(filename) else None

    predictions = inference(input_image)

    training_plot = path_if_present("training_performance.png")
    accuracy_plot = path_if_present("final_test_accuracy.png")

    # The accuracy figure is read verbatim from a text file, if one exists.
    accuracy_file = "final_test_accuracy.txt"
    if os.path.exists(accuracy_file):
        with open(accuracy_file, "r") as handle:
            accuracy_value = handle.read().strip()
        accuracy_message = f"Final Test Accuracy: {accuracy_value}%"
    else:
        accuracy_message = "Final test accuracy not available."

    return predictions, training_plot, accuracy_plot, accuracy_message
# === Gradio Interface Setup ===
# No bundled example images.
example_images = []

# The upload widget hands the image to show_results as a PIL image.
_upload_input = gr.Image(type='pil', label='Upload Image')

# One output component per element of the tuple returned by show_results,
# in the same order.
_result_outputs = [
    gr.Label(num_top_classes=3, label='Predictions'),
    gr.Image(type='filepath', label='Training Performance'),
    gr.Image(type='filepath', label='Final Test Accuracy Plot'),
    gr.Textbox(label='Final Test Accuracy'),
]

interface = gr.Interface(
    fn=show_results,
    inputs=_upload_input,
    outputs=_result_outputs,
    title='CIFAR-10 Image Classification with DCLR Optimizer',
    description='Upload an image to see predictions. Training/test plots and accuracy show benchmark results on CIFAR-10.',
    examples=example_images,
)

# Start the web UI only when this file is run as a script.
if __name__ == '__main__':
    interface.launch()