AravindKumarRajendran commited on
Commit
d67c1ff
Β·
1 Parent(s): a84cb50

final working space

Browse files
README.md CHANGED
@@ -10,3 +10,172 @@ pinned: false
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+
14
+
15
+ # ResNet50 Image Classifier
16
+
17
+ This is a Gradio web application that uses a trained ResNet50 model to classify images. The application provides real-time predictions with top-3 confidence scores for uploaded images.
18
+
19
+ ## Live Demo
20
+
21
+ Visit the application at [Hugging Face Spaces URL]
22
+
23
+ ## Features
24
+
25
+ - Real-time image classification
26
+ - Top-3 predictions with confidence scores
27
+ - Support for various image formats
28
+ - User-friendly interface
29
+ - Detailed prediction logging
30
+ - Example images for testing
31
+
32
+ ## Using the Application
33
+
34
+ ### Quick Start
35
+ 1. Visit the Hugging Face Space
36
+ 2. Upload an image using one of these methods:
37
+ - Click the "Upload Image" button
38
+ - Drag and drop an image into the input area
39
+ - Use the provided example images
40
+
41
+ ### Input Requirements
42
+ - Supported formats: JPG, PNG, BMP
43
+ - Both color and grayscale images accepted
44
+ - Images are automatically:
45
+ - Resized to 256 pixels
46
+ - Center cropped to 224x224
47
+ - Normalized using ImageNet statistics
48
+
49
+ ### Output Format
50
+ The model returns:
51
+ 1. **Predicted Class**: The most likely class
52
+ 2. **Top 3 Predictions**: Three most likely classes with confidence scores
53
+
54
+ Example output:
55
+ ```
56
+ Predicted Class: dog
57
+ Top 3 Predictions:
58
+ dog: 95.32%
59
+ cat: 3.45%
60
+ fox: 1.23%
61
+ ```
62
+
63
+ ## Technical Details
64
+
65
+ ### Model Architecture
66
+ - Base model: ResNet50
67
+ - Input size: 224x224 pixels
68
+ - Output: Class probabilities through softmax
69
+ - Model format: PyTorch (.pth)
70
+
71
+ ### Image Processing Pipeline
72
+ ```python
73
+ transform = transforms.Compose([
74
+ transforms.Resize(256),
75
+ transforms.CenterCrop(224),
76
+ transforms.ToTensor(),
77
+ transforms.Normalize(
78
+ mean=[0.485, 0.456, 0.406],
79
+ std=[0.229, 0.224, 0.225]
80
+ )
81
+ ])
82
+ ```
83
+
84
+ ### File Structure
85
+ ```
86
+ .
87
+ β”œβ”€β”€ app.py # Main application file
88
+ β”œβ”€β”€ requirements.txt # Dependencies
89
+ β”œβ”€β”€ README.md # Documentation
90
+ β”œβ”€β”€ src/
91
+ β”‚ └── model_10.pth # Trained model weights
92
+ β”‚ └── classes.txt # Class labels
93
+ β”œβ”€β”€ models/
94
+ β”‚ └── model_n.pth # other models
95
+ └── examples/ # Example images
96
+ β”œβ”€β”€ example1.jpg
97
+ └── example2.jpg
98
+ ```
99
+
100
+ ## Deployment Guide
101
+
102
+ ### Prerequisites
103
+ 1. Hugging Face account
104
+ 2. Trained ResNet50 model (.pth format)
105
+ 3. Class labels file (classes.txt)
106
+ 4. Example images (optional)
107
+
108
+ ### Deployment Steps
109
+ 1. Create a new Space:
110
+ - Go to huggingface.co/spaces
111
+ - Click "Create new Space"
112
+ - Select "Gradio" as the SDK
113
+ - Use the provided space configuration from this README
114
+
115
+ 2. Upload required files:
116
+ - All files from the File Structure section
117
+ - Ensure correct file paths in app.py
118
+
119
+ 3. The Space will automatically build and deploy
120
+
121
+
122
+ ### Space Configuration
123
+ ```yaml
124
+ title: ResNetonImageNet - ResNet50 Image Classifier
125
+ emoji: πŸ”
126
+ colorFrom: blue
127
+ colorTo: red
128
+ sdk: gradio
129
+ sdk_version: 5.9.1
130
+ app_file: app.py
131
+ pinned: false
132
+ ```
133
+
134
+ ## Troubleshooting
135
+
136
+ ### Common Issues
137
+ 1. **Model Loading Errors**
138
+ - Verify model path in app.py
139
+ - Check model format and class count
140
+
141
+ 2. **Image Upload Issues**
142
+ - Verify supported formats
143
+ - Check image file size
144
+
145
+ 3. **Prediction Errors**
146
+ - First prediction may be slower (model loading)
147
+ - Check input image quality
148
+
149
+ ### Performance Notes
150
+ - CPU inference by default
151
+ - GPU supported if available
152
+ - Batch processing not supported
153
+ - Real-time predictions
154
+
155
+ ## Development
156
+
157
+ ### Requirements
158
+ ```
159
+ torch>=2.0.0
160
+ torchvision>=0.15.0
161
+ gradio>=4.19.2
162
+ Pillow>=9.0.0
163
+ numpy>=1.21.0
164
+ ```
165
+
166
+ ### Local Development
167
+ 1. Clone the repository
168
+ 2. Install dependencies:
169
+ ```bash
170
+ pip install -r requirements.txt
171
+ ```
172
+ 3. Run locally:
173
+ ```bash
174
+ python app.py
175
+ ```
176
+
177
+ ## Support
178
+
179
+ - GitHub Issues: [Repository URL]
180
+ - Hugging Face Forum: [Forum URL]
181
+ - Documentation: [Docs URL]
app.py CHANGED
@@ -3,212 +3,151 @@ import torch
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  from torchvision.models import resnet50
6
- import os
7
- import logging
8
- from typing import Optional, Union
9
- import numpy as np
10
  from pathlib import Path
 
 
 
