import inspect

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

# Class index (position in the model's output vector) -> lesion name shown to the user.
lesion_type_dict = {
    0: 'Actinic keratoses',
    1: 'Basal cell carcinoma',
    2: 'Benign keratosis-like lesions',  # stray trailing space removed from the display name
    3: 'Dermatofibroma',
    4: 'Melanocytic nevi',
    5: 'Melanoma',
    6: 'Vascular lesions',
}

MODEL_PATH = "export_cpu.pkl"

# Built once at import time instead of on every request.
# NOTE(review): the 224x224 size and ImageNet mean/std are assumed to match the
# preprocessing used when the model was trained -- confirm against training code.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def load_model(path=MODEL_PATH):
    """Load the serialized classifier onto the CPU and switch it to inference mode.

    Args:
        path: Location of the serialized model; defaults to ``export_cpu.pkl``.

    Returns:
        The deserialized model object (expected to be callable on a batch tensor).
    """
    load_kwargs = {"map_location": torch.device("cpu")}
    # torch>=2.6 defaults to weights_only=True, which refuses to unpickle a whole
    # model object; opt out explicitly on versions that understand the flag.
    # SECURITY: this unpickles arbitrary objects -- only ever load a trusted file.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    loaded = torch.load(path, **load_kwargs)
    # Without eval(), any Dropout/BatchNorm layers would stay in training mode
    # and distort predictions.
    if isinstance(loaded, torch.nn.Module):
        loaded.eval()
    return loaded


model = load_model()  # loaded once at import; shared by every request


def classify_img(img):
    """Classify a skin-lesion image and return a human-readable report.

    Args:
        img: Image from the Gradio input component -- a NumPy array with the
            default ``gr.Image()`` settings (a PIL image is also accepted), or
            ``None`` when the user submits without an image.

    Returns:
        A text block naming the predicted class, followed by every class with
        its softmax probability, sorted from most to least likely.
    """
    if img is None:
        return "Please upload an image."

    pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
    # Grayscale or RGBA uploads would break the 3-channel Normalize step.
    pil_img = pil_img.convert("RGB")
    input_tensor = preprocess(pil_img).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        output = model(input_tensor)

    # output[0]: scores for the single image in the batch.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    ranked = sorted(
        enumerate(probabilities.tolist()), key=lambda pair: pair[1], reverse=True
    )

    predicted_name = lesion_type_dict[ranked[0][0]]
    lines = [f"Predicted Class: {predicted_name}", "", "Class Probabilities:"]
    lines.extend(f"{lesion_type_dict[index]}: {prob:.4f}" for index, prob in ranked)
    return "\n".join(lines) + "\n"


# Gradio interface setup
image = gr.Image()  # input component
label = gr.Textbox()  # output component: formatted text report

# One entry per row because the interface has a single input component; the
# expected diagnosis is encoded in each filename.
# NOTE(review): "melanoma.jpg" is the only lowercase filename -- confirm it
# matches the file on disk (paths are case-sensitive on Linux).
EXAMPLE_IMAGES = [
    "examples/Actinic keratoses.jpg",
    "examples/Basal cell carcinoma.jpg",
    "examples/Benign keratosis-like lesions.jpg",
    "examples/Dermatofibroma.jpg",
    "examples/melanoma.jpg",
    "examples/Melanocytic nevi.jpg",
    "examples/Vascular lesions.jpg",
]

iface = gr.Interface(
    fn=classify_img,
    inputs=image,
    outputs=label,
    examples=[[path] for path in EXAMPLE_IMAGES],
)

if __name__ == "__main__":
    iface.launch()  # start the local Gradio server