"""Gradio demo that serves a fine-tuned ResNet50 image classifier.

At import time the module reads the class list from ``CLASSES_PATH`` and the
checkpoint from ``MODEL_PATH``, then exposes ``predict_image`` through a
Gradio ``Interface`` named ``iface``.
"""
import logging
import warnings
from pathlib import Path

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import resnet50

warnings.filterwarnings('ignore')

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Path configurations
MODEL_PATH = Path('src/model_10.pth')
CLASSES_PATH = Path('src/classes.txt')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image preprocessing. Per the original author this mirrors the transforms
# used at training time; the mean/std values are the ones hard-coded below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Keys under which a checkpoint dict may wrap the actual state dict,
# checked in this order.
_STATE_DICT_KEYS = ("model", "state_dict", "model_state_dict")

# Candidate example images offered in the UI when they are present on disk.
_EXAMPLE_PATHS = (
    Path("examples/example1.jpg"),
    Path("examples/example2.jpg"),
)


def load_classes():
    """Return the class names listed in ``CLASSES_PATH``, one per line.

    Surrounding whitespace is stripped and blank lines are skipped, so a
    stray empty line does not become a class (which would change the size of
    the classifier head built from this list).

    Raises:
        OSError: If the class file cannot be opened.
    """
    with open(CLASSES_PATH, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [name for name in stripped if name]


def _extract_state_dict(checkpoint):
    """Return the state dict stored in ``checkpoint``.

    If ``checkpoint`` is a dict containing one of ``_STATE_DICT_KEYS``, the
    value under the first matching key is returned; otherwise ``checkpoint``
    itself is treated as the state dict.
    """
    if isinstance(checkpoint, dict):
        for key in _STATE_DICT_KEYS:
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def _clean_state_dict(state_dict):
    """Return a copy of ``state_dict`` with wrapper prefixes removed from keys.

    Every ``"module."`` substring is dropped, then a leading ``"model."`` is
    stripped, so the keys line up with a bare torchvision ResNet.
    """
    cleaned = {}
    for key, value in state_dict.items():
        name = key.replace("module.", "")
        if name.startswith("model."):
            name = name[len("model."):]
        cleaned[name] = value
    return cleaned


def load_model(num_classes=None):
    """Load the trained ResNet50 model from ``MODEL_PATH``.

    Args:
        num_classes: Size of the final fully connected layer. When ``None``
            (the default) it is taken from ``len(load_classes())``.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        Exception: Any error raised while building the model or reading the
            checkpoint is logged and re-raised unchanged.
    """
    try:
        if num_classes is None:
            num_classes = len(load_classes())

        # Initialize model with a classifier head sized for our classes
        model = resnet50(weights=None)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        state_dict = _clean_state_dict(_extract_state_dict(checkpoint))

        # strict=False tolerates key mismatches, so report them explicitly:
        # otherwise a badly matched checkpoint would leave layers at their
        # random initial values without any hint in the logs.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            logger.warning(
                "Checkpoint is missing %d model keys (left at initial values): %s",
                len(result.missing_keys),
                result.missing_keys,
            )
        if result.unexpected_keys:
            logger.warning(
                "Checkpoint has %d unexpected keys (ignored): %s",
                len(result.unexpected_keys),
                result.unexpected_keys,
            )

        model.to(DEVICE)
        model.eval()

        logger.info("Model loaded successfully")
        return model
    except Exception as e:
        logger.error("Error loading model: %s", e)
        raise


# Global variables
CLASSES = load_classes()
MODEL = load_model(len(CLASSES))


def predict_image(image):
    """Predict the class of ``image`` and report the top predictions.

    Args:
        image: A ``PIL.Image.Image``, an array accepted by
            ``Image.fromarray``, or ``None``.

    Returns:
        A ``(predicted_class, predictions_text)`` tuple of strings.
        ``predictions_text`` holds up to three ``"<class>: <pct>%"`` lines,
        best first. When ``image`` is ``None`` or prediction fails, the tuple
        carries a short status message and a detail message instead.
    """
    if image is None:
        return "No image provided", "Please upload an image"

    try:
        # Convert to PIL Image if needed
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Normalize expects exactly three channels; grayscale or RGBA input
        # would otherwise fail in the transform pipeline.
        image = image.convert("RGB")

        # Preprocess image
        input_tensor = transform(image).unsqueeze(0).to(DEVICE)

        # Get prediction
        with torch.no_grad():
            output = MODEL(input_tensor)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # Get top-3 predictions; topk raises if k exceeds the class count
        k = min(3, probabilities.numel())
        top_prob, top_indices = torch.topk(probabilities, k=k)
        top_prob = top_prob.tolist()
        top_indices = top_indices.tolist()

        # Format predictions, one per line
        predictions_text = "\n".join(
            f"{CLASSES[idx]}: {prob * 100:.2f}%"
            for prob, idx in zip(top_prob, top_indices)
        )

        # Get top prediction
        predicted_class = CLASSES[top_indices[0]]

        # Log predictions
        logger.info("Predicted class: %s", predicted_class)
        logger.info("Top 3 predictions:\n%s", predictions_text)

        return predicted_class, predictions_text
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing
        # the app, but keep the traceback in the logs.
        logger.exception("Prediction error: %s", e)
        return "Error in prediction", str(e)


def _existing_examples():
    """Return the example rows whose image files exist, or ``None`` if none do."""
    rows = [[path.as_posix()] for path in _EXAMPLE_PATHS if path.is_file()]
    return rows or None


# Create Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Textbox(label="Predicted Class"),
        gr.Textbox(label="Top 3 Predictions", lines=3),
    ],
    title="ResNet50 Image Classifier",
    description=(
        "Upload an image to classify.\n"
        "The model will predict the class and show confidence scores for the top 3 predictions."
    ),
    # Only offer examples that are actually on disk, so a partially
    # populated examples/ directory does not hand Gradio missing files.
    examples=_existing_examples(),
    theme=gr.themes.Base(),
)

# Launch the app
if __name__ == "__main__":
    iface.launch()