from datasets import load_dataset
from transformers import GPT2Tokenizer

# Load the tokenizer; GPT-2 ships without a pad token, so reuse EOS for padding
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Dataset names grouped by category: conversation, coding, and math
datasets = {
    "conversation": [
        "bavard/personachat_truecased",
        "li2017dailydialog/daily_dialog",
        "ssbuild/alpaca_convai2"
    ],
    "coding": [
        "lvwerra/stack-exchange-paired",
        "iamtarun/python_code_instructions_18k_alpaca",
        "code-search-net/code_search_net"
    ],
    "math": [
        "allenai/math_qa",
        "qfq/openaimath",
        "meta-math/MetaMathQA"
    ]
}
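
# Optional: preview one dataset's splits and columns before processing; these
# sources differ in structure (dialog turns vs. Q&A pairs vs. code), so it is
# worth checking what a DatasetDict actually contains. Uncomment to run.
# preview = load_dataset("meta-math/MetaMathQA")
# print(preview)                                    # splits and row counts
# print(next(iter(preview.values())).column_names)  # column names of first split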

# Tokenize a batch; the text column differs per dataset, so it is passed in.
# List-valued columns (e.g. multi-turn dialogs) are joined into one string.
def tokenize_function(examples, text_column):
    texts = examples[text_column]
    texts = [" ".join(t) if isinstance(t, list) else str(t) for t in texts]
    return tokenizer(texts, padding="max_length", truncation=True)
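
# Quick sanity check on a toy batch; the "text" key here is only for the
# example, real column names are detected per dataset below. Uncomment to run.
# sample = {"text": ["Hello, how are you?", "def add(a, b): return a + b"]}
# print(tokenize_function(sample, text_column="text")["input_ids"][0][:10])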

# Load and process the datasets
def load_and_process_datasets(datasets):
    processed_datasets = {}
    
    for category, dataset_list in datasets.items():
        category_data = []
        
        for dataset_name in dataset_list:
            print(f"Loading dataset: {dataset_name}")
            # Some of these datasets ship a loading script and may need
            # trust_remote_code=True on newer versions of the datasets library.
            dataset = load_dataset(dataset_name)

            # Choose the column to tokenize: "text" if present, otherwise the
            # first string-typed column of the first split as a best guess.
            split = next(iter(dataset.values()))
            candidates = [c for c, f in split.features.items()
                          if getattr(f, "dtype", None) == "string"]
            text_column = ("text" if "text" in split.column_names
                           else (candidates[0] if candidates else split.column_names[0]))

            # Apply batched tokenization, forwarding the chosen column name
            tokenized_data = dataset.map(tokenize_function, batched=True,
                                         fn_kwargs={"text_column": text_column})
            category_data.append(tokenized_data)
        
        processed_datasets[category] = category_data
    
    return processed_datasets

# Run the function
processed_datasets = load_and_process_datasets(datasets)
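
# Optional: merge each category's train splits into a single Dataset so one
# training run can see all of it. This is a sketch under two assumptions:
# every source exposes a "train" split, and only the tokenized columns are
# kept so features match across datasets (concatenate_datasets requires it).
from datasets import concatenate_datasets

def merge_category(tokenized_list, columns=("input_ids", "attention_mask")):
    train_splits = [ds["train"].select_columns(list(columns)) for ds in tokenized_list]
    return concatenate_datasets(train_splits)

# Example: one combined training set per category
# conversation_train = merge_category(processed_datasets["conversation"])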

# Optionally, save the processed datasets to disk (if you need them locally)
# processed_datasets['conversation'][0].save_to_disk('./conversation_dataset')
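
# A sketch for saving every processed dataset under ./processed/<category>/,
# assuming it is acceptable to flatten hub names into directory names.
# import os
# for category, tokenized_list in processed_datasets.items():
#     for name, tokenized in zip(datasets[category], tokenized_list):
#         out_dir = os.path.join("processed", category, name.replace("/", "__"))
#         tokenized.save_to_disk(out_dir)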