AdaFortiTran / src /main.py
BerkIGuler's picture
minor fixed on src/main/parser.py
71dbdc8
#!/usr/bin/env python3
"""
Main entry point for OFDM channel estimation model training.
This script provides the command-line interface for training OFDM channel estimation
models. It loads configuration files, parses command-line arguments, and initiates
the training process.
Dataset Requirements:
The training script expects datasets with the following structure:
Training/Validation Sets:
Directory containing .mat files with naming convention:
{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
Test Sets:
Directory with subdirectories for different test conditions:
test_set/
├── DS_test_set/ # Delay Spread tests
│   ├── DS_50/
│   ├── DS_100/
│   └── ...
├── SNR_test_set/ # SNR tests
│   ├── SNR_10/
│   ├── SNR_20/
│   └── ...
└── MDS_test_set/ # Multi-Doppler tests
    ├── DOP_200/
    ├── DOP_400/
    └── ...
Each .mat file must contain variable 'H' with shape [subcarriers, symbols, 3]:
- H[:, :, 0]: Ground truth channel (complex-valued channel matrix)
- H[:, :, 1]: LS channel estimate with zeros for non-pilot positions (complex-valued) - used as input to models
- H[:, :, 2]: Bilinear interpolated LS channel estimate (complex-valued) - available but currently unused
"""
import logging
import sys
from datetime import datetime
from pathlib import Path
from src.main.parser import parse_arguments
from src.main.trainer import train
from src.config import load_config
from src.config.schemas import ModelConfig
def setup_logging(log_level: str, log_dir: Path, exp_id: str) -> None:
    """Configure root logging to write to stdout and a per-experiment file.

    Args:
        log_level: Logging level name (DEBUG, INFO, WARNING, ERROR, CRITICAL),
            matched case-insensitively.
        log_dir: Directory for log files; created (with parents) if missing.
            A plain string path is accepted as well as a ``Path``.
        exp_id: Experiment identifier, embedded in the log file name so the
            log can be matched to its run.

    Raises:
        ValueError: If ``log_level`` does not name a logging level.
    """
    # Resolve the level up front so a typo fails with a clear message instead
    # of an AttributeError raised from inside the basicConfig() call.
    level = getattr(logging, log_level.upper(), None)
    if not isinstance(level, int):
        raise ValueError(
            f"Invalid log level: {log_level!r}. "
            "Expected one of: DEBUG, INFO, WARNING, ERROR, CRITICAL"
        )

    # Command-line parsers frequently hand over plain strings; normalise so
    # the Path operations below always work.
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    # Name the file after exp_id for easy matching with other run artifacts.
    log_file = log_dir / f"training_{exp_id}.log"

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler(log_file),
        ],
    )
def main() -> None:
    """Parse arguments, load configuration, and run model training.

    Steps: parse command-line arguments, configure logging, load the system
    and model configuration files, check that the requested model name
    matches the model type in the model config, log a summary, and call
    ``train``.

    Exits the process with status 1 if training is interrupted by the user
    or fails with any exception.
    """
    # Bind the logger before the try block: the except handlers below use it,
    # and a failure in argument parsing or logging setup would otherwise hit
    # an unbound local and mask the real error.
    logger = logging.getLogger(__name__)

    try:
        # Parse command-line arguments
        args = parse_arguments()

        # Set up logging
        setup_logging(args.python_log_level, args.python_log_dir, args.exp_id)

        logger.info("Starting OFDM channel estimation model training")
        logger.info(f"Model: {args.model_name}")
        logger.info(f"System config: {args.system_config_path}")
        logger.info(f"Model config: {args.model_config_path}")
        logger.info(f"Experiment ID: {args.exp_id}")

        # Load and validate configurations
        logger.info("Loading configuration files...")
        system_config, model_config = load_config(
            args.system_config_path,
            args.model_config_path,
        )

        # Validate model type consistency: the CLI model name must be known
        # and must equal the model type declared in the model config.
        supported_models = ("linear", "fortitran", "adafortitran")
        if args.model_name not in supported_models:
            raise ValueError(
                f"Unknown model name: {args.model_name}. "
                f"Expected one of: {list(supported_models)}"
            )
        if model_config.model_type != args.model_name:
            raise ValueError(
                f"Model type mismatch: config specifies '{model_config.model_type}' "
                f"but model name is '{args.model_name}'"
            )

        logger.info("Configuration loaded successfully")
        logger.info(
            f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x "
            f"{system_config.ofdm.num_symbols} symbols"
        )
        logger.info(
            f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x "
            f"{system_config.pilot.num_symbols} symbols"
        )

        # Log model-specific information
        if model_config.model_type == "linear":
            logger.info(f"Linear model with device: {model_config.device}")
        elif model_config.model_type == "fortitran":
            logger.info(
                f"FortiTran model: {model_config.num_layers} layers, "
                f"{model_config.model_dim} dimensions"
            )
            logger.info("Channel adaptation: disabled")
        elif model_config.model_type == "adafortitran":
            logger.info(
                f"AdaFortiTran model: {model_config.num_layers} layers, "
                f"{model_config.model_dim} dimensions"
            )
            logger.info("Channel adaptation: enabled")
            logger.info(f"Adaptive token length: {model_config.adaptive_token_length}")
        else:
            # Defensive only: the validation above already restricts the
            # model type to the three supported names.
            logger.warning(f"Unknown model type: {model_config.model_type}")

        # Start training
        logger.info("Initializing training...")
        train(system_config, model_config, args)
        logger.info("Training completed successfully")

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: record the full traceback, then exit non-zero.
        logger.exception(f"Training failed with error: {e}")
        sys.exit(1)
# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()