import gradio as gr
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
# Load the pre-trained model (model.pth is expected to hold a fully pickled nn.Module;
# on PyTorch >= 2.6 pass weights_only=False to torch.load so the full object can be unpickled)
model = torch.load('model.pth', map_location=torch.device('cpu'))
model.eval()
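# If model.pth instead holds only a state_dict (an assumption; adjust to your checkpoint),
# the architecture has to be rebuilt before loading, e.g. with a torchvision ResNet:
#
#   from torchvision.models import resnet50
#   model = resnet50(num_classes=6)
#   model.load_state_dict(torch.load('model.pth', map_location='cpu'))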
# Target layer to pull activations from for Grad-CAM
# (assumes a ResNet-style backbone that exposes a layer4 block)
target_layers = [model.layer4[-1]]
# Define the class labels
class_labels = ['Crazing', 'Inclusion', 'Patches', 'Pitted', 'Rolled', 'Scratches']
# Transformations for input images (the mean/std are assumed to be training-set
# statistics of a grayscale dataset replicated across three channels)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4562, 0.4562, 0.4562], std=[0.2502, 0.2502, 0.2502]),
])

# Inverse of the normalization above, used to recover a displayable image:
# Normalize computes (x - mean) / std, so the inverse uses mean' = -mean/std and std' = 1/std
inv_normalize = transforms.Normalize(
    mean=[-0.4562 / 0.2502] * 3,
    std=[1 / 0.2502] * 3
)
# Gradio app interface
def classify_image(inp, transparency=0.8):
    model.to("cpu")
    input_tensor = preprocess(inp)
    input_batch = input_tensor.unsqueeze(0).to('cpu')  # Create a batch of size 1

    # Grad-CAM heatmap for the predicted class (the use_cuda argument was removed
    # in recent pytorch_grad_cam releases; the device is taken from the model)
    cam = GradCAM(model=model, target_layers=target_layers)
    grayscale_cam = cam(input_tensor=input_batch, targets=None)
    grayscale_cam = grayscale_cam[0, :]

    # Undo the normalization and rescale to [0, 1] before overlaying the heatmap
    img = inv_normalize(input_tensor)
    rgb_img = np.transpose(img.cpu().numpy(), (1, 2, 0))
    rgb_img = (rgb_img - rgb_img.min()) / (rgb_img.max() - rgb_img.min())
    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)

    # Class probabilities
    with torch.no_grad():
        output = model(input_batch)
    probabilities = F.softmax(output[0], dim=0)
    class_probabilities = {class_labels[i]: float(probabilities[i]) for i in range(len(class_labels))}

    return inp, class_probabilities, visualization
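# Quick local sanity check (a sketch; 'sample.jpg' is a hypothetical test image,
# not part of this repo). Gradio calls classify_image with the same arguments:
#
#   from PIL import Image
#   _, probs, cam_overlay = classify_image(Image.open('sample.jpg').convert('RGB'), transparency=0.6)
#   print(probs)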
iface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(0, 1, value=0.8, label="Original image weight (lower = stronger Grad-CAM overlay)")
    ],
    outputs=[
        # shape= and .style() were removed in Gradio 4; set height/width directly instead
        gr.Image(type="pil", label="Input Image", height=300, width=300),
        gr.Label(label="Probability of Defect", num_top_classes=3),
        gr.Image(type="numpy", label="Grad-CAM", height=300, width=300)
    ],
    title="Metal Defects Image Classification",
    description="The classification depends on the uploaded image being at a microscopic scale :)"
)

iface.launch()