Tudohuang's picture
Update app.py
39c5317
import tqdm
import random
from sklearn.metrics import f1_score
import torchvision.models as models
import numpy as np
import torch
import torch.nn as nn
import albumentations as A
import imageio
from PIL import Image
import gradio as gr
import os
import glob
import cv2
# Pick the compute device: the GPU when CUDA is available, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Path to the serialized model checkpoint (relative to the working directory).
model_path = "MRIy14.pt"
class SimpleCNN(nn.Module):
    """Three conv -> ReLU -> 2x2 max-pool stages followed by a linear classifier.

    The classifier's input width (32 * 28 * 28) fixes the expected input to
    3-channel 224x224 images: each of the three max-pools halves the spatial
    size (224 -> 112 -> 56 -> 28).

    The attribute names (conv1 ... fc) are load-bearing. ``predict`` hooks
    ``conv3`` for Grad-CAM. This class is never instantiated in this file; it
    is presumably defined here so that ``torch.load`` can unpickle the saved
    model object (TODO confirm the checkpoint's pickled class).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Stage 1: 3 -> 16 channels, spatial size halved.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 16 -> 32 channels, spatial size halved.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 3: 32 -> 32 channels, spatial size halved.
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier over the flattened 32 x 28 x 28 feature map.
        flat_features = 32 * 28 * 28
        self.fc = nn.Linear(flat_features, num_classes)

    def forward(self, x):
        """Return raw class scores (no softmax), shape (batch, num_classes)."""
        stages = (
            (self.conv1, self.relu1, self.pool1),
            (self.conv2, self.relu2, self.pool2),
            (self.conv3, self.relu3, self.pool3),
        )
        for conv, activation, pool in stages:
            x = pool(activation(conv(x)))
        return self.fc(torch.flatten(x, start_dim=1))
# Load the trained network. The checkpoint holds a whole pickled model object
# (the result is used directly as a module: .to(), .eval(), .conv3), not just a
# state_dict. torch >= 2.6 defaults to weights_only=True, which rejects such
# files, so opt out explicitly (the argument exists since torch 1.13).
# NOTE(review): unpickling can execute arbitrary code -- only acceptable because
# this is a trusted checkpoint shipped alongside the app.
model = torch.load(model_path, map_location="cpu", weights_only=False)
model = model.to(device)
model.eval()  # inference mode (affects dropout / batch-norm layers, if any)

# Preprocessing: resize every input to 224x224, the size SimpleCNN's fc layer
# (32 * 28 * 28 inputs) is built for.
transform = A.Compose([
    A.Resize(224, 224),
])
def generate_heatmap(last_conv_layer, input_tensor, pred_class, net=None):
    """Build a Grad-CAM heatmap for ``pred_class`` from ``last_conv_layer``.

    Runs a forward pass through ``net`` while a forward hook captures the
    output of ``last_conv_layer``. It differentiates the ``pred_class`` score
    with respect to that feature map and weights each channel by its
    spatially (and batch-) averaged gradient. The weighted channels are
    averaged, negative values are clipped to 0, and the result is scaled so
    its maximum is 1.

    Args:
        last_conv_layer: a module inside ``net`` whose output is used as the
            feature map (assumed to be (batch, channels, H, W) -- TODO confirm
            for other layers).
        input_tensor: the model input; a batch of one is the intended use.
        pred_class: index of the class score to explain.
        net: the model to run; defaults to the module-level ``model``.

    Returns:
        A float32 NumPy array with the feature map's spatial size, values in
        [0, 1]. If the map has no positive value, it is returned as all zeros
        (rather than NaN from a 0/0 normalization).

    Raises:
        ValueError: if ``last_conv_layer`` never ran during the forward pass.
    """
    if net is None:
        net = model

    captured = {}

    def _capture(module, inputs, output):
        captured["features"] = output

    hook = last_conv_layer.register_forward_hook(_capture)
    try:
        # Grad-CAM needs autograd even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            model_output = net(input_tensor)
            features = captured.get("features")
            if features is None:
                raise ValueError("last_conv_layer was not executed by the model")
            # .sum() makes the target scalar; for a batch of one it is just the score.
            (gradients,) = torch.autograd.grad(
                outputs=model_output[:, pred_class].sum(), inputs=features
            )
    finally:
        # Always detach the hook, even if the forward/backward pass failed.
        hook.remove()

    # One weight per channel: the gradient averaged over batch and space.
    channel_weights = gradients.mean(dim=(0, 2, 3))
    # Weight the (detached) feature map out-of-place instead of mutating the
    # autograd-tracked tensor, then average over channels.
    cam = (features.detach() * channel_weights.view(1, -1, 1, 1)).mean(dim=1).squeeze()

    heatmap = np.maximum(cam.cpu().numpy(), 0)
    peak = float(heatmap.max())
    if peak > 0:
        heatmap /= peak
    return heatmap
def overlay_heatmap(image, heatmap, intensity=0.5, threshold=0.5):
    """Blend a JET-coloured heatmap onto ``image`` where the heatmap is strong.

    Args:
        image: uint8 image array, assumed RGB (H, W, 3) as produced by
            gradio's ``gr.Image``. (H, W) / (H, W, 1) grayscale is expanded to
            three channels and an RGBA alpha channel is dropped.
        heatmap: 2-D float map, expected in [0, 1] (e.g. from
            ``generate_heatmap``). It is resized to the image size; NaNs are
            treated as 0.
        intensity: weight of the heatmap colour in the blend
            (0 = image only, 1 = heatmap colour only).
        threshold: only pixels whose heatmap value exceeds
            ``threshold * heatmap.max()`` are blended; the rest are unchanged.

    Returns:
        A uint8 RGB array with the same height and width as ``image``.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[..., None]
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[..., :3]
    height, width = image.shape[:2]

    # Work on the scalar map so thresholding is per pixel, not per colour channel.
    scalar = np.nan_to_num(np.asarray(heatmap, dtype=np.float32))
    scalar = np.clip(cv2.resize(scalar, (width, height)), 0.0, 1.0)

    colored = cv2.applyColorMap(np.uint8(255 * scalar), cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so it matches the RGB input image.
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    mask = scalar > scalar.max() * threshold  # 2-D mask -> whole pixels are blended
    blended = image.astype(np.float32)
    blended[mask] = blended[mask] * (1 - intensity) + colored[mask] * intensity
    return np.clip(blended, 0, 255).astype(np.uint8)
def predict(image):
    """Classify an uploaded image and, for class 1, overlay a Grad-CAM heatmap.

    Args:
        image: the value supplied by ``gr.Image()`` (a NumPy array by default),
            or None when the user submits without an image.

    Returns:
        ``("癌症", overlay)`` when the model predicts class 1, where overlay is
        the image with a Grad-CAM heatmap of ``model.conv3`` blended in;
        otherwise ``("健康", image)`` with the input returned unchanged.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("請先上傳圖片")  # "Please upload an image first."

    # The network's first conv expects 3 channels: expand grayscale, drop alpha.
    rgb = np.asarray(image)
    if rgb.ndim == 2:
        rgb = rgb[..., None]
    if rgb.shape[-1] == 1:
        rgb = np.repeat(rgb, 3, axis=-1)
    elif rgb.shape[-1] == 4:
        rgb = rgb[..., :3]

    processed_image = transform(image=rgb)["image"]
    # HWC -> CHW float32, plus a leading batch dimension of 1.
    processed_image = np.transpose(processed_image, (2, 0, 1)).astype(np.float32)
    input_tensor = torch.from_numpy(processed_image[None, ...]).to(device)

    with torch.no_grad():
        output = model(input_tensor)
    prediction = int(output.argmax(dim=1).item())

    if prediction != 1:
        return "健康", image

    # Grad-CAM needs gradients, so it runs outside the no_grad block above.
    heatmap = generate_heatmap(model.conv3, input_tensor, prediction)
    return "癌症", overlay_heatmap(rgb, heatmap)
# Web UI: one image in; a text label and an image (heatmap overlay or the
# original input) out. `demo` stays at module level so Gradio tooling can find it.
demo = gr.Interface(fn=predict, inputs=gr.Image(), outputs=["text", "image"])

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (tests, `gradio app.py` reload mode) does not launch or block.
    demo.launch()