zhuoranyang's picture
Deploy app with precomputed results for p=15,23,29,31
b753304 verified
#!/usr/bin/env python3
"""
Module NN: Neural Network Training Wrapper
==========================================
A flexible wrapper for training neural networks on modular arithmetic tasks.
Supports command-line parameter overrides for easy batch experimentation.
Usage Examples:
# Use default config
python module_nn.py
# Override specific parameters
python module_nn.py --p 17 --lr 0.01 --num_epochs 10000
# Run batch experiments on init_type, optimizer, and act_type (16 total combinations)
python module_nn.py --experiments
# Run batch experiments with custom parameters
python module_nn.py --experiments --p 17 --num_epochs 3000
# Dry run to see configuration
python module_nn.py --dry_run --p 23 --lr 0.001
# Multiple parameters for single experiment
python module_nn.py --p 23 --lr 0.001 --d_mlp 256 --act_type ReLU --seed 42
Bash Script Example:
# Run experiments for different primes
for p in 17 23 31; do
python module_nn.py --experiments --p $p --num_epochs 3000
done
"""
import argparse
import sys
from collections import deque
from utils import *
from nnTrainer import Trainer
def parse_arguments(argv=None):
    """Parse command line arguments with support for config overrides.

    Every value-carrying option defaults to ``None`` so that downstream code
    can distinguish "not supplied" from an explicit value. The ``store_true``
    flags (``--dry_run``, ``--no_wandb``, ``--experiments``) default to
    ``False``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """

    def _str_to_bool(text):
        # argparse's ``type=bool`` calls bool() on the raw string, so every
        # non-empty value -- including "False" -- would come out as True.
        lowered = text.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value (true/false), got {text!r}"
        )

    parser = argparse.ArgumentParser(
        description='Neural Network Training for Modular Arithmetic',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )

    # Data parameters
    parser.add_argument('--p', type=int, help='Prime number for modular arithmetic')
    parser.add_argument('--d_vocab', type=int, help='Vocabulary size (defaults to p)')
    parser.add_argument('--fn_name', type=str, choices=['add', 'subtract', 'x2xyy2'], help='Function to learn')
    parser.add_argument('--frac_train', type=float, help='Fraction of data for training')
    parser.add_argument('--batch_style', type=str, help='Batch processing style')

    # Model parameters
    parser.add_argument('--d_model', type=int, help='Model embedding dimensionality')
    parser.add_argument('--d_mlp', type=int, help='MLP layer dimensionality')
    parser.add_argument('--act_type', type=str, choices=['ReLU', 'GeLU', 'Quad', 'Id'], help='Activation function')
    parser.add_argument('--embed_type', type=str, choices=['one_hot', 'learned'], help='Embedding type')
    parser.add_argument('--init_type', type=str, choices=['random', 'single-freq'], help='Weight initialization')
    parser.add_argument('--init_scale', type=float, help='Scale factor for weight initialization')
    parser.add_argument('--freq_num', type=int, help='Number of frequencies for single-freq init')

    # Training parameters
    parser.add_argument('--num_epochs', type=int, help='Number of training epochs')
    parser.add_argument('--lr', type=float, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, help='Weight decay')
    parser.add_argument('--optimizer', type=str, choices=['AdamW', 'SGD'], help='Optimizer')
    parser.add_argument('--stopping_thresh', type=float, help='Early stopping threshold')
    parser.add_argument('--save_models', type=_str_to_bool, help='Whether to save models')
    parser.add_argument('--save_every', type=int, help='Save frequency (epochs)')
    parser.add_argument('--seed', type=int, help='Random seed')

    # Special flags
    parser.add_argument('--config', type=str, help='Path to custom config file')
    parser.add_argument('--dry_run', action='store_true', help='Print config and exit without training')
    parser.add_argument('--no_wandb', action='store_true', help='Disable wandb logging')
    # Documented in the module usage examples; without this registration
    # argparse rejects the flag as unrecognized.
    parser.add_argument('--experiments', action='store_true',
                        help='Run batch experiments over init_type, optimizer, and act_type')

    return parser.parse_args(argv)
