ds-ai-app / app.py
yxccai's picture
Update app.py
764612d verified
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
import torch
import re
from transformers import AutoConfig, LlamaConfig, AutoModelForSequenceClassification, AutoTokenizer
import torch
# ==== 自定义配置类 ====
class CustomLlamaConfig(LlamaConfig):
model_type = "custom_llama" # 新名称
def _rope_scaling_validation(self):
pass # 禁用RoPE验证
# ==== 注册配置 ====
AutoConfig.register("custom_llama", CustomLlamaConfig) # 使用新名称
# ==== 加载模型 ====
# 1. 加载配置
config = CustomLlamaConfig.from_pretrained("unsloth/DeepSeek-R1-Distill-Llama-8B")
# 2. 加载模型(关键修改)
model = AutoModelForSequenceClassification.from_pretrained(
"unsloth/DeepSeek-R1-Distill-Llama-8B",
config=config,
trust_remote_code=True,
_config_class=CustomLlamaConfig # 明确指定配置类
)
# 3. 加载适配器
model.load_adapter("yxccai/ds-ai-app")
# 4. 加载分词器
tokenizer = AutoTokenizer.from_pretrained("yxccai/ds-ai-app")
# 2. 加载你的适配器
# model.load_adapter("yxccai/ds-ai-app") # 替换为你的仓库名
# model = LlamaForSequenceClassification.from_pretrained(
# "yxccai/ds-ai-model",
# trust_remote_code=True # 添加这行
# )
# tokenizer = LlamaTokenizer.from_pretrained("yxccai/ds-ai-model")
# tokenizer = AutoTokenizer.from_pretrained("yxccai/ds-ai-app")
# 疾病标签映射(必须与训练时完全一致!)
disease_labels = [
"脑梗死",
"动脉狭窄",
"动脉闭塞",
"脑缺血",
"其他脑血管病变",
"脑出血",
"动脉瘤",
"动脉壶腹",
# 根据实际标签补充完整列表...
]
# 标准化输入模板(与训练时完全一致)
MEDICAL_PROMPT = """以下是描述任务的指令,请写出一个适当完成请求的回答。
### 指令:
你是一位专业医生,需要根据患者的主诉和检查结果给出诊断结论。回答必须严格按照以下格式:
诊断结论:[具体疾病名称]
### 问题:
{}
### 回答:
{}""" # 第二个占位符保留用于兼容性
def medical_diagnosis(symptoms):
try:
# 输入预处理
symptoms = symptoms.strip()
if not symptoms:
return "⚠️ 请输入有效的症状描述"
# 检测危险关键词
emergency_keywords = ["昏迷", "胸痛", "呼吸困难", "意识丧失"]
if any(kw in symptoms for kw in emergency_keywords):
return "🚨 检测到危急症状!请立即前往急诊科就诊!"
# 构建标准化输入
formatted_input = MEDICAL_PROMPT.format(symptoms, "")
# 模型推理
inputs = tokenizer(
formatted_input,
max_length=1024,
truncation=True,
padding=True,
return_tensors="pt"
).to("cuda")
with torch.no_grad():
outputs = model(**inputs)
# 后处理
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(probabilities).item()
confidence = probabilities[0][predicted_class].item()
# 生成结构化报告
diagnosis_report = f"""
🩺 诊断报告:
-----------------------------
▪ 主要症状:{extract_key_symptoms(symptoms)}
▪ 最可能诊断:{disease_labels[predicted_class]}
▪ 置信度:{confidence*100:.1f}%
▪ 鉴别诊断:{get_differential_diagnosis(predicted_class)}
-----------------------------
⚠️ 注意:本结果仅供参考,请以临床检查为准
"""
return diagnosis_report
except Exception as e:
return f"❌ 诊断过程中出现错误:{str(e)}"
def extract_key_symptoms(text):
"""提取关键症状"""
keywords = ["头晕", "肢体无力", "言语不利", "麻木", "呕吐"]
found = [kw for kw in keywords if kw in text]
return "、".join(found[:3]) + "等" if len(found) > 3 else "、".join(found)
def get_differential_diagnosis(disease_id):
"""获取鉴别诊断"""
differential_map = {
0: ["脑出血", "短暂性脑缺血发作", "颅内肿瘤"],
1: ["动脉粥样硬化", "血管炎", "纤维肌性发育不良"],
2: ["动脉栓塞", "大动脉炎", "血栓形成"],
3: ["梅尼埃病", "前庭神经炎", "低血糖反应"],
}
return " | ".join(differential_map.get(disease_id, []))
# 创建医疗专用界面
interface = gr.Interface(
fn=medical_diagnosis,
inputs=gr.Textbox(
label="患者症状描述",
placeholder="请输入详细症状(示例:持续头痛三天,伴随恶心呕吐)",
lines=5
),
outputs=gr.Markdown(
label="AI辅助诊断报告",
show_copy_button=True
),
title="神经内科疾病辅助诊断系统",
description="**专业提示**:请输入完整的症状描述,包括:\n- 主要症状及持续时间\n- 伴随症状\n- 既往病史\n- 检查结果",
examples=[
["主诉:左侧肢体无力3天,伴言语不清。既往脑梗死病史5年..."],
["头晕伴行走不稳2天,MRI显示小脑梗死灶..."],
["突发右侧肢体麻木,CTA显示颈动脉狭窄..."],
],
allow_flagging="never",
theme="soft"
)
# 安全设置
interface.launch(
server_name="0.0.0.0",
server_port=7860,
auth=("doctor", "dsaimodel") # 建议修改为自定义账号密码
)