# soutrik
# orphan branch
# c3d82b0
# raw
# history blame
# 2.95 kB
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from pathlib import Path
from loguru import logger
from src.model import LitEfficientNet
from src.utils.aws_s3_services import S3Handler
# Register a file sink for inference logs: INFO level, "1 MB" rotation setting.
logger.add(
    "logs/inference.log",
    level="INFO",
    rotation="1 MB",
)
class MNISTClassifier:
    """Run single-image digit inference with a ``LitEfficientNet`` checkpoint.

    On construction the checkpoint is loaded onto CUDA when available
    (CPU otherwise), the model is put into evaluation mode, and the
    preprocessing pipeline is built: resize to 28x28, convert to a tensor,
    then normalise with mean 0.5 and std 0.5.
    """

    def __init__(self, checkpoint_path="./checkpoints/best_model.ckpt"):
        """Load the model and set up preprocessing.

        Args:
            checkpoint_path: Location of the checkpoint file to load.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        self.checkpoint_path = checkpoint_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Inference will run on device: {self.device}")

        # Load the weights, then switch the module to evaluation mode.
        self.model = self.load_model()
        self.model.eval()

        # Preprocessing applied to every incoming image before inference.
        self.transform = transforms.Compose(
            [
                transforms.Resize((28, 28)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        self.labels = [str(i) for i in range(10)]  # MNIST labels are 0-9

    def load_model(self):
        """Load the checkpoint at ``self.checkpoint_path`` onto ``self.device``.

        Returns:
            The model restored by ``LitEfficientNet.load_from_checkpoint``,
            moved to ``self.device``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        if not Path(self.checkpoint_path).exists():
            logger.error(f"Checkpoint not found: {self.checkpoint_path}")
            raise FileNotFoundError(f"Checkpoint not found: {self.checkpoint_path}")

        logger.info(f"Loading model from checkpoint: {self.checkpoint_path}")
        # map_location makes deserialisation target the inference device, so a
        # checkpoint written on a GPU machine still loads on a CPU-only host
        # instead of failing before the .to() call is ever reached.
        model = LitEfficientNet.load_from_checkpoint(
            self.checkpoint_path,
            map_location=self.device,
        )
        return model.to(self.device)

    @torch.no_grad()
    def predict(self, image):
        """Perform inference on a single image.

        Args:
            image: Input image in PIL format, or ``None``.

        Returns:
            dict | None: Mapping from each label in ``self.labels`` to its
            softmax probability as a Python float, or ``None`` when no image
            was supplied.
        """
        if image is None:
            logger.error("No image provided for prediction.")
            return None

        # Preprocess and add a leading batch dimension of size 1.
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # assumes the model returns scores shaped (batch, num_classes) — TODO confirm
        output = self.model(img_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # Pair each probability with the label at the same index.
        return {
            self.labels[idx]: prob
            for idx, prob in enumerate(probabilities.tolist())
        }
# ---------------------------------------------------------------------------
# Application wiring. Everything below the constant runs at import time.
# ---------------------------------------------------------------------------

# Checkpoint file handed to the classifier.
checkpoint_path = "./checkpoints/best_model.ckpt"

# Pull checkpoints from the "deep-bucket-s3" bucket before the model is built.
# presumably ("checkpoints_test", "checkpoints") is (remote folder, local
# folder) — verify against S3Handler.download_folder.
# NOTE(review): the download is issued on every import; nothing here checks
# whether the checkpoint is already present locally.
s3_handler = S3Handler(bucket_name="deep-bucket-s3")
s3_handler.download_folder("checkpoints_test", "checkpoints")

# Build the classifier once so every request reuses the loaded model.
classifier = MNISTClassifier(checkpoint_path=checkpoint_path)

# Gradio components: a 160x160 image input delivered as a PIL image in
# mode "L", and a label output configured with num_top_classes=1.
_image_input = gr.Image(height=160, width=160, image_mode="L", type="pil")
_label_output = gr.Label(num_top_classes=1)

# Gradio interface that routes uploads to the classifier.
demo = gr.Interface(
    fn=classifier.predict,
    inputs=_image_input,
    outputs=_label_output,
    title="MNIST Classifier",
    description="Upload a handwritten digit image to classify it (0-9).",
)

# Start the UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(share=True)