11
 
12
- # Set up logging
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
- # Directory Configuration
17
- BASE_DIR = Path(__file__).resolve().parent
18
- MODELS_DIR = BASE_DIR / "models"
19
- EXAMPLES_DIR = BASE_DIR / "examples"
20
- STATIC_DIR = BASE_DIR / "static" / "uploaded"
21
-
22
- # Ensure directories exist
23
- STATIC_DIR.mkdir(parents=True, exist_ok=True)
24
-
25
- # Global variables
26
- MODEL_PATH = MODELS_DIR / "resnet_50.pth"
27
- CLASSES_PATH = BASE_DIR / "classes.txt"
28
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
 
30
- def load_class_labels() -> Optional[list]:
31
- """
32
- Load class labels from the classes.txt file
33
- """
34
- try:
35
- if not CLASSES_PATH.exists():
36
- raise FileNotFoundError(f"Classes file not found at {CLASSES_PATH}")
37
-
38
- with open(CLASSES_PATH, 'r') as f:
39
- return [line.strip() for line in f.readlines()]
40
- except Exception as e:
41
- logger.error(f"Error loading class labels: {str(e)}")
42
- return None
43
-
44
- # Load class labels
45
- CLASS_NAMES = load_class_labels()
46
- if CLASS_NAMES is None:
47
- raise RuntimeError("Failed to load class labels from classes.txt")
48
-
49
- # Cache the model to avoid reloading for each prediction
50
- model = None
51
-
52
- def load_model() -> Optional[torch.nn.Module]:
53
  """
54
- Load the ResNet50 model with error handling
55
  """
56
- global model
57
-
58
  try:
59
- if model is not None:
60
- return model
61
-
62
- if not MODEL_PATH.exists():
63
- raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
64
 
65
- logger.info(f"Loading model on {DEVICE}")
66
- model = resnet50(pretrained=False)
67
- model.fc = torch.nn.Linear(model.fc.in_features, len(CLASS_NAMES))
68
 
69
- # Load the model weights
70
- state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
71
-
72
- if 'state_dict' in state_dict:
73
- state_dict = state_dict['state_dict']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- model.load_state_dict(state_dict)
 
76
  model.to(DEVICE)
77
  model.eval()
78
 
79
  logger.info("Model loaded successfully")
80
  return model
81
-
82
  except Exception as e:
83
- logger.error(f"Error loading model: {str(e)}")
84
- return None
85
 
86
- def preprocess_image(image: Union[np.ndarray, Image.Image]) -> Optional[torch.Tensor]:
87
- """
88
- Preprocess the input image with error handling
89
- """
90
- try:
91
- if isinstance(image, np.ndarray):
92
- image = Image.fromarray(image)
93
-
94
- transform = transforms.Compose([
95
- transforms.Resize((224, 224)),
96
- transforms.ToTensor(),
97
- transforms.Normalize(
98
- mean=[0.485, 0.456, 0.406],
99
- std=[0.229, 0.224, 0.225]
100
- )
101
- ])
102
-
103
- return transform(image).unsqueeze(0).to(DEVICE)
104
-
105
- except Exception as e:
106
- logger.error(f"Error preprocessing image: {str(e)}")
107
- return None
108
 
