import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import SwinForImageClassification
from torchvision import transforms

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Swin Transformer backbone with its original ImageNet classification
# head (1000 classes); the fine-tuned weights are loaded over it below
model = SwinForImageClassification.from_pretrained("microsoft/swin-large-patch4-window12-384")

# Modify input channels to 4 (RGB + mask)
original_conv = model.swin.embeddings.patch_embeddings.projection
new_conv = torch.nn.Conv2d(
    in_channels=4,
    out_channels=original_conv.out_channels,
    kernel_size=original_conv.kernel_size,
    stride=original_conv.stride,
    padding=original_conv.padding,
    bias=original_conv.bias is not None
)
# Reuse the pretrained RGB filters and initialize the new mask channel as the
# mean of the RGB filters; the pretrained bias, if any, is carried over too
with torch.no_grad():
    new_conv.weight[:, :3] = original_conv.weight.clone()
    new_conv.weight[:, 3] = original_conv.weight.mean(dim=1)
    if original_conv.bias is not None:
        new_conv.bias.copy_(original_conv.bias)
model.swin.embeddings.patch_embeddings.projection = new_conv
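# Optional sanity check: the inflated projection should now accept 4-channel
# input while keeping the pretrained embedding dimension
assert new_conv.weight.shape[1] == 4
assert new_conv.weight.shape[0] == original_conv.weight.shape[0]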

# Load the trained state dict from best_model.pth
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.to(device)
model.eval()

# Preprocessing for the Swin input: resize to 384x384 and normalize with the
# standard ImageNet statistics used by the pretrained backbone
swin_transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Label mapping for the 7 lesion classes (HAM10000-style abbreviations)
label_to_idx = {
    'akiec': 0,  # actinic keratoses / intraepithelial carcinoma
    'bcc': 1,    # basal cell carcinoma
    'bkl': 2,    # benign keratosis-like lesions
    'df': 3,     # dermatofibroma
    'mel': 4,    # melanoma
    'nv': 5,     # melanocytic nevi
    'vasc': 6    # vascular lesions
}
idx_to_label = {v: k for k, v in label_to_idx.items()}

# Prediction function
def predict(image):
    # Gradio delivers a PIL image here (type="pil"), but guard against numpy
    # arrays and force 3-channel RGB so the transform sees the expected shape
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # Preprocess the image for the Swin Transformer
    swin_image = swin_transform(image).to(device)
    
    # Generate a dummy mask channel (all zeros)
    mask = torch.zeros(1, 384, 384).to(device)
    
    # Combine image and dummy mask
    combined = torch.cat([swin_image, mask], dim=0).unsqueeze(0)  # Add batch dimension
    
    # The classification head still has 1000 outputs; only the first 7 logits
    # correspond to the trained lesion classes, so slice before the argmax
    with torch.no_grad():
        outputs = model(combined).logits[:, :7]
        pred = outputs.argmax(dim=1)
        pred_label = idx_to_label[pred.item()]
    
    return pred_label

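# Quick smoke test without launching the UI (the path "example.jpg" is only a
# placeholder for a local sample image):
#   print(predict(Image.open("example.jpg")))
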
# Create Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(),
    title="Skin Cancer Classification",
    description="Upload a skin lesion image to classify it into one of seven classes: akiec, bcc, bkl, df, mel, nv, vasc."
)

# Launch the interface
iface.launch()
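
# launch() serves the app locally by default; iface.launch(share=True) would
# create a temporary public link instead.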