RFTSystems commited on
Commit
1faeebc
·
verified ·
1 Parent(s): a1b4ce7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -121
app.py CHANGED
@@ -1,20 +1,10 @@
1
- import os
2
  import torch
3
  import torchvision.transforms as transforms
4
- import torchvision
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  from PIL import Image
8
  import gradio as gr
9
- import numpy as np
10
-
11
- # === Paths ===
12
- ART_DIR = "artifacts"
13
- DCLR_MODEL_PATH = os.path.join(ART_DIR, "dclr_simple_cnn.pth")
14
- DCLR_PERF_PNG = os.path.join(ART_DIR, "dclr_training_performance.png")
15
- DCLR_ACC_PNG = os.path.join(ART_DIR, "dclr_final_test_accuracy.png")
16
- DCLR_ACC_TXT = os.path.join(ART_DIR, "dclr_final_test_accuracy.txt")
17
- BENCHMARK_TXT = os.path.join(ART_DIR, "benchmark_results.txt")
18
 
19
  # === Simple CNN Model Definition ===
20
  class SimpleCNN(nn.Module):
@@ -33,33 +23,28 @@ class SimpleCNN(nn.Module):
33
  x = F.relu(self.fc1(x))
34
  return self.fc2(x)
35
 
36
- # === Load DCLR model (for inference tab) ===
37
  model = SimpleCNN()
38
- if os.path.exists(DCLR_MODEL_PATH):
39
- model.load_state_dict(torch.load(DCLR_MODEL_PATH, map_location=torch.device('cpu')))
 
 
40
  model.eval()
41
- print(f"Model loaded successfully from {DCLR_MODEL_PATH}")
42
  else:
43
- print(f"Warning: Model file '{DCLR_MODEL_PATH}' not found. Run train_dclr_model.py.")
44
 
45
  # === CIFAR-10 Class Labels ===
46
  class_labels = ['plane','car','bird','cat','deer','dog','frog','horse','ship','truck']
47
 
48
- # === Image Preprocessing (consistent with training normalization) ===
49
  preprocess = transforms.Compose([
50
  transforms.Resize(32),
51
  transforms.ToTensor(),
52
- transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
53
  ])
54
 
55
- # === CIFAR-10 Test Loader for Benchmark Mode ===
56
- test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.Compose([
57
- transforms.ToTensor(),
58
- transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
59
- ]))
60
- test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
61
-
62
- # === Inference Function (single image) ===
63
  def inference(input_image: Image.Image):
64
  if model.training:
65
  model.eval()
@@ -70,106 +55,39 @@ def inference(input_image: Image.Image):
70
  confidences = {class_labels[i]: float(probabilities[0,i]) for i in range(len(class_labels))}
71
  return confidences
72
 
73
- # === Benchmark Mode: Evaluate DCLR on full test set (real-time) ===
74
- def benchmark_dclr_realtime():
75
- if not os.path.exists(DCLR_MODEL_PATH):
76
- return "Model missing. Run training first.", {}, None, None
77
-
78
- # Load weights fresh to avoid any accidental state drift
79
- local_model = SimpleCNN()
80
- local_model.load_state_dict(torch.load(DCLR_MODEL_PATH, map_location=torch.device('cpu')))
81
- local_model.eval()
82
-
83
- correct = 0
84
- total = 0
85
- class_correct = np.zeros(10)
86
- class_total = np.zeros(10)
87
-
88
- with torch.no_grad():
89
- for inputs, labels in test_loader:
90
- outputs = local_model(inputs)
91
- _, predicted = outputs.max(1)
92
- total += labels.size(0)
93
- correct += predicted.eq(labels).sum().item()
94
- c = (predicted == labels).squeeze()
95
- for i in range(len(labels)):
96
- label = labels[i].item()
97
- class_correct[label] += c[i].item()
98
- class_total[label] += 1
99
-
100
- overall_acc = round(100.0 * correct / total, 2)
101
- classwise_acc = {class_labels[i]: round(100.0 * class_correct[i] / class_total[i], 2) for i in range(10)}
102
-
103
- perf_plot = DCLR_PERF_PNG if os.path.exists(DCLR_PERF_PNG) else None
104
- acc_plot = DCLR_ACC_PNG if os.path.exists(DCLR_ACC_PNG) else None
105
-
106
- return f"{overall_acc}%", classwise_acc, perf_plot, acc_plot
107
 
108
- # === Benchmark Comparison: Read real ledger (DCLR vs Adam vs Lion) ===
109
- def benchmark_comparison():
110
- if os.path.exists(BENCHMARK_TXT):
111
- with open(BENCHMARK_TXT, "r") as f:
112
- return f.read()
113
- return "No benchmark_results.txt found. Please run train_dclr_model.py to generate real numbers."
114
 
115
- # === Prepare CIFAR-10 Sample Gallery (one per class with captions) ===
116
- sample_dir = "examples"
117
- os.makedirs(sample_dir, exist_ok=True)
 
 
 
118
 
119
- transform_gallery = transforms.Compose([transforms.ToPILImage()])
120
- raw_test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
121
-
122
- example_images = []
123
- seen_classes = set()
124
- for idx in range(len(raw_test_set)):
125
- img, label = raw_test_set[idx]
126
- if label not in seen_classes:
127
- pil_img = transform_gallery(img)
128
- file_path = os.path.join(sample_dir, f"example_{class_labels[label]}.png")
129
- pil_img.save(file_path)
130
- example_images.append([file_path, f"Sample {class_labels[label]}"])
131
- seen_classes.add(label)
132
- if len(seen_classes) == 10:
133
- break
134
 
135
  # === Gradio Interface Setup ===
