File size: 5,234 Bytes
e068992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import pandas as pd
from datasets import load_dataset
import json
from tqdm import tqdm

ds = load_dataset("OpenAssistant/oasst1")
train = ds['train']
val = ds['validation']

# データフレームを連結
df = pd.concat([pd.DataFrame(train), pd.DataFrame(val)])

ds_ja = load_dataset("kunishou/oasst1-89k-ja")

# データフレーム
df_ja = pd.DataFrame(ds_ja['train'])

# 'message_id' をキーにして df_ja と df を結合し、df_ja の列名が優先されるようにします。
merged_df = df_ja.merge(df, on='message_id', how='left', suffixes=('', '_y'))

# 重複した列を削除します。
merged_df = merged_df.drop(columns=[col for col in merged_df.columns if col.endswith('_y')])

# 同じmessage_tree_idでデータをグループ化
grouped = merged_df.groupby('message_tree_id')

def find_longest_chain(group, root_message_id):
    max_length = 0  # 最長のチェーンの長さを初期化
    min_toxicity = 2.0  # 最小の毒性を初期化
    leaf_id = None  # 最長のチェーンの末端のメッセージIDを初期化

    # グループ内の各行に対して処理を行う
    for _, row in group.iterrows():
        current_id = row['message_id']
        if current_id == root_message_id:
            continue  # ルートメッセージを処理しない

        chain_length = 0  # チェーンの長さを初期化
        toxicity = 1.0  # 毒性を初期化

        # ルートメッセージにたどり着くまでチェーンを辿る
        while current_id != 'nan':
            chain_length += 1
            detoxify_data = group.loc[group['message_id'] == current_id, 'detoxify'].iloc[0]
            toxicity = detoxify_data['toxicity'] if detoxify_data is not None else 1.0  # 毒性がない場合は1.0を代入
            current_id = group.loc[group['message_id'] == current_id, 'parent_id'].values[0]

            # チェーンが現在の最長のチェーンと同じか長く、毒性が現在の最小の毒性以下の場合
            if chain_length >= max_length and toxicity <= min_toxicity:
                max_length = chain_length
                min_toxicity = toxicity
                leaf_id = row['message_id']  # 末端のメッセージIDを更新

    return leaf_id  # 最長のチェーンの末端のメッセージIDを返す

leafs = []  # 最長チェーンの末端のメッセージIDを格納するリストを初期化


for _, group in tqdm(grouped):
    # parent_idがnullのメッセージを見つける(ルートメッセージ)
    root_message = group[group['parent_id'] == 'nan'].iloc[0]
    root_message_id = root_message['message_id']

    # 英語かスペイン語か日本語
    if root_message['lang'] in ['en', 'es', 'ja']:
        leaf_id = find_longest_chain(group, root_message_id)
        leafs.append(leaf_id)

# 最も深いメッセージから辿ってメッセージを作成する関数
def create_message_path(message):
    role = "User" if message['role'] == "prompter" else "Assistant"  # メッセージの役割に応じて、UserかAssistantを選択
    formatted_message = f"{role}:{message['text_ja']}"  # 役割とメッセージを連結
    if pd.isnull(message['parent_id']):  # 親メッセージがない場合
        return [formatted_message]
    else:
        parent_messages = merged_df[merged_df['message_id'] == message['parent_id']]  # 親メッセージを検索
        if parent_messages.empty:  # 親メッセージが見つからない場合
            return [formatted_message]
        parent_message = parent_messages.iloc[0]  # 親メッセージを取得
        # 親メッセージから再帰的にメッセージを作成し、現在のメッセージを追加
        return create_message_path(parent_message) + [formatted_message]

result = []  # 結果を格納するリストを初期化
for leaf_id in tqdm(leafs):  # 進捗状況を表示するためにtqdmを使用
    leaf_message = merged_df[merged_df['message_id'] == leaf_id].iloc[0]  # 末端のメッセージを取得
    leaf_text = create_message_path(leaf_message)  # 末端のメッセージからメッセージのチェーンを作成
    leaf_json = {}
    odd = len(leaf_text) % 2
    if len(leaf_text) <= 3:  # メッセージのチェーンが3つ以下の場合
        leaf_json['instruction'] = leaf_text[0].replace("User:", "", 1)
        leaf_json['input'] = ""
        leaf_json['output'] = leaf_text[1].replace("Assistant:", "", 1)
    else:  # メッセージのチェーンが4つ以上の場合
        instruction = ""
        for t in leaf_text[0:-2-odd]:  # 最後の2つのメッセージを除いて、指示文を作成
            instruction += t + " "
        leaf_json['instruction'] = instruction
        leaf_json['input'] = leaf_text[-2-odd]  # 入力メッセージを設定
        leaf_json['output'] = leaf_text[-1-odd].replace("Assistant:", "", 1)  # 出力メッセージを設定
    result.append(leaf_json)  # 結果リストにJSONを追加

# JSON データを作成
json_data = json.dumps(result, ensure_ascii=False, indent=4)

# JSON をファイルに保存
with open("oasst1_ja.json", "w", encoding="utf-8") as json_file:
    json_file.write(json_data)