dolly-japanese-gpt-1b / train_data /make_json_from_oasst1_ja.py
inu-ai's picture
Upload train_data
e068992
import pandas as pd
from datasets import load_dataset
import json
from tqdm import tqdm
ds = load_dataset("OpenAssistant/oasst1")
train = ds['train']
val = ds['validation']
# データフレームを連結
df = pd.concat([pd.DataFrame(train), pd.DataFrame(val)])
ds_ja = load_dataset("kunishou/oasst1-89k-ja")
# データフレーム
df_ja = pd.DataFrame(ds_ja['train'])
# 'message_id' をキーにして df_ja と df を結合し、df_ja の列名が優先されるようにします。
merged_df = df_ja.merge(df, on='message_id', how='left', suffixes=('', '_y'))
# 重複した列を削除します。
merged_df = merged_df.drop(columns=[col for col in merged_df.columns if col.endswith('_y')])
# 同じmessage_tree_idでデータをグループ化
grouped = merged_df.groupby('message_tree_id')
def find_longest_chain(group, root_message_id):
max_length = 0 # 最長のチェーンの長さを初期化
min_toxicity = 2.0 # 最小の毒性を初期化
leaf_id = None # 最長のチェーンの末端のメッセージIDを初期化
# グループ内の各行に対して処理を行う
for _, row in group.iterrows():
current_id = row['message_id']
if current_id == root_message_id:
continue # ルートメッセージを処理しない
chain_length = 0 # チェーンの長さを初期化
toxicity = 1.0 # 毒性を初期化
# ルートメッセージにたどり着くまでチェーンを辿る
while current_id != 'nan':
chain_length += 1
detoxify_data = group.loc[group['message_id'] == current_id, 'detoxify'].iloc[0]
toxicity = detoxify_data['toxicity'] if detoxify_data is not None else 1.0 # 毒性がない場合は1.0を代入
current_id = group.loc[group['message_id'] == current_id, 'parent_id'].values[0]
# チェーンが現在の最長のチェーンと同じか長く、毒性が現在の最小の毒性以下の場合
if chain_length >= max_length and toxicity <= min_toxicity:
max_length = chain_length
min_toxicity = toxicity
leaf_id = row['message_id'] # 末端のメッセージIDを更新
return leaf_id # 最長のチェーンの末端のメッセージIDを返す
leafs = [] # 最長チェーンの末端のメッセージIDを格納するリストを初期化
for _, group in tqdm(grouped):
# parent_idがnullのメッセージを見つける(ルートメッセージ)
root_message = group[group['parent_id'] == 'nan'].iloc[0]
root_message_id = root_message['message_id']
# 英語かスペイン語か日本語
if root_message['lang'] in ['en', 'es', 'ja']:
leaf_id = find_longest_chain(group, root_message_id)
leafs.append(leaf_id)
# 最も深いメッセージから辿ってメッセージを作成する関数
def create_message_path(message):
role = "User" if message['role'] == "prompter" else "Assistant" # メッセージの役割に応じて、UserかAssistantを選択
formatted_message = f"{role}:{message['text_ja']}" # 役割とメッセージを連結
if pd.isnull(message['parent_id']): # 親メッセージがない場合
return [formatted_message]
else:
parent_messages = merged_df[merged_df['message_id'] == message['parent_id']] # 親メッセージを検索
if parent_messages.empty: # 親メッセージが見つからない場合
return [formatted_message]
parent_message = parent_messages.iloc[0] # 親メッセージを取得
# 親メッセージから再帰的にメッセージを作成し、現在のメッセージを追加
return create_message_path(parent_message) + [formatted_message]
result = [] # 結果を格納するリストを初期化
for leaf_id in tqdm(leafs): # 進捗状況を表示するためにtqdmを使用
leaf_message = merged_df[merged_df['message_id'] == leaf_id].iloc[0] # 末端のメッセージを取得
leaf_text = create_message_path(leaf_message) # 末端のメッセージからメッセージのチェーンを作成
leaf_json = {}
odd = len(leaf_text) % 2
if len(leaf_text) <= 3: # メッセージのチェーンが3つ以下の場合
leaf_json['instruction'] = leaf_text[0].replace("User:", "", 1)
leaf_json['input'] = ""
leaf_json['output'] = leaf_text[1].replace("Assistant:", "", 1)
else: # メッセージのチェーンが4つ以上の場合
instruction = ""
for t in leaf_text[0:-2-odd]: # 最後の2つのメッセージを除いて、指示文を作成
instruction += t + " "
leaf_json['instruction'] = instruction
leaf_json['input'] = leaf_text[-2-odd] # 入力メッセージを設定
leaf_json['output'] = leaf_text[-1-odd].replace("Assistant:", "", 1) # 出力メッセージを設定
result.append(leaf_json) # 結果リストにJSONを追加
# JSON データを作成
json_data = json.dumps(result, ensure_ascii=False, indent=4)
# JSON をファイルに保存
with open("oasst1_ja.json", "w", encoding="utf-8") as json_file:
json_file.write(json_data)