136
- with gr.Blocks() as demo:
137
- gr.Markdown("# DCLR Optimiser — CIFAR-10 Artifact Viewer")
138
- gr.Markdown("Upload an image for prediction, or use Benchmark tabs for real test results. All numbers are computed from CIFAR-10 runs and saved as reproducible artifacts.")
139
-
140
- with gr.Tab("Single Image Inference (DCLR)"):
141
- inp = gr.Image(type='pil', label='Upload Image (32x32 assumed)')
142
- out = gr.Label(num_top_classes=3, label='Predictions')
143
- perf_img = gr.Image(type='filepath', label='DCLR Training Performance', value=DCLR_PERF_PNG if os.path.exists(DCLR_PERF_PNG) else None)
144
- acc_img = gr.Image(type='filepath', label='DCLR Final Test Accuracy Plot', value=DCLR_ACC_PNG if os.path.exists(DCLR_ACC_PNG) else None)
145
- acc_text = gr.Textbox(label='DCLR Final Test Accuracy')
146
- # If the accuracy text file exists, load it at UI init
147
- if os.path.exists(DCLR_ACC_TXT):
148
- with open(DCLR_ACC_TXT, "r") as f:
149
- acc_text.value = f"Final Test Accuracy: {f.read().strip()}%"
150
- # Hook
151
- inp.change(fn=inference, inputs=inp, outputs=out)
152
-
153
- gr.Examples(
154
- examples=example_images,
155
- inputs=inp,
156
- label="CIFAR-10 Samples (one per class)"
157
- )
158
-
159
- with gr.Tab("Benchmark Mode (DCLR real-time)"):
160
- btn = gr.Button("Run DCLR Benchmark on CIFAR-10 Test Set")
161
- overall = gr.Textbox(label="Overall Test Accuracy (DCLR)")
162
- classwise = gr.JSON(label="Per-Class Accuracy (%) (DCLR)")
163
- perf_plot = gr.Image(type='filepath', label='DCLR Training Performance')
164
- acc_plot = gr.Image(type='filepath', label='DCLR Final Test Accuracy Plot')
165
-
166
- btn.click(fn=benchmark_dclr_realtime, inputs=None, outputs=[overall, classwise, perf_plot, acc_plot])
167
 
168
- with gr.Tab("Benchmark Comparison (DCLR vs Adam vs Lion)"):
169
- gr.Markdown("Reads real results from artifacts/benchmark_results.txt produced by training.")
170
- show_btn = gr.Button("Show Real Benchmark Ledger")
171
- ledger_box = gr.Textbox(label="Benchmark Results", lines=10)
172
- show_btn.click(fn=benchmark_comparison, inputs=None, outputs=ledger_box)
 
 
 
 
 
 
 
 
173
 
174
  if __name__ == '__main__':
175
- demo.launch()
 
 
1
  import torch
2
  import torchvision.transforms as transforms
 
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  from PIL import Image
6
  import gradio as gr
7
+ import os
 
 
 
 
 
 
 
 
8
 
9
  # === Simple CNN Model Definition ===
10
  class SimpleCNN(nn.Module):
 
23
  x = F.relu(self.fc1(x))
24
  return self.fc2(x)
25
 
26
+ # === Model Loading ===
27
  model = SimpleCNN()
28
+ model_path = 'simple_cnn_dclr_tuned.pth'
29
+
30
+ if os.path.exists(model_path):
31
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
32
  model.eval()
33
+ print(f"Model loaded successfully from {model_path}")
34
  else:
35
+ print(f"Warning: Model file '{model_path}' not found. Please run train_dclr_model.py first.")
36
 
37
  # === CIFAR-10 Class Labels ===
38
  class_labels = ['plane','car','bird','cat','deer','dog','frog','horse','ship','truck']
39
 
40
+ # === Image Preprocessing ===
41
  preprocess = transforms.Compose([
42
  transforms.Resize(32),
43
  transforms.ToTensor(),
44
+ transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
45
  ])
46
 
47
+ # === Inference Function ===
 
 
 
 
 
 
 
48
  def inference(input_image: Image.Image):
49
  if model.training:
50
  model.eval()
 
55
  confidences = {class_labels[i]: float(probabilities[0,i]) for i in range(len(class_labels))}
56
  return confidences
57
 
58
+ # === Results Viewer Function ===
59
+ def show_results(input_image: Image.Image):
60
+ preds = inference(input_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ # Load plots if they exist
63
+ perf_plot = "training_performance.png" if os.path.exists("training_performance.png") else None
64
+ acc_plot = "final_test_accuracy.png" if os.path.exists("final_test_accuracy.png") else None
 
 
 
65
 
66
+ # Load final test accuracy number
67
+ test_acc_text = "Final test accuracy not available."
68
+ if os.path.exists("final_test_accuracy.txt"):
69
+ with open("final_test_accuracy.txt", "r") as f:
70
+ test_acc_value = f.read().strip()
71
+ test_acc_text = f"Final Test Accuracy: {test_acc_value}%"
72
 
73
+ return preds, perf_plot, acc_plot, test_acc_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # === Gradio Interface Setup ===
76
+ example_images = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ interface = gr.Interface(
79
+ fn=show_results,
80
+ inputs=gr.Image(type='pil', label='Upload Image'),
81
+ outputs=[
82
+ gr.Label(num_top_classes=3, label='Predictions'),
83
+ gr.Image(type='filepath', label='Training Performance'),
84
+ gr.Image(type='filepath', label='Final Test Accuracy Plot'),
85
+ gr.Textbox(label='Final Test Accuracy')
86
+ ],
87
+ title='CIFAR-10 Image Classification with DCLR Optimizer',
88
+ description='Upload an image to see predictions. Training/test plots and accuracy show benchmark results on CIFAR-10.',
89
+ examples=example_images
90
+ )
91
 
92
  if __name__ == '__main__':
93
+ interface.launch()