"""Sort a parquet dataset of tokenized games by length and split it into shards."""

import math
import os

import numpy as np
import pandas as pd
import pyarrow.parquet as pq

# Script configuration, used only when this file is run directly.
INPUT_FILE = '/media/hailey/TVBox/NEW_stable.parquet'
OUTPUT_DIR = '/media/hailey/More/AI/mamba.py/data/stable'
N_SPLITS = 360  # should be roughly input size / 10MB
PREFIX = "stable"
MIN_LEN = 200
MAX_LEN = 1536


def sort_and_split_parquet(
    input_file: str,
    output_dir: str,
    n_splits: int,
    prefix: str,
    min_len: int,
    max_len: int,
) -> None:
    """Filter games by length, sort them by length, and write them out in shards.

    Rows of ``input_file`` whose ``'tokenized'`` entry has a length outside
    ``[min_len, max_len]`` are dropped. The remaining rows are ordered by
    ascending length and written to ``output_dir`` as consecutive files named
    ``{prefix}_{i}.parquet``, each holding ``ceil(rows / n_splits)`` rows
    (the last file may hold fewer).

    Args:
        input_file: Path of the parquet file to read.
        output_dir: Directory that receives the shards; created if missing.
        n_splits: Maximum number of shards to write. Fewer are written when
            the rows run out before ``n_splits`` shards have been filled.
        prefix: File-name prefix of every shard.
        min_len: Smallest ``len(tokenized)`` to keep (inclusive).
        max_len: Largest ``len(tokenized)`` to keep (inclusive).

    Raises:
        ValueError: If ``n_splits`` is smaller than 1.
    """
    if n_splits < 1:
        raise ValueError(f"n_splits must be a positive integer, got {n_splits!r}")

    os.makedirs(output_dir, exist_ok=True)

    # Load the parquet file
    print("Loading parquet file...")
    df = pq.read_table(input_file).to_pandas()

    # Measure every game once. The lengths live in a side array rather than a
    # temporary DataFrame column so that an input column that happens to be
    # called 'length' is neither overwritten nor dropped.
    print("Sorting games & filtering by length...")
    lengths = df['tokenized'].apply(len).to_numpy(dtype=np.int64)

    keep = (lengths >= min_len) & (lengths <= max_len)
    removed = len(df) - int(keep.sum())
    if removed:
        print(f"Removed {removed} ({removed / len(df):.2%}) short and long games.")

    # Filter first so only surviving rows are sorted; the stable sort keeps
    # equal-length games in their original relative order, which makes the
    # shard contents reproducible.
    kept_lengths = lengths[keep]
    order = np.argsort(kept_lengths, kind="stable")
    sorted_lengths = kept_lengths[order]
    df_sorted = df[keep].iloc[order]

    # Calculate the number of rows per split
    total_rows = len(df_sorted)
    rows_per_split = math.ceil(total_rows / n_splits)
    print("Dataset sorted. Splitting...")

    games = 0
    # Split and save each part
    for i in range(n_splits):
        start_row = i * rows_per_split
        if start_row >= total_rows:
            # Rounding rows_per_split up can exhaust the rows early; writing
            # (and indexing into) an empty shard would fail, so stop here.
            print(f"Only {i} of {n_splits} splits were needed for {total_rows} games.")
            break
        end_row = min(start_row + rows_per_split, total_rows)
        split_df = df_sorted.iloc[start_row:end_row]
        games += len(split_df)

        first_game_length = sorted_lengths[start_row]
        last_game_length = sorted_lengths[end_row - 1]

        # Save the split DataFrame as a parquet file
        split_file_name = f"{prefix}_{i}.parquet"
        split_df.to_parquet(os.path.join(output_dir, split_file_name))
        print(f"Saved {split_file_name}... \nGame lengths: {first_game_length} - {last_game_length}")

    print(f"Saved {games} games total.")


def main() -> None:
    """Run the sort-and-split job with the module-level configuration."""
    sort_and_split_parquet(INPUT_FILE, OUTPUT_DIR, N_SPLITS, PREFIX, MIN_LEN, MAX_LEN)
    print("Done.")


if __name__ == "__main__":
    main()