109
- def predict(image: Union[np.ndarray, None]) -> tuple[str, dict]:
110
  """
111
- Make predictions on the input image with comprehensive error handling
112
- Returns the predicted class and top 5 confidence scores
113
  """
114
  try:
115
  if image is None:
116
- return "Error: No image provided", {}
117
 
118
- model = load_model()
119
- if model is None:
120
- return "Error: Failed to load model", {}
121
-
122
- # Ensure model is in eval mode
123
- model.eval()
124
 
125
- input_tensor = preprocess_image(image)
126
- if input_tensor is None:
127
- return "Error: Failed to preprocess image", {}
128
-
129
  with torch.no_grad():
130
- input_tensor = input_tensor.to(DEVICE)
131
- output = model(input_tensor)
132
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
133
-
134
- # Get predictions and confidences
135
- top_5_probs, top_5_indices = torch.topk(probabilities, k=5)
136
 
137
- # Format confidences with exactly 2 decimal places
138
- confidences = {
139
- CLASS_NAMES[idx.item()]: "{:.2f}".format(float(prob.item() * 100))
140
- for prob, idx in zip(top_5_probs, top_5_indices)
141
- }
142
 
143
- predicted_class = CLASS_NAMES[top_5_indices[0].item()]
 
 
 
 
 
144
 
145
- return predicted_class, confidences
146
-
147
- except Exception as e:
148
- logger.error(f"Prediction error: {str(e)}")
149
- return f"Error during prediction: {str(e)}", {}
150
-
151
- def get_example_list() -> list:
152
- """
153
- Get list of example images from the examples directory
154
- """
155
- try:
156
- examples = []
157
- for ext in ['.jpg', '.jpeg', '.png']:
158
- examples.extend(list(EXAMPLES_DIR.glob(f'*{ext}')))
159
- return [[str(ex)] for ex in sorted(examples)]
160
- except Exception as e:
161
- logger.error(f"Error loading examples: {str(e)}")
162
- return []
163
-
164
- # Create Gradio interface with error handling
165
- try:
166
- with gr.Blocks(theme=gr.themes.Base()) as iface:
167
- gr.Markdown("# Image Classification with ResNet50")
168
- gr.Markdown("Upload an image to classify. The model will predict the class and show top 5 confidence scores.")
169
 
170
- with gr.Row():
171
- with gr.Column(scale=1):
172
- input_image = gr.Image(type="numpy", label="Upload Image")
173
- predict_btn = gr.Button("Predict")
174
-
175
- with gr.Column(scale=1):
176
- output_label = gr.Label(label="Predicted Class", num_top_classes=1)
177
- confidence_label = gr.Label(label="Top 5 Predictions", num_top_classes=5)
178
-
179
- # Add examples
180
- gr.Examples(
181
- examples=get_example_list(),
182
- inputs=input_image,
183
- outputs=[output_label, confidence_label],
184
- fn=predict,
185
- cache_examples=True
186
- )
187
-
188
- # Set up prediction event
189
- predict_btn.click(
190
- fn=predict,
191
- inputs=input_image,
192
- outputs=[output_label, confidence_label]
193
- )
194
- input_image.change(
195
- fn=predict,
196
- inputs=input_image,
197
- outputs=[output_label, confidence_label]
198
- )
199
-
200
- except Exception as e:
201
- logger.error(f"Error creating Gradio interface: {str(e)}")
202
- raise
203
-
204
- if __name__ == "__main__":
205
- try:
206
- load_model() # Pre-load the model
207
- iface.launch(
208
- share=False,
209
- server_name="0.0.0.0",
210
- server_port=7860,
211
- debug=False
212
- )
213
  except Exception as e:
214
- logger.error(f"Error launching application: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  from torchvision.models import resnet50
 
 
 
 
6
  from pathlib import Path
7
+ import logging
8
+ import warnings
9
+ warnings.filterwarnings('ignore')
10
 
11
+ # Setup logging
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
15
+ # Path configurations
16
+ MODEL_PATH = Path('src/model_10.pth')
17
+ CLASSES_PATH = Path('models/classes.txt')
 
 
 
 
 
 
 
 
 
18
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
 
20
+ # Image preprocessing - using the same transforms as training
21
+ transform = transforms.Compose([
22
+ transforms.Resize(256),
23
+ transforms.CenterCrop(224),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(
26
+ mean=[0.485, 0.456, 0.406],
27
+ std=[0.229, 0.224, 0.225]
28
+ )
29
+ ])
30
+
31
+ def load_classes():
32
+ with open(CLASSES_PATH) as f:
33
+ return [line.strip() for line in f.readlines()]
34
+
35
+ def load_model():
 
 
 
 
 
 
 
36
  """
37
+ Load the trained ResNet50 model
38
  """
 
 
39
  try:
40
+ # Initialize model
41
+ model = resnet50(weights=None)
42
+ num_classes = len(load_classes())
43
+ model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
 
44
 
45
+ # Load checkpoint
46
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
 
47
 
48
+ # Extract state dict from checkpoint
49
+ if isinstance(checkpoint, dict):
50
+ if "model" in checkpoint:
51
+ state_dict = checkpoint["model"]
52
+ elif "state_dict" in checkpoint:
53
+ state_dict = checkpoint["state_dict"]
54
+ elif "model_state_dict" in checkpoint:
55
+ state_dict = checkpoint["model_state_dict"]
56
+ else:
57
+ state_dict = checkpoint
58
+ else:
59
+ state_dict = checkpoint
60
+
61
+ # Clean state dict keys
62
+ new_state_dict = {}
63
+ for k, v in state_dict.items():
64
+ name = k.replace("module.", "")
65
+ if name.startswith("model."):
66
+ name = name[6:]
67
+ new_state_dict[name] = v
68
 
69
+ # Load state dict and set to eval mode
70
+ model.load_state_dict(new_state_dict, strict=False)
71
  model.to(DEVICE)
72
  model.eval()
73
 
74
  logger.info("Model loaded successfully")
75
  return model
76
+
77
  except Exception as e:
78
+ logger.error(f"Error loading model: {e}")
79
+ raise
80
 
81
+ # Global variables
82
+ CLASSES = load_classes()
83
+ MODEL = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ def predict_image(image):
86
  """
87
+ Predict class for input image with top-3 accuracy
 
88
  """
89
  try:
90
  if image is None:
91
+ return "No image provided", "Please upload an image"
92
 
93
+ # Convert to PIL Image if needed
94
+ if not isinstance(image, Image.Image):
95
+ image = Image.fromarray(image)
 
 
 
96
 
97
+ # Preprocess image
98
+ input_tensor = transform(image).unsqueeze(0).to(DEVICE)
99
+
100
+ # Get prediction
101
  with torch.no_grad():
102
+ output = MODEL(input_tensor)
 
103
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
 
 
 
104
 
105
+ # Get top-3 predictions
106
+ top3_prob, top3_indices = torch.topk(probabilities, k=3)
 
 
 
107
 
108
+ # Format predictions
109
+ predictions = []
110
+ for prob, idx in zip(top3_prob, top3_indices):
111
+ class_name = CLASSES[idx]
112
+ confidence = prob.item() * 100
113
+ predictions.append(f"{class_name}: {confidence:.2f}%")
114
 
115
+ # Join predictions with newlines
116
+ predictions_text = "\n".join(predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ # Get top prediction
119
+ predicted_class = CLASSES[top3_indices[0]]
120
+
121
+ # Log predictions
122
+ logger.info(f"Predicted class: {predicted_class}")
123
+ logger.info(f"Top 3 predictions:\n{predictions_text}")
124
+
125
+ return predicted_class, predictions_text
126
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  except Exception as e:
128
+ logger.error(f"Prediction error: {e}")
129
+ return "Error in prediction", str(e)
130
+
131
+ # Create Gradio interface
132
+ iface = gr.Interface(
133
+ fn=predict_image,
134
+ inputs=gr.Image(type="pil", label="Upload Image"),
135
+ outputs=[
136
+ gr.Textbox(label="Predicted Class"),
137
+ gr.Textbox(label="Top 3 Predictions", lines=3)
138
+ ],
139
+ title="ResNet50 Image Classifier",
140
+ description=(
141
+ "Upload an image to classify.\n"
142
+ "The model will predict the class and show confidence scores for the top 3 predictions."
143
+ ),
144
+ examples=[
145
+ ["examples/example1.jpg"],
146
+ ["examples/example2.jpg"]
147
+ ] if Path("examples").exists() else None,
148
+ theme=gr.themes.Base()
149
+ )
150
+
151
+ # Launch the app
152
+ if __name__ == "__main__":
153
+ iface.launch()
models/{resnet_50.pth β†’ model_14.pth} RENAMED
File without changes
classes.txt β†’ src/classes.txt RENAMED
File without changes
{models β†’ src}/model_10.pth RENAMED
File without changes