def override_config(config_dict, args):
    """Override config values with command line arguments.

    The nested config is flattened to a single ``{name: value}`` mapping
    (nested dict keys themselves are discarded; if the same leaf name occurs
    in more than one section the later one wins), overrides are applied, and
    the result is rebuilt into the fixed ``data`` / ``model`` / ``training``
    layout. Leaf names that are not listed in one of those three sections are
    dropped from the result. ``config_dict`` itself is not modified.

    Args:
        config_dict: Possibly nested configuration mapping.
        args: Object whose ``vars()`` yields the parsed command line options
            (typically an ``argparse.Namespace``). ``None`` means "option not
            supplied".

    Returns:
        A new dict of the form ``{'data': {...}, 'model': {...},
        'training': {...}}``.
    """
    flat_config = {}

    def flatten_dict(mapping):
        for key, value in mapping.items():
            if isinstance(value, dict):
                flatten_dict(value)
            else:
                flat_config[key] = value

    flatten_dict(config_dict)

    # store_true flags are never None: an omitted flag arrives as False. They
    # therefore only count as an override when they are actually set, and a
    # set flag is honoured even if the config file has no such key.
    switch_flags = ('no_wandb',)

    for arg_name, arg_value in vars(args).items():
        if arg_name in switch_flags:
            if arg_value:
                # run_experiment reads `no_wandb` from the built Config.
                # NOTE(review): assumes Config accepts this extra training
                # key when the config file did not define it -- confirm.
                flat_config[arg_name] = True
                print(f"Override: {arg_name} = {arg_value}")
            continue
        if arg_value is not None and arg_name in flat_config:
            flat_config[arg_name] = arg_value
            print(f"Override: {arg_name} = {arg_value}")

    # Reconstruct nested structure
    sections = {
        'data': ['p', 'd_vocab', 'fn_name', 'frac_train', 'batch_style'],
        'model': ['d_model', 'd_mlp', 'act_type', 'embed_type', 'init_type',
                  'init_scale', 'freq_num'],
        'training': ['num_epochs', 'lr', 'weight_decay', 'optimizer',
                     'stopping_thresh', 'save_models', 'save_every', 'seed',
                     'no_wandb'],
    }
    return {
        section: {name: flat_config[name] for name in names if name in flat_config}
        for section, names in sections.items()
    }
def run_experiment(config_dict):
    """Run the training experiment with given configuration.

    Builds a ``Config`` from ``config_dict``, prints a summary of it, creates
    a ``Trainer`` and steps it once per epoch. Training stops early as soon as
    the test loss drops below ``stopping_thresh``; otherwise a checkpoint is
    written whenever the config says it is time to save. A final save is
    always performed after the loop.

    Args:
        config_dict: Nested configuration mapping passed to ``Config``.

    Returns:
        The ``Trainer`` instance after training, so callers can read its
        recorded losses, accuracies, and run name.
    """
    print("="*80)
    print("MODULAR ARITHMETIC NEURAL NETWORK TRAINING")
    print("="*80)

    # Create config object
    pipeline_config = Config(config_dict)
    print("Configuration loaded successfully")
    print(f"Device: {pipeline_config.device}")
    print(f"Prime p: {pipeline_config.p}")
    print(f"Vocabulary size: {pipeline_config.d_vocab}")
    print(f"Model dimensions: d_model={pipeline_config.d_model}, d_mlp={pipeline_config.d_mlp}")
    print(f"Function: {pipeline_config.fn_name}")
    print(f"Activation: {pipeline_config.act_type}")
    print(f"Seed: {pipeline_config.seed}")
    print(f"Init scale: {pipeline_config.init_scale}")
    print(f"Learning rate: {pipeline_config.lr}")
    print("-" * 80)

    # Initialize trainer; wandb stays enabled unless the config carries a
    # truthy `no_wandb` attribute.
    use_wandb = not getattr(pipeline_config, 'no_wandb', False)
    world = Trainer(config=pipeline_config, use_wandb=use_wandb)
    print(f'Run name: {world.run_name}')
    world.initial_save_if_appropriate()

    print(f"Starting training for {pipeline_config.num_epochs} epochs...")
    print("-" * 80)

    # Training loop
    for epoch in range(pipeline_config.num_epochs):
        # Only the test loss drives control flow here; the trainer records
        # the per-step metrics itself (read back below via world.*).
        _, test_loss = world.do_a_training_step(epoch)
        current_test_loss = test_loss.item()

        # Stop training if test loss falls below the threshold. This check
        # runs before the periodic save, so the stopping epoch is covered by
        # the final save only.
        if current_test_loss < pipeline_config.stopping_thresh:
            print(f"Early stopping at epoch {epoch}: test loss {current_test_loss:.6f} < {pipeline_config.stopping_thresh}")
            break

        # Save model state if it's time to do so
        if pipeline_config.is_it_time_to_save(epoch=epoch):
            world.save_epoch(epoch=epoch, local_save=True)

    # Save final model state after training is complete
    print("-" * 80)
    print("Training completed! Saving final model...")
    world.post_training_save(save_optimizer_and_scheduler=True)

    print(f"Final train loss: {world.train_losses[-1]:.6f}")
    print(f"Final test loss: {world.test_losses[-1]:.6f}")
    print(f"Final train accuracy: {world.train_accs[-1]:.4f}")
    print(f"Final test accuracy: {world.test_accs[-1]:.4f}")
    print("="*80)

    return world
