File size: 5,667 Bytes
4e938bd 9727e5e 71dbdc8 4e938bd 71dbdc8 4e938bd 8a42349 687eaba 4e938bd 71dbdc8 4e938bd 71dbdc8 4e938bd 71dbdc8 4e938bd 71dbdc8 4e938bd 71dbdc8 4e938bd 71dbdc8 4e938bd 71dbdc8 687eaba 71dbdc8 8a42349 71dbdc8 4e938bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
#!/usr/bin/env python3
"""
Main entry point for OFDM channel estimation model training.
This script provides the command-line interface for training OFDM channel estimation
models. It loads configuration files, parses command-line arguments, and initiates
the training process.
Dataset Requirements:
The training script expects datasets with the following structure:
Training/Validation Sets:
Directory containing .mat files with naming convention:
{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
Test Sets:
Directory with subdirectories for different test conditions:
test_set/
βββ DS_test_set/ # Delay Spread tests
β βββ DS_50/
β βββ DS_100/
β βββ ...
βββ SNR_test_set/ # SNR tests
β βββ SNR_10/
β βββ SNR_20/
β βββ ...
βββ MDS_test_set/ # Multi-Doppler tests
βββ DOP_200/
βββ DOP_400/
βββ ...
Each .mat file must contain variable 'H' with shape [subcarriers, symbols, 3]:
- H[:, :, 0]: Ground truth channel (complex-valued channel matrix)
- H[:, :, 1]: LS channel estimate with zeros for non-pilot positions (complex-valued) - used as input to models
- H[:, :, 2]: Bilinear interpolated LS channel estimate (complex-valued) - available but currently unused
"""
import logging
import sys
from datetime import datetime
from pathlib import Path
from src.main.parser import parse_arguments
from src.main.trainer import train
from src.config import load_config
from src.config.schemas import ModelConfig
def setup_logging(log_level: str, log_dir: Path, exp_id: str) -> None:
"""Set up logging configuration.
Args:
log_level: Logging level string (DEBUG, INFO, WARNING, ERROR, CRITICAL)
log_dir: Directory path for log files
exp_id: Experiment identifier for log file naming
"""
# Create logs directory if it doesn't exist
log_dir.mkdir(parents=True, exist_ok=True)
# Create log file path using exp_id for easy matching
log_file = log_dir / f"training_{exp_id}.log"
logging.basicConfig(
level=getattr(logging, log_level.upper()),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler(log_file)
]
)
def main() -> None:
"""Main entry point for the training script."""
try:
# Parse command-line arguments
args = parse_arguments()
# Set up logging
setup_logging(args.python_log_level, args.python_log_dir, args.exp_id)
logger = logging.getLogger(__name__)
logger.info("Starting OFDM channel estimation model training")
logger.info(f"Model: {args.model_name}")
logger.info(f"System config: {args.system_config_path}")
logger.info(f"Model config: {args.model_config_path}")
logger.info(f"Experiment ID: {args.exp_id}")
# Load and validate configurations
logger.info("Loading configuration files...")
system_config, model_config = load_config(
args.system_config_path,
args.model_config_path
)
# Validate model type consistency
expected_model_types = {
"linear": "linear",
"fortitran": "fortitran",
"adafortitran": "adafortitran"
}
if args.model_name not in expected_model_types:
raise ValueError(f"Unknown model name: {args.model_name}. Expected one of: {list(expected_model_types.keys())}")
if model_config.model_type != expected_model_types[args.model_name]:
raise ValueError(f"Model type mismatch: config specifies '{model_config.model_type}' but model name is '{args.model_name}'")
logger.info("Configuration loaded successfully")
logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
# Log model-specific information
if model_config.model_type == "linear":
logger.info(f"Linear model with device: {model_config.device}")
elif model_config.model_type == "fortitran":
logger.info(f"FortiTran model: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
logger.info(f"Channel adaptation: disabled")
elif model_config.model_type == "adafortitran":
logger.info(f"AdaFortiTran model: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
logger.info(f"Channel adaptation: enabled")
logger.info(f"Adaptive token length: {model_config.adaptive_token_length}")
else:
logger.warning(f"Unknown model type: {model_config.model_type}")
# Start training
logger.info("Initializing training...")
train(system_config, model_config, args)
logger.info("Training completed successfully")
except KeyboardInterrupt:
logger.info("Training interrupted by user")
sys.exit(1)
except Exception as e:
logger.error(f"Training failed with error: {str(e)}")
sys.exit(1)
if __name__ == "__main__":
main() |