# chess-mamba-vs-xformer / sort_split.py
# Uploaded by HaileyStorm (commit 80bc2b3, "Upload 5 files").
import pandas as pd
import pyarrow.parquet as pq
import os
import numpy as np
import math
def sort_and_split_parquet(input_file, output_dir, n_splits, prefix, min_len, max_len):
    """Filter games by token length, sort them by length, and write them out in shards.

    Reads ``input_file`` with pyarrow, sorts rows by the length of their
    ``tokenized`` value (shortest first), keeps only rows whose length lies in
    ``[min_len, max_len]`` (inclusive), and writes consecutive slices of the
    result to ``{prefix}_{i}.parquet`` files inside ``output_dir``.

    Args:
        input_file: Path to the source parquet file. It must contain a
            ``tokenized`` column whose values support ``len()``.
        output_dir: Directory for the shard files; created if missing.
        n_splits: Maximum number of shards to write (must be >= 1). Each
            shard holds ``ceil(kept_rows / n_splits)`` rows, except possibly
            the last. Shards that would be empty are not written, so fewer
            than ``n_splits`` files may be produced.
        prefix: File-name prefix for the shards.
        min_len: Minimum ``tokenized`` length to keep (inclusive).
        max_len: Maximum ``tokenized`` length to keep (inclusive).

    Raises:
        ValueError: If ``n_splits`` is less than 1.
    """
    if n_splits < 1:
        raise ValueError(f"n_splits must be >= 1, got {n_splits}")
    os.makedirs(output_dir, exist_ok=True)

    # Load the parquet file
    print("Loading parquet file...")
    df = pq.read_table(input_file).to_pandas()

    # Sort by the length of the 'tokenized' column, computing each length once
    # and reusing it for the range filter below.
    print("Sorting games & filtering by length...")
    df['length'] = df['tokenized'].apply(len)
    df_sorted = df.sort_values(by='length')
    total_before = len(df_sorted)
    in_range = (df_sorted['length'] >= min_len) & (df_sorted['length'] <= max_len)
    df_sorted = df_sorted[in_range].drop(columns=['length'])
    if len(df_sorted) < total_before:
        removed = total_before - len(df_sorted)
        print(f"Removed {removed} ({float(removed)/total_before:.2%}) short and long games.")

    # Calculate the number of rows per split
    total_rows = len(df_sorted)
    rows_per_split = math.ceil(total_rows / n_splits)
    print("Dataset sorted. Splitting...")
    games = 0
    # Split and save each part
    for i in range(n_splits):
        start_row = i * rows_per_split
        if start_row >= total_rows:
            # ceil() can use up all rows before n_splits shards are written
            # (e.g. 10 rows / 6 splits -> 2 rows each, shard 5 empty), and an
            # empty shard has no first/last game to report. All later shards
            # would be empty too, so stop here.
            print(f"No games left after {i} splits; skipping remaining splits.")
            break
        end_row = min(start_row + rows_per_split, total_rows)
        split_df = df_sorted.iloc[start_row:end_row]
        games += len(split_df)
        first_game_length = len(split_df.iloc[0]['tokenized'])
        last_game_length = len(split_df.iloc[-1]['tokenized'])
        # Save the split DataFrame as a parquet file.
        # NOTE(review): after sorting/filtering the index is no longer a
        # RangeIndex, so pandas stores it as an extra column in each file;
        # confirm whether index=False is wanted downstream.
        split_file_name = f"{prefix}_{i}.parquet"
        split_df.to_parquet(os.path.join(output_dir, split_file_name))
        print(f"Saved {split_file_name}... Game lengths: {first_game_length} - {last_game_length}")
    print(f"Saved {games} games total.")
# Run configuration. These stay at module level so they remain importable.
input_file = '/media/hailey/TVBox/NEW_stable.parquet'
output_dir = '/media/hailey/More/AI/mamba.py/data/stable'
n_splits = 360  # roughly (input file size / 10 MB), aiming for ~10 MB shards
prefix = "stable"
min_len = 200
max_len = 1536

# Only do filesystem work and run the pipeline when executed as a script,
# so importing this module has no side effects.
if __name__ == "__main__":
    os.makedirs(output_dir, exist_ok=True)
    sort_and_split_parquet(input_file, output_dir, n_splits, prefix, min_len, max_len)
    print("Done.")