import pandas as pd from datasets import load_dataset import json from tqdm import tqdm ds = load_dataset("OpenAssistant/oasst1") train = ds['train'] val = ds['validation'] # データフレームを連結 df = pd.concat([pd.DataFrame(train), pd.DataFrame(val)]) ds_ja = load_dataset("kunishou/oasst1-89k-ja") # データフレーム df_ja = pd.DataFrame(ds_ja['train']) # 'message_id' をキーにして df_ja と df を結合し、df_ja の列名が優先されるようにします。 merged_df = df_ja.merge(df, on='message_id', how='left', suffixes=('', '_y')) # 重複した列を削除します。 merged_df = merged_df.drop(columns=[col for col in merged_df.columns if col.endswith('_y')]) # 同じmessage_tree_idでデータをグループ化 grouped = merged_df.groupby('message_tree_id') def find_longest_chain(group, root_message_id): max_length = 0 # 最長のチェーンの長さを初期化 min_toxicity = 2.0 # 最小の毒性を初期化 leaf_id = None # 最長のチェーンの末端のメッセージIDを初期化 # グループ内の各行に対して処理を行う for _, row in group.iterrows(): current_id = row['message_id'] if current_id == root_message_id: continue # ルートメッセージを処理しない chain_length = 0 # チェーンの長さを初期化 toxicity = 1.0 # 毒性を初期化 # ルートメッセージにたどり着くまでチェーンを辿る while current_id != 'nan': chain_length += 1 detoxify_data = group.loc[group['message_id'] == current_id, 'detoxify'].iloc[0] toxicity = detoxify_data['toxicity'] if detoxify_data is not None else 1.0 # 毒性がない場合は1.0を代入 current_id = group.loc[group['message_id'] == current_id, 'parent_id'].values[0] # チェーンが現在の最長のチェーンと同じか長く、毒性が現在の最小の毒性以下の場合 if chain_length >= max_length and toxicity <= min_toxicity: max_length = chain_length min_toxicity = toxicity leaf_id = row['message_id'] # 末端のメッセージIDを更新 return leaf_id # 最長のチェーンの末端のメッセージIDを返す leafs = [] # 最長チェーンの末端のメッセージIDを格納するリストを初期化 for _, group in tqdm(grouped): # parent_idがnullのメッセージを見つける(ルートメッセージ) root_message = group[group['parent_id'] == 'nan'].iloc[0] root_message_id = root_message['message_id'] # 英語かスペイン語か日本語 if root_message['lang'] in ['en', 'es', 'ja']: leaf_id = find_longest_chain(group, root_message_id) leafs.append(leaf_id) # 最も深いメッセージから辿ってメッセージを作成する関数 def create_message_path(message): role = "User" if message['role'] == "prompter" else "Assistant" # メッセージの役割に応じて、UserかAssistantを選択 formatted_message = f"{role}:{message['text_ja']}" # 役割とメッセージを連結 if pd.isnull(message['parent_id']): # 親メッセージがない場合 return [formatted_message] else: parent_messages = merged_df[merged_df['message_id'] == message['parent_id']] # 親メッセージを検索 if parent_messages.empty: # 親メッセージが見つからない場合 return [formatted_message] parent_message = parent_messages.iloc[0] # 親メッセージを取得 # 親メッセージから再帰的にメッセージを作成し、現在のメッセージを追加 return create_message_path(parent_message) + [formatted_message] result = [] # 結果を格納するリストを初期化 for leaf_id in tqdm(leafs): # 進捗状況を表示するためにtqdmを使用 leaf_message = merged_df[merged_df['message_id'] == leaf_id].iloc[0] # 末端のメッセージを取得 leaf_text = create_message_path(leaf_message) # 末端のメッセージからメッセージのチェーンを作成 leaf_json = {} odd = len(leaf_text) % 2 if len(leaf_text) <= 3: # メッセージのチェーンが3つ以下の場合 leaf_json['instruction'] = leaf_text[0].replace("User:", "", 1) leaf_json['input'] = "" leaf_json['output'] = leaf_text[1].replace("Assistant:", "", 1) else: # メッセージのチェーンが4つ以上の場合 instruction = "" for t in leaf_text[0:-2-odd]: # 最後の2つのメッセージを除いて、指示文を作成 instruction += t + " " leaf_json['instruction'] = instruction leaf_json['input'] = leaf_text[-2-odd] # 入力メッセージを設定 leaf_json['output'] = leaf_text[-1-odd].replace("Assistant:", "", 1) # 出力メッセージを設定 result.append(leaf_json) # 結果リストにJSONを追加 # JSON データを作成 json_data = json.dumps(result, ensure_ascii=False, indent=4) # JSON をファイルに保存 with open("oasst1_ja.json", "w", encoding="utf-8") as json_file: json_file.write(json_data)