File size: 2,453 Bytes
ca1f33d
 
 
43aa94f
ca1f33d
8bb6835
ca1f33d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0af68ff
ca1f33d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ccca058
ca1f33d
97dd5a4
ccca058
 
 
 
 
 
 
ca1f33d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
from joblib import load
from gradio import File
from PIL import Image
import gradio as gr
import matplotlib.pyplot as plt
import io

# --- Device and preprocessing -------------------------------------------
# All inference runs on the CPU; every checkpoint below is mapped onto it.
device = torch.device("cpu")

# Preprocessing applied to every uploaded image: resize the shorter side to
# 224, take a 224x224 centre crop, convert to a tensor and normalise each
# channel with the commonly used ImageNet mean / std values.
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# --- Anomaly detector ---------------------------------------------------
# Load the Isolation Forest model (a joblib-serialised estimator).
clf = load('Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib')

# --- Feature extractor --------------------------------------------------
# ResNet-50 backbone whose classification head is replaced by an empty
# Sequential (a pass-through), so the network outputs the pooled features
# instead of class logits. Weights come from a saved state dict.
feature_extractor_path = 'Models/feature_extractor.pth'
feature_extractor = models.resnet50(weights=None)
feature_extractor.fc = nn.Sequential()
feature_extractor.load_state_dict(
    torch.load(feature_extractor_path, map_location=device)
)
feature_extractor.to(device)
feature_extractor.eval()

# --- Gastric classification model ---------------------------------------
# This checkpoint stores the *entire* pickled model object, not just a state
# dict. PyTorch >= 2.6 defaults to ``weights_only=True``, which refuses to
# unpickle arbitrary objects, so the flag must be set explicitly.
# SECURITY NOTE: ``weights_only=False`` executes pickle code from the file;
# only ever load checkpoints from a trusted source here.
GASTRIC_MODEL_PATH = 'Gastric_Models/the_resnet_50_model.pth'
model_ft = torch.load(
    GASTRIC_MODEL_PATH,
    map_location=device,
    weights_only=False,
)
model_ft.to(device)
model_ft.eval()

# Anomaly detection and classification function
def _is_anomaly(clf, feature_extractor, input_image):
    """Return True when the anomaly detector flags the image as an outlier.

    Args:
        clf: Fitted anomaly detector exposing ``predict``.
        feature_extractor: Network that maps the image batch to features.
        input_image: Preprocessed image batch (first dimension is batch).

    Returns:
        bool: True if the first (only) sample is predicted as an outlier.
    """
    with torch.no_grad():
        features = feature_extractor(input_image)
    # Flatten to (batch, n_features) so the estimator receives a 2-D array.
    features = features.reshape(features.shape[0], -1).cpu().numpy()
    # NOTE(review): assumes the scikit-learn outlier convention, where
    # ``predict`` returns -1 for outliers and +1 for inliers, and that the
    # detector was fitted on features from this same extractor -- confirm.
    return bool(clf.predict(features)[0] == -1)


def classify_image(uploaded_image):
    """Screen an uploaded image for anomalies, then classify it.

    Args:
        uploaded_image: Path to the uploaded image file, or None when the
            user submitted the form without choosing a file.

    Returns:
        tuple: ``(message, None)``. ``message`` is either a notice (no
        upload / anomaly detected) or the predicted class together with its
        softmax probability as a percentage. The second element is always
        None.
    """
    # Nothing was uploaded: report it instead of crashing in Image.open.
    if uploaded_image is None:
        return "No image uploaded. Please upload an image.", None

    # The context manager releases the file handle as soon as the pixel
    # data has been copied into the converted RGB image.
    with Image.open(uploaded_image) as raw_image:
        image = raw_image.convert('RGB')
    input_image = data_transforms(image).unsqueeze(0).to(device)

    # Anomaly detection: skip classification for out-of-distribution input.
    if _is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None

    # Classification
    with torch.no_grad():
        outputs = model_ft(input_image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    predicted_class_index = outputs.argmax(dim=1).item()
    class_names = ['abnormal', 'normal']
    predicted_class_name = class_names[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100

    return f"Class: {predicted_class_name}, Probability: {predicted_probability:.2f}%", None

# Build the web UI. classify_image returns a 2-tuple (text message, image or
# None), so two output components are declared: a textbox that shows the
# message and an image slot for the second value.
iface = gr.Interface(
    fn=classify_image,
    inputs=File(type="filepath"),
    outputs=[
        gr.Textbox(label="Result"),
        gr.Image(label="Visualization"),
    ],
    title="GastroHub AI Gastric Image Classifier",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above is a sample image to test the results of the model. Click it to see the results.",
    examples=[
        ["Gastric_Images/Ladybug.png"],
    ],
    allow_flagging="never",
)

# Run the Gradio app
iface.launch()