inu-ai commited on
Commit
e068992
·
1 Parent(s): 8a482d8

Upload train_data

Browse files
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ train_data/databricks-dolly-15k-ja.json filter=lfs diff=lfs merge=lfs -text
36
+ train_data/oasst1_ja.json filter=lfs diff=lfs merge=lfs -text
train_data/databricks-dolly-15k-ja.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7053bd9081719ea68765e0e743c6e222fd78de65a41de3411d11362b631815e
3
+ size 17061804
train_data/make_json_from_oasst1_ja.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from datasets import load_dataset
3
+ import json
4
+ from tqdm import tqdm
5
+
6
+ ds = load_dataset("OpenAssistant/oasst1")
7
+ train = ds['train']
8
+ val = ds['validation']
9
+
10
+ # データフレームを連結
11
+ df = pd.concat([pd.DataFrame(train), pd.DataFrame(val)])
12
+
13
+ ds_ja = load_dataset("kunishou/oasst1-89k-ja")
14
+
15
+ # データフレーム
16
+ df_ja = pd.DataFrame(ds_ja['train'])
17
+
18
+ # 'message_id' をキーにして df_ja と df を結合し、df_ja の列名が優先されるようにします。
19
+ merged_df = df_ja.merge(df, on='message_id', how='left', suffixes=('', '_y'))
20
+
21
+ # 重複した列を削除します。
22
+ merged_df = merged_df.drop(columns=[col for col in merged_df.columns if col.endswith('_y')])
23
+
24
+ # 同じmessage_tree_idでデータをグループ化
25
+ grouped = merged_df.groupby('message_tree_id')
26
+
27
+ def find_longest_chain(group, root_message_id):
28
+ max_length = 0 # 最長のチェーンの長さを初期化
29
+ min_toxicity = 2.0 # 最小の毒性を初期化
30
+ leaf_id = None # 最長のチェーンの末端のメッセージIDを初期化
31
+
32
+ # グループ内の各行に対して処理を行う
33
+ for _, row in group.iterrows():
34
+ current_id = row['message_id']
35
+ if current_id == root_message_id:
36
+ continue # ルートメッセージを処理しない
37
+
38
+ chain_length = 0 # チェーンの長さを初期化
39
+ toxicity = 1.0 # 毒性を初期化
40
+
41
+ # ルートメッセージにたどり着くまでチェーンを辿る
42
+ while current_id != 'nan':
43
+ chain_length += 1
44
+ detoxify_data = group.loc[group['message_id'] == current_id, 'detoxify'].iloc[0]
45
+ toxicity = detoxify_data['toxicity'] if detoxify_data is not None else 1.0 # 毒性がない場合は1.0を代入
46
+ current_id = group.loc[group['message_id'] == current_id, 'parent_id'].values[0]
47
+
48
+ # チェーンが現在の最長のチェーンと同じか長く、毒性が現在の最小の毒性以下の場合
49
+ if chain_length >= max_length and toxicity <= min_toxicity:
50
+ max_length = chain_length
51
+ min_toxicity = toxicity
52
+ leaf_id = row['message_id'] # 末端のメッセージIDを更新
53
+
54
+ return leaf_id # 最長のチェーンの末端のメッセージIDを返す
55
+
56
+ leafs = [] # 最長チェーンの末端のメッセージIDを格納するリストを初期化
57
+
58
+
59
+ for _, group in tqdm(grouped):
60
+ # parent_idがnullのメッセージを見つける(ルートメッセージ)
61
+ root_message = group[group['parent_id'] == 'nan'].iloc[0]
62
+ root_message_id = root_message['message_id']
63
+
64
+ # 英語かスペイン語か日本語
65
+ if root_message['lang'] in ['en', 'es', 'ja']:
66
+ leaf_id = find_longest_chain(group, root_message_id)
67
+ leafs.append(leaf_id)
68
+
69
+ # 最も深いメッセージから辿ってメッセージを作成する関数
70
+ def create_message_path(message):
71
+ role = "User" if message['role'] == "prompter" else "Assistant" # メッセージの役割に応じて、UserかAssistantを選択
72
+ formatted_message = f"{role}:{message['text_ja']}" # 役割とメッセージを連結
73
+ if pd.isnull(message['parent_id']): # 親メッセージがない場合
74
+ return [formatted_message]
75
+ else:
76
+ parent_messages = merged_df[merged_df['message_id'] == message['parent_id']] # 親メッセージを検索
77
+ if parent_messages.empty: # 親メッセージが見つからない場合
78
+ return [formatted_message]
79
+ parent_message = parent_messages.iloc[0] # 親メッセージを取得
80
+ # 親メッセージから再帰的にメッセージを作成し、現在のメッセージを追加
81
+ return create_message_path(parent_message) + [formatted_message]
82
+
83
+ result = [] # 結果を格納するリストを初期化
84
+ for leaf_id in tqdm(leafs): # 進捗状況を表示するためにtqdmを使用
85
+ leaf_message = merged_df[merged_df['message_id'] == leaf_id].iloc[0] # 末端のメッセージを取得
86
+ leaf_text = create_message_path(leaf_message) # 末端のメッセージからメッセージのチェーンを作成
87
+ leaf_json = {}
88
+ odd = len(leaf_text) % 2
89
+ if len(leaf_text) <= 3: # メッセージのチェーンが3つ以下の場合
90
+ leaf_json['instruction'] = leaf_text[0].replace("User:", "", 1)
91
+ leaf_json['input'] = ""
92
+ leaf_json['output'] = leaf_text[1].replace("Assistant:", "", 1)
93
+ else: # メッセージのチェーンが4つ以上の場合
94
+ instruction = ""
95
+ for t in leaf_text[0:-2-odd]: # 最後の2つのメッセージを除いて、指示文を作成
96
+ instruction += t + " "
97
+ leaf_json['instruction'] = instruction
98
+ leaf_json['input'] = leaf_text[-2-odd] # 入力メッセージを設定
99
+ leaf_json['output'] = leaf_text[-1-odd].replace("Assistant:", "", 1) # 出力メッセージを設定
100
+ result.append(leaf_json) # 結果リスト��JSONを追加
101
+
102
+ # JSON データを作成
103
+ json_data = json.dumps(result, ensure_ascii=False, indent=4)
104
+
105
+ # JSON をファイルに保存
106
+ with open("oasst1_ja.json", "w", encoding="utf-8") as json_file:
107
+ json_file.write(json_data)
train_data/oasst1_ja.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f726ad7113b60c4bfe322884d385c9d9c7da1c99b0e7deeac27dc6934ac3b3c1
3
+ size 14463337