Spaces: Build error
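# Gradio demo: upload a pathology feature file (.npy), run the OmniPath
# multi-task classification model on it, and chat with a Qwen LLM about
# the predictions (diagnosis, treatment, prognosis, staging, histology).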
import gradio as gr
import torch
import numpy as np
from pathlib import Path
import re
from Model import OmniPathWithInterTaskAttention
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
import os
from threading import Thread

# Force Gradio to use the English locale
os.environ["GRADIO_LOCALE"] = "en"

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Preload the models once at startup (avoids reloading on every request)
def load_models():
    """Preload the necessary models."""
    # 1. Load the classification model
    ckpt_path = "best_model.pth"
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(f"Model file not found: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device)

    label_mappings = ckpt.get('label_mappings', None)
    if not label_mappings:
        raise ValueError("The checkpoint is missing label_mappings")

    ck_cfg = ckpt.get('config', {})
    feature_dim = 768  # Adjust to match your actual feature dimension
    hidden_dim = int(ck_cfg.get('hidden_dim', 256))
    dropout = float(ck_cfg.get('dropout', 0.3))
    use_inter_task_attention = bool(ck_cfg.get('use_inter_task_attention', True))
    inter_task_heads = int(ck_cfg.get('inter_task_heads', 4))

    classification_model = OmniPathWithInterTaskAttention(
        label_mappings=label_mappings,
        feature_dim=feature_dim,
        hidden_dim=hidden_dim,
        dropout=dropout,
        use_inter_task_attention=use_inter_task_attention,
        inter_task_heads=inter_task_heads
    ).to(device)
    classification_model.load_state_dict(ckpt['model_state_dict'], strict=False)
    classification_model.eval()

    # 2. Load the text generation model
    llm_model_name = "Qwen/Qwen3-0.6B"
    # llm_model_name = "Qwen/QwQ-32B"
    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_name,
        device_map="auto",
        # 4-bit quantization (needs the bitsandbytes and accelerate packages,
        # and typically a CUDA GPU)
        quantization_config=BitsAndBytesConfig(load_in_4bit=True)
    )

    return classification_model, llm_model, tokenizer, label_mappings


# Preload the models
classification_model, llm_model, tokenizer, label_mappings = load_models()
def analyze_npy_file(npy_file):
    """Analyze an NPY file and return the prediction results."""
    if npy_file is None:
        return None, "Please upload an NPY file first"
    try:
        # With gr.File(type="filepath") the value is a path string,
        # not a tempfile object, so use it directly
        npy_path = npy_file if isinstance(npy_file, str) else npy_file.name

        # Read the NPY file
        arr = np.load(npy_path, allow_pickle=False)
        if not isinstance(arr, np.ndarray) or arr.ndim != 2:
            return None, "Error: the NPY file must be a two-dimensional feature matrix"
        features = torch.from_numpy(arr).float()

        # Extract the short patient ID (e.g. a TCGA barcode) from the filename
        p = Path(npy_path)
        m = re.search(r'(TCGA-[A-Z0-9]{2}-[A-Z0-9]{4})', p.name.upper())
        short_id = m.group(1) if m else p.stem[:12]

        # Inference
        feat_batch = features.unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = classification_model(feat_batch)

        # Decode the per-task predictions
        pred_names, pred_scores = {}, {}
        for task_name, logits in outputs.items():
            probs = torch.softmax(logits[0], dim=-1)
            idx = int(torch.argmax(probs).item())
            classes = label_mappings[task_name]['classes']
            class_name = classes[idx] if 0 <= idx < len(classes) else str(idx)
            pred_names[task_name] = class_name
            pred_scores[task_name] = float(probs[idx].item())

        # Format the results
        results_text = f"Patient ID: {short_id}\n\nPrediction Results:\n"
        for task, name in pred_names.items():
            results_text += f"- {task}: {name} (Confidence: {pred_scores.get(task, 0.0):.3f})\n"

        return {"pred_names": pred_names, "pred_scores": pred_scores, "patient_id": short_id}, results_text
    except Exception as e:
        return None, f"An error occurred during processing: {str(e)}"
def generate_response(message, chat_history, analysis_results):
    """Generate a streamed LLM response."""
    if analysis_results is None:
        yield "Please upload an NPY file first to analyze the patient data.", chat_history
        return

    pred_names = analysis_results["pred_names"]
    pred_scores = analysis_results["pred_scores"]
    patient_id = analysis_results["patient_id"]

    # Build the context block from the classification results
    context = f"Patient {patient_id} analysis results:\n"
    for task, name in pred_names.items():
        context += f"- {task}: {name} (confidence: {pred_scores.get(task, 0.0):.3f})\n"

    # Pick a prompt template with simple keyword matching on the question
    if "diagnosis" in message.lower() or "result" in message.lower():
        prompt = f"{context}\nBased on the above analysis results, provide a detailed diagnosis summary and interpretation."
    elif "treatment" in message.lower() or "therapy" in message.lower():
        prompt = f"{context}\nBased on the diagnosis, suggest appropriate treatment options and considerations."
    elif "prognosis" in message.lower() or "outlook" in message.lower():
        prompt = f"{context}\nDiscuss the prognosis and potential outcomes for this patient."
    elif "stage" in message.lower():
        prompt = f"{context}\nExplain the staging information and its clinical implications."
    elif "histology" in message.lower() or "type" in message.lower():
        prompt = f"{context}\nDescribe the histological characteristics and their significance."
    else:
        prompt = f"{context}\nUser question: {message}\nPlease provide a helpful response based on the analysis results."

    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device)

    # Run generation in a background thread and stream tokens back
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=lambda: llm_model.generate(
            **model_inputs,
            max_new_tokens=1024,  # Keep the output fairly short to improve speed
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            streamer=streamer
        )
    )
    thread.start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        # Stream the partial answer to the UI in real time
        yield "", chat_history + [(message, partial)]

    # After generation finishes, write the final answer back into the history
    chat_history.append((message, partial))
    yield "", chat_history
def upload_file(npy_file, chat_history, analysis_results):
    """Handle file upload and the initial analysis."""
    if npy_file is None:
        return chat_history, analysis_results, "Please select a file to upload"

    new_analysis_results, results_text = analyze_npy_file(npy_file)
    if new_analysis_results is None:
        return chat_history, analysis_results, results_text

    # Add the analysis results to the chat
    chat_history.append(("System", f"File uploaded and analyzed successfully!\n{results_text}"))
    chat_history.append(("System", "You can now ask questions about this patient's diagnosis, treatment options, prognosis, etc."))
    return chat_history, new_analysis_results, "Analysis completed successfully!"


def example_click(example):
    """Handle a click on an example question."""
    return example
# Create the conversational interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏥 Medical Pathology Diagnostic Chat Assistant
    Upload a pathology NPY file and chat with the AI assistant about the diagnosis, treatment options, prognosis, and more.
    """)

    # Store the analysis results in session state
    analysis_results = gr.State(value=None)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Patient Data")
            file_input = gr.File(
                label="Upload NPY Feature File",
                file_types=[".npy"],
                type="filepath"
            )
            upload_btn = gr.Button("Upload & Analyze", variant="primary")
            status_output = gr.Textbox(
                label="Status",
                lines=2,
                interactive=False
            )

        with gr.Column(scale=2):
            gr.Markdown("### Chat with Medical Assistant")
            # Tuple-style history: each entry is a (user, assistant) pair
            chatbot = gr.Chatbot(
                label="Conversation",
                height=400
            )
            with gr.Row():
                msg = gr.Textbox(
                    label="Your Question",
                    placeholder="Ask about diagnosis, treatment, prognosis...",
                    lines=2,
                    scale=4
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("Clear Chat")

    gr.Markdown("### Suggested Questions")
    examples = gr.Examples(
        examples=[
            "What is the diagnosis?",
            "What treatment options are available?",
            "What is the prognosis?",
            "Explain the staging information",
            "Describe the histological findings"
        ],
        inputs=msg,         # Clicking an example fills the message box
        fn=example_click,   # Handler run for an example
        outputs=msg,        # Output goes back to the message box
        label="Click a question to use it"
    )
    # Event handlers
    upload_btn.click(
        upload_file,
        inputs=[file_input, chatbot, analysis_results],
        outputs=[chatbot, analysis_results, status_output]
    )

    send_btn.click(
        generate_response,
        inputs=[msg, chatbot, analysis_results],
        outputs=[msg, chatbot]
    )

    msg.submit(
        generate_response,
        inputs=[msg, chatbot, analysis_results],
        outputs=[msg, chatbot]
    )

    clear_btn.click(
        lambda: ([], None, "Chat cleared"),
        inputs=[],
        outputs=[chatbot, analysis_results, status_output]
    )


if __name__ == "__main__":
    demo.launch(share=True)