RFTSystems committed on
Commit a1b4ce7 · verified · 1 Parent(s): 8492c41

Update app.py

Files changed (1)
  1. app.py +95 -33
app.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import torch
 import torchvision.transforms as transforms
 import torchvision
@@ -5,9 +6,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
 import gradio as gr
-import os
 import numpy as np
 
+# === Paths ===
+ART_DIR = "artifacts"
+DCLR_MODEL_PATH = os.path.join(ART_DIR, "dclr_simple_cnn.pth")
+DCLR_PERF_PNG = os.path.join(ART_DIR, "dclr_training_performance.png")
+DCLR_ACC_PNG = os.path.join(ART_DIR, "dclr_final_test_accuracy.png")
+DCLR_ACC_TXT = os.path.join(ART_DIR, "dclr_final_test_accuracy.txt")
+BENCHMARK_TXT = os.path.join(ART_DIR, "benchmark_results.txt")
+
 # === Simple CNN Model Definition ===
 class SimpleCNN(nn.Module):
     def __init__(self):
@@ -25,29 +33,30 @@ class SimpleCNN(nn.Module):
         x = F.relu(self.fc1(x))
         return self.fc2(x)
 
-# === Model Loading ===
+# === Load DCLR model (for inference tab) ===
 model = SimpleCNN()
-model_path = 'simple_cnn_dclr_tuned.pth'
-
-if os.path.exists(model_path):
-    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+if os.path.exists(DCLR_MODEL_PATH):
+    model.load_state_dict(torch.load(DCLR_MODEL_PATH, map_location=torch.device('cpu')))
     model.eval()
-    print(f"Model loaded successfully from {model_path}")
+    print(f"Model loaded successfully from {DCLR_MODEL_PATH}")
 else:
-    print(f"Warning: Model file '{model_path}' not found. Please run train_dclr_model.py first.")
+    print(f"Warning: Model file '{DCLR_MODEL_PATH}' not found. Run train_dclr_model.py.")
 
 # === CIFAR-10 Class Labels ===
 class_labels = ['plane','car','bird','cat','deer','dog','frog','horse','ship','truck']
 
-# === Image Preprocessing ===
+# === Image Preprocessing (consistent with training normalization) ===
 preprocess = transforms.Compose([
     transforms.Resize(32),
     transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
+    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
 ])
 
 # === CIFAR-10 Test Loader for Benchmark Mode ===
-test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
+test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
+]))
 test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
 
 # === Inference Function (single image) ===
@@ -61,9 +70,16 @@ def inference(input_image: Image.Image):
     confidences = {class_labels[i]: float(probabilities[0,i]) for i in range(len(class_labels))}
     return confidences
 
-# === Benchmark Mode: Evaluate on full test set ===
-def benchmark():
-    model.eval()
+# === Benchmark Mode: Evaluate DCLR on full test set (real-time) ===
+def benchmark_dclr_realtime():
+    if not os.path.exists(DCLR_MODEL_PATH):
+        return "Model missing. Run training first.", {}, None, None
+
+    # Load weights fresh to avoid any accidental state drift
+    local_model = SimpleCNN()
+    local_model.load_state_dict(torch.load(DCLR_MODEL_PATH, map_location=torch.device('cpu')))
+    local_model.eval()
+
     correct = 0
     total = 0
     class_correct = np.zeros(10)
@@ -71,7 +87,7 @@ def benchmark():
 
     with torch.no_grad():
        for inputs, labels in test_loader:
-            outputs = model(inputs)
+            outputs = local_model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
@@ -81,33 +97,79 @@ def benchmark():
             class_correct[label] += c[i].item()
             class_total[label] += 1
 
-    overall_acc = 100.0 * correct / total
+    overall_acc = round(100.0 * correct / total, 2)
     classwise_acc = {class_labels[i]: round(100.0 * class_correct[i] / class_total[i], 2) for i in range(10)}
 
-    # Load plots if they exist
-    perf_plot = "training_performance.png" if os.path.exists("training_performance.png") else None
-    acc_plot = "final_test_accuracy.png" if os.path.exists("final_test_accuracy.png") else None
-
-    return overall_acc, classwise_acc, perf_plot, acc_plot
+    perf_plot = DCLR_PERF_PNG if os.path.exists(DCLR_PERF_PNG) else None
+    acc_plot = DCLR_ACC_PNG if os.path.exists(DCLR_ACC_PNG) else None
+
+    return f"{overall_acc}%", classwise_acc, perf_plot, acc_plot
+
+# === Benchmark Comparison: Read real ledger (DCLR vs Adam vs Lion) ===
+def benchmark_comparison():
+    if os.path.exists(BENCHMARK_TXT):
+        with open(BENCHMARK_TXT, "r") as f:
+            return f.read()
+    return "No benchmark_results.txt found. Please run train_dclr_model.py to generate real numbers."
+
+# === Prepare CIFAR-10 Sample Gallery (one per class with captions) ===
+sample_dir = "examples"
+os.makedirs(sample_dir, exist_ok=True)
+
+transform_gallery = transforms.Compose([transforms.ToPILImage()])
+raw_test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
+
+example_images = []
+seen_classes = set()
+for idx in range(len(raw_test_set)):
+    img, label = raw_test_set[idx]
+    if label not in seen_classes:
+        pil_img = transform_gallery(img)
+        file_path = os.path.join(sample_dir, f"example_{class_labels[label]}.png")
+        pil_img.save(file_path)
+        example_images.append([file_path, f"Sample {class_labels[label]}"])
+        seen_classes.add(label)
+    if len(seen_classes) == 10:
+        break
 
 # === Gradio Interface Setup ===
 with gr.Blocks() as demo:
-    gr.Markdown("## CIFAR-10 Image Classification with DCLR Optimizer")
-    gr.Markdown("Upload an image for prediction, or run Benchmark Mode to see full test accuracy.")
+    gr.Markdown("# DCLR Optimiser — CIFAR-10 Artifact Viewer")
+    gr.Markdown("Upload an image for prediction, or use Benchmark tabs for real test results. All numbers are computed from CIFAR-10 runs and saved as reproducible artifacts.")
 
-    with gr.Tab("Single Image Inference"):
-        inp = gr.Image(type='pil', label='Upload Image')
+    with gr.Tab("Single Image Inference (DCLR)"):
+        inp = gr.Image(type='pil', label='Upload Image (32x32 assumed)')
         out = gr.Label(num_top_classes=3, label='Predictions')
+        perf_img = gr.Image(type='filepath', label='DCLR Training Performance', value=DCLR_PERF_PNG if os.path.exists(DCLR_PERF_PNG) else None)
+        acc_img = gr.Image(type='filepath', label='DCLR Final Test Accuracy Plot', value=DCLR_ACC_PNG if os.path.exists(DCLR_ACC_PNG) else None)
+        acc_text = gr.Textbox(label='DCLR Final Test Accuracy')
+        # If the accuracy text file exists, load it at UI init
+        if os.path.exists(DCLR_ACC_TXT):
+            with open(DCLR_ACC_TXT, "r") as f:
+                acc_text.value = f"Final Test Accuracy: {f.read().strip()}%"
+        # Hook
         inp.change(fn=inference, inputs=inp, outputs=out)
 
-    with gr.Tab("Benchmark Mode"):
-        btn = gr.Button("Run Benchmark on CIFAR-10 Test Set")
-        overall = gr.Textbox(label="Overall Test Accuracy")
-        classwise = gr.JSON(label="Per-Class Accuracy (%)")
-        perf_plot = gr.Image(type='filepath', label='Training Performance')
-        acc_plot = gr.Image(type='filepath', label='Final Test Accuracy Plot')
-
-        btn.click(fn=benchmark, inputs=None, outputs=[overall, classwise, perf_plot, acc_plot])
+        gr.Examples(
+            examples=example_images,
+            inputs=inp,
+            label="CIFAR-10 Samples (one per class)"
+        )
+
+    with gr.Tab("Benchmark Mode (DCLR real-time)"):
+        btn = gr.Button("Run DCLR Benchmark on CIFAR-10 Test Set")
+        overall = gr.Textbox(label="Overall Test Accuracy (DCLR)")
+        classwise = gr.JSON(label="Per-Class Accuracy (%) (DCLR)")
+        perf_plot = gr.Image(type='filepath', label='DCLR Training Performance')
+        acc_plot = gr.Image(type='filepath', label='DCLR Final Test Accuracy Plot')
+
+        btn.click(fn=benchmark_dclr_realtime, inputs=None, outputs=[overall, classwise, perf_plot, acc_plot])
+
+    with gr.Tab("Benchmark Comparison (DCLR vs Adam vs Lion)"):
+        gr.Markdown("Reads real results from artifacts/benchmark_results.txt produced by training.")
+        show_btn = gr.Button("Show Real Benchmark Ledger")
+        ledger_box = gr.Textbox(label="Benchmark Results", lines=10)
+        show_btn.click(fn=benchmark_comparison, inputs=None, outputs=ledger_box)
 
 if __name__ == '__main__':
     demo.launch()
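Note on the preprocessing change: the old app normalized uploads with ImageNet statistics (mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]) while the benchmark test loader applied no normalization at all, so neither path matched the distribution the checkpoint was trained on. The new code uses Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) on both paths. A minimal sketch of the training-side transform this implies; train_dclr_model.py is not part of this commit, so the augmentation choices below are illustrative assumptions, with only the Normalize statistics confirmed by the diff:

import torchvision.transforms as transforms

# Assumed training pipeline for train_dclr_model.py (hypothetical):
# a common CIFAR-10 recipe. Only the Normalize((0.5,...), (0.5,...))
# statistics are confirmed by the app.py diff, and they must match
# the app's inference and benchmark transforms exactly.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])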
 
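The body of inference() is unchanged by this commit and largely elided from the diff; only the signature, the confidences dict, and the return are visible as context. A sketch of the elided middle, consistent with those visible lines and using app.py's module-level preprocess, model, and class_labels (the exact elided statements are assumptions, not taken from the repository):

import torch
import torch.nn.functional as F
from PIL import Image

def inference(input_image: Image.Image):
    # Preprocess and add a batch dimension (assumed; elided in the diff)
    input_tensor = preprocess(input_image.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = F.softmax(logits, dim=1)
    # These two lines appear verbatim as context in the diff
    confidences = {class_labels[i]: float(probabilities[0, i]) for i in range(len(class_labels))}
    return confidences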