Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
import torch.nn as nn | |
import albumentations as A | |
import imageio | |
from PIL import Image | |
import gradio as gr | |
import os | |
import glob | |
class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier.

    This class is never instantiated in this file. It is defined so that
    ``torch.load`` below can unpickle the checkpoint, which is presumably a
    whole pickled ``SimpleCNN`` saved from a script run as ``__main__``
    (confirm against the training code). Pickle resolves the class by
    module + name at load time, so this definition must appear BEFORE the
    ``torch.load`` call. Attribute names (conv1, relu1, ...) must not be
    renamed, because ``forward`` reads them from the unpickled instance.

    Expected input is (N, 3, 224, 224). Three 2x2 max-pools shrink 224 to
    28, which matches the ``32 * 28 * 28`` input features of ``fc``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(32 * 28 * 28, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (N, num_classes)."""
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.pool3(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch axis
        x = self.fc(x)
        return x


# Path to the serialized model checkpoint.
model_path = "MRIy14.pt"
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_full_model(path, map_device):
    """Load a whole pickled nn.Module from ``path`` onto ``map_device``.

    ``map_location`` lets a checkpoint saved on a GPU load on a CPU-only
    host. ``weights_only=False`` is required for full-module pickles, since
    recent torch versions default it to True. Older torch versions do not
    know this keyword and raise TypeError, so we retry without it.
    NOTE(review): this unpickles arbitrary objects. Only ever point
    ``path`` at a trusted, locally shipped checkpoint.
    """
    try:
        return torch.load(path, map_location=map_device, weights_only=False)
    except TypeError:
        return torch.load(path, map_location=map_device)


# Load the model (SimpleCNN is already defined above, so unpickling can find it).
model = _load_full_model(model_path, device)
model = model.to(device)
model.eval()  # inference mode for the whole app lifetime
def predict(image_path):
    """Classify a single image file as "Cancer" or "Healthy".

    Preprocessing mirrors the model's expected input:
    1. Convert to 8-bit grayscale (PIL mode 'L').
    2. Resize to 224x224.
    3. Replicate into 3 channels.
    4. Add a batch axis.
    5. Scale by the image maximum.

    Args:
        image_path: Path to an image file readable by ``imageio.imread``.

    Returns:
        "Cancer" when the arg-max class of the module-level ``model`` is 1,
        otherwise "Healthy".
    """
    raw = imageio.imread(image_path)
    gray = np.array(Image.fromarray(raw).convert('L'))
    gray = A.Resize(224, 224)(image=gray)['image']
    batch = np.stack([gray] * 3, axis=0)[None, ...]  # (1, 3, 224, 224)

    # Normalize to [0, 1] by the image maximum. An all-black image has a
    # maximum of 0; the unguarded division produced NaNs for it.
    peak = batch.max()
    if peak > 0:
        batch = batch / peak

    tensor = torch.from_numpy(batch.astype(np.float32)).to(device)
    with torch.no_grad():
        outputs = model(tensor)
    # Batch size is 1, so .item() yields a plain Python int class index.
    predicted_class = outputs.argmax(dim=1).item()
    return "Cancer" if predicted_class == 1 else "Healthy"
# Gradio callback function.
def process_input(file_obj):
    """Classify an uploaded image, or every regular file in a folder.

    Args:
        file_obj: The File input's value. This is either an object exposing
            the on-disk path as ``.name`` (a tempfile wrapper) or a plain
            path (str / os.PathLike). It is ``None`` when nothing was
            uploaded.

    Returns:
        - The upload prompt when nothing was uploaded.
        - The single prediction string for a file.
        - For a directory, one "filename: prediction" line per regular file,
          sorted by name. This is plain text, because the interface's
          output component is "text".
    """
    if file_obj is None:
        return "請上傳檔案"
    # Accept plain paths as well as objects carrying the path in `.name`.
    # PathLike is checked first, because pathlib.Path.name is only the
    # basename.
    if isinstance(file_obj, (str, os.PathLike)):
        path = file_obj
    else:
        path = file_obj.name

    if not os.path.isdir(path):
        return predict(path)

    lines = []
    for filename in sorted(os.listdir(path)):
        file_path = os.path.join(path, filename)
        if not os.path.isfile(file_path):
            continue  # skip sub-directories; predict() can only read files
        lines.append(f"{filename}: {predict(file_path)}")
    return "\n".join(lines)
# Build the Gradio interface: a single file upload mapped to a text result.
upload_input = gr.inputs.File(type="file", label="上傳單個圖片或資料夾")

iface = gr.Interface(
    fn=process_input,
    inputs=upload_input,
    outputs="text",
    title="MRI 腫瘤辨識系統",
    description="上傳一張 MRI 圖片或包含多張圖片的資料夾進行腫瘤辨識",
)

# Start the web server.
iface.launch()