BerkIGuler committed on
Commit 4fa78a3 · 1 Parent(s): 71dbdc8

new README.md and Trainer design

Files changed (6)
  1. README.md +432 -2
  2. requirements.txt +5 -2
  3. src/main/parser.py +145 -36
  4. src/main/train_helpers.py +0 -265
  5. src/main/trainer.py +472 -110
  6. src/utils.py +0 -17
README.md CHANGED
@@ -1,7 +1,437 @@
1
- # Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076) accepted at ICC 2025, Montreal, Canada.
2
 
 
 
 
3
 
4
- ## License
5
 
6
  This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
7
 
 
1
+ # AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation
2
 
3
+ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE)
4
+ [![Python](https://img.shields.io/badge/Python-3.11+-blue.svg)](https://www.python.org/)
5
+ [![PyTorch](https://img.shields.io/badge/PyTorch-1.8+-red.svg)](https://pytorch.org/)
6
 
7
+ Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076) accepted at ICC 2025, Montreal, Canada.
8
+
9
+ ## 📖 Overview
10
+
11
+ AdaFortiTran is an adaptive transformer-based model for OFDM channel estimation that dynamically adjusts to varying channel conditions (SNR, delay spread, Doppler shift). It combines transformer attention with channel-aware adaptation mechanisms to maintain robust performance across diverse wireless environments.
12
+
13
+ ### Key Features
14
+ - **🔄 Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
15
+ - **⚡ High Performance**: State-of-the-art results on OFDM channel estimation tasks
16
+ - **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
17
+ - **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
18
+ - **🚀 Production Ready**: Training pipeline with checkpointing, mixed-precision support, early stopping, and TensorBoard logging
19
+
20
+ ## 🏗️ Architecture
21
+
22
+ The project implements three model variants:
23
+
24
+ 1. **Linear Estimator**: Simple learned linear transformation baseline
25
+ 2. **FortiTran**: Fixed transformer-based channel estimator
26
+ 3. **AdaFortiTran**: Adaptive transformer with channel condition awareness
27
+
28
+ ### Model Comparison
29
+
30
+ | Model | Channel Adaptation | Complexity | Performance |
31
+ |-------|-------------------|------------|-------------|
32
+ | Linear | ❌ | Low | Baseline |
33
+ | FortiTran | ❌ | Medium | Good |
34
+ | AdaFortiTran | ✅ | High | **Best** |
35
+
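+ The snippet below is a rough sketch of how a variant is constructed and called; import paths and signatures follow the README's own evaluation example and the forward-pass handling in `src/main/trainer.py`, so treat the details as assumptions rather than a fixed API:
+
+ ```python
+ # Sketch only: construction mirrors the Manual Evaluation example below;
+ # the call patterns are inferred from _forward_pass in src/main/trainer.py.
+ from src.config import load_config
+ from src.models import AdaFortiTranEstimator
+
+ system_config, model_config = load_config(
+     'config/system_config.yaml',
+     'config/adafortitran.yaml'
+ )
+ model = AdaFortiTranEstimator(system_config, model_config)
+
+ # Linear / FortiTran:  estimate = model(ls_estimate)
+ # AdaFortiTran:        estimate = model(ls_estimate, meta_data)  # SNR, DS, Doppler
+ ```
+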
36
+ ## 🚀 Quick Start
37
+
38
+ ### Installation
39
+
40
+ 1. **Clone the repository**:
41
+ ```bash
42
+ git clone https://github.com/your-username/AdaFortiTran.git
43
+ cd AdaFortiTran
44
+ ```
45
+
46
+ 2. **Install dependencies**:
47
+ ```bash
48
+ pip install -r requirements.txt
49
+ ```
50
+
51
+ 3. **Verify installation**:
52
+ ```bash
53
+ python -c "import torch; print(f'PyTorch {torch.__version__}')"
54
+ ```
55
+
56
+ ### Basic Training
57
+
58
+ Train an AdaFortiTran model with default settings:
59
+
60
+ ```bash
61
+ python src/main.py \
62
+ --model_name adafortitran \
63
+ --system_config_path config/system_config.yaml \
64
+ --model_config_path config/adafortitran.yaml \
65
+ --train_set data/train \
66
+ --val_set data/val \
67
+ --test_set data/test \
68
+ --exp_id my_experiment
69
+ ```
70
+
71
+ ### Advanced Training
72
+
73
+ Use all available features for optimal performance:
74
+
75
+ ```bash
76
+ python src/main.py \
77
+ --model_name adafortitran \
78
+ --system_config_path config/system_config.yaml \
79
+ --model_config_path config/adafortitran.yaml \
80
+ --train_set data/train \
81
+ --val_set data/val \
82
+ --test_set data/test \
83
+ --exp_id advanced_experiment \
84
+ --batch_size 128 \
85
+ --lr 5e-4 \
86
+ --max_epoch 100 \
87
+ --patience 10 \
88
+ --weight_decay 1e-4 \
89
+ --gradient_clip_val 1.0 \
90
+ --use_mixed_precision \
91
+ --save_every_n_epochs 5 \
92
+ --num_workers 8 \
93
+ --test_every_n 5
94
+ ```
95
+
96
+ ## 📁 Project Structure
97
+
98
+ ```
99
+ AdaFortiTran/
100
+ ├── config/ # Configuration files
101
+ │ ├── system_config.yaml # OFDM system parameters
102
+ │ ├── adafortitran.yaml # AdaFortiTran model config
103
+ │ ├── fortitran.yaml # FortiTran model config
104
+ │ └── linear.yaml # Linear model config
105
+ ├── data/ # Dataset directory
106
+ │ ├── train/ # Training data
107
+ │ ├── val/ # Validation data
108
+ │ └── test/ # Test data (DS, MDS, SNR sets)
109
+ ├── src/ # Source code
110
+ │ ├── main/ # Training pipeline
111
+ │ │ ├── trainer.py # Enhanced ModelTrainer
112
+ │ │ └── parser.py # Command-line argument parser
113
+ │ ├── models/ # Model implementations
114
+ │ │ ├── adafortitran.py # AdaFortiTran model
115
+ │ │ ├── fortitran.py # FortiTran model
116
+ │ │ ├── linear.py # Linear model
117
+ │ │ └── blocks/ # Model building blocks
118
+ │ ├── data/ # Data loading
119
+ │ │ └── dataset.py # Dataset and DataLoader classes
120
+ │ ├── config/ # Configuration management
121
+ │ │ ├── config_loader.py # YAML configuration loader
122
+ │ │ └── schemas.py # Pydantic validation schemas
123
+ │ └── utils.py # Utility functions
124
+ ├── requirements.txt # Python dependencies
125
+ ├── README.md # This file
126
+ ```
127
+
128
+ ## ⚙️ Configuration
129
+
130
+ ### System Configuration (`config/system_config.yaml`)
131
+
132
+ Defines OFDM system parameters:
133
+
134
+ ```yaml
135
+ ofdm:
+   num_scs: 120        # Number of subcarriers
+   num_symbols: 14     # Number of OFDM symbols
+
+ pilot:
+   num_scs: 12         # Number of pilot subcarriers
+   num_symbols: 2      # Number of pilot symbols
142
+ ```
143
+
144
+ ### Model Configuration (`config/adafortitran.yaml`)
145
+
146
+ Defines model architecture parameters:
147
+
148
+ ```yaml
149
+ model_type: 'adafortitran'
150
+ patch_size: [3, 2] # Patch dimensions
151
+ num_layers: 6 # Transformer layers
152
+ model_dim: 128 # Model dimension
153
+ num_head: 4 # Attention heads
154
+ activation: 'gelu' # Activation function
155
+ dropout: 0.1 # Dropout rate
156
+ max_seq_len: 512 # Maximum sequence length
157
+ pos_encoding_type: 'learnable' # Positional encoding
158
+ channel_adaptivity_hidden_sizes: [7, 42, 560] # Adaptation layers
159
+ adaptive_token_length: 6 # Adaptive token length
160
+ ```
161
+
162
+ ## 🎯 Training Features
163
+
164
+ ### Advanced Training Options
165
+
166
+ | Feature | Description | Default |
167
+ |---------|-------------|---------|
168
+ | `--use_mixed_precision` | Enable mixed precision training | False |
169
+ | `--gradient_clip_val` | Gradient clipping value | None |
170
+ | `--weight_decay` | Weight decay for optimizer | 0.0 |
171
+ | `--save_checkpoints` | Enable model checkpointing | True |
172
+ | `--save_best_only` | Save only best model | True |
173
+ | `--resume_from_checkpoint` | Resume from checkpoint | None |
174
+ | `--num_workers` | Data loading workers | 4 |
175
+ | `--pin_memory` | Pin memory for GPU | True |
176
+
177
+ ### Callback System
178
+
179
+ The training pipeline includes an extensible callback system:
180
+
181
+ - **TensorBoard Logging**: Automatic metric tracking and visualization
182
+ - **Checkpoint Management**: Flexible checkpoint saving strategies
183
+ - **Custom Callbacks**: Easy to add new logging or monitoring systems
184
+
185
+ ### Performance Optimizations
186
+
187
+ - **Mixed Precision Training**: Faster training on modern GPUs
188
+ - **Optimized Data Loading**: Configurable workers and memory pinning
189
+ - **Gradient Clipping**: Stable training with configurable clipping
190
+ - **Early Stopping**: Automatic training termination on plateau
191
+
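+ As a rough illustration, the snippet below (standalone PyTorch with a placeholder model and a dummy batch, not the project's trainer code) shows the pattern that `--use_mixed_precision` and `--gradient_clip_val` enable inside each training step:
+
+ ```python
+ # Minimal AMP + gradient-clipping sketch; the model, batch, and clip value are placeholders.
+ import torch
+ from torch import nn
+
+ use_amp = torch.cuda.is_available()                  # AMP is only used on CUDA devices
+ device = torch.device("cuda" if use_amp else "cpu")
+ model = nn.Linear(10, 10).to(device)
+ x, y = torch.randn(8, 10, device=device), torch.randn(8, 10, device=device)
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+ scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # --use_mixed_precision
+ clip_val = 1.0                                       # --gradient_clip_val
+
+ optimizer.zero_grad()
+ with torch.cuda.amp.autocast(enabled=use_amp):       # mixed-precision forward + loss
+     loss = nn.functional.mse_loss(model(x), y)
+ scaler.scale(loss).backward()
+ scaler.unscale_(optimizer)                           # unscale before clipping
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)
+ scaler.step(optimizer)
+ scaler.update()
+ ```
+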
192
+ ## 📊 Dataset Format
193
+
194
+ ### Expected File Structure
195
+
196
+ ```
197
+ data/
198
+ ├── train/
199
+ │ ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
200
+ │ ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
201
+ │ └── ...
202
+ ├── val/
203
+ │ └── ...
204
+ └── test/
205
+ ├── DS_test_set/ # Delay Spread tests
206
+ │ ├── DS_50/
207
+ │ ├── DS_100/
208
+ │ └── ...
209
+ ├── SNR_test_set/ # SNR tests
210
+ │ ├── SNR_10/
211
+ │ ├── SNR_20/
212
+ │ └── ...
213
+ └── MDS_test_set/ # Multi-Doppler tests
214
+ ├── DOP_200/
215
+ ├── DOP_400/
216
+ └── ...
217
+ ```
218
+
219
+ ### File Naming Convention
220
+
221
+ Files must follow the pattern:
222
+ ```
223
+ {file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
224
+ ```
225
+
226
+ Example: `1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat`
227
+
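+ For reference, a hypothetical helper that parses these names with a regular expression (the project's dataset code may extract the fields differently):
+
+ ```python
+ # Hypothetical filename parser for the convention above.
+ import re
+
+ PATTERN = re.compile(
+     r"(?P<file_number>\d+)_SNR-(?P<snr>-?\d+)_DS-(?P<delay_spread>\d+)"
+     r"_DOP-(?P<doppler>\d+)_N-(?P<pilot_freq>\d+)_(?P<channel_type>[\w-]+)\.mat"
+ )
+
+ match = PATTERN.fullmatch("1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat")
+ assert match is not None
+ print(match.groupdict())
+ # {'file_number': '1', 'snr': '20', 'delay_spread': '50',
+ #  'doppler': '500', 'pilot_freq': '3', 'channel_type': 'TDL-A'}
+ ```
+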
228
+ ### Data Format
229
+
230
+ Each `.mat` file must contain a variable `H` with shape `[subcarriers, symbols, 3]`:
231
+ - `H[:, :, 0]`: Ground truth channel (complex values)
232
+ - `H[:, :, 1]`: LS channel estimate with zeros for non-pilot positions
233
+ - `H[:, :, 2]`: Reserved for future use
234
+
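+ A quick way to sanity-check a file (assuming MATLAB-style `.mat` files readable by `scipy.io.loadmat`; the path below is just an example):
+
+ ```python
+ # Load one dataset file and inspect the H variable described above.
+ import numpy as np
+ from scipy.io import loadmat
+
+ mat = loadmat("data/train/1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat")
+ H = mat["H"]                      # [subcarriers, symbols, 3], complex-valued
+ print(H.shape, H.dtype)
+
+ ideal = H[:, :, 0]                # ground-truth channel
+ ls_estimate = H[:, :, 1]          # LS estimate, zeros off the pilot grid
+ print("non-zero pilot entries:", np.count_nonzero(ls_estimate))
+ ```
+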
235
+ ## 🔧 Usage Examples
236
+
237
+ ### Training Different Models
238
+
239
+ **Linear Estimator**:
240
+ ```bash
241
+ python src/main.py \
242
+ --model_name linear \
243
+ --system_config_path config/system_config.yaml \
244
+ --model_config_path config/linear.yaml \
245
+ --train_set data/train \
246
+ --val_set data/val \
247
+ --test_set data/test \
248
+ --exp_id linear_baseline
249
+ ```
250
+
251
+ **FortiTran**:
252
+ ```bash
253
+ python src/main.py \
254
+ --model_name fortitran \
255
+ --system_config_path config/system_config.yaml \
256
+ --model_config_path config/fortitran.yaml \
257
+ --train_set data/train \
258
+ --val_set data/val \
259
+ --test_set data/test \
260
+ --exp_id fortitran_experiment
261
+ ```
262
+
263
+ **AdaFortiTran**:
264
+ ```bash
265
+ python src/main.py \
266
+ --model_name adafortitran \
267
+ --system_config_path config/system_config.yaml \
268
+ --model_config_path config/adafortitran.yaml \
269
+ --train_set data/train \
270
+ --val_set data/val \
271
+ --test_set data/test \
272
+ --exp_id adafortitran_experiment
273
+ ```
274
+
275
+ ### Resume Training
276
+
277
+ ```bash
278
+ python src/main.py \
279
+ --model_name adafortitran \
280
+ --system_config_path config/system_config.yaml \
281
+ --model_config_path config/adafortitran.yaml \
282
+ --train_set data/train \
283
+ --val_set data/val \
284
+ --test_set data/test \
285
+ --exp_id resumed_experiment \
286
+ --resume_from_checkpoint runs/adafortitran_experiment/best/checkpoint_epoch_50.pt
287
+ ```
288
+
289
+ ### Hyperparameter Tuning
290
+
291
+ ```bash
292
+ python src/main.py \
293
+ --model_name adafortitran \
294
+ --system_config_path config/system_config.yaml \
295
+ --model_config_path config/adafortitran.yaml \
296
+ --train_set data/train \
297
+ --val_set data/val \
298
+ --test_set data/test \
299
+ --exp_id hyperparameter_tuning \
300
+ --batch_size 64 \
301
+ --lr 1e-3 \
302
+ --max_epoch 50 \
303
+ --patience 5 \
304
+ --weight_decay 1e-5 \
305
+ --gradient_clip_val 0.5 \
306
+ --use_mixed_precision \
307
+ --test_every_n 5
308
+ ```
309
+
310
+ ## 📈 Monitoring and Logging
311
+
312
+ ### TensorBoard Integration
313
+
314
+ Training automatically logs metrics to TensorBoard:
315
+
316
+ ```bash
317
+ tensorboard --logdir runs/
318
+ ```
319
+
320
+ Available metrics:
321
+ - Training/validation loss
322
+ - Learning rate
323
+ - Test performance across conditions
324
+ - Error visualizations
325
+ - Model hyperparameters
326
+
327
+ ### Log Files
328
+
329
+ Training logs are saved to:
330
+ - `logs/training_{exp_id}.log`: Python logging output
331
+ - `runs/{model_name}_{exp_id}/`: TensorBoard logs and checkpoints
332
+
333
+ ## 🧪 Testing and Evaluation
334
+
335
+ ### Automatic Testing
336
+
337
+ The training pipeline automatically evaluates models on:
338
+ - **DS (Delay Spread)**: Varying delay spread conditions
339
+ - **SNR**: Different signal-to-noise ratios
340
+ - **MDS (Multi-Doppler)**: Various Doppler shift scenarios
341
+
342
+ ### Manual Evaluation
343
+
344
+ ```python
345
+ import torch
+ from src.models import AdaFortiTranEstimator
346
+ from src.config import load_config
347
+
348
+ # Load configurations
349
+ system_config, model_config = load_config(
350
+ 'config/system_config.yaml',
351
+ 'config/adafortitran.yaml'
352
+ )
353
+
354
+ # Initialize model
355
+ model = AdaFortiTranEstimator(system_config, model_config)
356
+
357
+ # Load checkpoint
358
+ checkpoint = torch.load('checkpoint.pt')
359
+ model.load_state_dict(checkpoint['model_state_dict'])
360
+
361
+ # Evaluate
362
+ model.eval()
363
+ # ... evaluation code
364
+ ```
365
+
366
+ ## 🔬 Research and Development
367
+
368
+ ### Adding Custom Callbacks
369
+
370
+ ```python
371
+ from src.main.trainer import Callback, TrainingMetrics
372
+
373
+ class CustomCallback(Callback):
+     # Callback is an ABC: on_epoch_begin, on_training_begin, and
+     # on_training_end must also be overridden (no-op bodies are fine).
+     def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
+         # Custom logic here
+         print(f"Epoch {epoch}: Train Loss = {metrics.train_loss:.4f}")
377
+ ```
378
+
379
+ ### Extending Models
380
+
381
+ The modular architecture makes it easy to add new model variants:
382
+
383
+ ```python
384
+ from src.models.fortitran import BaseFortiTranEstimator
385
+
386
+ class CustomEstimator(BaseFortiTranEstimator):
+     def __init__(self, system_config, model_config):
+         super().__init__(system_config, model_config, use_channel_adaptation=True)
+         # Add custom components
390
+ ```
391
+
392
+ ## 🐛 Troubleshooting
393
+
394
+ ### Common Issues
395
+
396
+ **CUDA Out of Memory**:
397
+ - Reduce batch size: `--batch_size 32`
398
+ - Enable mixed precision: `--use_mixed_precision`
399
+ - Reduce number of workers: `--num_workers 2`
400
+
401
+ **Slow Training**:
402
+ - Increase number of workers: `--num_workers 8`
403
+ - Enable pin memory: `--pin_memory`
404
+ - Use mixed precision: `--use_mixed_precision`
405
+
406
+ **Poor Convergence**:
407
+ - Adjust learning rate: `--lr 1e-4`
408
+ - Add gradient clipping: `--gradient_clip_val 1.0`
409
+ - Increase patience: `--patience 10`
410
+
411
+ ### Getting Help
412
+
413
+ 1. Check the logs in `logs/training_{exp_id}.log`
414
+ 2. Verify dataset format matches requirements
415
+ 3. Ensure all dependencies are installed correctly
416
+ 4. Check TensorBoard for training curves
417
+
418
+ ## 📚 Citation
419
+
420
+ If you use this code in your research, please cite:
421
+
422
+ ```bibtex
423
+ @misc{guler2025adafortitranadaptivetransformermodel,
424
+ title={AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation},
425
+ author={Berkay Guler and Hamid Jafarkhani},
426
+ year={2025},
427
+ eprint={2505.09076},
428
+ archivePrefix={arXiv},
429
+ primaryClass={cs.LG},
430
+ url={https://arxiv.org/abs/2505.09076},
431
+ }
432
+ ```
433
+
434
+ ## 📄 License
435
 
436
  This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
437
 
requirements.txt CHANGED
@@ -1,5 +1,8 @@
1
  torch
2
  pydantic
3
- yaml
4
  scipy
5
- tqdm
 
 
 
 
1
  torch
2
  pydantic
3
+ pyyaml
4
  scipy
5
+ tqdm
6
+ matplotlib
7
+ prettytable
8
+ tensorboard
src/main/parser.py CHANGED
@@ -10,7 +10,7 @@ of training runs.
10
  from pathlib import Path
11
  import argparse
12
  from pydantic import BaseModel, Field, model_validator
13
- from typing import Self
14
 
15
 
16
  class TrainingArguments(BaseModel):
@@ -41,9 +41,22 @@ class TrainingArguments(BaseModel):
41
  lr: Learning rate for optimizer
42
  max_epoch: Maximum number of training epochs
43
  patience: Early stopping patience in epochs
 
 
 
44
 
45
  # Evaluation
46
  test_every_n: Number of epochs between test evaluations
 
 
 
 
 
 
 
 
 
 
47
  """
48
 
49
  # Model Configuration
@@ -67,10 +80,23 @@ class TrainingArguments(BaseModel):
67
  lr: float = Field(default=1e-3, gt=0, description="Initial learning rate")
68
  max_epoch: int = Field(default=10, gt=0, description="Maximum number of training epochs")
69
  patience: int = Field(default=3, gt=0, description="Early stopping patience (epochs)")
 
 
 
70
 
71
  # Evaluation
72
  test_every_n: int = Field(default=10, gt=0, description="Test model every N epochs")
73
 
 
 
 
 
 
 
 
 
 
 
74
  @model_validator(mode='after')
75
  def validate_paths(self) -> Self:
76
  """Validate path-related arguments.
@@ -92,6 +118,13 @@ class TrainingArguments(BaseModel):
92
  if not self.model_config_path.suffix == '.yaml':
93
  raise ValueError(f"Model config file must be a .yaml file: {self.model_config_path}")
94
 
 
 
 
 
 
 
 
95
  return self
96
 
97
 
@@ -161,58 +194,134 @@ def parse_arguments() -> TrainingArguments:
161
  help='Experiment identifier for log folder naming'
162
  )
163
 
164
- # Optional arguments
165
- optional = parser.add_argument_group('optional arguments')
166
- optional.add_argument(
167
- '--python_log_level',
168
- type=str,
169
- default="INFO",
170
- help='Logger level for python logging module'
171
- )
172
- optional.add_argument(
173
- '--tensorboard_log_dir',
174
- type=Path,
175
- default="runs",
176
- help='Directory for tensorboard logs'
177
- )
178
- optional.add_argument(
179
- '--python_log_dir',
180
- type=Path,
181
- default="logs",
182
- help='Directory for python logging files'
183
- )
184
- optional.add_argument(
185
- '--test_every_n',
186
  type=int,
187
- default=10,
188
- help='Test model every N epochs'
 
 
 
 
 
 
189
  )
190
- optional.add_argument(
191
  '--max_epoch',
192
  type=int,
193
  default=10,
194
  help='Maximum number of training epochs'
195
  )
196
- optional.add_argument(
197
  '--patience',
198
  type=int,
199
  default=3,
200
  help='Early stopping patience (epochs)'
201
  )
202
- optional.add_argument(
203
- '--batch_size',
 
 
204
  type=int,
205
- default=64,
206
- help='Training batch size'
207
  )
208
 
209
- optional.add_argument(
210
- '--lr',
211
- type=float,
212
- default=1e-3,
213
- help='Initial learning rate'
 
 
214
  )
215
 
 
 
216
 
217
  args = parser.parse_args()
218
 
 
10
  from pathlib import Path
11
  import argparse
12
  from pydantic import BaseModel, Field, model_validator
13
+ from typing import Self, Optional
14
 
15
 
16
  class TrainingArguments(BaseModel):
 
41
  lr: Learning rate for optimizer
42
  max_epoch: Maximum number of training epochs
43
  patience: Early stopping patience in epochs
44
+ weight_decay: Weight decay for optimizer
45
+ gradient_clip_val: Gradient clipping value
46
+ use_mixed_precision: Whether to use mixed precision training
47
 
48
  # Evaluation
49
  test_every_n: Number of epochs between test evaluations
50
+
51
+ # Checkpointing
52
+ save_checkpoints: Whether to save model checkpoints
53
+ save_best_only: Whether to save only the best model
54
+ save_every_n_epochs: Save checkpoint every N epochs
55
+ resume_from_checkpoint: Path to checkpoint to resume from
56
+
57
+ # Data Loading
58
+ num_workers: Number of data loading workers
59
+ pin_memory: Whether to pin memory for faster GPU transfer
60
  """
61
 
62
  # Model Configuration
 
80
  lr: float = Field(default=1e-3, gt=0, description="Initial learning rate")
81
  max_epoch: int = Field(default=10, gt=0, description="Maximum number of training epochs")
82
  patience: int = Field(default=3, gt=0, description="Early stopping patience (epochs)")
83
+ weight_decay: float = Field(default=0.0, ge=0.0, description="Weight decay for optimizer")
84
+ gradient_clip_val: Optional[float] = Field(default=None, gt=0, description="Gradient clipping value")
85
+ use_mixed_precision: bool = Field(default=False, description="Whether to use mixed precision training")
86
 
87
  # Evaluation
88
  test_every_n: int = Field(default=10, gt=0, description="Test model every N epochs")
89
 
90
+ # Checkpointing
91
+ save_checkpoints: bool = Field(default=True, description="Whether to save model checkpoints")
92
+ save_best_only: bool = Field(default=True, description="Whether to save only the best model")
93
+ save_every_n_epochs: Optional[int] = Field(default=None, gt=0, description="Save checkpoint every N epochs")
94
+ resume_from_checkpoint: Optional[Path] = Field(default=None, description="Path to checkpoint to resume from")
95
+
96
+ # Data Loading
97
+ num_workers: int = Field(default=4, ge=0, description="Number of data loading workers")
98
+ pin_memory: bool = Field(default=True, description="Whether to pin memory for faster GPU transfer")
99
+
100
  @model_validator(mode='after')
101
  def validate_paths(self) -> Self:
102
  """Validate path-related arguments.
 
118
  if not self.model_config_path.suffix == '.yaml':
119
  raise ValueError(f"Model config file must be a .yaml file: {self.model_config_path}")
120
 
121
+ # Validate checkpoint path if provided
122
+ if self.resume_from_checkpoint is not None:
123
+ if not self.resume_from_checkpoint.exists():
124
+ raise ValueError(f"Checkpoint file not found: {self.resume_from_checkpoint}")
125
+ if not self.resume_from_checkpoint.suffix == '.pt':
126
+ raise ValueError(f"Checkpoint file must be a .pt file: {self.resume_from_checkpoint}")
127
+
128
  return self
129
 
130
 
 
194
  help='Experiment identifier for log folder naming'
195
  )
196
 
197
+ # Training hyperparameters
198
+ training = parser.add_argument_group('training hyperparameters')
199
+ training.add_argument(
200
+ '--batch_size',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  type=int,
202
+ default=64,
203
+ help='Training batch size'
204
+ )
205
+ training.add_argument(
206
+ '--lr',
207
+ type=float,
208
+ default=1e-3,
209
+ help='Initial learning rate'
210
  )
211
+ training.add_argument(
212
  '--max_epoch',
213
  type=int,
214
  default=10,
215
  help='Maximum number of training epochs'
216
  )
217
+ training.add_argument(
218
  '--patience',
219
  type=int,
220
  default=3,
221
  help='Early stopping patience (epochs)'
222
  )
223
+ training.add_argument(
224
+ '--weight_decay',
225
+ type=float,
226
+ default=0.0,
227
+ help='Weight decay for optimizer'
228
+ )
229
+ training.add_argument(
230
+ '--gradient_clip_val',
231
+ type=float,
232
+ default=None,
233
+ help='Gradient clipping value (disabled if not specified)'
234
+ )
235
+ training.add_argument(
236
+ '--use_mixed_precision',
237
+ action='store_true',
238
+ help='Use mixed precision training (requires PyTorch >= 1.6)'
239
+ )
240
+
241
+ # Evaluation settings
242
+ evaluation = parser.add_argument_group('evaluation settings')
243
+ evaluation.add_argument(
244
+ '--test_every_n',
245
  type=int,
246
+ default=10,
247
+ help='Test model every N epochs'
248
  )
249
 
250
+ # Checkpointing settings
251
+ checkpointing = parser.add_argument_group('checkpointing settings')
252
+ checkpointing.add_argument(
253
+ '--save_checkpoints',
254
+ action='store_true',
255
+ default=True,
256
+ help='Save model checkpoints'
257
+ )
258
+ checkpointing.add_argument(
259
+ '--no_save_checkpoints',
260
+ action='store_false',
261
+ dest='save_checkpoints',
262
+ help='Disable saving model checkpoints'
263
+ )
264
+ checkpointing.add_argument(
265
+ '--save_best_only',
266
+ action='store_true',
267
+ default=True,
268
+ help='Save only the best model based on validation loss'
269
+ )
270
+ checkpointing.add_argument(
271
+ '--save_every_n_epochs',
272
+ type=int,
273
+ default=None,
274
+ help='Save checkpoint every N epochs (in addition to best model)'
275
+ )
276
+ checkpointing.add_argument(
277
+ '--resume_from_checkpoint',
278
+ type=Path,
279
+ default=None,
280
+ help='Path to checkpoint file to resume training from'
281
  )
282
 
283
+ # Data loading settings
284
+ data_loading = parser.add_argument_group('data loading settings')
285
+ data_loading.add_argument(
286
+ '--num_workers',
287
+ type=int,
288
+ default=4,
289
+ help='Number of data loading workers'
290
+ )
291
+ data_loading.add_argument(
292
+ '--pin_memory',
293
+ action='store_true',
294
+ default=True,
295
+ help='Pin memory for faster GPU transfer'
296
+ )
297
+ data_loading.add_argument(
298
+ '--no_pin_memory',
299
+ action='store_false',
300
+ dest='pin_memory',
301
+ help='Disable pin memory'
302
+ )
303
+
304
+ # Logging settings
305
+ logging_group = parser.add_argument_group('logging settings')
306
+ logging_group.add_argument(
307
+ '--python_log_level',
308
+ type=str,
309
+ default="INFO",
310
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
311
+ help='Logger level for python logging module'
312
+ )
313
+ logging_group.add_argument(
314
+ '--tensorboard_log_dir',
315
+ type=Path,
316
+ default="runs",
317
+ help='Directory for tensorboard logs'
318
+ )
319
+ logging_group.add_argument(
320
+ '--python_log_dir',
321
+ type=Path,
322
+ default="logs",
323
+ help='Directory for python logging files'
324
+ )
325
 
326
  args = parser.parse_args()
327
 
src/main/train_helpers.py DELETED
@@ -1,265 +0,0 @@
1
- """
2
- Training helper functions for OFDM channel estimation models.
3
-
4
- This module provides utility functions for training, evaluating, and testing
5
- deep learning models for OFDM channel estimation tasks. It includes functions
6
- for performing training epochs, model evaluation, prediction generation,
7
- and performance statistics calculation across different test conditions.
8
- """
9
-
10
- from typing import Dict, List, Tuple, Union, Callable
11
- import torch
12
- from torch import nn
13
- from torch.utils.data import DataLoader
14
- from torch.optim import Optimizer
15
- from torch.optim.lr_scheduler import ExponentialLR
16
- from src.utils import to_db, concat_complex_channel
17
-
18
- # Type aliases
19
- ComplexTensor = torch.Tensor # Complex tensor
20
- BatchType = Tuple[ComplexTensor, ComplexTensor, Union[Dict, None]]
21
- TestDataLoadersType = List[Tuple[str, DataLoader]]
22
- StatsType = Dict[int, float]
23
-
24
-
25
- def get_all_test_stats(
26
- model: nn.Module,
27
- test_dataloaders: Dict[str, TestDataLoadersType],
28
- loss_fn: Callable
29
- ) -> Tuple[StatsType, StatsType, StatsType]:
30
- """
31
- Evaluate model on all test datasets.
32
-
33
- Calculates performance statistics (MSE in dB) for a model across different
34
- test conditions: Delay Spread (DS), Max Doppler Shift (MDS), and
35
- Signal-to-Noise Ratio (SNR).
36
-
37
- Args:
38
- model: Model to evaluate
39
- test_dataloaders: Dictionary containing DataLoader objects for test sets:
40
- - "DS": Delay Spread test set
41
- - "MDS": Max Doppler Shift test set
42
- - "SNR": Signal-to-Noise Ratio test set
43
- loss_fn: Loss function for evaluation
44
-
45
- Returns:
46
- Tuple containing statistics (MSE in dB) for DS, MDS, and SNR test sets,
47
- where each set of statistics is a dictionary mapping parameter values to MSE
48
- """
49
- ds_stats = get_test_stats(model, test_dataloaders["DS"], loss_fn)
50
- mds_stats = get_test_stats(model, test_dataloaders["MDS"], loss_fn)
51
- snr_stats = get_test_stats(model, test_dataloaders["SNR"], loss_fn)
52
- return ds_stats, mds_stats, snr_stats
53
-
54
-
55
- def get_test_stats(
56
- model: nn.Module,
57
- test_dataloaders: TestDataLoadersType,
58
- loss_fn: Callable
59
- ) -> StatsType:
60
- """
61
- Evaluate model on provided test dataloaders.
62
-
63
- Calculates performance statistics (MSE in dB) for a model on a
64
- specific set of test conditions.
65
-
66
- Args:
67
- model: Model to evaluate
68
- test_dataloaders: List of (name, DataLoader) tuples for test sets,
69
- where names are in format "parameter_value"
70
- loss_fn: Loss function for evaluation
71
-
72
- Returns:
73
- Dictionary mapping test parameter values (as integers) to MSE values in dB
74
- """
75
- stats: StatsType = {}
76
- sorted_loaders = sorted(
77
- test_dataloaders,
78
- key=lambda x: int(x[0].split("_")[1])
79
- )
80
-
81
- for name, test_dataloader in sorted_loaders:
82
- var, val = name.split("_")
83
- test_loss = eval_model(model, test_dataloader, loss_fn)
84
- db_error = to_db(test_loss)
85
- print(f"{var}:{val} Test MSE: {db_error:.4f} dB")
86
- stats[int(val)] = db_error
87
-
88
- return stats
89
-
90
-
91
- def eval_model(
92
- model: nn.Module,
93
- eval_dataloader: DataLoader,
94
- loss_fn: Callable
95
- ) -> float:
96
- """
97
- Evaluate model on given dataloader.
98
-
99
- Calculates the average loss for a model on a dataset without
100
- performing parameter updates.
101
-
102
- Args:
103
- model: Model to evaluate
104
- eval_dataloader: DataLoader containing evaluation data
105
- loss_fn: Loss function for computing error
106
-
107
- Returns:
108
- Average validation loss (adjusted for complex values)
109
-
110
- Notes:
111
- Loss is multiplied by 2 to account for complex-valued matrices being
112
- represented as real-valued matrices of double size.
113
- """
114
- val_loss = 0.0
115
- model.eval()
116
-
117
- with torch.no_grad():
118
- for batch in eval_dataloader:
119
- estimated_channel, ideal_channel = _forward_pass(batch, model)
120
- output = _compute_loss(estimated_channel, ideal_channel, loss_fn)
121
- val_loss += (2 * output.item() * batch[0].size(0))
122
-
123
- val_loss /= sum(len(batch[0]) for batch in eval_dataloader)
124
- return val_loss
125
-
126
-
127
- def predict_channels(
128
- model: nn.Module,
129
- test_dataloaders: TestDataLoadersType
130
- ) -> Dict[int, Dict[str, ComplexTensor]]:
131
- """
132
- Generate channel predictions for test datasets.
133
-
134
- Creates predictions for a sample from each test dataset to enable
135
- visualization and error analysis.
136
-
137
- Args:
138
- model: Model to use for predictions
139
- test_dataloaders: List of (name, DataLoader) tuples for test sets,
140
- where names are in format "parameter_value"
141
-
142
- Returns:
143
- Dictionary mapping test parameter values (as integers) to dictionaries containing
144
- estimated and ideal channels for a single sample
145
- """
146
- channels: Dict[int, Dict[str, ComplexTensor]] = {}
147
- sorted_loaders = sorted(
148
- test_dataloaders,
149
- key=lambda x: int(x[0].split("_")[1])
150
- )
151
-
152
- for name, test_dataloader in sorted_loaders:
153
- with torch.no_grad():
154
- batch = next(iter(test_dataloader))
155
- estimated_channels, ideal_channels = _forward_pass(batch, model)
156
-
157
- var, val = name.split("_")
158
- channels[int(val)] = {
159
- "estimated_channel": estimated_channels[0],
160
- "ideal_channel": ideal_channels[0]
161
- }
162
-
163
- return channels
164
-
165
-
166
- def train_epoch(
167
- model: nn.Module,
168
- optimizer: Optimizer,
169
- loss_fn: Callable,
170
- scheduler: ExponentialLR,
171
- train_dataloader: DataLoader
172
- ) -> float:
173
- """
174
- Train model for one epoch.
175
-
176
- Performs a complete training iteration over the dataset, including:
177
- - Forward pass through the model
178
- - Loss calculation
179
- - Backpropagation
180
- - Parameter updates
181
- - Learning rate scheduling
182
-
183
- Args:
184
- model: Model to train
185
- optimizer: Optimizer for updating model parameters
186
- loss_fn: Loss function for computing error
187
- scheduler: Learning rate scheduler
188
- train_dataloader: DataLoader containing training data
189
-
190
- Returns:
191
- Average training loss for the epoch (adjusted for complex values)
192
-
193
- Notes:
194
- Loss is multiplied by 2 to account for complex-valued matrices being
195
- represented as real-valued matrices of double size.
196
- """
197
- train_loss = 0.0
198
- model.train()
199
-
200
- for batch in train_dataloader:
201
- optimizer.zero_grad()
202
- estimated_channel, ideal_channel = _forward_pass(batch, model)
203
- output = _compute_loss(estimated_channel, ideal_channel, loss_fn)
204
- output.backward()
205
- optimizer.step()
206
- train_loss += (2 * output.item() * batch[0].size(0))
207
-
208
- scheduler.step()
209
- train_loss /= sum(len(batch[0]) for batch in train_dataloader)
210
- return train_loss
211
-
212
-
213
- def _forward_pass(batch: BatchType, model: nn.Module) -> Tuple[ComplexTensor, ComplexTensor]:
214
- """
215
- Perform forward pass through model.
216
-
217
- Processes input data through the appropriate model based on its type,
218
- handling different input requirements for different model architectures.
219
-
220
- Args:
221
- batch: Tuple containing (estimated_channel, ideal_channel, metadata)
222
- model: Model to use for processing
223
-
224
- Returns:
225
- Tuple of (processed_estimated_channel, ideal_channel)
226
-
227
- Raises:
228
- ValueError: If model type is not recognized
229
- """
230
- estimated_channel, ideal_channel, meta_data = batch
231
-
232
- # All models now handle complex input directly
233
- if hasattr(model, 'use_channel_adaptation') and model.use_channel_adaptation:
234
- # AdaFortiTran uses meta_data for channel adaptation
235
- estimated_channel = model(estimated_channel, meta_data)
236
- else:
237
- # Linear and FortiTran models don't use meta_data
238
- estimated_channel = model(estimated_channel)
239
-
240
- return estimated_channel, ideal_channel.to(model.device)
241
-
242
-
243
- def _compute_loss(
244
- estimated_channel: ComplexTensor,
245
- ideal_channel: ComplexTensor,
246
- loss_fn: Callable
247
- ) -> torch.Tensor:
248
- """
249
- Calculate loss between estimated and ideal channels.
250
-
251
- Computes the loss between model output and ground truth using the specified
252
- loss function, with appropriate handling of complex values.
253
-
254
- Args:
255
- estimated_channel: Estimated channel from model
256
- ideal_channel: Ground truth ideal channel
257
- loss_fn: Loss function to compute error
258
-
259
- Returns:
260
- Computed loss value as a scalar tensor
261
- """
262
- return loss_fn(
263
- concat_complex_channel(estimated_channel),
264
- concat_complex_channel(ideal_channel)
265
- )
 
 
src/main/trainer.py CHANGED
@@ -11,9 +11,12 @@ import torch
11
  from torch import nn, optim
12
  from torch.utils.data import DataLoader
13
  from torch.utils.tensorboard.writer import SummaryWriter
14
- from typing import Dict, Tuple, Type, Union
15
  import logging
16
  from tqdm import tqdm
 
 
 
17
 
18
  from .parser import TrainingArguments
19
  from src.data.dataset import MatDataset, get_test_dataloaders
@@ -33,6 +36,291 @@ from src.config.schemas import SystemConfig, ModelConfig
33
  ModelType = Union[LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator]
34
 
35
 
 
 
36
  class ModelTrainer:
37
  """Handles the training and evaluation of deep learning models.
38
 
@@ -59,6 +347,9 @@ class ModelTrainer:
59
  val_loader: DataLoader for validation set (used for validation)
60
  test_loaders: Dictionary of test set DataLoaders (used for testing)
61
  logger: Logger instance for logging messages
 
 
 
62
  """
63
 
64
  MODEL_REGISTRY: Dict[str, Type[ModelType]] = {
@@ -86,13 +377,59 @@ class ModelTrainer:
86
  self.logger = logging.getLogger(__name__)
87
 
88
  self.model = self._initialize_model()
89
- self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
 
 
 
 
 
 
 
90
  self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.EXP_LR_GAMMA)
91
  self.early_stopper = EarlyStopping(patience=args.patience)
92
-
93
  self.training_loss = nn.MSELoss()
94
 
 
 
 
 
 
 
95
  self.train_loader, self.val_loader, self.test_loaders = self._get_dataloaders()
 
 
 
 
 
 
96
 
97
  def _setup_tensorboard(self) -> SummaryWriter:
98
  """Set up TensorBoard logging.
@@ -134,26 +471,30 @@ class ModelTrainer:
134
  return model
135
 
136
  def _get_dataloaders(self) -> Tuple[DataLoader, DataLoader, dict[str, list[tuple[str, DataLoader]]]]:
 
137
  pilot_dims = [self.system_config.pilot.num_scs, self.system_config.pilot.num_symbols]
 
138
  # Training and validation dataloaders
139
- train_dataset = MatDataset(
140
- self.args.train_set,
141
- pilot_dims
142
- )
143
- val_dataset = MatDataset(
144
- self.args.val_set,
145
- pilot_dims
146
- )
147
  train_loader = DataLoader(
148
  train_dataset,
149
  batch_size=self.args.batch_size,
150
- shuffle=True
 
 
151
  )
 
152
  val_loader = DataLoader(
153
  val_dataset,
154
  batch_size=self.args.batch_size,
155
- shuffle=True
 
 
156
  )
 
 
157
  test_loaders = {
158
  "DS": get_test_dataloaders(
159
  self.args.test_set / "DS_test_set",
@@ -173,11 +514,7 @@ class ModelTrainer:
173
  }
174
  return train_loader, val_loader, test_loaders
175
 
176
- def _log_test_results(
177
- self,
178
- epoch: int,
179
- test_stats: Dict[str, Dict]
180
- ) -> None:
181
  """Log test results to TensorBoard.
182
 
183
  Creates and logs visualizations for model performance across different test conditions.
@@ -198,7 +535,7 @@ class ModelTrainer:
198
  )
199
 
200
  # Plot error images
201
- predicted_channels = self._predict_channels(self.test_loaders[key])
202
  self.writer.add_figure(
203
  tag=f"{key} Error Images (Epoch:{epoch + 1})",
204
  figure=get_error_images(
@@ -208,15 +545,20 @@ class ModelTrainer:
208
  )
209
  )
210
 
211
- def _run_tests(self, epoch: int) -> None:
212
  """Run tests and log results.
213
 
214
  Evaluates the model on all test datasets and logs performance metrics and visualizations.
215
 
216
  Args:
217
  epoch: Current training epoch
 
 
 
218
  """
219
- ds_stats, mds_stats, snr_stats = self._get_all_test_stats()
 
 
220
 
221
  test_stats = {
222
  "DS": ds_stats,
@@ -225,6 +567,8 @@ class ModelTrainer:
225
  }
226
 
227
  self._log_test_results(epoch, test_stats)
 
 
228
 
229
  def _log_final_metrics(self, final_epoch: int) -> None:
230
  """Log final training metrics and hyperparameters.
@@ -270,92 +614,84 @@ class ModelTrainer:
270
  except Exception as e:
271
  self.writer.add_text("Error", f"Failed to log final test results: {str(e)}")
272
 
273
- def _compute_loss(self, estimated_channel, ideal_channel, loss_fn):
274
- return loss_fn(
275
- concat_complex_channel(estimated_channel),
276
- concat_complex_channel(ideal_channel)
277
- )
 
278
 
279
- def _forward_pass(self, batch, model):
280
- estimated_channel, ideal_channel, meta_data = batch
 
281
 
282
- # All models now handle complex input directly
283
- if isinstance(model, AdaFortiTranEstimator):
284
- # AdaFortiTran uses meta_data for channel adaptation
285
- estimated_channel = model(estimated_channel, meta_data)
286
- else:
287
- # Linear and FortiTran models don't use meta_data
288
- estimated_channel = model(estimated_channel)
289
-
290
- return estimated_channel, ideal_channel.to(model.device)
291
-
292
- def _train_epoch(self):
293
- train_loss = 0.0
294
- self.model.train()
295
- num_samples = 0
296
- for batch in self.train_loader:
297
- self.optimizer.zero_grad()
298
- estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
299
- output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
300
- output.backward()
301
- self.optimizer.step()
302
- batch_size = batch[0].size(0)
303
- train_loss += (2 * output.item() * batch_size)
304
- num_samples += batch_size
305
- self.scheduler.step()
306
- train_loss /= num_samples
307
- return train_loss
308
-
309
- def _eval_model(self, eval_dataloader):
310
- val_loss = 0.0
311
- self.model.eval()
312
- num_samples = 0
313
- with torch.no_grad():
314
- for batch in eval_dataloader:
315
- estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
316
- output = self._compute_loss(estimated_channel, ideal_channel, self.training_loss)
317
- batch_size = batch[0].size(0)
318
- val_loss += (2 * output.item() * batch_size)
319
- num_samples += batch_size
320
- val_loss /= num_samples
321
- return val_loss
322
-
323
- def _predict_channels(self, test_dataloaders):
324
- channels = {}
325
- sorted_loaders = sorted(
326
- test_dataloaders,
327
- key=lambda x: int(x[0].split("_")[1])
328
- )
329
- for name, test_dataloader in sorted_loaders:
330
- with torch.no_grad():
331
- batch = next(iter(test_dataloader))
332
- estimated_channels, ideal_channels = self._forward_pass(batch, self.model)
333
- var, val = name.split("_")
334
- channels[int(val)] = {
335
- "estimated_channel": estimated_channels[0],
336
- "ideal_channel": ideal_channels[0]
337
- }
338
- return channels
339
 
340
- def _get_test_stats(self, test_dataloaders):
341
- stats = {}
342
- sorted_loaders = sorted(
343
- test_dataloaders,
344
- key=lambda x: int(x[0].split("_")[1])
345
- )
346
- for name, test_dataloader in sorted_loaders:
347
- var, val = name.split("_")
348
- test_loss = self._eval_model(test_dataloader)
349
- db_error = to_db(test_loss)
350
- self.logger.info(f"{var}:{val} Test MSE: {db_error:.4f} dB")
351
- stats[int(val)] = db_error
352
- return stats
 
 
 
 
 
 
 
 
353
 
354
- def _get_all_test_stats(self):
355
- ds_stats = self._get_test_stats(self.test_loaders["DS"])
356
- mds_stats = self._get_test_stats(self.test_loaders["MDS"])
357
- snr_stats = self._get_test_stats(self.test_loaders["SNR"])
358
- return ds_stats, mds_stats, snr_stats
 
 
 
 
 
 
 
 
 
359
 
360
  def train(self) -> None:
361
  """Execute the training loop.
@@ -366,21 +702,43 @@ class ModelTrainer:
366
  - Early stopping when validation loss plateaus
367
  - Logging final metrics and results
368
  """
 
 
 
 
369
  last_epoch = 0
370
  pbar = tqdm(range(self.args.max_epoch), desc="Training")
 
371
  for epoch in pbar:
372
  last_epoch = epoch
 
 
 
 
 
373
  # Training step
374
- train_loss = self._train_epoch()
375
- self.writer.add_scalar('Loss/Train', train_loss, epoch + 1)
376
-
377
  # Validation step
378
- val_loss = self._eval_model(self.val_loader)
379
- self.writer.add_scalar('Loss/Val', val_loss, epoch + 1)
 
 
 
 
 
 
 
 
 
 
 
380
 
381
  # Update progress bar with loss info
382
  pbar.set_description(
383
- f"Epoch {epoch + 1}/{self.args.max_epoch} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
 
 
384
 
385
  if self.early_stopper.early_stop(val_loss):
386
  pbar.write(f"Early stopping triggered at epoch {epoch + 1}")
@@ -391,8 +749,12 @@ class ModelTrainer:
391
  message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
392
  pbar.write(message)
393
  self._run_tests(epoch)
 
394
  self._log_final_metrics(last_epoch)
395
- self.writer.close()
 
 
 
396
 
397
 
398
  def train(system_config: SystemConfig, model_config: ModelConfig, args: TrainingArguments) -> None:
 
11
  from torch import nn, optim
12
  from torch.utils.data import DataLoader
13
  from torch.utils.tensorboard.writer import SummaryWriter
14
+ from typing import Dict, Tuple, Type, Union, Optional, List, Protocol
15
  import logging
16
  from tqdm import tqdm
17
+ from dataclasses import dataclass
18
+ from pathlib import Path
19
+ from abc import ABC, abstractmethod
20
 
21
  from .parser import TrainingArguments
22
  from src.data.dataset import MatDataset, get_test_dataloaders
 
36
  ModelType = Union[LinearEstimator, AdaFortiTranEstimator, FortiTranEstimator]
37
 
38
 
39
+ @dataclass
40
+ class TrainingMetrics:
41
+ """Container for training metrics."""
42
+ train_loss: float
43
+ val_loss: float
44
+ epoch: int
45
+ learning_rate: float
46
+
47
+
48
+ @dataclass
49
+ class TestResults:
50
+ """Container for test results."""
51
+ ds_stats: Dict[int, float]
52
+ mds_stats: Dict[int, float]
53
+ snr_stats: Dict[int, float]
54
+
55
+
56
+ class Callback(ABC):
57
+ """Base class for training callbacks."""
58
+
59
+ @abstractmethod
60
+ def on_epoch_begin(self, epoch: int) -> None:
61
+ """Called at the beginning of each epoch."""
62
+ pass
63
+
64
+ @abstractmethod
65
+ def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
66
+ """Called at the end of each epoch."""
67
+ pass
68
+
69
+ @abstractmethod
70
+ def on_training_begin(self) -> None:
71
+ """Called at the beginning of training."""
72
+ pass
73
+
74
+ @abstractmethod
75
+ def on_training_end(self) -> None:
76
+ """Called at the end of training."""
77
+ pass
78
+
79
+
80
+ class CheckpointCallback(Callback):
81
+ """Callback for saving model checkpoints."""
82
+
83
+ def __init__(self, save_dir: Path, save_best_only: bool = True,
84
+ save_every_n_epochs: Optional[int] = None):
85
+ self.save_dir = save_dir
86
+ self.save_best_only = save_best_only
87
+ self.save_every_n_epochs = save_every_n_epochs
88
+ self.best_val_loss = float('inf')
89
+ self.trainer = None
90
+
91
+ def set_trainer(self, trainer: 'ModelTrainer') -> None:
92
+ """Set the trainer reference."""
93
+ self.trainer = trainer
94
+
95
+ def on_epoch_begin(self, epoch: int) -> None:
96
+ pass
97
+
98
+ def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
99
+ if self.trainer is None:
100
+ return
101
+
102
+ # Save best model
103
+ if self.save_best_only and metrics.val_loss < self.best_val_loss:
104
+ self.best_val_loss = metrics.val_loss
105
+ self.trainer.save_checkpoint(
106
+ epoch, metrics,
107
+ checkpoint_dir=self.save_dir / "best"
108
+ )
109
+
110
+ # Save every N epochs
111
+ if (self.save_every_n_epochs is not None and
112
+ (epoch + 1) % self.save_every_n_epochs == 0):
113
+ self.trainer.save_checkpoint(
114
+ epoch, metrics,
115
+ checkpoint_dir=self.save_dir / "periodic"
116
+ )
117
+
118
+ def on_training_begin(self) -> None:
119
+ self.save_dir.mkdir(parents=True, exist_ok=True)
120
+
121
+ def on_training_end(self) -> None:
122
+ pass
123
+
124
+
125
+ class TensorBoardCallback(Callback):
126
+ """Callback for TensorBoard logging."""
127
+
128
+ def __init__(self, writer: SummaryWriter):
129
+ self.writer = writer
130
+
131
+ def on_epoch_begin(self, epoch: int) -> None:
132
+ pass
133
+
134
+ def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
135
+ self.writer.add_scalar('Loss/Train', metrics.train_loss, metrics.epoch + 1)
136
+ self.writer.add_scalar('Loss/Val', metrics.val_loss, metrics.epoch + 1)
137
+ self.writer.add_scalar('Learning_Rate', metrics.learning_rate, metrics.epoch + 1)
138
+
139
+ def on_training_begin(self) -> None:
140
+ pass
141
+
142
+ def on_training_end(self) -> None:
143
+ self.writer.close()
144
+
145
+
146
+ class TrainingLoop:
147
+ """Handles the core training loop logic."""
148
+
149
+ def __init__(self, model: ModelType, optimizer: optim.Optimizer,
150
+ scheduler: optim.lr_scheduler.LRScheduler,
151
+ loss_fn: nn.Module, device: torch.device, scaler: Optional[torch.cuda.amp.GradScaler] = None,
152
+ gradient_clip_val: Optional[float] = None):
153
+ self.model = model
154
+ self.optimizer = optimizer
155
+ self.scheduler = scheduler
156
+ self.loss_fn = loss_fn
157
+ self.device = device
158
+ self.scaler = scaler
159
+ self.gradient_clip_val = gradient_clip_val
160
+
161
+ def _compute_loss(self, estimated_channel: torch.Tensor,
162
+ ideal_channel: torch.Tensor) -> torch.Tensor:
163
+ """Compute loss between estimated and ideal channels."""
164
+ return self.loss_fn(
165
+ concat_complex_channel(estimated_channel),
166
+ concat_complex_channel(ideal_channel)
167
+ )
168
+
169
+ def _forward_pass(self, batch: Tuple[torch.Tensor, torch.Tensor, Tuple],
170
+ model: ModelType) -> Tuple[torch.Tensor, torch.Tensor]:
171
+ """Perform forward pass through the model."""
172
+ estimated_channel, ideal_channel, meta_data = batch
173
+
174
+ # All models now handle complex input directly
175
+ if isinstance(model, AdaFortiTranEstimator):
176
+ # AdaFortiTran uses meta_data for channel adaptation
177
+ estimated_channel = model(estimated_channel, meta_data)
178
+ else:
179
+ # Linear and FortiTran models don't use meta_data
180
+ estimated_channel = model(estimated_channel)
181
+
182
+ return estimated_channel, ideal_channel.to(model.device)
183
+
184
+ def train_epoch(self, train_loader: DataLoader) -> float:
185
+ """Train for one epoch."""
186
+ train_loss = 0.0
187
+ self.model.train()
188
+ num_samples = 0
189
+
190
+ for batch in train_loader:
191
+ self.optimizer.zero_grad()
192
+ estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
193
+
194
+ if self.scaler:
195
+ with torch.cuda.amp.autocast():
196
+ loss = self._compute_loss(estimated_channel, ideal_channel)
197
+ self.scaler.scale(loss).backward()
198
+
199
+ # Gradient clipping
200
+ if self.gradient_clip_val:
201
+ self.scaler.unscale_(self.optimizer)
202
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val)
203
+
204
+ self.scaler.step(self.optimizer)
205
+ self.scaler.update()
206
+ else:
207
+ loss = self._compute_loss(estimated_channel, ideal_channel)
208
+ loss.backward()
209
+
210
+ # Gradient clipping
211
+ if self.gradient_clip_val:
212
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val)
213
+
214
+ self.optimizer.step()
215
+
216
+ batch_size = batch[0].size(0)
217
+ train_loss += (2 * loss.item() * batch_size)
218
+ num_samples += batch_size
219
+
220
+ self.scheduler.step()
221
+ return train_loss / num_samples
222
+
223
+ def evaluate(self, eval_loader: DataLoader) -> float:
224
+ """Evaluate the model."""
225
+ val_loss = 0.0
226
+ self.model.eval()
227
+ num_samples = 0
228
+
229
+ with torch.no_grad():
230
+ for batch in eval_loader:
231
+ estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
232
+
233
+ if self.scaler:
234
+ with torch.cuda.amp.autocast():
235
+ loss = self._compute_loss(estimated_channel, ideal_channel)
236
+ else:
237
+ loss = self._compute_loss(estimated_channel, ideal_channel)
238
+
239
+ batch_size = batch[0].size(0)
240
+ val_loss += (2 * loss.item() * batch_size)
241
+ num_samples += batch_size
242
+
243
+ return val_loss / num_samples
244
+
245
+
246
+ class ModelEvaluator:
247
+ """Handles model evaluation and testing."""
248
+
249
+ def __init__(self, model: ModelType, device: torch.device, logger: logging.Logger):
250
+ self.model = model
251
+ self.device = device
252
+ self.logger = logger
253
+
254
+ def _forward_pass(self, batch: Tuple[torch.Tensor, torch.Tensor, Tuple],
255
+ model: ModelType) -> Tuple[torch.Tensor, torch.Tensor]:
256
+ """Perform forward pass through the model."""
257
+ estimated_channel, ideal_channel, meta_data = batch
258
+
259
+ if isinstance(model, AdaFortiTranEstimator):
260
+ estimated_channel = model(estimated_channel, meta_data)
261
+ else:
262
+ estimated_channel = model(estimated_channel)
263
+
264
+ return estimated_channel, ideal_channel.to(model.device)
265
+
266
+ def predict_channels(self, test_dataloaders: List[Tuple[str, DataLoader]]) -> Dict[int, Dict]:
267
+ """Predict channels for visualization."""
268
+ channels = {}
269
+ sorted_loaders = sorted(
270
+ test_dataloaders,
271
+ key=lambda x: int(x[0].split("_")[1])
272
+ )
273
+
274
+ for name, test_dataloader in sorted_loaders:
275
+ with torch.no_grad():
276
+ batch = next(iter(test_dataloader))
277
+ estimated_channels, ideal_channels = self._forward_pass(batch, self.model)
278
+
279
+ var, val = name.split("_")
280
+ channels[int(val)] = {
281
+ "estimated_channel": estimated_channels[0],
282
+ "ideal_channel": ideal_channels[0]
283
+ }
284
+ return channels
285
+
286
+ def get_test_stats(self, test_dataloaders: List[Tuple[str, DataLoader]],
287
+ loss_fn: nn.Module) -> Dict[int, float]:
288
+ """Get test statistics for a set of dataloaders."""
289
+ stats = {}
290
+ sorted_loaders = sorted(
291
+ test_dataloaders,
292
+ key=lambda x: int(x[0].split("_")[1])
293
+ )
294
+
295
+ for name, test_dataloader in sorted_loaders:
296
+ var, val = name.split("_")
297
+ test_loss = self._evaluate_dataloader(test_dataloader, loss_fn)
298
+ db_error = to_db(test_loss)
299
+ self.logger.info(f"{var}:{val} Test MSE: {db_error:.4f} dB")
300
+ stats[int(val)] = db_error
301
+ return stats
302
+
303
+ def _evaluate_dataloader(self, dataloader: DataLoader, loss_fn: nn.Module) -> float:
304
+ """Evaluate a single dataloader."""
305
+ total_loss = 0.0
306
+ num_samples = 0
307
+ self.model.eval()
308
+
309
+ with torch.no_grad():
310
+ for batch in dataloader:
311
+ estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
312
+ loss = loss_fn(
313
+ concat_complex_channel(estimated_channel),
314
+ concat_complex_channel(ideal_channel)
315
+ )
316
+
317
+ batch_size = batch[0].size(0)
318
+ total_loss += (2 * loss.item() * batch_size)
319
+ num_samples += batch_size
320
+
321
+ return total_loss / num_samples
322
+
323
+
324
  class ModelTrainer:
325
  """Handles the training and evaluation of deep learning models.
326
 
 
347
  val_loader: DataLoader for validation set (used for validation)
348
  test_loaders: Dictionary of test set DataLoaders (used for testing)
349
  logger: Logger instance for logging messages
350
+ training_loop: TrainingLoop instance for core training logic
351
+ evaluator: ModelEvaluator instance for evaluation logic
352
+ callbacks: List of training callbacks
353
  """
354
 
355
  MODEL_REGISTRY: Dict[str, Type[ModelType]] = {
 
377
  self.logger = logging.getLogger(__name__)
378
 
379
  self.model = self._initialize_model()
380
+
381
+ # Initialize optimizer with weight decay
382
+ self.optimizer = optim.Adam(
383
+ self.model.parameters(),
384
+ lr=args.lr,
385
+ weight_decay=args.weight_decay
386
+ )
387
+
388
  self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.EXP_LR_GAMMA)
389
  self.early_stopper = EarlyStopping(patience=args.patience)
 
390
  self.training_loss = nn.MSELoss()
391
 
392
+ # Initialize mixed precision training if requested
393
+ self.scaler = None
394
+ if args.use_mixed_precision and self.device.type == 'cuda':
395
+ self.scaler = torch.cuda.amp.GradScaler()
396
+ self.logger.info("Mixed precision training enabled")
397
+
398
  self.train_loader, self.val_loader, self.test_loaders = self._get_dataloaders()
399
+
400
+ # Initialize components
401
+ self.training_loop = TrainingLoop(
402
+ self.model, self.optimizer, self.scheduler, self.training_loss,
403
+ self.device, self.scaler, self.args.gradient_clip_val
404
+ )
405
+ self.evaluator = ModelEvaluator(self.model, self.device, self.logger)
406
+
407
+ # Initialize callbacks
408
+ self.callbacks = self._setup_callbacks()
409
+
410
+ # Resume from checkpoint if specified
411
+ if args.resume_from_checkpoint is not None:
412
+ self._resume_from_checkpoint(args.resume_from_checkpoint)
413
+
414
+ def _setup_callbacks(self) -> List[Callback]:
415
+ """Set up training callbacks."""
416
+ callbacks = []
417
+
418
+ # TensorBoard callback
419
+ callbacks.append(TensorBoardCallback(self.writer))
420
+
421
+ # Checkpoint callback (only if checkpointing is enabled)
422
+ if self.args.save_checkpoints:
423
+ checkpoint_dir = self.args.tensorboard_log_dir / f"{self.args.model_name}_{self.args.exp_id}"
424
+ checkpoint_callback = CheckpointCallback(
425
+ save_dir=checkpoint_dir,
426
+ save_best_only=self.args.save_best_only,
427
+ save_every_n_epochs=self.args.save_every_n_epochs
428
+ )
429
+ checkpoint_callback.set_trainer(self)
430
+ callbacks.append(checkpoint_callback)
431
+
432
+ return callbacks
433
 
434
    def _setup_tensorboard(self) -> SummaryWriter:
        """Set up TensorBoard logging.

        return model

    def _get_dataloaders(self) -> Tuple[DataLoader, DataLoader, dict[str, list[tuple[str, DataLoader]]]]:
+       """Get training, validation, and test dataloaders."""
        pilot_dims = [self.system_config.pilot.num_scs, self.system_config.pilot.num_symbols]
+
        # Training and validation dataloaders
+       train_dataset = MatDataset(self.args.train_set, pilot_dims)
+       val_dataset = MatDataset(self.args.val_set, pilot_dims)
+
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
+           shuffle=True,
+           num_workers=self.args.num_workers,
+           pin_memory=self.args.pin_memory and self.device.type == 'cuda'
        )
+
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.args.batch_size,
+           shuffle=False,  # No need to shuffle validation data
+           num_workers=self.args.num_workers,
+           pin_memory=self.args.pin_memory and self.device.type == 'cuda'
        )
+
+       # Test dataloaders
        test_loaders = {
            "DS": get_test_dataloaders(
                self.args.test_set / "DS_test_set",

        }
        return train_loader, val_loader, test_loaders

+   def _log_test_results(self, epoch: int, test_stats: Dict[str, Dict]) -> None:
        """Log test results to TensorBoard.

        Creates and logs visualizations for model performance across different test conditions.

            )

            # Plot error images
+           predicted_channels = self.evaluator.predict_channels(self.test_loaders[key])
            self.writer.add_figure(
                tag=f"{key} Error Images (Epoch:{epoch + 1})",
                figure=get_error_images(

                )
            )

+   def _run_tests(self, epoch: int) -> TestResults:
        """Run tests and log results.

        Evaluates the model on all test datasets and logs performance metrics and visualizations.

        Args:
            epoch: Current training epoch
+
+       Returns:
+           TestResults containing all test statistics
        """
+       ds_stats = self.evaluator.get_test_stats(self.test_loaders["DS"], self.training_loss)
+       mds_stats = self.evaluator.get_test_stats(self.test_loaders["MDS"], self.training_loss)
+       snr_stats = self.evaluator.get_test_stats(self.test_loaders["SNR"], self.training_loss)

        test_stats = {
            "DS": ds_stats,

        }

        self._log_test_results(epoch, test_stats)
+
+       return TestResults(ds_stats, mds_stats, snr_stats)

    def _log_final_metrics(self, final_epoch: int) -> None:
        """Log final training metrics and hyperparameters.

        except Exception as e:
            self.writer.add_text("Error", f"Failed to log final test results: {str(e)}")

+   def _get_all_test_stats(self) -> Tuple[Dict[int, float], Dict[int, float], Dict[int, float]]:
+       """Get all test statistics."""
+       ds_stats = self.evaluator.get_test_stats(self.test_loaders["DS"], self.training_loss)
+       mds_stats = self.evaluator.get_test_stats(self.test_loaders["MDS"], self.training_loss)
+       snr_stats = self.evaluator.get_test_stats(self.test_loaders["SNR"], self.training_loss)
+       return ds_stats, mds_stats, snr_stats

+   def save_checkpoint(self, epoch: int, metrics: TrainingMetrics,
+                       checkpoint_dir: Optional[Path] = None) -> None:
+       """Save model checkpoint.

+       Args:
+           epoch: Current epoch number
+           metrics: Current training metrics
+           checkpoint_dir: Directory to save checkpoint (defaults to tensorboard log dir)
+       """
+       if checkpoint_dir is None:
+           checkpoint_dir = self.args.tensorboard_log_dir / f"{self.args.model_name}_{self.args.exp_id}"
+
+       checkpoint_dir.mkdir(parents=True, exist_ok=True)
+       checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"
+
+       checkpoint = {
+           'epoch': epoch,
+           'model_state_dict': self.model.state_dict(),
+           'optimizer_state_dict': self.optimizer.state_dict(),
+           'scheduler_state_dict': self.scheduler.state_dict(),
+           'train_loss': metrics.train_loss,
+           'val_loss': metrics.val_loss,
+           'learning_rate': metrics.learning_rate,
+           'system_config': self.system_config,
+           'model_config': self.model_config,
+           'args': self.args
+       }
+
+       # Save scaler state if using mixed precision
+       if self.scaler:
+           checkpoint['scaler_state_dict'] = self.scaler.state_dict()
+
+       torch.save(checkpoint, checkpoint_path)
+       self.logger.info(f"Checkpoint saved to {checkpoint_path}")

+   def load_checkpoint(self, checkpoint_path: Path) -> int:
+       """Load model checkpoint.
+
+       Args:
+           checkpoint_path: Path to checkpoint file
+
+       Returns:
+           Epoch number of loaded checkpoint
+       """
+       checkpoint = torch.load(checkpoint_path, map_location=self.device)
+
+       self.model.load_state_dict(checkpoint['model_state_dict'])
+       self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+       self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+       # Load scaler state if it exists
+       if self.scaler and 'scaler_state_dict' in checkpoint:
+           self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
+
+       self.logger.info(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
+       return checkpoint['epoch']

+   def _resume_from_checkpoint(self, checkpoint_path: Path) -> None:
+       """Resume training from a checkpoint.
+
+       Args:
+           checkpoint_path: Path to checkpoint file
+       """
+       start_epoch = self.load_checkpoint(checkpoint_path)
+       self.logger.info(f"Resuming training from epoch {start_epoch}")
+
+       # Update the early stopper with the best loss from checkpoint
+       checkpoint = torch.load(checkpoint_path, map_location=self.device)
+       if 'val_loss' in checkpoint:
+           self.early_stopper.min_loss = checkpoint['val_loss']
+           self.logger.info(f"Early stopper initialized with validation loss: {checkpoint['val_loss']:.4f}")

    def train(self) -> None:
        """Execute the training loop.

        - Early stopping when validation loss plateaus
        - Logging final metrics and results
        """
+       # Notify callbacks that training is beginning
+       for callback in self.callbacks:
+           callback.on_training_begin()
+
        last_epoch = 0
        pbar = tqdm(range(self.args.max_epoch), desc="Training")
+
        for epoch in pbar:
            last_epoch = epoch
+
+           # Notify callbacks that epoch is beginning
+           for callback in self.callbacks:
+               callback.on_epoch_begin(epoch)
+
            # Training step
+           train_loss = self.training_loop.train_epoch(self.train_loader)
+
            # Validation step
+           val_loss = self.training_loop.evaluate(self.val_loader)
+
+           # Create metrics object
+           metrics = TrainingMetrics(
+               train_loss=train_loss,
+               val_loss=val_loss,
+               epoch=epoch,
+               learning_rate=self.optimizer.param_groups[0]['lr']
+           )
+
+           # Notify callbacks that epoch has ended
+           for callback in self.callbacks:
+               callback.on_epoch_end(epoch, metrics)

            # Update progress bar with loss info
            pbar.set_description(
+               f"Epoch {epoch + 1}/{self.args.max_epoch} - "
+               f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
+           )

            if self.early_stopper.early_stop(val_loss):
                pbar.write(f"Early stopping triggered at epoch {epoch + 1}")

            message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
            pbar.write(message)
            self._run_tests(epoch)
+
        self._log_final_metrics(last_epoch)
+
+       # Notify callbacks that training has ended
+       for callback in self.callbacks:
+           callback.on_training_end()


def train(system_config: SystemConfig, model_config: ModelConfig, args: TrainingArguments) -> None:
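
> Note: the files written by `save_checkpoint` above are plain `torch.save` dictionaries, so they can be inspected outside the trainer. A minimal sketch, assuming a checkpoint produced by this commit exists at the hypothetical path below; the keys simply mirror the dictionary assembled in `save_checkpoint`:

```python
from pathlib import Path

import torch

# Hypothetical location: save_checkpoint writes checkpoint_epoch_{epoch}.pt
# under <tensorboard_log_dir>/<model_name>_<exp_id>/.
ckpt_path = Path("runs/adafortitran_exp1/checkpoint_epoch_10.pt")

# Newer PyTorch releases may need weights_only=False here, since the
# checkpoint also stores config objects alongside the state dicts.
ckpt = torch.load(ckpt_path, map_location="cpu")

# Inspect training state without instantiating the model.
print(sorted(ckpt.keys()))
print(f"epoch={ckpt['epoch']}  train_loss={ckpt['train_loss']:.4f}  val_loss={ckpt['val_loss']:.4f}")

# Restoring a model would then follow load_checkpoint above:
# model.load_state_dict(ckpt['model_state_dict'])
```
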
src/utils.py CHANGED
@@ -180,24 +180,7 @@ def concat_complex_channel(channel_matrix):
    return cat_channel_m


- def inverse_concat_complex_channel(channel_matrix: torch.Tensor) -> torch.Tensor:
-     """
-     Reconstruct complex channel matrix from concatenated real matrix.
-
-     Reverses the operation performed by concat_complex_channel by
-     splitting the tensor and combining the parts into a complex tensor.
-
-     Args:
-         channel_matrix: Real-valued matrix of shape (B, F, 2*T)
-
-     Returns:
-         Complex matrix of shape (B, F, T)
-     """
-     split_idx = channel_matrix.shape[-1] // 2
-     return torch.complex(
-         channel_matrix[:, :split_idx],
-         channel_matrix[:, split_idx:]
-     )


def get_test_stats_plot(x_name, stats, methods, show=False):
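
> Note: with `inverse_concat_complex_channel` removed, `concat_complex_channel` remains the single point where complex channels are flattened to real tensors for the MSE losses used in training and evaluation. A minimal sketch of that convention, assuming real and imaginary parts are concatenated along the last dimension as the (B, F, 2*T) shape in the removed docstring suggests; the function below is a hypothetical stand-in, not the repo's implementation:

```python
import torch

def concat_complex_sketch(h: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for concat_complex_channel: (B, F, T) complex -> (B, F, 2*T) real."""
    return torch.cat([h.real, h.imag], dim=-1)

h_ideal = torch.randn(4, 120, 14, dtype=torch.complex64)   # example (B, F, T) channel
h_est = h_ideal + 0.1 * torch.randn_like(h_ideal)          # noisy estimate
mse = torch.nn.functional.mse_loss(concat_complex_sketch(h_est),
                                   concat_complex_sketch(h_ideal))
print(mse.item())
```
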
 