def run_batch_experiments(base_config):
    """Run batch experiments on init_type, optimizer, and act_type.

    Every combination of the three settings is trained once through
    ``run_experiment``. Each run gets its own seed (``1024 + experiment
    number``) and at most 5000 epochs. A failing run is recorded with its
    error message and does not abort the remaining runs. A summary table is
    printed at the end.

    Args:
        base_config: Nested config dict with at least ``'model'`` and
            ``'training'`` sections, the latter containing ``'num_epochs'``.
            It is not modified; each experiment works on its own deep copy.

    Returns:
        A list with one dict per experiment. Successful entries carry the
        final losses/accuracies and the run name; failed entries carry an
        ``'error'`` string instead.
    """
    # Function-scope imports keep this block self-contained.
    import copy
    from itertools import product

    print("="*80)
    print("BATCH EXPERIMENTS: init_type, optimizer, act_type")
    print("="*80)

    results = []

    # Test parameters
    init_types = ['random', 'single-freq']
    optimizers = ['AdamW', 'SGD']
    act_types = ['ReLU', 'GeLU', 'Quad', 'Id']
    total_experiments = len(init_types) * len(optimizers) * len(act_types)

    print(f"Running {total_experiments} experiments...")
    print("-" * 80)

    # product() iterates with the rightmost factor fastest, i.e. the same
    # order as nesting init_type > optimizer > act_type loops.
    grid = product(init_types, optimizers, act_types)
    for experiment_count, (init_type, optimizer, act_type) in enumerate(grid, start=1):
        print(f"\nExperiment {experiment_count}/{total_experiments}")
        print(f"Configuration: init_type={init_type}, optimizer={optimizer}, act_type={act_type}")
        print("-" * 50)

        # A shallow dict.copy() would share the nested 'model'/'training'
        # dicts, so the assignments below would leak into base_config and
        # into every later experiment.
        exp_config = copy.deepcopy(base_config)
        exp_config['model']['init_type'] = init_type
        exp_config['training']['optimizer'] = optimizer
        exp_config['model']['act_type'] = act_type
        # Use different seeds for each experiment
        exp_config['training']['seed'] = 1024 + experiment_count
        # Reduce epochs for faster batch testing
        exp_config['training']['num_epochs'] = min(exp_config['training']['num_epochs'], 5000)

        try:
            # Run the experiment
            trainer = run_experiment(exp_config)

            # Collect results
            result = {
                'experiment': experiment_count,
                'init_type': init_type,
                'optimizer': optimizer,
                'act_type': act_type,
                'seed': exp_config['training']['seed'],
                'final_train_loss': trainer.train_losses[-1],
                'final_test_loss': trainer.test_losses[-1],
                'final_train_acc': trainer.train_accs[-1],
                'final_test_acc': trainer.test_accs[-1],
                'run_name': trainer.run_name
            }
            results.append(result)
            print(f"✓ Experiment {experiment_count} completed successfully")
            print(f" Final test accuracy: {result['final_test_acc']:.4f}")
        except Exception as e:
            # Deliberately broad: one broken configuration must not stop the
            # rest of the batch; the failure is reported in the summary.
            print(f"✗ Experiment {experiment_count} failed: {str(e)}")
            results.append({
                'experiment': experiment_count,
                'init_type': init_type,
                'optimizer': optimizer,
                'act_type': act_type,
                'seed': exp_config['training']['seed'],
                'error': str(e)
            })

        print("-" * 50)

    # Print summary
    print("\n" + "="*80)
    print("BATCH EXPERIMENTS SUMMARY")
    print("="*80)

    successful_results = [r for r in results if 'error' not in r]
    failed_results = [r for r in results if 'error' in r]

    print(f"Total experiments: {total_experiments}")
    print(f"Successful: {len(successful_results)}")
    print(f"Failed: {len(failed_results)}")

    if successful_results:
        print("\nTop 5 Results by Test Accuracy:")
        print("-" * 50)
        sorted_results = sorted(successful_results, key=lambda x: x['final_test_acc'], reverse=True)
        for i, result in enumerate(sorted_results[:5]):
            print(f"{i+1}. Test Acc: {result['final_test_acc']:.4f} | "
                  f"init_type={result['init_type']}, optimizer={result['optimizer']}, "
                  f"act_type={result['act_type']}")

        # The detailed table relies on sorted_results, so it stays inside
        # the "at least one success" branch.
        print("\nDetailed Results:")
        print("-" * 80)
        print(f"{'Exp':<3} {'Init':<11} {'Opt':<5} {'Act':<4} {'Train Acc':<9} {'Test Acc':<8} {'Train Loss':<10} {'Test Loss':<9}")
        print("-" * 80)
        for result in sorted_results:
            print(f"{result['experiment']:<3} "
                  f"{result['init_type']:<11} "
                  f"{result['optimizer']:<5} "
                  f"{result['act_type']:<4} "
                  f"{result['final_train_acc']:<9.4f} "
                  f"{result['final_test_acc']:<8.4f} "
                  f"{result['final_train_loss']:<10.6f} "
                  f"{result['final_test_loss']:<9.6f}")

    if failed_results:
        print(f"\nFailed Experiments:")
        for result in failed_results:
            print(f"Exp {result['experiment']}: {result['init_type']}, {result['optimizer']}, {result['act_type']} - {result['error']}")

    print("="*80)
    return results
