import html

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torchvision import transforms

# Predictions whose top-class softmax probability (in percent) is below this
# value are reported as "Not identified" instead of a class name.
CONFIDENCE_THRESHOLD = 30.0


def _load_lines(file_path):
    """Return the whitespace-stripped lines of a UTF-8 text file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]


# Function to load class names from a file
def load_class_names(file_path):
    """Load class names from *file_path*, one name per line.

    The model's predicted class index is used to look up this list, so the
    line order must match the order of the model's output units.
    """
    return _load_lines(file_path)


# Function to load Wikipedia links from a file
def load_class_info(file_path):
    """Load per-class info links from *file_path*, one link per line.

    Line ``i`` is shown as the "More Info" link for class ``i`` of
    ``class_names``.
    """
    return _load_lines(file_path)


# Function to load the model from a .pkl file
def load_model(model_path, model_type, num_classes=None):
    """Build a torchvision classifier and load saved weights into it.

    Args:
        model_path: Path to a file written with ``torch.save`` that holds the
            model's ``state_dict``.
        model_type: One of ``'mobilenetv2'``, ``'resnet18'`` or
            ``'densenet121'``.
        num_classes: Output size of the replaced classification head.
            Defaults to ``len(class_names)`` (the number of lines in the
            class-names file).

    Returns:
        The model with weights loaded, mapped to CPU, in evaluation mode.

    Raises:
        ValueError: If *model_type* is not one of the supported names.
    """
    if num_classes is None:
        num_classes = len(class_names)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # Recreate the architecture (no pretrained download) with a head sized
    # for our classes, so the saved state_dict's shapes line up.
    if model_type == 'mobilenetv2':
        model = models.mobilenet_v2(weights=None)
        model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    elif model_type == 'resnet18':
        model = models.resnet18(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_type == 'densenet121':
        model = models.densenet121(weights=None)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    model.load_state_dict(model_state_dict)
    model.eval()
    return model


# Resize to the 224x224 input size and normalise with the ImageNet
# channel means / standard deviations.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict_image(image, model_choice):
    """Classify *image* with the model selected in the dropdown.

    Args:
        image: The upload from ``gr.Image``. This is a NumPy array with
            Gradio's default ``type="numpy"``, but a PIL image also works.
            It is ``None`` when nothing was uploaded.
        model_choice: A key of ``model_paths``.

    Returns:
        A ``(label_text, html)`` pair for the ``gr.Label`` and ``gr.HTML``
        outputs. ``html`` is ``""`` when there is no link to show.
    """
    global model, current_model

    if model_choice not in model_paths:
        return "Error: Please select a valid model.", ""
    if image is None:
        return "Error: Please upload an image.", ""

    # Only one model is kept in memory; reload when the selection changes.
    if model_choice != current_model:
        model = load_model(model_paths[model_choice], model_choice)
        current_model = model_choice

    # Let PIL infer the mode (L / RGB / RGBA), then force 3 channels so the
    # 3-channel Normalize in val_transform always applies.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'))
    image = image.convert('RGB')

    batch = val_transform(image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        outputs = model(batch)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        confidence, pred = torch.max(probabilities, 0)
    confidence_score = confidence.item() * 100

    if confidence_score < CONFIDENCE_THRESHOLD:
        return "Not identified", ""

    idx = pred.item()
    result = f"{class_names[idx]}: {confidence_score:.2f}%"

    # classinfo.txt may have fewer lines than classes.txt; omit the link
    # rather than raising IndexError.
    wiki_link = class_info[idx] if idx < len(class_info) else ""
    if not wiki_link:
        return result, ""

    # Escape the URL so quotes or angle brackets in classinfo.txt cannot
    # break out of the href attribute.
    html_result = (
        f'<a href="{html.escape(wiki_link, quote=True)}" '
        'target="_blank" rel="noopener noreferrer">More Info</a>'
    )
    return result, html_result


# --- Module-level setup (runs at import time) ---

# Load class names and class info
class_file_path = 'classes.txt'
class_info_path = 'classinfo.txt'
class_names = load_class_names(class_file_path)
class_info = load_class_info(class_info_path)
num_classes = len(class_names)

# Define model paths
model_paths = {
    'densenet121': 'densenet121_15EpochsPretrainedNoExtractionNoLR_model.pkl',
    'resnet18': 'resnet18_25EpochsPretrainedExtractionNoLR_model.pkl',
    'mobilenetv2': 'mobilenetv2_25EpochsPretrainedExtractionNoLR_model.pkl',
}

# Set default model
current_model = 'densenet121'
model = load_model(model_paths[current_model], current_model)

# Create the Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(height=500),
        # Choices come from model_paths so the two can never drift apart.
        gr.Dropdown(choices=list(model_paths), value=current_model,
                    label="Select Model"),
    ],
    outputs=[gr.Label(num_top_classes=1), gr.HTML()],
    title="Animal Classification",
    description="Upload an image to get the predicted label",
    allow_flagging="never",
)

# Launch the interface (kept at module level, as in the original script)
iface.launch()