File size: 2,531 Bytes
ae8f111
 
 
 
 
 
 
 
 
 
412ca3a
 
 
 
ae8f111
300c3a7
18a45f0
 
 
412ca3a
ae8f111
3dee463
ae8f111
 
 
 
 
 
412ca3a
ae8f111
 
 
412ca3a
ae8f111
 
 
 
 
 
 
412ca3a
ae8f111
412ca3a
ae8f111
 
 
 
 
 
 
 
 
412ca3a
ae8f111
 
 
 
 
 
3dee463
412ca3a
 
 
 
 
ab15865
64f36cb
412ca3a
3dee463
 
 
 
3e7941e
 
3dee463
3e7941e
5dcad43
3dee463
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

import os
import datetime

# Setup
# Inference runs on CPU only.
device = torch.device("cpu")
# Every Grad-CAM overlay produced by predict_retinopathy is written here.
save_dir = "/home/user/app/saved_predictions"
# Create the folder once. exist_ok=True makes the call safe if the folder
# appears between the check and the creation. The check only decides
# whether to print the one-time "created" message. (The original code
# called os.makedirs a second time redundantly.)
if not os.path.isdir(save_dir):
    os.makedirs(save_dir, exist_ok=True)
    print("📁 Folder created:", save_dir)

# Load model: an untrained ResNet-50 whose final fully connected layer is
# replaced by a 2-output head, then filled from the local checkpoint.
# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source. On torch >= 1.13, consider weights_only=True.
model = models.resnet50(weights=None)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)
model.load_state_dict(
    torch.load("resnet50_dr_classifier.pth", map_location=device)
)
# nn.Module.to() and .eval() both return the module itself, so chaining
# them yields the same object in evaluation mode.
model = model.to(device).eval()

# Grad-CAM: the explainer is attached to the last block of ResNet's layer4,
# the deepest convolutional stage before pooling and the classifier head.
target_layer = model.layer4[-1]
cam = GradCAM(target_layers=[target_layer], model=model)

# Preprocessing pipeline, applied in order:
#   1. resize to 224x224;
#   2. convert to a float tensor in [0, 1];
#   3. normalize per channel with the usual ImageNet mean/std.
# NOTE(review): step 3 presumably matches the normalization used during
# training; confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Predict and save
def predict_retinopathy(image):
    """Classify an uploaded image, build a Grad-CAM overlay and save it to disk.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(cam_pil, text)`` tuple:
            - ``cam_pil`` is the Grad-CAM heatmap blended onto the 224x224 RGB
              version of the input.
            - ``text`` is ``"<label> (Confidence: <p>)"``, where ``<p>`` is the
              softmax probability (0-1, two decimals) of the *predicted* class.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty image input. Show a readable
        # message instead of crashing with AttributeError on .convert().
        raise gr.Error("Please upload an image first.")

    # Microseconds stop two predictions made within the same second (same
    # label and confidence) from overwriting each other's saved file.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    img = image.convert("RGB").resize((224, 224))
    img_tensor = transform(img).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        output = model(img_tensor)
        probs = F.softmax(output, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred].item()

    # Class index 0 is reported as DR, any other index as NoDR.
    # NOTE(review): this must match the class order used in training; confirm.
    label = "DR" if pred == 0 else "NoDR"

    # Grad-CAM backpropagates through the model, so it is called outside
    # the no_grad block. show_cam_on_image expects a float RGB image in
    # [0, 1], hence the /255 scaling and the contiguous copy.
    rgb_img_np = np.ascontiguousarray(np.array(img).astype(np.float32) / 255.0)
    grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
    cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
    cam_pil = Image.fromarray(cam_image)

    # Save the overlay; the filename records the label and confidence.
    filename = f"{timestamp}_{label}_{confidence:.2f}.png"
    cam_pil.save(os.path.join(save_dir, filename))

    return cam_pil, f"{label} (Confidence: {confidence:.2f})"

# Gradio app
# Bound to `demo` (the name Gradio tooling looks for); launched immediately,
# as before.
demo = gr.Interface(
    fn=predict_retinopathy,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Метод Grad-CAM"),
        # predict_retinopathy returns "<label> (Confidence: <p>)", where <p> is
        # the 0-1 probability of the *predicted* class. The old label,
        # "DR probability in %", described neither quantity, so the label
        # now reads "Result and model confidence".
        gr.Text(label="Результат и уверенность модели"),
    ],
    title="Диагностика диабетической ретинопатии",
    description="Загрузите ОКТ и смотрите ИИ-карту Grad-CAM heatmap",
)
demo.launch()