import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
from PIL import Image
import numpy as np


# Function to load class names from a file
def load_class_names(file_path):
    with open(file_path, 'r') as f:
        class_names = [line.strip() for line in f.readlines()]
    return class_names


# Function to load Wikipedia links from a file
def load_class_info(file_path):
    with open(file_path, 'r') as f:
        class_info = [line.strip() for line in f.readlines()]
    return class_info


# Function to load the model from a .pkl file
def load_model(model_path, model_type):
    # Load the model state dictionary
    model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # Create an instance of the model based on model_type
    if model_type == 'mobilenetv2':
        model = models.mobilenet_v2(pretrained=False)
        model.classifier[1] = nn.Linear(model.last_channel, num_classes)
    elif model_type == 'resnet18':
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_type == 'densenet121':
        model = models.densenet121(pretrained=False)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Load the state dictionary into the model
    model.load_state_dict(model_state_dict)
    # Set the model to evaluation mode
    model.eval()
    return model
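
# Note: the .pkl checkpoints are assumed (they are not shown in this Space) to
# contain plain state dicts, e.g. saved during training with something like
#     torch.save(model.state_dict(), 'densenet121_..._model.pkl')
# which is why torch.load() above yields a dict of weights rather than a full
# model object. load_model() also reads the global num_classes defined further
# below, so it must only be called after the class list has been loaded.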
# Define the transformation
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# Define the prediction function
def predict_image(image, model_choice):
    global model, current_model

    # Check if a model is selected
    if model_choice not in model_paths:
        return "Error: Please select a valid model.", ""

    # Check if an image is provided
    if image is None:
        return "Error: Please upload an image.", ""

    # Load the selected model if it's not already loaded
    if model_choice != current_model:
        model_path = model_paths[model_choice]
        model = load_model(model_path, model_choice)
        current_model = model_choice

    # Convert the NumPy array to a PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype('uint8'), 'RGB')

    image = val_transform(image)
    image = image.unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        confidence, preds = torch.max(probabilities, 0)

    confidence_score = confidence.item() * 100
    if confidence_score < 30:
        result = "Not identified"
        html_result = ""
    else:
        class_name = class_names[preds.item()]
        wiki_link = class_info[preds.item()]
        result = f"{class_name}: {confidence_score:.2f}%"
        html_result = f"<h1><br><a href='{wiki_link}' target='_blank'>More Info</a></h1>"

    return result, html_result


# Load class names and class info
class_file_path = 'classes.txt'
class_info_path = 'classinfo.txt'
class_names = load_class_names(class_file_path)
class_info = load_class_info(class_info_path)
num_classes = len(class_names)

# Define model paths
model_paths = {
    'densenet121': 'densenet121_15EpochsPretrainedNoExtractionNoLR_model.pkl',
    'resnet18': 'resnet18_25EpochsPretrainedExtractionNoLR_model.pkl',
    'mobilenetv2': 'mobilenetv2_25EpochsPretrainedExtractionNoLR_model.pkl'
}

# Set default model
current_model = 'densenet121'
model = load_model(model_paths[current_model], current_model)

# Create the Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=[
        gr.Image(height=500),
        gr.Dropdown(choices=["densenet121", "resnet18", "mobilenetv2"], value="densenet121", label="Select Model")
    ],
    outputs=[gr.Label(num_top_classes=1), gr.HTML()],
    title="Animal Classification",
    description="Upload an image to get the predicted label",
    allow_flagging="never",
)

# Launch the interface
iface.launch()
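
# For quick local debugging outside the Gradio UI, one could call predict_image()
# directly, roughly as sketched below. The image path is a hypothetical placeholder,
# not a file shipped with this Space, and the choice of 'resnet18' is arbitrary:
#
#     test_image = Image.open('example.jpg').convert('RGB')   # hypothetical path
#     label, link_html = predict_image(np.array(test_image), 'resnet18')
#     print(label)
#     print(link_html)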