Tudohuang's picture
Create app.py
df3846f
raw
history blame
2.79 kB
import numpy as np
import torch
import torch.nn as nn
import albumentations as A
import imageio
from PIL import Image
import gradio as gr
import os
import glob
# Path to the serialized model checkpoint.
model_path = "MRIy14.pt"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier.

    Each stage is a 3x3 convolution (padding 1) + ReLU + 2x2 max-pool with
    stride 2, with channel widths 3 -> 16 -> 32 -> 32. The flattened features
    feed a single linear layer. The three stride-2 pools shrink the spatial
    size by a factor of 8, so the ``32 * 28 * 28`` input width of ``fc``
    corresponds to 224x224 input images.

    Args:
        num_classes: Number of output logits produced by ``fc``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(32 * 28 * 28, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes).

        ``x`` must be (batch, 3, H, W) with H and W such that three 2x2
        poolings leave 28x28 (e.g. 224x224), otherwise ``fc`` rejects it.
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        return self.fc(x)


# Load the trained network. The result of torch.load is moved and switched to
# eval mode directly, so the checkpoint is presumably a full pickled nn.Module
# (not a state_dict). Pickle resolves classes by their qualified name (e.g.
# __main__.SimpleCNN when saved from a script), so the load must come AFTER
# SimpleCNN is defined above.
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# files. PyTorch >= 2.6 defaults to weights_only=True, which rejects full
# module pickles; pass weights_only=False there if loading fails.
model = torch.load(model_path, map_location=device)
model = model.to(device)
model.eval()
def predict(image_path):
    """Classify one image file as "Cancer" or "Healthy".

    The image is converted to 8-bit grayscale, resized to 224x224, replicated
    into three channels, scaled by its maximum value and passed through the
    module-level ``model`` on ``device``.

    Args:
        image_path: Path of an image file readable by ``imageio``.

    Returns:
        "Cancer" when the model's highest-scoring class index is 1,
        otherwise "Healthy".

    Errors raised while reading or decoding the file propagate to the caller.
    """
    raw = imageio.imread(image_path)
    # Collapse any colour/alpha channels into a single 8-bit grayscale plane.
    gray = np.array(Image.fromarray(raw).convert('L'))
    gray = A.Resize(224, 224)(image=gray)['image']
    # Replicate the gray plane into 3 channels and add a batch axis:
    # shape (1, 3, 224, 224).
    batch = np.stack([gray] * 3, axis=0)[None, ...]
    # Scale into [0, 1] by the image maximum. An all-black image has a maximum
    # of 0; dividing by it would yield NaNs, so such an image stays all zeros.
    peak = batch.max()
    scaled = batch / peak if peak > 0 else batch.astype(np.float64)
    # Cast to float32 before the device transfer (the model input dtype).
    tensor = torch.from_numpy(scaled.astype(np.float32)).to(device)
    with torch.no_grad():
        logits = model(tensor)
    # Index of the highest-scoring class for the single image in the batch.
    label = torch.max(logits, 1)[1].item()
    return "Cancer" if label == 1 else "Healthy"
# Gradio callback function
def process_input(file_obj):
    """Run ``predict`` on an uploaded file, or on every file in a directory.

    Args:
        file_obj: The upload passed in by the Gradio File component. Objects
            exposing a ``.name`` path (the tempfile wrapper produced by
            ``type="file"``) and plain path strings are both accepted.
            ``None`` means nothing was uploaded.

    Returns:
        A prompt string when nothing was uploaded; for a directory, a list of
        ``(filename, result)`` tuples for its regular files, sorted by name;
        otherwise the single ``predict`` result string.
    """
    if file_obj is None:
        return "請上傳檔案"
    # Accept both tempfile-like objects (path in .name) and plain str paths.
    path = getattr(file_obj, "name", file_obj)
    if not os.path.isdir(path):
        return predict(path)
    # Sort for a stable output order, and skip sub-directories, which
    # predict() cannot read as images.
    return [
        (entry, predict(os.path.join(path, entry)))
        for entry in sorted(os.listdir(path))
        if os.path.isfile(os.path.join(path, entry))
    ]
# Build the Gradio interface.
iface = gr.Interface(
    fn=process_input,
    # gr.File is the current name of the deprecated gr.inputs.File (the
    # gr.inputs namespace was removed in Gradio 4); type="file" keeps passing
    # a tempfile-like object whose .name is the uploaded path.
    inputs=gr.File(type="file", label="上傳單個圖片或資料夾"),
    outputs="text",
    title="MRI 腫瘤辨識系統",
    description="上傳一張 MRI 圖片或包含多張圖片的資料夾進行腫瘤辨識"
)
# Launch the web interface.
iface.launch()