# MedicalAILabo — app.py (commit cfdde14)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
import torchvision.transforms as T
from PIL import Image
from lib.framework import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin
# ===========================================
# 1) Path configuration
# ===========================================
# Trained model weights and the JSON file of training-time parameters.
WEIGHT_PATH = "./cxp_projection_rotation.pt"
PARAMETER_JSON = "./parameters.json"
# ===========================================
# 2) Class label definitions
# ===========================================
# Output-index -> label-name mappings for the two classification heads.
LABEL_APorPA = ["AP", "PA", "Lateral"]
LABEL_ROUND = ["Upright", "Inverted", "Left-rotation", "Right-rotation"]
# ===========================================
# 3) Preprocessing class
# ===========================================
class ImageHandler(ImageMixin):
    """Turn a PIL image into the batched tensor dict the model expects."""

    def __init__(self, params):
        # Dataloader parameters kept for the ImageMixin machinery.
        self.params = params
        # PIL -> float32 tensor in [0, 1], shape [C, H, W].
        self.transform = T.Compose([T.ToTensor()])

    def set_image(self, image: Image.Image):
        """Return {"image": tensor} with a leading batch dimension added."""
        batched = self.transform(image).unsqueeze(0)  # [1, C, H, W]
        return {"image": batched}
# ===========================================
# 4) Parameter loading
# ===========================================
def load_parameter(parameter_path):
    """Load the training-time parameters from JSON and split them into
    the model / dataloader argument groups used for CPU inference."""
    args = ParamSet()
    for key, value in _retrieve_parameter(parameter_path).items():
        setattr(args, key, value)
    # Inference-time overrides: no augmentation/sampling, no pretrained
    # re-initialization, run everything on CPU.
    args.augmentation = "no"
    args.sampler = "no"
    args.pretrained = False
    args.mlp = None
    args.net = args.model
    args.device = torch.device("cpu")
    model_args = _dispatch_by_group(args, "model")
    loader_args = _dispatch_by_group(args, "dataloader")
    return model_args, loader_args
# Split the saved parameters into model-building and preprocessing groups.
args_model, args_dataloader = load_parameter(PARAMETER_JSON)
# ===========================================
# 5) Model creation & weight loading
# ===========================================
# Build the network from the saved parameters, load the trained weights,
# and switch to eval mode (disables dropout / uses running BN statistics).
model = create_model(args_model)
print(f"Loading weights from {WEIGHT_PATH}")
model.load_weight(WEIGHT_PATH)
model.eval()
# ===========================================
# 6) 推論+HTML生成
# ===========================================
def predict_html(image_path: str) -> str:
# 画像読み込み
img = Image.open(image_path).convert("L")
handler = ImageHandler(args_dataloader)
batch = handler.set_image(img)
with torch.no_grad():
outputs = model(batch)
logits_proj = outputs.get("label_APorPA")
logits_rot = outputs.get("label_round")
# softmax で確率に変換
probs_proj = F.softmax(logits_proj, dim=1)[0].cpu().numpy()
probs_rot = F.softmax(logits_rot, dim=1)[0].cpu().numpy()
# argmax でラベル選択
idx_proj = int(probs_proj.argmax())
idx_rot = int(probs_rot.argmax())
pred_proj = LABEL_APorPA[idx_proj]
pred_rot = LABEL_ROUND[idx_rot]
conf_proj = float(probs_proj[idx_proj])
conf_rot = float(probs_rot[idx_rot])
# ファイル名から元ラベル取得(例: "1_AP_Upright.png" → orig_proj="AP", orig_rot="Upright")
base = os.path.splitext(os.path.basename(image_path))[0]
parts = base.split("_")
if len(parts) >= 3:
orig_proj = parts[1]
orig_rot = parts[2]
else:
orig_proj = orig_rot = None
# 警告HTML作成用ヘルパー
def make_warning(kind, orig, pred, conf):
# kind: "projection" or "rotation"
high_thr = 0.8
med_thr = 0.5
if orig and orig != pred:
if conf >= high_thr:
return (
f"<p style='color:red'>⚠ Potentially mislabeled {kind}: "
f"filename says {orig}, model predicts {pred} (confidence {conf:.2f})</p>"
)
elif conf >= med_thr:
return (
f"<p style='color:orange'>⚠ There is a possibility of mislabeled {kind}: "
f"model predicts {pred} with moderate confidence ({conf:.2f})</p>"
)
if conf < med_thr:
return (
f"<p style='color:orange'>⚠ Low confidence for {kind} ({conf:.2f}); "
f"please check image quality or framing.</p>"
)
return ""
# 警告HTML
warn_html = ""
warn_html += make_warning("projection", orig_proj, pred_proj, conf_proj)
warn_html += make_warning("rotation", orig_rot, pred_rot, conf_rot)
# クラスごとのスコア表示用HTML
scores_proj = ", ".join(
f"{LABEL_APorPA[i]}: {p:.2f}" for i, p in enumerate(probs_proj)
)
scores_rot = ", ".join(
f"{LABEL_ROUND[i]}: {p:.2f}" for i, p in enumerate(probs_rot)
)
# 結果表示用HTML
html = (
f"<p><strong>Projection :</strong> {pred_proj} "
f"<small>({scores_proj})</small></p>"
f"<p><strong>Rotation :</strong> {pred_rot} "
f"<small>({scores_rot})</small></p>"
f"{warn_html}"
)
return html
# ===========================================
# 7) Gradio UI
# ===========================================
# NOTE: the rotation terms shown to users must match LABEL_ROUND exactly
# ("Left-rotation"/"Right-rotation"), otherwise the filename comparison in
# predict_html reports spurious mislabel warnings.
html_header = """
<div style="padding:10px;border:1px solid #ddd;border-radius:5px">
<h2>Chest X‑ray Projection & Rotation Classification</h2>
<p>Upload a 256×256 grayscale PNG. The model predicts projection (AP/PA/Lateral)
and rotation (Upright/Inverted/Left/Right) and shows softmax confidences.
It warns if filename label differs or if confidence is low. Please name the files using the format: [Number]_projection_rotation.png.
For the projection part of the filename, please use one of the following three terms: AP/PA/Lateral
For the rotation part of the filename, please use one of the following four terms: Upright/Inverted/Left-rotation/Right-rotation
As samples, We have prepared two sets of images: A PA view in the Upright position. An AP view with Left rotation. For each set, we have created two versions, one of which includes a mislabel in either its projection or rotation tag.</p>
</div>
"""
with gr.Blocks(title="CXR Projection & Rotation") as demo:
    gr.HTML(html_header)
    with gr.Row():
        input_image = gr.Image(
            label="Upload PNG (256×256)",
            type="filepath",
            image_mode="L"
        )
        output_html = gr.HTML()
    send_btn = gr.Button("Run Inference")
    send_btn.click(
        fn=predict_html,
        inputs=input_image,
        outputs=output_html
    )
    # Bundled sample images (names follow the convention described above).
    gr.Examples(
        examples=[
            "./sample/1_AP_Upright.png",
            "./sample/1_PA_Inverted.png",
            "./sample/2_AP_Right-rotation.png",
            "./sample/2_Lateral_Left-rotation.png",
        ],
        inputs=input_image
    )
    # List the sample filenames so users can copy the naming convention.
    gr.Markdown(
        "**Sample filenames:** \n"
        "- 1_AP_Upright.png \n"
        "- 1_PA_Inverted.png \n"
        "- 2_AP_Right-rotation.png \n"
        "- 2_Lateral_Left-rotation.png"
    )
if __name__ == "__main__":
    demo.launch(debug=True)