AravindKumarRajendran commited on
Commit
7ae96e1
·
1 Parent(s): ac1ff66
Files changed (2) hide show
  1. app.py +194 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ from torchvision.models import resnet50
6
+ import os
7
+ import logging
8
+ from typing import Optional, Union
9
+ import numpy as np
10
+ from pathlib import Path
11
+
12
+ # Set up logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Directory Configuration
17
+ BASE_DIR = Path(__file__).resolve().parent
18
+ MODELS_DIR = BASE_DIR / "models"
19
+ EXAMPLES_DIR = BASE_DIR / "examples"
20
+ STATIC_DIR = BASE_DIR / "static" / "uploaded"
21
+
22
+ # Ensure directories exist
23
+ STATIC_DIR.mkdir(parents=True, exist_ok=True)
24
+
25
+ # Global variables
26
+ MODEL_PATH = MODELS_DIR / "resnet_50.pth"
27
+ CLASSES_PATH = MODELS_DIR / "classes.txt"
28
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
+
30
+ def load_class_labels() -> Optional[list]:
31
+ """
32
+ Load class labels from the classes.txt file
33
+ """
34
+ try:
35
+ if not CLASSES_PATH.exists():
36
+ raise FileNotFoundError(f"Classes file not found at {CLASSES_PATH}")
37
+
38
+ with open(CLASSES_PATH, 'r') as f:
39
+ return [line.strip() for line in f.readlines()]
40
+ except Exception as e:
41
+ logger.error(f"Error loading class labels: {str(e)}")
42
+ return None
43
+
44
+ # Load class labels
45
+ CLASS_NAMES = load_class_labels()
46
+ if CLASS_NAMES is None:
47
+ raise RuntimeError("Failed to load class labels from classes.txt")
48
+
49
+ # Cache the model to avoid reloading for each prediction
50
+ model = None
51
+
52
+ def load_model() -> Optional[torch.nn.Module]:
53
+ """
54
+ Load the ResNet50 model with error handling
55
+ """
56
+ global model
57
+
58
+ try:
59
+ if model is not None:
60
+ return model
61
+
62
+ if not MODEL_PATH.exists():
63
+ raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
64
+
65
+ logger.info(f"Loading model on {DEVICE}")
66
+ model = resnet50(pretrained=False)
67
+ model.fc = torch.nn.Linear(model.fc.in_features, len(CLASS_NAMES))
68
+
69
+ # Load the model weights
70
+ state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
71
+
72
+ if 'state_dict' in state_dict:
73
+ state_dict = state_dict['state_dict']
74
+
75
+ model.load_state_dict(state_dict)
76
+ model.to(DEVICE)
77
+ model.eval()
78
+
79
+ logger.info("Model loaded successfully")
80
+ return model
81
+
82
+ except Exception as e:
83
+ logger.error(f"Error loading model: {str(e)}")
84
+ return None
85
+
86
+ def preprocess_image(image: Union[np.ndarray, Image.Image]) -> Optional[torch.Tensor]:
87
+ """
88
+ Preprocess the input image with error handling
89
+ """
90
+ try:
91
+ if isinstance(image, np.ndarray):
92
+ image = Image.fromarray(image)
93
+
94
+ transform = transforms.Compose([
95
+ transforms.Resize((224, 224)),
96
+ transforms.ToTensor(),
97
+ transforms.Normalize(
98
+ mean=[0.485, 0.456, 0.406],
99
+ std=[0.229, 0.224, 0.225]
100
+ )
101
+ ])
102
+
103
+ return transform(image).unsqueeze(0).to(DEVICE)
104
+
105
+ except Exception as e:
106
+ logger.error(f"Error preprocessing image: {str(e)}")
107
+ return None
108
+
109
+ def predict(image: Union[np.ndarray, None]) -> tuple[str, dict]:
110
+ """
111
+ Make predictions on the input image with comprehensive error handling
112
+ Returns the predicted class and top 5 confidence scores
113
+ """
114
+ try:
115
+ if image is None:
116
+ return "Error: No image provided", {}
117
+
118
+ model = load_model()
119
+ if model is None:
120
+ return "Error: Failed to load model", {}
121
+
122
+ input_tensor = preprocess_image(image)
123
+ if input_tensor is None:
124
+ return "Error: Failed to preprocess image", {}
125
+
126
+ with torch.no_grad():
127
+ output = model(input_tensor)
128
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
129
+
130
+ predicted_class_idx = torch.argmax(probabilities).item()
131
+ predicted_class = CLASS_NAMES[predicted_class_idx]
132
+
133
+ # Get top 5 predictions
134
+ top_5_probs, top_5_indices = torch.topk(probabilities, k=5)
135
+
136
+ # Create confidence dictionary for top 5 classes
137
+ confidences = {
138
+ CLASS_NAMES[idx.item()]: float(prob.item())
139
+ for prob, idx in zip(top_5_probs, top_5_indices)
140
+ }
141
+
142
+ return predicted_class, confidences
143
+
144
+ except Exception as e:
145
+ logger.error(f"Prediction error: {str(e)}")
146
+ return f"Error during prediction: {str(e)}", {}
147
+
148
+ def get_example_list() -> list:
149
+ """
150
+ Get list of example images from the examples directory
151
+ """
152
+ try:
153
+ examples = []
154
+ for ext in ['.jpg', '.jpeg', '.png']:
155
+ examples.extend(list(EXAMPLES_DIR.glob(f'*{ext}')))
156
+ return [[str(ex)] for ex in sorted(examples)]
157
+ except Exception as e:
158
+ logger.error(f"Error loading examples: {str(e)}")
159
+ return []
160
+
161
+ # Create Gradio interface with error handling
162
+ try:
163
+ iface = gr.Interface(
164
+ fn=predict,
165
+ inputs=gr.Image(type="numpy", label="Upload Image"),
166
+ outputs=[
167
+ gr.Label(label="Predicted Class", num_top_classes=1),
168
+ gr.Label(label="Top 5 Predictions", num_top_classes=5)
169
+ ],
170
+ title="Image Classification with ResNet50",
171
+ description=(
172
+ "Upload an image to classify:\n"
173
+ "The model will predict the class and show top 5 confidence scores."
174
+ ),
175
+ examples=get_example_list(),
176
+ cache_examples=True,
177
+ theme=gr.themes.Base()
178
+ )
179
+
180
+ except Exception as e:
181
+ logger.error(f"Error creating Gradio interface: {str(e)}")
182
+ raise
183
+
184
+ if __name__ == "__main__":
185
+ try:
186
+ load_model() # Pre-load the model
187
+ iface.launch(
188
+ share=False,
189
+ server_name="0.0.0.0",
190
+ server_port=7860,
191
+ debug=False
192
+ )
193
+ except Exception as e:
194
+ logger.error(f"Error launching application: {str(e)}")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ gradio>=3.50.2
4
+ Pillow>=9.0.0
5
+ numpy>=1.21.0
6
+ typing-extensions>=4.0.0