# File size: 2,858 bytes (revision 3d93357)
import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, SwinForImageClassification
from torchvision import transforms
# Select the compute device: the default CUDA device when PyTorch reports one
# is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Hugging Face checkpoint the classifier architecture is built from.
SWIN_CHECKPOINT = "microsoft/swin-large-patch4-window12-384"

# Load the Swin Transformer with its original classifier head (1000 classes);
# the fine-tuned weights loaded below are expected to fit this architecture
# once the patch embedding is widened to 4 input channels.
# NOTE(review): swin_processor is not used by predict() (which uses the
# torchvision pipeline); it is kept only so the module namespace is unchanged.
swin_processor = AutoImageProcessor.from_pretrained(SWIN_CHECKPOINT)
model = SwinForImageClassification.from_pretrained(SWIN_CHECKPOINT)

# Widen the patch-embedding convolution from 3 (RGB) to 4 input channels
# (RGB + mask), copying every other hyper-parameter from the original layer.
patch_embeddings = model.swin.embeddings.patch_embeddings
original_conv = patch_embeddings.projection
new_conv = torch.nn.Conv2d(
    in_channels=4,
    out_channels=original_conv.out_channels,
    kernel_size=original_conv.kernel_size,
    stride=original_conv.stride,
    padding=original_conv.padding,
    bias=original_conv.bias is not None,
)
with torch.no_grad():
    # Reuse the pretrained RGB filters and seed the mask channel with their
    # per-filter mean. The strict load_state_dict below overwrites these
    # values; they only matter if that load is ever relaxed or removed.
    new_conv.weight[:, :3] = original_conv.weight
    new_conv.weight[:, 3] = original_conv.weight.mean(dim=1)
    if original_conv.bias is not None:
        # Previously left randomly initialised; copy it like the weights.
        new_conv.bias.copy_(original_conv.bias)
patch_embeddings.projection = new_conv
# Keep the recorded channel count in sync with the new projection.
# NOTE(review): presumably some transformers versions check pixel_values'
# channel dimension against this attribute in the forward pass — confirm.
patch_embeddings.num_channels = new_conv.in_channels
model.config.num_channels = new_conv.in_channels

# Load the fine-tuned weights. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint file cannot run
# arbitrary code on load (requires torch >= 1.13).
state_dict = torch.load("best_model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
# Preprocessing pipeline for the Swin input: resize to 384x384, convert the
# PIL image to a float tensor, then normalise each of the three channels.
# NOTE(review): the mean/std values look like the usual ImageNet statistics.
swin_transform = transforms.Compose(
    [
        transforms.Resize(size=(384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Class codes in logit order: the code at position i maps to logit index i.
label_to_idx = {
    code: position
    for position, code in enumerate(("akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"))
}
# Inverse lookup that turns a predicted index back into its class code.
idx_to_label = dict(enumerate(label_to_idx))
# Prediction function
def predict(image):
    """Classify an uploaded image into one of the supported class codes.

    Args:
        image: A ``PIL.Image.Image`` or a NumPy array (converted with
            ``Image.fromarray``). ``None`` means no image was supplied.

    Returns:
        The predicted class code (a value of ``idx_to_label``), or a short
        prompt string when ``image`` is ``None``.
    """
    # Gradio passes None when the user submits without uploading an image;
    # previously this crashed inside the transform.
    if image is None:
        return "Please upload an image."
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # swin_transform normalises exactly three channels, so grayscale,
    # palette or RGBA inputs must be converted first or Normalize fails.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # (C, H, W) normalised tensor on the model's device.
    swin_image = swin_transform(image).to(device)
    # The model takes a 4th (mask) channel; none is available at inference,
    # so append an all-zero channel matching the image's size, dtype and device.
    mask = torch.zeros_like(swin_image[:1])
    batch = torch.cat([swin_image, mask], dim=0).unsqueeze(0)  # add batch dim

    # Only the first len(idx_to_label) logits of the head correspond to
    # the supported classes.
    num_classes = len(idx_to_label)
    with torch.no_grad():
        logits = model(batch).logits[:, :num_classes]
    pred_idx = int(logits.argmax(dim=1).item())
    return idx_to_label[pred_idx]
# Create the Gradio interface: one image input, one text output showing the
# predicted class code. The class list in the description is generated from
# label_to_idx so it cannot drift from the labels predict() can return.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(),
    title="Skin Cancer Classification",
    description=(
        "Upload an image to classify the type of skin cancer. "
        f"Supported classes: {', '.join(label_to_idx)}."
    ),
)
# Launch the interface when this script runs (unchanged: no __main__ guard).
iface.launch()