import gradio as gr
import torch
import numpy as np
from pathlib import Path
import re
from Model import OmniPathWithInterTaskAttention
from transformers import AutoModelForCausalLM, AutoTokenizer
import transformers
import os
from threading import Thread
from transformers import TextIteratorStreamer
# Force Gradio to use the English locale
os.environ["GRADIO_LOCALE"] = "en"
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Preload models once (avoid reloading on every request)
@torch.no_grad()
def load_models():
"""Preload necessary models"""
# 1. Load classification model
ckpt_path = "best_model.pth"
if not Path(ckpt_path).exists():
raise FileNotFoundError(f"Model file not found: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=device)
label_mappings = ckpt.get('label_mappings', None)
if not label_mappings:
raise ValueError("The checkpoint is missing label_mappings")
ck_cfg = ckpt.get('config', {})
feature_dim = 768 # Adjust according to your actual feature dimension
hidden_dim = int(ck_cfg.get('hidden_dim', 256))
dropout = float(ck_cfg.get('dropout', 0.3))
use_inter_task_attention = bool(ck_cfg.get('use_inter_task_attention', True))
inter_task_heads = int(ck_cfg.get('inter_task_heads', 4))
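    # Architecture hyper-parameters are read from the checkpoint's saved config; the
    # fallback values are assumed to match the training defaults.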
classification_model = OmniPathWithInterTaskAttention(
label_mappings=label_mappings,
feature_dim=feature_dim,
hidden_dim=hidden_dim,
dropout=dropout,
use_inter_task_attention=use_inter_task_attention,
inter_task_heads=inter_task_heads
).to(device)
classification_model.load_state_dict(ckpt['model_state_dict'], strict=False)
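    # strict=False tolerates key mismatches between the checkpoint and the current
    # model definition (e.g. renamed or extra heads) instead of raising.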
classification_model.eval()
# 2. Load text generation model
llm_model_name = "Qwen/Qwen3-0.6B"
# llm_model_name = "Qwen/QwQ-32B"
tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
llm_model = AutoModelForCausalLM.from_pretrained(
llm_model_name,
device_map="auto",
load_in_4bit=True
)
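    # Note: load_in_4bit requires the bitsandbytes package and (typically) a CUDA GPU.
    # On recent transformers releases the preferred, roughly equivalent form is:
    #   from transformers import BitsAndBytesConfig
    #   AutoModelForCausalLM.from_pretrained(
    #       llm_model_name, device_map="auto",
    #       quantization_config=BitsAndBytesConfig(load_in_4bit=True))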
return classification_model, llm_model, tokenizer, label_mappings
# Preload the models at startup
classification_model, llm_model, tokenizer, label_mappings = load_models()
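# Quick local smoke test for the upload path (assumption: features are 768-dim patch
# embeddings, one row per patch; any row count works), e.g.:
#   import numpy as np
#   np.save("TCGA-AB-1234_features.npy", np.random.randn(200, 768).astype(np.float32))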
def analyze_npy_file(npy_file):
"""Analyze NPY file and return prediction results"""
if npy_file is None:
return None, "Please upload an NPY file first"
try:
# Read NPY file
        npy_path = npy_file if isinstance(npy_file, str) else npy_file.name  # gr.File(type="filepath") yields a plain path string
        arr = np.load(npy_path, allow_pickle=False)
if not isinstance(arr, np.ndarray) or arr.ndim != 2:
return None, "Error: NPY file must be a two-dimensional feature matrix"
features = torch.from_numpy(arr).float()
# Extract short ID
        p = Path(npy_path)
m = re.search(r'(TCGA-[A-Z0-9]{2}-[A-Z0-9]{4})', p.name.upper())
short_id = m.group(1) if m else p.stem[:12]
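        # e.g. "TCGA-AB-1234_features.npy" -> "TCGA-AB-1234"; otherwise fall back to
        # the first 12 characters of the file stem.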
# Inference
        feat_batch = features.unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = classification_model(feat_batch)
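        # The classifier is expected to return a dict of {task_name: logits},
        # one entry per prediction head defined in label_mappings.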
# Decode results
pred_names, pred_scores = {}, {}
for task_name, logits in outputs.items():
probs = torch.softmax(logits[0], dim=-1)
idx = int(torch.argmax(probs).item())
classes = label_mappings[task_name]['classes']
class_name = classes[idx] if 0 <= idx < len(classes) else str(idx)
pred_names[task_name] = class_name
pred_scores[task_name] = float(probs[idx].item())
# Format results
results_text = f"Patient ID: {short_id}\n\nPrediction Results:\n"
for task, name in pred_names.items():
results_text += f"- {task}: {name} (Confidence: {pred_scores.get(task, 0.0):.3f})\n"
return {"pred_names": pred_names, "pred_scores": pred_scores, "patient_id": short_id}, results_text
except Exception as e:
return None, f"An error occurred during processing: {str(e)}"
def generate_response(message, chat_history, analysis_results):
"""Generate streamed LLM response"""
if analysis_results is None:
yield "Please upload an NPY file first to analyze the patient data.", chat_history
return
pred_names = analysis_results["pred_names"]
pred_scores = analysis_results["pred_scores"]
patient_id = analysis_results["patient_id"]
context = f"Patient {patient_id} analysis results:\n"
for task, name in pred_names.items():
context += f"- {task}: {name} (confidence: {pred_scores.get(task, 0.0):.3f})\n"
if "diagnosis" in message.lower() or "result" in message.lower():
prompt = f"{context}\nBased on the above analysis results, provide a detailed diagnosis summary and interpretation."
elif "treatment" in message.lower() or "therapy" in message.lower():
prompt = f"{context}\nBased on the diagnosis, suggest appropriate treatment options and considerations."
elif "prognosis" in message.lower() or "outlook" in message.lower():
prompt = f"{context}\nDiscuss the prognosis and potential outcomes for this patient."
elif "stage" in message.lower():
prompt = f"{context}\nExplain the staging information and its clinical implications."
elif "histology" in message.lower() or "type" in message.lower():
prompt = f"{context}\nDescribe the histological characteristics and their significance."
else:
prompt = f"{context}\nUser question: {message}\nPlease provide a helpful response based on the analysis results."
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
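    # generate() runs on a background thread and pushes decoded text chunks into the
    # streamer, so this generator can yield partial answers to Gradio as they arrive.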
thread = Thread(
target=lambda: llm_model.generate(
**model_inputs,
            max_new_tokens=1024,  # 🚀 kept relatively small so responses stream back quickly
do_sample=True,
temperature=0.7,
top_p=0.9,
streamer=streamer
)
)
thread.start()
partial = ""
for new_text in streamer:
partial += new_text
        # Stream the partial answer to the UI in real time
yield "", chat_history + [(message, partial)]
    # Once generation finishes, write the final answer back into the chat history
chat_history.append((message, partial))
yield "", chat_history
def upload_file(npy_file, chat_history, analysis_results):
"""Handle file upload and initial analysis"""
if npy_file is None:
return chat_history, analysis_results, "Please select a file to upload"
new_analysis_results, results_text = analyze_npy_file(npy_file)
if new_analysis_results is None:
return chat_history, analysis_results, results_text
# Add analysis results to chat
chat_history.append(("System", f"File uploaded and analyzed successfully!\n{results_text}"))
chat_history.append(("System", "You can now ask questions about this patient's diagnosis, treatment options, prognosis, etc."))
return chat_history, new_analysis_results, "Analysis completed successfully!"
def example_click(example):
"""Handle example question click"""
return example
# Create conversational interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🏥 Medical Pathology Diagnostic Chat Assistant
Upload a pathology NPY file and chat with the AI assistant about the diagnosis, treatment options, prognosis, and more.
""")
# Store analysis results in session state
analysis_results = gr.State(value=None)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Upload Patient Data")
file_input = gr.File(
label="Upload NPY Feature File",
file_types=[".npy"],
type="filepath"
)
upload_btn = gr.Button("Upload & Analyze", variant="primary")
status_output = gr.Textbox(
label="Status",
lines=2,
interactive=False
)
with gr.Column(scale=2):
gr.Markdown("### Chat with Medical Assistant")
chatbot = gr.Chatbot(
label="Conversation",
height=400
)
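            # Note: the chat history uses the legacy (user, assistant) tuple format;
            # newer Gradio versions prefer gr.Chatbot(type="messages").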
with gr.Row():
msg = gr.Textbox(
label="Your Question",
placeholder="Ask about diagnosis, treatment, prognosis...",
lines=2,
scale=4
)
send_btn = gr.Button("Send", variant="primary", scale=1)
with gr.Row():
clear_btn = gr.Button("Clear Chat")
gr.Markdown("### Suggested Questions")
examples = gr.Examples(
examples=[
"What is the diagnosis?",
"What treatment options are available?",
"What is the prognosis?",
"Explain the staging information",
"Describe the histological findings"
],
        inputs=msg,  # feed the clicked example into the message textbox
        fn=example_click,  # handler invoked when an example is clicked
        outputs=msg,  # write the handler's result back to the message textbox
label="Click a question to use it"
)
# Event handlers
upload_btn.click(
upload_file,
inputs=[file_input, chatbot, analysis_results],
outputs=[chatbot, analysis_results, status_output]
)
send_btn.click(
generate_response,
inputs=[msg, chatbot, analysis_results],
outputs=[msg, chatbot]
)
msg.submit(
generate_response,
inputs=[msg, chatbot, analysis_results],
outputs=[msg, chatbot]
)
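    # Pressing Enter in the question box triggers the same streaming handler as Send.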
clear_btn.click(
lambda: ([], None, "Chat cleared"),
inputs=[],
outputs=[chatbot, analysis_results, status_output]
)
if __name__ == "__main__":
demo.launch(share=True)