# garbage-segregate / ml/predict.py
# (Deploy waste classification backend with ML model — commit bf17f74)
"""
Inference script for waste classification
Optimized for CPU with fast preprocessing
"""
import base64
import json
import time
from io import BytesIO
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms, models
class WasteClassifier:
    """Waste-classification inference wrapper around an EfficientNet-B0 checkpoint.

    Loads a fine-tuned model from ``model_path`` and exposes single-image
    and batch prediction. Optimized for CPU but uses CUDA when available.
    """

    def __init__(self, model_path='ml/models/best_model.pth', device=None):
        """Load the checkpoint and build the inference pipeline.

        Args:
            model_path: Path to a torch checkpoint containing
                'categories' (list[str]) and 'model_state_dict'.
            device: Optional torch.device; defaults to CUDA if available.
        """
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # NOTE(review): torch.load uses pickle under the hood — only load
        # trusted checkpoints (consider weights_only=True if the checkpoint
        # format allows it).
        checkpoint = torch.load(model_path, map_location=self.device)
        self.categories = checkpoint['categories']

        # Rebuild the architecture used at training time: EfficientNet-B0
        # with the classifier head replaced to match our category count.
        self.model = models.efficientnet_b0(pretrained=False)
        num_features = self.model.classifier[1].in_features
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.3),
            torch.nn.Linear(num_features, len(self.categories))
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        # Standard ImageNet normalization — must match training preprocessing.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        print(f"Model loaded successfully on {self.device}")
        print(f"Categories: {self.categories}")

    def preprocess_image(self, image_input):
        """Normalize any supported input into a (1, 3, 224, 224) tensor.

        Accepts a PIL Image, a file path, a 'data:image/...;base64,...'
        data URI, or a numpy array.

        Raises:
            ValueError: If ``image_input`` is none of the supported types.
        """
        if isinstance(image_input, str):
            if image_input.startswith('data:image'):
                # Data-URI: strip the 'data:image/...;base64,' prefix.
                # NOTE(review): a bare base64 string without the prefix is
                # treated as a file path — callers must send full data URIs.
                image_data = image_input.split(',')[1]
                image_bytes = base64.b64decode(image_data)
                image = Image.open(BytesIO(image_bytes)).convert('RGB')
            else:
                # Plain file path
                image = Image.open(image_input).convert('RGB')
        elif isinstance(image_input, np.ndarray):
            image = Image.fromarray(image_input).convert('RGB')
        elif isinstance(image_input, Image.Image):
            image = image_input.convert('RGB')
        else:
            raise ValueError(f"Unsupported image input type: {type(image_input)}")
        # unsqueeze adds the batch dimension expected by the model.
        return self.transform(image).unsqueeze(0)

    def predict(self, image_input):
        """Predict the waste category for a single image.

        Returns:
            dict: {
                'category': str,          # top-1 class name
                'confidence': float,      # top-1 softmax probability
                'probabilities': dict,    # class name -> probability
                'timestamp': int,         # epoch milliseconds
            }
        """
        image_tensor = self.preprocess_image(image_input).to(self.device)

        with torch.no_grad():
            outputs = self.model(image_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)

        predicted_category = self.categories[predicted_idx.item()]
        confidence_score = confidence.item()

        # Per-class probabilities as plain floats for JSON serialization.
        prob_dict = {
            category: float(prob)
            for category, prob in zip(self.categories, probabilities[0].cpu().numpy())
        }

        return {
            'category': predicted_category,
            'confidence': confidence_score,
            'probabilities': prob_dict,
            # BUG FIX: the original used np.datetime64('now').astype(int) / 1e6,
            # but np.datetime64('now') has *second* resolution, so the result
            # was seconds/1e6 (~1750) — not a usable timestamp. Emit epoch
            # milliseconds instead.
            'timestamp': int(time.time() * 1000)
        }

    def predict_batch(self, image_inputs):
        """Predict categories for an iterable of images; returns a list of dicts."""
        return [self.predict(image_input) for image_input in image_inputs]
def export_to_onnx(model_path='ml/models/best_model.pth',
                   output_path='ml/models/model.onnx'):
    """Export the trained PyTorch classifier to ONNX format for deployment."""
    # Load the model via the same inference wrapper used at serving time.
    clf = WasteClassifier(model_path)

    # Tracing input: a single random RGB image at the model's expected size.
    sample = torch.randn(1, 3, 224, 224).to(clf.device)

    # Batch dimension is kept dynamic so the ONNX graph accepts any batch size.
    dynamic = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    torch.onnx.export(
        clf.model,
        sample,
        output_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes=dynamic,
    )
    print(f"Model exported to ONNX: {output_path}")
if __name__ == "__main__":
# Test inference
classifier = WasteClassifier()
# Example usage
test_image = "ml/data/processed/test/recyclable/sample.jpg"
if Path(test_image).exists():
result = classifier.predict(test_image)
print("\nPrediction Result:")
print(json.dumps(result, indent=2))