基于 RoBERTa 微调的医学问诊意图识别模型
项目简介
项目来源:中科(安徽)G60智慧健康创新研究院(以下简称 “中科”)围绕心理健康大模型研发的对话导诊系统,本项目为其中的意图识别任务。
模型用途:将用户输入对话系统中的
query
文本进行意图识别,判别其意向是【问诊】or【闲聊】。
数据描述
数据来源:由 Hugging Face 的开源对话数据集,以及中科内部的垂域医学对话数据集经过清洗和预处理融合构建而成。
数据划分:共计 6000 条样本,其中,训练集 4800 条,测试集1200 条,并在数据构建过程中确保了正负样例的平衡。
数据样例:
[
{
"query": "最近热门的5部电影叫什么名字",
"label": "nonmed"
},
{
"query": "关节疼痛,足痛可能是什么原因",
"label": "med"
},
{
"query": "最近出冷汗,肚子疼,恶心与呕吐,严重影响学习工作",
"label": "med"
}
]
实验环境
CPU:6核 E5-2680 V4
GPU:RTX3060,12.6GB显存
预装镜像:Ubuntu 20.04,Python 3.9/3.10,PyTorch 2.0.1,TensorFlow 2.13.0,Docker 20.10.10, CUDA 尽量维持在最新版本
需手动安装的库:
pip install transformers datasets evaluate accelerate
训练方式
- 基于 Hugging Face 的
transformers
库对哈工大讯飞联合实验室 (HFL) 发布的 chinese-roberta-wwm-ext 中文预训练模型进行微调。
训练参数、效果与局限性
训练参数
{ output_dir: "output", num_train_epochs: 2, learning_rate: 3e-5, lr_scheduler_type: "cosine", per_device_train_batch_size: 16, per_device_eval_batch_size: 16, weight_decay: 0.01, warmup_ratio: 0.02, logging_steps: 0.01, logging_strategy: "steps", fp16: True, eval_strategy: "steps", eval_steps: 0.1, save_strategy: 'epoch' }
微调效果
数据集 准确率 F1分数 测试集 0.99 0.98 局限性
整体而言,微调后模型对于医学问诊的意图识别效果不错;但碍于本次用于模型训练的数据量终究有限且样本多样性欠佳,故在某些情况下的效果可能存在偏差。
如何使用
单样本推理示例
from transformers import AutoTokenizer from transformers import AutoModelForSequenceClassification ID2LABEL = {0: "闲聊", 1: "问诊"} MODEL_NAME = 'StevenZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base' tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForSequenceClassification.from_pretrained( MODEL_NAME, torch_dtype='auto' ) query = '这孩子目前28岁,情绪不好时经常无征兆吐血,呼吸系统和消化系统做过多次检查,没有检查出结果,最近三天连续早晨出现吐血现象' tokenized_query = tokenizer(query, return_tensors='pt') tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()} outputs = model(**tokenized_query) pred_id = outputs.logits.argmax(-1).item() intent = ID2LABEL[pred_id] print(intent) # "问诊"
批次数据推理示例
from transformers import AutoTokenizer from transformers import AutoModelForSequenceClassification ID2LABEL = {0: "闲聊", 1: "问诊"} MODEL_NAME = 'StevenZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base' tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left') model = AutoModelForSequenceClassification.from_pretrained( MODEL_NAME, torch_dtype='auto' ) query = [ '胃痛,连续拉肚子好几天了,有时候半夜还呕吐', '腿上的毛怎样去掉,不用任何药学和医学器械', '你好,感冒咳嗽用什么药?', '你觉得今天天气如何?我感觉咱可以去露营了!' ] tokenized_query = tokenizer(query, return_tensors='pt', padding=True, truncation=True) tokenized_query = {k: v.to(model.device) for k, v in tokenized_query.items()} outputs = model(**tokenized_query) pred_ids = outputs.logits.argmax(-1).tolist() intent = [ID2LABEL[pred_id] for pred_id in pred_ids] print(intent) # ["问诊", "闲聊", "问诊", "闲聊"]
- Downloads last month
- 11
Model tree for StevenZhun/RoBERTa-Chinese-Med-Inquiry-Intention-Recognition-base
Base model
hfl/chinese-roberta-wwm-ext