|
import pandas as pd |
|
from datasets import load_dataset |
|
import json |
|
from tqdm import tqdm |
|
|
|
# Original OASST1 conversations: both official splits are stacked so that a
# translated message from either split can be matched back to its metadata.
ds = load_dataset("OpenAssistant/oasst1")
train, val = ds['train'], ds['validation']
df = pd.concat([pd.DataFrame(split) for split in (train, val)])

# Japanese translations; only the 'train' split is read here.
ds_ja = load_dataset("kunishou/oasst1-89k-ja")
df_ja = pd.DataFrame(ds_ja['train'])

# Left-join the original rows onto each translated message by message_id.
# Where both frames share a column name, the Japanese-side column keeps its
# name and the original's copy gets a '_y' suffix; every '_y' column is then
# discarded so the Japanese values win.
merged_df = df_ja.merge(df, how='left', on='message_id', suffixes=('', '_y'))
merged_df = merged_df[[name for name in merged_df.columns if not name.endswith('_y')]]

# One group per conversation tree.
grouped = merged_df.groupby('message_tree_id')
|
|
|
def _toxicity_score(detoxify_data):
    """Return the ``'toxicity'`` score of a detoxify record, or 1.0 if unusable.

    Anything that is not a dict with a non-null ``'toxicity'`` value (``None``,
    a NaN left by the left merge, a dict missing the key) counts as maximally
    toxic, so such messages lose toxicity tie-breaks.
    """
    if isinstance(detoxify_data, dict):
        score = detoxify_data.get('toxicity')
        if score is not None and not pd.isnull(score):
            return float(score)
    return 1.0


def find_longest_chain(group, root_message_id):
    """Pick the endpoint of the longest, least toxic reply chain in one tree.

    Every non-root message is a candidate endpoint. Its chain is found by
    following ``parent_id`` links until reaching an id that is not a message
    in ``group`` (the root's null/'nan' parent, or a parent missing from the
    data). Candidates are ranked by chain length first. Among equally long
    chains, the one whose most toxic message scores lowest wins. Any remaining
    tie goes to the candidate that appears last in ``group``.

    Args:
        group: DataFrame of one tree's messages with ``message_id``,
            ``parent_id`` and ``detoxify`` columns.
        root_message_id: Id of the tree's root message; it is never returned.

    Returns:
        The chosen ``message_id``, or ``None`` if ``group`` holds no message
        other than the root.
    """
    # Index the tree once instead of filtering the DataFrame at every hop.
    parent_of = {}
    toxicity_of = {}
    for message_id, parent_id, detoxify_data in zip(
        group['message_id'], group['parent_id'], group['detoxify']
    ):
        if message_id not in parent_of:  # first row wins, matching .iloc[0]
            parent_of[message_id] = parent_id
            toxicity_of[message_id] = _toxicity_score(detoxify_data)

    best_rank = None
    leaf_id = None
    for candidate_id in group['message_id']:
        if candidate_id == root_message_id:
            continue

        chain_length = 0
        chain_toxicity = 0.0
        visited = set()  # guards against malformed parent cycles
        current_id = candidate_id
        while current_id in parent_of and current_id not in visited:
            visited.add(current_id)
            chain_length += 1
            # A chain is as toxic as its worst message. The old code kept only
            # the last score read, which was always the root's.
            chain_toxicity = max(chain_toxicity, toxicity_of[current_id])
            current_id = parent_of[current_id]

        # Longer first, then lower toxicity; '>=' keeps the last-seen tie.
        rank = (chain_length, -chain_toxicity)
        if best_rank is None or rank >= best_rank:
            best_rank = rank
            leaf_id = candidate_id

    return leaf_id
|
|
|
# For every tree whose root message is in one of the selected languages,
# record the id of the leaf chosen by find_longest_chain.
leafs = []

