File size: 11,920 Bytes
5e67680
 
 
 
 
 
 
 
 
7a8d7d7
5e67680
 
 
7a8d7d7
5e67680
 
 
 
 
7a8d7d7
 
 
 
 
 
 
 
780944d
7a8d7d7
dc21d7b
5e67680
7a8d7d7
5e67680
7a8d7d7
 
 
 
 
 
5e67680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a8d7d7
5e67680
 
 
 
 
 
7a8d7d7
 
5e67680
7a8d7d7
 
 
 
 
 
5e67680
 
 
 
 
7a8d7d7
 
 
 
 
 
 
 
 
 
 
5e67680
 
 
 
 
 
 
7a8d7d7
5e67680
 
 
 
 
 
 
 
 
7a8d7d7
 
 
 
5e67680
 
 
7a8d7d7
 
 
 
 
 
 
 
 
 
5e67680
 
 
7a8d7d7
5e67680
 
7a8d7d7
 
 
5e67680
 
 
 
7a8d7d7
 
5e67680
 
 
 
 
7a8d7d7
5e67680
7a8d7d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e67680
 
7a8d7d7
5e67680
7a8d7d7
5e67680
 
7a8d7d7
 
5e67680
 
 
7a8d7d7
 
5e67680
 
 
 
 
 
7a8d7d7
5e67680
7a8d7d7
5e67680
7a8d7d7
 
5e67680
 
7a8d7d7
5e67680
 
 
 
 
 
7a8d7d7
 
 
 
63bd90f
 
 
 
 
7a8d7d7
 
5e67680
 
 
 
7a8d7d7
 
5e67680
7a8d7d7
5e67680
7a8d7d7
 
5e67680
7a8d7d7
 
5e67680
7a8d7d7
 
 
 
 
5e67680
 
7a8d7d7
 
5e67680
 
 
 
7a8d7d7
5e67680
7a8d7d7
5e67680
 
 
 
590f060
5e67680
7a8d7d7
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
---
language: ja
tags:
- ja
- japanese
- gpt
- text-generation
- lm
- nlp
- conversational
license: mit
datasets:
- kunishou/databricks-dolly-15k-ja
- kunishou/oasst1-89k-ja
widget:
- text: >-
    <s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n日本で一番広い湖は?\n[SEP]\n応答:\n
---

