|
--- |
|
license: apache-2.0 |
|
--- |
|
## few_shot_intent_gpt2 |
|
|
|
这个模型是基于 [uer/gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) 模型在 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集上微调的结果. |
|
|
|
(1)训练在(11000 steps)处 Early Stop。这相当于加载的 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集的 1 个 epoch 处。 |
|
|
|
(2)此处保存的是 checkpoint-6000 (6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。 |
|
|
|
|
|
最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。 |
|
|
|
你可以在此处体验该模型 [qgyd2021/gpt2_chat](https://huggingface.co/spaces/qgyd2021/gpt2_chat)。 |
|
|
|
|
|
### TensorBoard 数集 |
|
|
|
**Eval Loss** 见下图: |
|
|
|
![eval_loss.jpg](docs/pictures/eval_loss.jpg) |
|
|
|
|
|
**Learning rate** 见下图: |
|
|
|
学习率从 2e-4 下降到 1.4e-4。 |
|
|
|
![learning_rate.jpg](docs/pictures/learning_rate.jpg) |
|
|
|
|
|
|
|
|
|
### 讨论 |
|
|
|
(1)最优解在不到 1 个 epoch 处得到。 |
|
|
|
* 这可能说明 GPT2 模型大小,相对于任务复杂度来说太小了。 |
|
|
|
* 模型进入到局部最终解而无法跳出,应考虑使用较大的学习率,或更换学习率调度器。 |
|
|
|
(2)后续应考虑针对 prompt-response 中 response 部分进行训练。 |
|
|
|
* 即只优化 response 部分的损失以提升识别结果与 prompt 之间的注意力机制。当前的训练有可能只是使模型拟合了 few shot 数据的格式,而并没有拟合到意图识别的目的。 |
|
|
|
(3)模型使用中的体会。 |
|
|
|
* 如果在使用过程中,模型生成 response 不在 prompt 中给定的选项,这可能说明模型已经过拟合了。 |
|
|
|
* 如果模型生成 response 在 prompt 中,但答案不正确,则说明模型已学习到生成的表层模型,而没有学习到意图识别的目的。则建议在此模型基础上进一步优化 response 部分的损失。 |
|
|
|
|
|
|
|
### 其它 |
|
|
|
训练时加载数据集的代码 |
|
```python |
|
#!/usr/bin/python3 |
|
# -*- coding: utf-8 -*- |
|
import argparse |
|
import json |
|
|
|
from datasets import load_dataset |
|
from datasets.download.download_manager import DownloadMode |
|
from tqdm import tqdm |
|
|
|
from project_settings import project_path |
|
|
|
|
|
def get_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str) |
|
parser.add_argument("--dataset_split", default=None, type=str) |
|
parser.add_argument( |
|
"--dataset_cache_dir", |
|
default=(project_path / "hub_datasets").as_posix(), |
|
type=str |
|
) |
|
|
|
parser.add_argument("--num_epochs", default=1, type=int) |
|
|
|
parser.add_argument("--train_subset", default="train.jsonl", type=str) |
|
parser.add_argument("--valid_subset", default="valid.jsonl", type=str) |
|
args = parser.parse_args() |
|
return args |
|
|
|
|
|
def main(): |
|
args = get_args() |
|
|
|
name_list = [ |
|
# "a_intent_prompt", |
|
"amazon_massive_intent_en_us_prompt", |
|
"amazon_massive_intent_zh_cn_prompt", |
|
"atis_intents_prompt", |
|
"banking77_prompt", |
|
"bi_text11_prompt", |
|
"bi_text27_prompt", |
|
# "book6_prompt", |
|
"carer_prompt", |
|
"chatbots_prompt", |
|
"chinese_news_title_prompt", |
|
"cmid_4class_prompt", |
|
"cmid_36class_prompt", |
|
"coig_cqia_prompt", |
|
"conv_intent_prompt", |
|
"crosswoz_prompt", |
|
"dmslots_prompt", |
|
"dnd_style_intents_prompt", |
|
"emo2019_prompt", |
|
"finance21_prompt", |
|
"ide_intent_prompt", |
|
"intent_classification_prompt", |
|
"jarvis_intent_prompt", |
|
"mobile_assistant_prompt", |
|
"mtop_intent_prompt", |
|
"out_of_scope_prompt", |
|
"ri_sawoz_domain_prompt", |
|
"ri_sawoz_general_prompt", |
|
"small_talk_prompt", |
|
"smp2017_task1_prompt", |
|
"smp2019_task1_domain_prompt", |
|
"smp2019_task1_intent_prompt", |
|
# "snips_built_in_intents_prompt", |
|
"star_wars_prompt", |
|
"suicide_intent_prompt", |
|
"snips_built_in_intents_prompt", |
|
"telemarketing_intent_cn_prompt", |
|
"telemarketing_intent_en_prompt", |
|
"vira_intents_prompt", |
|
] |
|
|
|
with open(args.train_subset, "w", encoding="utf-8") as f: |
|
for _ in range(args.num_epochs): |
|
for name in name_list: |
|
print(name) |
|
dataset = load_dataset( |
|
path=args.dataset_path, |
|
name=name, |
|
split="train", |
|
cache_dir=args.dataset_cache_dir, |
|
download_mode=DownloadMode.FORCE_REDOWNLOAD, |
|
ignore_verifications=True |
|
) |
|
for sample in tqdm(dataset): |
|
row = json.dumps(sample, ensure_ascii=False) |
|
f.write("{}\n".format(row)) |
|
|
|
with open(args.valid_subset, "w", encoding="utf-8") as f: |
|
for _ in range(args.num_epochs): |
|
for name in name_list: |
|
print(name) |
|
dataset = load_dataset( |
|
path=args.dataset_path, |
|
name=name, |
|
split="test", |
|
cache_dir=args.dataset_cache_dir, |
|
download_mode=DownloadMode.FORCE_REDOWNLOAD, |
|
ignore_verifications=True |
|
) |
|
for sample in tqdm(dataset): |
|
row = json.dumps(sample, ensure_ascii=False) |
|
f.write("{}\n".format(row)) |
|
|
|
return |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|
|
``` |
|
|
|
|