import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
from PIL import Image
import numpy as np


def _read_stripped_lines(file_path):
    """Return every line of a text file with surrounding whitespace stripped.

    The file is decoded as UTF-8 (a leading byte-order mark, if present, is
    dropped) instead of relying on the platform's locale encoding. Blank lines
    are kept as empty strings so that entry ``i`` always corresponds to line
    ``i`` of the file.
    """
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        return [line.strip() for line in f]


def load_class_names(file_path):
    """Load class names from *file_path*, one name per line, in file order.

    Entry ``i`` of the returned list is line ``i`` of the file; presumably the
    caller indexes it with the model's predicted class id.
    """
    return _read_stripped_lines(file_path)


def load_class_info(file_path):
    """Load per-class Wikipedia links from *file_path*, one per line, in file order.

    Entry ``i`` of the returned list is line ``i`` of the file, so it should
    line up with the list returned by ``load_class_names``.
    """
    return _read_stripped_lines(file_path)


# State-dict key of the final nn.Linear weight that load_model replaces for
# each supported architecture. Its first dimension is the number of classes.
_HEAD_WEIGHT_KEYS = {
    'mobilenet': 'classifier.1.weight',
    'resnet': 'fc.weight',
    'densenet': 'classifier.weight',
}


def load_model(model_path, model_type='resnet', num_classes=None):
    """Build a torchvision classifier and load saved weights into it.

    Args:
        model_path: Path to a checkpoint (a ``.pkl`` file in this app) holding
            a state dict; it is passed straight to ``load_state_dict``.
        model_type: ``'mobilenet'`` (MobileNetV2), ``'resnet'`` (ResNet-50)
            or ``'densenet'`` (DenseNet-121).
        num_classes: Output size of the replacement classification head.
            When ``None`` (the default) it is read from the checkpoint's head
            weight, so the function no longer depends on a module-level
            ``num_classes`` being defined before it is called.

    Returns:
        The model with the checkpoint's weights, on CPU, in evaluation mode.

    Raises:
        ValueError: If *model_type* is unsupported, or if *num_classes* is
            omitted and the checkpoint lacks the expected head weight.
    """
    # Validate first so an unsupported type fails before any file I/O.
    if model_type not in _HEAD_WEIGHT_KEYS:
        raise ValueError(f"Unsupported model type: {model_type}")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source; on torch >= 1.13,
    # consider passing weights_only=True.
    model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    if num_classes is None:
        head_key = _HEAD_WEIGHT_KEYS[model_type]
        if head_key not in model_state_dict:
            raise ValueError(
                f"Checkpoint {model_path!r} has no {head_key!r} entry; "
                f"is it a {model_type} state dict?"
            )
        # nn.Linear stores its weight as (out_features, in_features).
        num_classes = int(model_state_dict[head_key].shape[0])

    # Build the backbone without pretrained ImageNet weights; the checkpoint
    # supplies every parameter.
    # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
    # `weights=None`; kept as-is for compatibility with older installs.
    if model_type == 'mobilenet':
        model = models.mobilenet_v2(pretrained=False)
        model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    elif model_type == 'resnet':
        model = models.resnet50(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:  # 'densenet' -- the only remaining key in _HEAD_WEIGHT_KEYS
        model = models.densenet121(pretrained=False)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)

    model.load_state_dict(model_state_dict)
    # Switch dropout / batch-norm layers to inference behaviour.
    model.eval()
    return model


# Evaluation-time preprocessing: resize to 224x224, convert to a tensor, and
# normalize each channel with the standard ImageNet mean / std values.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Define the prediction function def predict_image(image): # Convert the NumPy array to a PIL Image if isinstance(image, np.ndarray): image = Image.fromarray(image.astype('uint8'), 'RGB') image = val_transform(image) image = image.unsqueeze(0) # Add batch dimension with torch.no_grad(): outputs = model(image) probabilities = torch.nn.functional.softmax(outputs[0], dim=0) confidence, preds = 
torch.max(probabilities, 0) confidence_score = confidence.item() * 100 if confidence_score < 30: result = "Not identified" html_result = "" else: class_name = class_names[preds.item()] wiki_link = class_info[preds.item()] result = f"{class_name}: {confidence_score:.2f}%" html_result = f"