|
|
|
""" |
|
Main entry point for OFDM channel estimation model training.

This script provides the command-line interface for training OFDM channel estimation
models. It parses command-line arguments, loads configuration files, and initiates
the training process.

Dataset Requirements:
    The training script expects datasets with the following structure:

    Training/Validation Sets:
        Directory containing .mat files with naming convention:
        {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat

        Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat

    Test Sets:
        Directory with subdirectories for different test conditions:
        test_set/
        ├── DS_test_set/     # Delay Spread tests
        │   ├── DS_50/
        │   ├── DS_100/
        │   └── ...
        ├── SNR_test_set/    # SNR tests
        │   ├── SNR_10/
        │   ├── SNR_20/
        │   └── ...
        └── MDS_test_set/    # Multi-Doppler tests
            ├── DOP_200/
            ├── DOP_400/
            └── ...

    Each .mat file must contain variable 'H' with shape [subcarriers, symbols, 3]:
        - H[:, :, 0]: Ground truth channel (complex-valued channel matrix)
        - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions (complex-valued) - used as input to models
        - H[:, :, 2]: Bilinear interpolated LS channel estimate (complex-valued) - available but currently unused
|
""" |
|
|
|
import logging |
|
import sys |
|
from datetime import datetime |
|
from pathlib import Path |
|
|
|
from src.main.parser import parse_arguments |
|
from src.main.trainer import train |
|
from src.config import load_config |
|
from src.config.schemas import ModelConfig |
|
|
|
|
|
def setup_logging(log_level: str, log_dir: Path, exp_id: str) -> None:
    """Configure root logging to write to stdout and a per-experiment file.

    Args:
        log_level: Logging level name (DEBUG, INFO, WARNING, ERROR, CRITICAL),
            matched case-insensitively.
        log_dir: Directory for log files, as a ``Path`` or a path string. It is
            created (including missing parents) if it does not exist.
        exp_id: Experiment identifier; the log file is named
            ``training_{exp_id}.log`` inside ``log_dir``.

    Raises:
        ValueError: If ``log_level`` does not name a numeric level constant of
            the ``logging`` module.
    """
    # Resolve the level first so a bad name fails with a clear message before
    # any directory or file is created.
    level = getattr(logging, log_level.upper(), None)
    if not isinstance(level, int):
        raise ValueError(
            f"Invalid log level: {log_level!r}. "
            "Expected one of: DEBUG, INFO, WARNING, ERROR, CRITICAL"
        )

    # Accept plain strings as well as Path objects.
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    log_file = log_dir / f"training_{exp_id}.log"

    # NOTE: per the logging docs, basicConfig leaves the root logger untouched
    # when it already has handlers configured.
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler(log_file),
        ],
    )
|
|
|
|
|
def main() -> None:
    """Run the training command-line entry point.

    Parses command-line arguments, configures logging, loads the system and
    model configuration files, checks that the requested model name agrees
    with the configured model type, logs a summary of the setup, and then
    calls ``train``.

    Any ``Exception`` raised along the way is logged with its traceback and
    the process exits with status 1. A ``KeyboardInterrupt`` is logged and
    also exits with status 1.
    """
    # Bind the logger before the ``try`` so both ``except`` handlers can use
    # it even when argument parsing or logging setup is what failed.
    logger = logging.getLogger(__name__)

    try:
        args = parse_arguments()

        setup_logging(args.python_log_level, args.python_log_dir, args.exp_id)

        logger.info("Starting OFDM channel estimation model training")
        logger.info(f"Model: {args.model_name}")
        logger.info(f"System config: {args.system_config_path}")
        logger.info(f"Model config: {args.model_config_path}")
        logger.info(f"Experiment ID: {args.exp_id}")

        logger.info("Loading configuration files...")
        system_config, model_config = load_config(
            args.system_config_path,
            args.model_config_path,
        )

        # Maps each accepted CLI model name to the model_type its config
        # file must declare.
        expected_model_types = {
            "linear": "linear",
            "fortitran": "fortitran",
            "adafortitran": "adafortitran",
        }

        if args.model_name not in expected_model_types:
            raise ValueError(f"Unknown model name: {args.model_name}. Expected one of: {list(expected_model_types)}")

        if model_config.model_type != expected_model_types[args.model_name]:
            raise ValueError(f"Model type mismatch: config specifies '{model_config.model_type}' but model name is '{args.model_name}'")

        logger.info("Configuration loaded successfully")
        logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
        logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")

        # Log a model-specific summary of the loaded configuration.
        if model_config.model_type == "linear":
            logger.info(f"Linear model with device: {model_config.device}")
        elif model_config.model_type == "fortitran":
            logger.info(f"FortiTran model: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
            logger.info("Channel adaptation: disabled")
        elif model_config.model_type == "adafortitran":
            logger.info(f"AdaFortiTran model: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
            logger.info("Channel adaptation: enabled")
            logger.info(f"Adaptive token length: {model_config.adaptive_token_length}")
        else:
            # Defensive only: the validation above should already have
            # rejected any other model type.
            logger.warning(f"Unknown model type: {model_config.model_type}")

        logger.info("Initializing training...")
        train(system_config, model_config, args)

        logger.info("Training completed successfully")

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
        sys.exit(1)
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback.
        logger.exception("Training failed with error: %s", e)
        sys.exit(1)
|
|
|
|
|
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()