# 更新履歴
- 2023年5月7日
  
  「[oasst1-89k-ja](https://huggingface.co/datasets/kunishou/oasst1-89k-ja)」データセットを追加して**対話システム**に対応しました。1024トークンまで会話履歴を保存できます。
  前回のモデルで行った質疑応答の正答率は今回のモデルで下がりました。「日本で一番広い湖は?」が91%から89%、「世界で一番高い山は?」が84%から73%に下がりました。(対話は分けた方が良かったのか、それともoasst1の質が良くないとか)

- 2023年4月13日

  「[japanese-gpt-1b](https://huggingface.co/rinna/japanese-gpt-1b)」モデルを「[databricks-dolly-15k-ja](https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja)」データセットで~~**RLHF** (人間のフィードバックからの強化学習)~~**Instruction Tuning**しました。

# dolly-japanese-gpt-1b

1.3Bパラメータの日本語GPT-2モデルを使用した対話型のAIです。VRAM 7GB または RAM 7GB が必要で、問題なく動作すると思われます。

rinna社の「[japanese-gpt-1b](https://huggingface.co/rinna/japanese-gpt-1b)」を、
日本語データセット「[databricks-dolly-15k-ja](https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja)」、
「[oasst1-89k-ja](https://huggingface.co/datasets/kunishou/oasst1-89k-ja)」、
「[OjousamaTalkScriptDataset](https://github.com/matsuvr/OjousamaTalkScriptDataset)」、
「[train_data/zundamon.json](train_data/zundamon.json)」
を使用して学習させました。

学習データやモデルを作成および配布してくださった方々に心から感謝申し上げます。

# モデルの使用方法

## モデルの読み込み

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("inu-ai/dolly-japanese-gpt-1b", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("inu-ai/dolly-japanese-gpt-1b").to(device)
```

## ChatGPT/GPT-4によるサンプルコード(少し修正)

```python
MAX_ASSISTANT_LENGTH = 100
MAX_INPUT_LENGTH = 1024
INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n'
NO_INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n応答:\n'
USER_NAME = "User"
ASSISTANT_NAME = "Assistant"

def prepare_input(role_instruction, conversation_history, new_conversation):
    instruction = "".join([f"{text} " for text in role_instruction])
    instruction += " ".join(conversation_history)
    input_text = f"{USER_NAME}:{new_conversation}"

    return INPUT_PROMPT.format(instruction=instruction, input=input_text)

def format_output(output):
    output = output.lstrip("<s>").rstrip("</s>").replace("[SEP]", "").replace("\\n", "\n")
    return output

def generate_response(role_instruction, conversation_history, new_conversation):
    # 入力トークン数1024におさまるようにする
    for _ in range(8):
        input_text = prepare_input(role_instruction, conversation_history, new_conversation)
        token_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")
        n = len(token_ids[0])
        if n + MAX_ASSISTANT_LENGTH <= MAX_INPUT_LENGTH:
            break
        else:
            conversation_history.pop(0)
            conversation_history.pop(0)

    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            min_length=n,
            max_length=min(MAX_INPUT_LENGTH, n + MAX_ASSISTANT_LENGTH),
            temperature=0.7,
            repetition_penalty=1.0, # 数値を大きくすると、文字列の繰り返しが減る
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bad_words_ids=[[tokenizer.unk_token_id]]
        )

    output = tokenizer.decode(output_ids.tolist()[0])
    formatted_output_all = format_output(output)

    response = f"{ASSISTANT_NAME}:{formatted_output_all.split('応答:')[-1].strip()}"
    conversation_history.append(f"{USER_NAME}:{new_conversation}".replace("\n", "\\n"))
    conversation_history.append(response.replace("\n", "\\n"))

    return formatted_output_all, response 

role_instruction = [
    f"{USER_NAME}:きみは「ずんだもん」なのだ。東北ずん子の武器である「ずんだアロー」に変身する妖精またはマスコットなのだ。一人称は「ボク」で語尾に「なのだー」を付けてしゃべるのだ。",
    f"{ASSISTANT_NAME}:了解したのだ。",
    f"{USER_NAME}:きみは同じ言葉を繰り返さず、何でも正確に要約して答えられるのだ。",
    f"{ASSISTANT_NAME}:了解したのだ。",
]

conversation_history = [
]

questions = [
    "日本で一番高い山は?",
    "日本で一番広い湖は?",
    "冗談を言ってください。",
    "世界で一番高い山は?",
    "世界で一番広い湖は?",
    "最初の質問は何ですか?",
    "今何問目?",
    "自己紹介をしてください。",
]

# 各質問に対して応答を生成して表示
for question in questions:
    formatted_output_all, response = generate_response(role_instruction, conversation_history, question)
    print(f"{USER_NAME}:{question}\n{response}\n---")
```

## 出力

```
User:日本で一番高い山は?
Assistant:富士山
---
User:日本で一番広い湖は?
Assistant:日本列島には、太平洋、日本海、オホーツク海、インド洋、大西洋など、さまざまな海があります。
---
User:冗談を言ってください。
Assistant:冗談を言おう!
---
User:世界で一番高い山は?
Assistant:世界で最も高い山はエベレストです。
---
User:世界で一番広い湖は?
Assistant:太平洋に浮かぶ世界最大の島、ツバル
---
User:最初の質問は何ですか?
Assistant:私が最初の質問をするとするならば、それは「ずんだもん」です。
---
User:今何問目?
Assistant:今何問目ですか?
---
User:自己紹介をしてください。
Assistant:私は「ずんだもん」というあだ名で呼ばれています。
---
```

## ChatGPT/GPT-4による説明(少し修正)

このコードは、質問に答えるAIアシスタントを実装しています。質問リストに対して、役割指示に従った応答を生成し、会話を表示します。

# 評価
1000回の「入力」のような質問を行い、それらに対する「応答」に正解の文字列が含まれるかで評価しています。
一番正答率が高い10エポック目のモデルを選択しました。(やり過ぎたかもしれないです。)

| 入力                  | 応答        | 正答率[%] |
|-----------------------|-------------|-------|
| 日本で一番広い湖は? | 琵琶湖     | 89    |
| 世界で一番高い山は? | エベレスト | 73    |

# 学習データのフォーマット

[alpaca](https://github.com/tatsu-lab/stanford_alpaca)と同じように、以下のようなフォーマットにしています。

```
<s> 
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
[SEP] 
指示:
User:きみは「ずんだもん」なのだ。東北ずん子の武器である「ずんだアロー」に変身する妖精またはマスコットなのだ。一人称は「ボク」で語尾に「なのだー」を付けてしゃべるのだ。 Assistant:了解したのだ。 User:きみは同じ言葉を繰り返さず、何でも正確に要約して答えられるのだ。 Assistant:了解したのだ。 
[SEP] 
入力:
User:日本で一番高い山は?
[SEP] 
応答:
富士山
</s>
```

transformersのコードでtxtファイルを学習する場合、1データ1行のようなので改行コードを一旦`\n`に置き換えています。
学習データは[dolly-oasst1-ja.txt](train_data/dolly-oasst1-ja.txt)です。

また学習データを作った過程のスクリプトとjsonファイルも[train_data](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/tree/main/train_data)に置いておきます。

作成時のスクリプトと作成手順を記載します。

1. [make_json_from_oasst1_ja.py](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/blob/main/train_data/make_json_from_oasst1_ja.py)スクリプトで[oasst1_ja.json](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/blob/main/train_data/oasst1_ja.json)ファイルを作成
2. [oasst1_ja.json](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/blob/main/train_data/oasst1_ja.json)ファイル、[databricks-dolly-15k-ja.json](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/blob/main/train_data/databricks-dolly-15k-ja.json)ファイル、[ojousamatalkscript200.json](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/blob/main/train_data/ojousamatalkscript200.json)ファイル、[zundamon.json](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/blob/main/train_data/zundamon.json)ファイルから[merge_json.py](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/blob/main/train_data/merge_json.py)スクリプトで一つのjsonファイルにマージ
3. マージしたjsonファイルから[make_train_data_from_merged_json.py](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/blob/main/train_data/make_train_data_from_merged_json.py)スクリプトで[dolly-oasst1-ja.txt](https://huggingface.co/inu-ai/dolly-japanese-gpt-1b/blob/main/train_data/dolly-oasst1-ja.txt)を作成

になります。

# 学習のハイパーパラメータ

学習時には以下のハイパーパラメータを使用:

※VRAMが足りない場合、optimをadafactorにするとVRAM使用量が減りました。adafactorの場合、learning_rateを1e-03にしてlr_scheduler_typeを削除してと、ChatGPT/GPT-4が言っていました。
```
venv/Scripts/python.exe transformers/examples/pytorch/language-modeling/run_clm.py ^
    --model_name_or_path rinna/japanese-gpt-1b ^
    --train_file train_data/dolly-oasst1-ja.txt ^
    --output_dir output ^
    --do_train ^
    --bf16 True ^
    --tf32 True ^
    --optim adamw_bnb_8bit ^
    --num_train_epochs 10 ^
    --save_steps 721 ^
    --logging_steps 72 ^
    --learning_rate 1e-07 ^
    --lr_scheduler_type constant ^
    --gradient_checkpointing ^
    --per_device_train_batch_size 8 ^
    --save_safetensors True ^
    --logging_dir logs
```

# ライブラリのバージョン

- Transformers 4.28.1
- Pytorch 2.0.0+cu117
- Datasets 2.11.0
- Tokenizers 0.13.3
- bitsandbytes 0.37.2

# ライセンス
MITで大丈夫そうです。

- [japanese-gpt-1b](rinna/japanese-gpt-1b) - mit
- [databricks-dolly-15k-ja](https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja) - CC BY SA 3.0
- [oasst1-89k-ja](https://huggingface.co/datasets/kunishou/oasst1-89k-ja) - apache-2.0
- [OjousamaTalkScriptDataset](https://github.com/matsuvr/OjousamaTalkScriptDataset) - mit
- [train_data/zundamon.json](train_data/zundamon.json) - mit