File size: 4,919 Bytes
f729637 cc77e93 51d6ebd c9129a5 a849b40 06e742c a849b40 06e742c c9129a5 51d6ebd 6599416 6484cbc ebb49a0 e420f46 3dd6650 e420f46 6484cbc 4f6421c 6484cbc a849b40 ef5bd87 fe14f27 a849b40 6599416 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
---
license: apache-2.0
---
## few_shot_intent_gpt2
这个模型是基于 [uer/gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) 模型在 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集上微调的结果.
(1)训练在(11000 steps)处 Early Stop。这相当于加载的 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集的 1 个 epoch 处。
(2)此处保存的是 checkpoint-6000 (6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。
最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。
你可以在此处体验该模型 [qgyd2021/gpt2_chat](https://huggingface.co/spaces/qgyd2021/gpt2_chat)。
### TensorBoard 数据
**Eval Loss** 见下图:
![eval_loss.jpg](docs/pictures/eval_loss.jpg)
**Learning rate** 见下图:
学习率从 2e-4 下降到 1.4e-4。
![learning_rate.jpg](docs/pictures/learning_rate.jpg)
### 讨论
(1)最优解在不到 1 个 epoch 处得到。
* 这可能说明 GPT2 模型大小,相对于任务复杂度来说太小了。
* 模型进入到局部最优解而无法跳出,应考虑使用较大的学习率,或更换学习率调度器。
### 其它
训练时加载数据集的代码
```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
from datasets import load_dataset
from datasets.download.download_manager import DownloadMode
from tqdm import tqdm
from project_settings import project_path
def get_args():
    """Build and parse the command-line arguments for the dataset-export script.

    Returns:
        argparse.Namespace with dataset location, cache directory, epoch count,
        and the two output JSONL filenames.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str)
    arg_parser.add_argument("--dataset_split", default=None, type=str)
    arg_parser.add_argument(
        "--dataset_cache_dir",
        default=(project_path / "hub_datasets").as_posix(),
        type=str
    )
    arg_parser.add_argument("--num_epochs", default=1, type=int)
    arg_parser.add_argument("--train_subset", default="train.jsonl", type=str)
    arg_parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
    return arg_parser.parse_args()
# Subset (config) names of qgyd2021/few_shot_intent_sft to export.
# Commented-out entries are deliberately excluded from training.
NAME_LIST = [
    # "a_intent_prompt",
    "amazon_massive_intent_en_us_prompt",
    "amazon_massive_intent_zh_cn_prompt",
    "atis_intents_prompt",
    "banking77_prompt",
    "bi_text11_prompt",
    "bi_text27_prompt",
    # "book6_prompt",
    "carer_prompt",
    "chatbots_prompt",
    "chinese_news_title_prompt",
    "cmid_4class_prompt",
    "cmid_36class_prompt",
    "coig_cqia_prompt",
    "conv_intent_prompt",
    "crosswoz_prompt",
    "dmslots_prompt",
    "dnd_style_intents_prompt",
    "emo2019_prompt",
    "finance21_prompt",
    "ide_intent_prompt",
    "intent_classification_prompt",
    "jarvis_intent_prompt",
    "mobile_assistant_prompt",
    "mtop_intent_prompt",
    "out_of_scope_prompt",
    "ri_sawoz_domain_prompt",
    "ri_sawoz_general_prompt",
    "small_talk_prompt",
    "smp2017_task1_prompt",
    "smp2019_task1_domain_prompt",
    "smp2019_task1_intent_prompt",
    # "snips_built_in_intents_prompt",
    "star_wars_prompt",
    "suicide_intent_prompt",
    "snips_built_in_intents_prompt",
    "telemarketing_intent_cn_prompt",
    "telemarketing_intent_en_prompt",
    "vira_intents_prompt",
]


def _export_split(args, split: str, output_file: str) -> None:
    """Download every subset of *split* and write all samples as JSON lines.

    The dataset is iterated ``args.num_epochs`` times, matching the original
    behavior of materializing repeated epochs into the output file.

    Args:
        args: parsed command-line namespace (dataset_path, dataset_cache_dir,
            num_epochs).
        split: dataset split name to load, e.g. "train" or "test".
        output_file: path of the JSONL file to (over)write.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        for _ in range(args.num_epochs):
            for name in NAME_LIST:
                print(name)
                dataset = load_dataset(
                    path=args.dataset_path,
                    name=name,
                    split=split,
                    cache_dir=args.dataset_cache_dir,
                    # FORCE_REDOWNLOAD re-fetches on every epoch iteration;
                    # kept as-is to preserve the original behavior.
                    download_mode=DownloadMode.FORCE_REDOWNLOAD,
                    ignore_verifications=True
                )
                for sample in tqdm(dataset):
                    row = json.dumps(sample, ensure_ascii=False)
                    f.write("{}\n".format(row))


def main():
    """Export the train and validation splits of every subset to JSONL files.

    The train/valid export loops in the original were ~30 duplicated lines
    differing only in split name and output path; they now share
    ``_export_split``.
    """
    args = get_args()
    _export_split(args, split="train", output_file=args.train_subset)
    # NOTE(review): the validation file is also repeated num_epochs times,
    # mirroring the original code — confirm this duplication is intended.
    _export_split(args, split="test", output_file=args.valid_subset)
    return
# Script entry point: only run the export when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()
```
|