File size: 5,667 Bytes
4e938bd
 
 
 
 
 
 
9727e5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71dbdc8
 
 
4e938bd
 
 
 
71dbdc8
4e938bd
 
 
 
8a42349
687eaba
4e938bd
 
71dbdc8
4e938bd
 
 
 
71dbdc8
 
4e938bd
71dbdc8
 
 
 
 
 
4e938bd
 
 
 
 
71dbdc8
4e938bd
 
 
 
 
 
 
 
 
 
 
71dbdc8
4e938bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71dbdc8
 
 
 
 
 
 
 
 
 
 
 
 
4e938bd
 
 
71dbdc8
 
687eaba
 
71dbdc8
 
 
 
 
 
 
8a42349
71dbdc8
4e938bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python3
"""
Main entry point for OFDM channel estimation model training.

This script provides the command-line interface for training OFDM channel estimation
models. It loads configuration files, parses command-line arguments, and initiates
the training process.

Dataset Requirements:
    The training script expects datasets with the following structure:
    
    Training/Validation Sets:
        Directory containing .mat files with naming convention:
        {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
        
        Example: 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
    
    Test Sets:
        Directory with subdirectories for different test conditions:
        test_set/
        ├── DS_test_set/     # Delay Spread tests
        │   ├── DS_50/
        │   ├── DS_100/
        │   └── ...
        ├── SNR_test_set/    # SNR tests
        │   ├── SNR_10/
        │   ├── SNR_20/
        │   └── ...
        └── MDS_test_set/    # Multi-Doppler tests
            ├── DOP_200/
            ├── DOP_400/
            └── ...
    
    Each .mat file must contain variable 'H' with shape [subcarriers, symbols, 3]:
    - H[:, :, 0]: Ground truth channel (complex-valued channel matrix)
    - H[:, :, 1]: LS channel estimate with zeros for non-pilot positions (complex-valued) - used as input to models
    - H[:, :, 2]: Bilinear interpolated LS channel estimate (complex-valued) - available but currently unused
"""

import logging
import sys
from datetime import datetime
from pathlib import Path

from src.main.parser import parse_arguments
from src.main.trainer import train
from src.config import load_config
from src.config.schemas import ModelConfig


def setup_logging(log_level: str, log_dir: Path, exp_id: str) -> None:
    """Set up root logging to stdout and to a per-experiment log file.

    Records are sent both to ``sys.stdout`` and to
    ``<log_dir>/training_<exp_id>.log``. The configuration is applied via
    ``logging.basicConfig`` (without ``force``), so it has no effect if the
    root logger already has handlers.

    Args:
        log_level: Logging level name, case-insensitive (DEBUG, INFO,
            WARNING, ERROR, CRITICAL).
        log_dir: Directory for log files; created (with parents) if missing.
            A plain ``str`` path is also accepted.
        exp_id: Experiment identifier used in the log file name.

    Raises:
        ValueError: If ``log_level`` does not name a logging level.
    """
    # Resolve the level first so a typo fails with a clear message (instead of
    # an AttributeError from getattr) and before any directory is created.
    # The isinstance check also rejects non-level attributes like BASIC_FORMAT.
    level = getattr(logging, log_level.upper(), None)
    if not isinstance(level, int):
        raise ValueError(
            f"Invalid log level: {log_level!r}. "
            "Expected one of: DEBUG, INFO, WARNING, ERROR, CRITICAL"
        )

    # Accept str paths too; Path(Path) is a no-op.
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    # Name the file after exp_id so logs can be matched to experiment outputs.
    log_file = log_dir / f"training_{exp_id}.log"

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler(log_file, encoding="utf-8"),
        ],
    )


def main() -> None:
    """Main entry point for the training script."""
    try:
        # Parse command-line arguments
        args = parse_arguments()
        
        # Set up logging
        setup_logging(args.python_log_level, args.python_log_dir, args.exp_id)
        logger = logging.getLogger(__name__)
        
        logger.info("Starting OFDM channel estimation model training")
        logger.info(f"Model: {args.model_name}")
        logger.info(f"System config: {args.system_config_path}")
        logger.info(f"Model config: {args.model_config_path}")
        logger.info(f"Experiment ID: {args.exp_id}")
        
        # Load and validate configurations
        logger.info("Loading configuration files...")
        system_config, model_config = load_config(
            args.system_config_path, 
            args.model_config_path
        )
        
        # Validate model type consistency
        expected_model_types = {
            "linear": "linear",
            "fortitran": "fortitran", 
            "adafortitran": "adafortitran"
        }
        
        if args.model_name not in expected_model_types:
            raise ValueError(f"Unknown model name: {args.model_name}. Expected one of: {list(expected_model_types.keys())}")
        
        if model_config.model_type != expected_model_types[args.model_name]:
            raise ValueError(f"Model type mismatch: config specifies '{model_config.model_type}' but model name is '{args.model_name}'")
        
        logger.info("Configuration loaded successfully")
        logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
        logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
        
        # Log model-specific information
        if model_config.model_type == "linear":
            logger.info(f"Linear model with device: {model_config.device}")
        elif model_config.model_type == "fortitran":
            logger.info(f"FortiTran model: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
            logger.info(f"Channel adaptation: disabled")
        elif model_config.model_type == "adafortitran":
            logger.info(f"AdaFortiTran model: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
            logger.info(f"Channel adaptation: enabled")
            logger.info(f"Adaptive token length: {model_config.adaptive_token_length}")
        else:
            logger.warning(f"Unknown model type: {model_config.model_type}")
        
        # Start training
        logger.info("Initializing training...")
        train(system_config, model_config, args)
        
        logger.info("Training completed successfully")
        
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"Training failed with error: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    # Run the training CLI only when executed as a script, not when imported.
    main()