|
import gradio as gr |
|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
from pathlib import Path |
|
from loguru import logger |
|
from src.model import LitEfficientNet |
|
from src.utils.aws_s3_services import S3Handler |
|
|
|
|
|
# Also send INFO-and-above records to logs/inference.log, rotating the file at 1 MB.
logger.add(
    sink="logs/inference.log",
    level="INFO",
    rotation="1 MB",
)
|
|
|
|
|
class MNISTClassifier:
    """Single-image digit classifier wrapping a ``LitEfficientNet`` checkpoint.

    The model is loaded once at construction, moved to CUDA when available
    (CPU otherwise) and switched to eval mode. ``predict`` takes a PIL image
    and returns a ``{label: probability}`` mapping, so it can be used directly
    as a Gradio ``fn``.
    """

    def __init__(self, checkpoint_path: str = "./checkpoints/best_model.ckpt"):
        """Load the checkpoint and build the preprocessing pipeline.

        Args:
            checkpoint_path: Path to the Lightning checkpoint to load.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        self.checkpoint_path = checkpoint_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Inference will run on device: {self.device}")

        self.model = self.load_model()
        # Inference only: put the module in eval mode once, up front.
        self.model.eval()

        # Resize to 28x28, convert to a tensor, then apply (x - 0.5) / 0.5.
        self.transform = transforms.Compose(
            [
                transforms.Resize((28, 28)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        # Output index i is reported under the label str(i).
        self.labels = [str(i) for i in range(10)]

    def load_model(self) -> "LitEfficientNet":
        """Load the model checkpoint for inference.

        Returns:
            The restored ``LitEfficientNet`` moved to ``self.device``.

        Raises:
            FileNotFoundError: If ``self.checkpoint_path`` does not exist.
        """
        if not Path(self.checkpoint_path).exists():
            logger.error(f"Checkpoint not found: {self.checkpoint_path}")
            raise FileNotFoundError(f"Checkpoint not found: {self.checkpoint_path}")

        logger.info(f"Loading model from checkpoint: {self.checkpoint_path}")
        return LitEfficientNet.load_from_checkpoint(self.checkpoint_path).to(
            self.device
        )

    @torch.no_grad()
    def predict(self, image: "Image.Image | None") -> "dict[str, float] | None":
        """Perform inference on a single image.

        Args:
            image: Input image in PIL format. Non-grayscale PIL images (any
                mode other than "L") are converted to "L" first, matching the
                single-channel input the Gradio UI requests.

        Returns:
            dict mapping each label to its softmax probability (Python
            floats), or ``None`` when ``image`` is ``None``.
        """
        if image is None:
            logger.error("No image provided for prediction.")
            return None

        # The UI asks Gradio for image_mode="L"; coerce other PIL modes
        # (RGB, RGBA, ...) so direct callers produce the same 1-channel tensor
        # instead of a 3/4-channel one the model was not fed in the app.
        mode = getattr(image, "mode", None)
        if mode is not None and mode != "L":
            image = image.convert("L")

        # unsqueeze(0) adds the batch dimension expected by the model.
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        output = self.model(img_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # One device->host transfer instead of a float() sync per class.
        probs = probabilities.cpu().tolist()
        return {self.labels[idx]: prob for idx, prob in enumerate(probs)}
|
|
|
|
|
|
|
# Local checkpoint the classifier loads; it is expected to be present once the
# S3 download below has populated the "checkpoints" directory.
checkpoint_path = "./checkpoints/best_model.ckpt"

# Fetch the checkpoint folder from S3 at import time, before the model is built.
# NOTE(review): presumably download_folder(<S3 prefix>, <local dir>) and the
# "checkpoints_test" prefix contains best_model.ckpt — confirm against S3Handler.
s3_handler = S3Handler(bucket_name="deep-bucket-s3")
s3_handler.download_folder(
    "checkpoints_test",
    "checkpoints",
)

# Builds the model once at import time; MNISTClassifier.load_model raises
# FileNotFoundError if the checkpoint is still missing at this point.
classifier = MNISTClassifier(checkpoint_path=checkpoint_path)

# Gradio UI: a grayscale ("L") PIL image in, the single top predicted label out.
demo = gr.Interface(
    fn=classifier.predict,
    inputs=gr.Image(height=160, width=160, image_mode="L", type="pil"),
    outputs=gr.Label(num_top_classes=1),
    title="MNIST Classifier",
    description="Upload a handwritten digit image to classify it (0-9).",
)

# Only launch the server when executed as a script, not when imported.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio share link — confirm
    # this is intended for the deployment environment.
    demo.launch(share=True)
|
|