for _, group in tqdm(grouped):
    # A root has no parent. The original filter matched the literal string
    # 'nan'; a real null is accepted too, so either encoding identifies it.
    root_rows = group[group['parent_id'].isna() | (group['parent_id'] == 'nan')]
    if root_rows.empty:
        # No root row in this tree: skip it instead of raising IndexError.
        continue

    root_message = root_rows.iloc[0]
    root_message_id = root_message['message_id']

    if root_message['lang'] in ['en', 'es', 'ja']:
        leaf_id = find_longest_chain(group, root_message_id)
        # A root-only tree has no leaf; appending None would make the export
        # loop fail on an empty lookup.
        if leaf_id is not None:
            leafs.append(leaf_id)
|
|
|
|
|
def create_message_path(message, messages=None):
    """Return the conversation from the tree root down to ``message``.

    Each turn is formatted as ``"User:<text_ja>"`` when ``role`` is
    ``"prompter"`` and ``"Assistant:<text_ja>"`` otherwise. The list is
    ordered root first and ends with ``message`` itself.

    The walk stops at a root, detected as a null ``parent_id`` or the literal
    string ``'nan'``. It also stops at a parent id absent from ``messages``, or
    at a repeated parent id, which guards against malformed cycles.

    Args:
        message: Row (e.g. ``pd.Series``) with ``role``, ``text_ja`` and
            ``parent_id`` fields.
        messages: DataFrame used to resolve parents by ``message_id``.
            Defaults to the module-level ``merged_df``.

    Returns:
        List of formatted turns, root first.
    """
    if messages is None:
        messages = merged_df

    path = []
    visited = set()
    current = message
    while True:
        role = "User" if current['role'] == "prompter" else "Assistant"
        path.append(f"{role}:{current['text_ja']}")

        parent_id = current['parent_id']
        if pd.isnull(parent_id) or parent_id == 'nan' or parent_id in visited:
            break
        visited.add(parent_id)

        parents = messages[messages['message_id'] == parent_id]
        if parents.empty:
            break
        current = parents.iloc[0]

    path.reverse()
    return path
|
|
|
def _build_record(leaf_text):
    """Turn a root-first list of "User:..."/"Assistant:..." turns into a record.

    Paths of 2-3 turns become a single-turn example:
    - ``instruction``: the first turn, with its "User:" prefix removed.
    - ``input``: empty.
    - ``output``: the second turn, with "Assistant:" removed.

    Longer paths become a multi-turn example:
    - ``instruction``: the earlier turns, prefixes kept, joined by spaces.
    - ``input``: the turn before the final reply.
    - ``output``: the final reply, with "Assistant:" removed.

    An odd-length path is assumed to end on an unanswered prompter turn, so
    that last turn is dropped.

    Returns:
        dict with 'instruction', 'input' and 'output' keys, or None when the
        path has fewer than two turns and thus no reply to export.
    """
    if len(leaf_text) < 2:
        return None

    odd = len(leaf_text) % 2
    if len(leaf_text) <= 3:
        return {
            'instruction': leaf_text[0].replace("User:", "", 1),
            'input': "",
            'output': leaf_text[1].replace("Assistant:", "", 1),
        }

    return {
        # join avoids the quadratic += build and its stray trailing space.
        'instruction': " ".join(leaf_text[:-2 - odd]),
        'input': leaf_text[-2 - odd],
        'output': leaf_text[-1 - odd].replace("Assistant:", "", 1),
    }


result = []
for leaf_id in tqdm(leafs):
    leaf_rows = merged_df[merged_df['message_id'] == leaf_id]
    if leaf_rows.empty:
        # e.g. a None id from a tree without replies: nothing to export.
        continue
    record = _build_record(create_message_path(leaf_rows.iloc[0]))
    if record is not None:
        result.append(record)
|
|
|
|
|
# Serialize all records as 4-space-indented JSON, keeping non-ASCII
# (Japanese) characters verbatim rather than as \u escapes.
json_data = json.dumps(result, indent=4, ensure_ascii=False)

with open("oasst1_ja.json", mode="w", encoding="utf-8") as json_file:
    json_file.write(json_data)
|
|