def main():
    """Main entry point.

    Loads the base configuration (from ``--config`` if given, otherwise the
    project default), applies command line overrides, and then does one of:

    * ``--dry_run``: print the resulting configuration and return ``None``;
    * ``--experiments``: run the batch grid and return its list of results;
    * otherwise: run a single experiment and return its trainer.

    ``--dry_run`` takes precedence over ``--experiments``.
    """
    args = parse_arguments()

    # Load base configuration
    if args.config:
        # Load custom config file. yaml is imported lazily so that it is only
        # needed on the code paths that actually use it.
        import yaml
        with open(args.config, 'r', encoding='utf-8') as f:
            configs = yaml.safe_load(f)
    else:
        # Load default config
        configs = read_config()

    # Override with command line arguments
    final_config = override_config(configs, args)

    if args.dry_run:
        print("DRY RUN - Configuration that would be used:")
        print("-" * 50)
        import yaml
        print(yaml.dump(final_config, default_flow_style=False, indent=2))
        return

    # Batch mode is advertised in the module usage text. getattr keeps this
    # working with a parser that does not define the flag.
    if getattr(args, 'experiments', False):
        return run_batch_experiments(final_config)

    # Run single experiment
    trainer = run_experiment(final_config)
    print("Experiment completed successfully!")
    print(f"Results saved to: {trainer.save_dir}/{trainer.run_name}")
    return trainer


if __name__ == "__main__":
    main()