Commit 4fa78a3 · Parent: 71dbdc8
new README.md and Trainer design

Files changed:
- README.md +432 -2
- requirements.txt +5 -2
- src/main/parser.py +145 -36
- src/main/train_helpers.py +0 -265
- src/main/trainer.py +472 -110
- src/utils.py +0 -17
README.md CHANGED
@@ -1,7 +1,437 @@

Removed: the previous seven-line README, which contained only a bare title heading and the license notice ("This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.").

Added: the new README below.
# AdaFortiTran: Adaptive Transformer Model for Robust OFDM Channel Estimation

[](LICENSE)
[](https://www.python.org/)
[](https://pytorch.org/)

Official implementation of [AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation](https://arxiv.org/abs/2505.09076), accepted at ICC 2025, Montreal, Canada.

## 📖 Overview

AdaFortiTran is a novel adaptive transformer-based model for OFDM channel estimation that dynamically adapts to varying channel conditions (SNR, delay spread, Doppler shift). The model combines transformer architectures with channel-aware adaptation mechanisms to achieve robust performance across diverse wireless environments.

### Key Features

- **🔄 Adaptive Architecture**: Dynamically adapts to channel conditions using meta-information
- **⚡ High Performance**: State-of-the-art results on OFDM channel estimation tasks
- **🧠 Transformer-Based**: Leverages attention mechanisms for long-range dependencies
- **🎯 Robust**: Maintains performance across varying SNR, delay spread, and Doppler conditions
- **🚀 Production Ready**: Comprehensive training pipeline with advanced features

## 🏗️ Architecture

The project implements three model variants:

1. **Linear Estimator**: Simple learned linear transformation baseline
2. **FortiTran**: Fixed transformer-based channel estimator
3. **AdaFortiTran**: Adaptive transformer with channel condition awareness

### Model Comparison

| Model | Channel Adaptation | Complexity | Performance |
|-------|--------------------|------------|-------------|
| Linear | ❌ | Low | Baseline |
| FortiTran | ❌ | Medium | Good |
| AdaFortiTran | ✅ | High | **Best** |

## 🚀 Quick Start

### Installation

1. **Clone the repository**:
   ```bash
   git clone https://github.com/your-username/AdaFortiTran.git
   cd AdaFortiTran
   ```

2. **Install dependencies**:
   ```bash
   pip install -r requirements.txt
   ```

3. **Verify installation**:
   ```bash
   python -c "import torch; print(f'PyTorch {torch.__version__}')"
   ```

### Basic Training

Train an AdaFortiTran model with default settings:

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id my_experiment
```

### Advanced Training

Use all available features for optimal performance:

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id advanced_experiment \
    --batch_size 128 \
    --lr 5e-4 \
    --max_epoch 100 \
    --patience 10 \
    --weight_decay 1e-4 \
    --gradient_clip_val 1.0 \
    --use_mixed_precision \
    --save_every_n_epochs 5 \
    --num_workers 8 \
    --test_every_n 5
```

## 📁 Project Structure

```
AdaFortiTran/
├── config/                    # Configuration files
│   ├── system_config.yaml     # OFDM system parameters
│   ├── adafortitran.yaml      # AdaFortiTran model config
│   ├── fortitran.yaml         # FortiTran model config
│   └── linear.yaml            # Linear model config
├── data/                      # Dataset directory
│   ├── train/                 # Training data
│   ├── val/                   # Validation data
│   └── test/                  # Test data (DS, MDS, SNR sets)
├── src/                       # Source code
│   ├── main/                  # Training pipeline
│   │   ├── trainer.py         # Enhanced ModelTrainer
│   │   └── parser.py          # Command-line argument parser
│   ├── models/                # Model implementations
│   │   ├── adafortitran.py    # AdaFortiTran model
│   │   ├── fortitran.py       # FortiTran model
│   │   ├── linear.py          # Linear model
│   │   └── blocks/            # Model building blocks
│   ├── data/                  # Data loading
│   │   └── dataset.py         # Dataset and DataLoader classes
│   ├── config/                # Configuration management
│   │   ├── config_loader.py   # YAML configuration loader
│   │   └── schemas.py         # Pydantic validation schemas
│   └── utils.py               # Utility functions
├── requirements.txt           # Python dependencies
└── README.md                  # This file
```

## ⚙️ Configuration

### System Configuration (`config/system_config.yaml`)

Defines OFDM system parameters:

```yaml
ofdm:
  num_scs: 120       # Number of subcarriers
  num_symbols: 14    # Number of OFDM symbols

pilot:
  num_scs: 12        # Number of pilot subcarriers
  num_symbols: 2     # Number of pilot symbols
```

### Model Configuration (`config/adafortitran.yaml`)

Defines model architecture parameters:

```yaml
model_type: 'adafortitran'
patch_size: [3, 2]              # Patch dimensions
num_layers: 6                   # Transformer layers
model_dim: 128                  # Model dimension
num_head: 4                     # Attention heads
activation: 'gelu'              # Activation function
dropout: 0.1                    # Dropout rate
max_seq_len: 512                # Maximum sequence length
pos_encoding_type: 'learnable'  # Positional encoding
channel_adaptivity_hidden_sizes: [7, 42, 560]  # Adaptation layers
adaptive_token_length: 6        # Adaptive token length
```

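As a rough illustration of how these YAML files map onto validated objects: the repository's own loader lives in `src/config/config_loader.py` with Pydantic schemas in `src/config/schemas.py`, so the class and field names below are assumptions for illustration only, not the project's actual API.

```python
# Hypothetical sketch of YAML loading + validation with pyyaml and pydantic
# (both listed in requirements.txt). The real loader may differ.
from pathlib import Path
import yaml
from pydantic import BaseModel


class GridParams(BaseModel):        # assumed schema, mirroring the YAML above
    num_scs: int
    num_symbols: int


class SystemConfigSketch(BaseModel):
    ofdm: GridParams
    pilot: GridParams


def load_yaml_validated(path: Path) -> SystemConfigSketch:
    """Read a YAML file and validate it against the Pydantic schema."""
    with open(path) as f:
        raw = yaml.safe_load(f)
    return SystemConfigSketch(**raw)


cfg = load_yaml_validated(Path("config/system_config.yaml"))
print(cfg.ofdm.num_scs, cfg.pilot.num_symbols)  # e.g. 120 2
```
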
## 🎯 Training Features

### Advanced Training Options

| Feature | Description | Default |
|---------|-------------|---------|
| `--use_mixed_precision` | Enable mixed precision training | False |
| `--gradient_clip_val` | Gradient clipping value | None |
| `--weight_decay` | Weight decay for optimizer | 0.0 |
| `--save_checkpoints` | Enable model checkpointing | True |
| `--save_best_only` | Save only best model | True |
| `--resume_from_checkpoint` | Resume from checkpoint | None |
| `--num_workers` | Data loading workers | 4 |
| `--pin_memory` | Pin memory for GPU | True |

### Callback System

The training pipeline includes an extensible callback system:

- **TensorBoard Logging**: Automatic metric tracking and visualization
- **Checkpoint Management**: Flexible checkpoint saving strategies
- **Custom Callbacks**: Easy to add new logging or monitoring systems

### Performance Optimizations

- **Mixed Precision Training**: Faster training on modern GPUs
- **Optimized Data Loading**: Configurable workers and memory pinning
- **Gradient Clipping**: Stable training with configurable clipping
- **Early Stopping**: Automatic training termination on plateau (see the sketch below)

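The early-stopping behavior follows the usual patience pattern: training stops once validation loss has not improved for `--patience` epochs. A minimal sketch is shown below; the repository's own `EarlyStopping` helper (constructed as `EarlyStopping(patience=...)` in `trainer.py`) may differ in details such as a minimum-improvement threshold.

```python
# Minimal patience-based early-stopping sketch (illustrative, not the
# project's actual EarlyStopping implementation).
class EarlyStoppingSketch:
    def __init__(self, patience: int = 3):
        self.patience = patience
        self.best_loss = float("inf")
        self.counter = 0

    def early_stop(self, val_loss: float) -> bool:
        """Return True once val_loss has not improved for `patience` epochs."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience
```
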
## 📊 Dataset Format

### Expected File Structure

```
data/
├── train/
│   ├── 1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   ├── 2_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat
│   └── ...
├── val/
│   └── ...
└── test/
    ├── DS_test_set/      # Delay Spread tests
    │   ├── DS_50/
    │   ├── DS_100/
    │   └── ...
    ├── SNR_test_set/     # SNR tests
    │   ├── SNR_10/
    │   ├── SNR_20/
    │   └── ...
    └── MDS_test_set/     # Multi-Doppler tests
        ├── DOP_200/
        ├── DOP_400/
        └── ...
```

### File Naming Convention

Files must follow the pattern:

```
{file_number}_SNR-{snr}_DS-{delay_spread}_DOP-{doppler}_N-{pilot_freq}_{channel_type}.mat
```

Example: `1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat`

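For illustration, the naming convention can be unpacked with a small regular expression. The helper below is hypothetical; the project's own parsing lives in `src/data/dataset.py` and may differ.

```python
# Hypothetical helper for parsing the file-naming convention above.
import re
from pathlib import Path

PATTERN = re.compile(
    r"(?P<file_number>\d+)_SNR-(?P<snr>-?\d+)_DS-(?P<delay_spread>\d+)"
    r"_DOP-(?P<doppler>\d+)_N-(?P<pilot_freq>\d+)_(?P<channel_type>.+)\.mat"
)

def parse_mat_filename(path: Path) -> dict:
    """Extract the metadata encoded in a dataset file name."""
    match = PATTERN.fullmatch(path.name)
    if match is None:
        raise ValueError(f"Unexpected file name: {path.name}")
    return match.groupdict()

print(parse_mat_filename(Path("1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat")))
# {'file_number': '1', 'snr': '20', 'delay_spread': '50',
#  'doppler': '500', 'pilot_freq': '3', 'channel_type': 'TDL-A'}
```
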
### Data Format

Each `.mat` file must contain a variable `H` with shape `[subcarriers, symbols, 3]` (see the loading sketch below):

- `H[:, :, 0]`: Ground truth channel (complex values)
- `H[:, :, 1]`: LS channel estimate with zeros for non-pilot positions
- `H[:, :, 2]`: Reserved for future use

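A quick way to inspect one of these files is with `scipy.io.loadmat` (SciPy is listed in `requirements.txt`). This is only a sketch; the actual loading logic lives in `src/data/dataset.py`.

```python
# Sketch: inspect one dataset file and split H into its three planes.
from scipy.io import loadmat

mat = loadmat("data/train/1_SNR-20_DS-50_DOP-500_N-3_TDL-A.mat")
H = mat["H"]                  # shape: [subcarriers, symbols, 3], complex-valued

ideal_channel = H[:, :, 0]    # ground truth channel
ls_estimate = H[:, :, 1]      # LS estimate, zeros at non-pilot positions
print(H.shape, ideal_channel.dtype)
```
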
## 🔧 Usage Examples

### Training Different Models

**Linear Estimator**:
```bash
python src/main.py \
    --model_name linear \
    --system_config_path config/system_config.yaml \
    --model_config_path config/linear.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id linear_baseline
```

**FortiTran**:
```bash
python src/main.py \
    --model_name fortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/fortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id fortitran_experiment
```

**AdaFortiTran**:
```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id adafortitran_experiment
```

### Resume Training

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id resumed_experiment \
    --resume_from_checkpoint runs/adafortitran_experiment/best/checkpoint_epoch_50.pt
```

### Hyperparameter Tuning

```bash
python src/main.py \
    --model_name adafortitran \
    --system_config_path config/system_config.yaml \
    --model_config_path config/adafortitran.yaml \
    --train_set data/train \
    --val_set data/val \
    --test_set data/test \
    --exp_id hyperparameter_tuning \
    --batch_size 64 \
    --lr 1e-3 \
    --max_epoch 50 \
    --patience 5 \
    --weight_decay 1e-5 \
    --gradient_clip_val 0.5 \
    --use_mixed_precision \
    --test_every_n 5
```

## 📈 Monitoring and Logging

### TensorBoard Integration

Training automatically logs metrics to TensorBoard:

```bash
tensorboard --logdir runs/
```

Available metrics:
- Training/validation loss
- Learning rate
- Test performance across conditions
- Error visualizations
- Model hyperparameters

### Log Files

Training logs are saved to:
- `logs/training_{exp_id}.log`: Python logging output
- `runs/{model_name}_{exp_id}/`: TensorBoard logs and checkpoints

## 🧪 Testing and Evaluation

### Automatic Testing

The training pipeline automatically evaluates models on:
- **DS (Delay Spread)**: Varying delay spread conditions
- **SNR**: Different signal-to-noise ratios
- **MDS (Multi-Doppler)**: Various Doppler shift scenarios

### Manual Evaluation

```python
import torch

from src.models import AdaFortiTranEstimator
from src.config import load_config

# Load configurations
system_config, model_config = load_config(
    'config/system_config.yaml',
    'config/adafortitran.yaml'
)

# Initialize model
model = AdaFortiTranEstimator(system_config, model_config)

# Load checkpoint
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate
model.eval()
# ... evaluation code
```

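One way to fill in the "evaluation code" placeholder is an MSE-in-dB loop like the one sketched below. It assumes `model` from the snippet above and a DataLoader `loader` yielding `(ls_estimate, ideal_channel, meta_data)` batches as described under "Data Format"; the exact batch layout and metric reported by the project's trainer may differ.

```python
# Hedged sketch: average test MSE in dB over a DataLoader.
import math
import torch

total_se, total_elems = 0.0, 0

with torch.no_grad():
    for ls_estimate, ideal_channel, meta_data in loader:
        prediction = model(ls_estimate, meta_data)   # AdaFortiTran consumes meta_data
        err = prediction - ideal_channel.to(prediction.device)
        total_se += err.abs().pow(2).sum().item()
        total_elems += err.numel()

mse = total_se / total_elems
print(f"Test MSE: {10 * math.log10(mse):.4f} dB")
```
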
## 🔬 Research and Development

### Adding Custom Callbacks

```python
from src.main.trainer import Callback, TrainingMetrics

class CustomCallback(Callback):
    def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
        # Custom logic here
        print(f"Epoch {epoch}: Train Loss = {metrics.train_loss:.4f}")
```

Note that `Callback` declares all four hooks (`on_epoch_begin`, `on_epoch_end`, `on_training_begin`, `on_training_end`) as abstract, so a concrete callback must implement the remaining three as well, even if only as `pass`.

### Extending Models

The modular architecture makes it easy to add new model variants:

```python
from src.models.fortitran import BaseFortiTranEstimator

class CustomEstimator(BaseFortiTranEstimator):
    def __init__(self, system_config, model_config):
        super().__init__(system_config, model_config, use_channel_adaptation=True)
        # Add custom components
```

## 🐛 Troubleshooting

### Common Issues

**CUDA Out of Memory**:
- Reduce batch size: `--batch_size 32`
- Enable mixed precision: `--use_mixed_precision`
- Reduce number of workers: `--num_workers 2`

**Slow Training**:
- Increase number of workers: `--num_workers 8`
- Enable pin memory: `--pin_memory`
- Use mixed precision: `--use_mixed_precision`

**Poor Convergence**:
- Adjust learning rate: `--lr 1e-4`
- Add gradient clipping: `--gradient_clip_val 1.0`
- Increase patience: `--patience 10`

### Getting Help

1. Check the logs in `logs/training_{exp_id}.log`
2. Verify dataset format matches requirements
3. Ensure all dependencies are installed correctly
4. Check TensorBoard for training curves

## 📚 Citation

If you use this code in your research, please cite:

```bibtex
@misc{guler2025adafortitranadaptivetransformermodel,
      title={AdaFortiTran: An Adaptive Transformer Model for Robust OFDM Channel Estimation},
      author={Berkay Guler and Hamid Jafarkhani},
      year={2025},
      eprint={2505.09076},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2505.09076},
}
```

## 📄 License

This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
requirements.txt CHANGED
@@ -1,5 +1,8 @@

pyyaml, matplotlib, prettytable, and tensorboard were added; the file now reads:

torch
pydantic
pyyaml
scipy
tqdm
matplotlib
prettytable
tensorboard
src/main/parser.py CHANGED

@@ -10,7 +10,7 @@
 from pathlib import Path
 import argparse
 from pydantic import BaseModel, Field, model_validator
-from typing import Self
+from typing import Self, Optional

The remaining hunks extend the TrainingArguments model (docstring, fields, and path validation) and restructure parse_arguments(): the previous flat set of `optional.add_argument(...)` definitions (logger level, tensorboard_log_dir, python_log_dir, test_every_n, max_epoch, patience, and related options) is replaced by named argument groups. The added lines are reconstructed below.
@@ -41,9 +41,22 @@ — new entries in the TrainingArguments docstring:

        lr: Learning rate for optimizer
        max_epoch: Maximum number of training epochs
        patience: Early stopping patience in epochs
        weight_decay: Weight decay for optimizer
        gradient_clip_val: Gradient clipping value
        use_mixed_precision: Whether to use mixed precision training

        # Evaluation
        test_every_n: Number of epochs between test evaluations

        # Checkpointing
        save_checkpoints: Whether to save model checkpoints
        save_best_only: Whether to save only the best model
        save_every_n_epochs: Save checkpoint every N epochs
        resume_from_checkpoint: Path to checkpoint to resume from

        # Data Loading
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory for faster GPU transfer
    """

@@ -67,10 +80,23 @@ — new Field declarations on TrainingArguments:

    lr: float = Field(default=1e-3, gt=0, description="Initial learning rate")
    max_epoch: int = Field(default=10, gt=0, description="Maximum number of training epochs")
    patience: int = Field(default=3, gt=0, description="Early stopping patience (epochs)")
    weight_decay: float = Field(default=0.0, ge=0.0, description="Weight decay for optimizer")
    gradient_clip_val: Optional[float] = Field(default=None, gt=0, description="Gradient clipping value")
    use_mixed_precision: bool = Field(default=False, description="Whether to use mixed precision training")

    # Evaluation
    test_every_n: int = Field(default=10, gt=0, description="Test model every N epochs")

    # Checkpointing
    save_checkpoints: bool = Field(default=True, description="Whether to save model checkpoints")
    save_best_only: bool = Field(default=True, description="Whether to save only the best model")
    save_every_n_epochs: Optional[int] = Field(default=None, gt=0, description="Save checkpoint every N epochs")
    resume_from_checkpoint: Optional[Path] = Field(default=None, description="Path to checkpoint to resume from")

    # Data Loading
    num_workers: int = Field(default=4, ge=0, description="Number of data loading workers")
    pin_memory: bool = Field(default=True, description="Whether to pin memory for faster GPU transfer")

@@ -92,6 +118,13 @@ — validate_paths() gains checkpoint validation:

        if not self.model_config_path.suffix == '.yaml':
            raise ValueError(f"Model config file must be a .yaml file: {self.model_config_path}")

        # Validate checkpoint path if provided
        if self.resume_from_checkpoint is not None:
            if not self.resume_from_checkpoint.exists():
                raise ValueError(f"Checkpoint file not found: {self.resume_from_checkpoint}")
            if not self.resume_from_checkpoint.suffix == '.pt':
                raise ValueError(f"Checkpoint file must be a .pt file: {self.resume_from_checkpoint}")

        return self

@@ -161,58 +194,134 @@ — parse_arguments() now defines grouped arguments:

    # Training hyperparameters
    training = parser.add_argument_group('training hyperparameters')
    training.add_argument(
        '--batch_size',
        type=int,
        default=64,
        help='Training batch size'
    )
    training.add_argument(
        '--lr',
        type=float,
        default=1e-3,
        help='Initial learning rate'
    )
    training.add_argument(
        '--max_epoch',
        type=int,
        default=10,
        help='Maximum number of training epochs'
    )
    training.add_argument(
        '--patience',
        type=int,
        default=3,
        help='Early stopping patience (epochs)'
    )
    training.add_argument(
        '--weight_decay',
        type=float,
        default=0.0,
        help='Weight decay for optimizer'
    )
    training.add_argument(
        '--gradient_clip_val',
        type=float,
        default=None,
        help='Gradient clipping value (disabled if not specified)'
    )
    training.add_argument(
        '--use_mixed_precision',
        action='store_true',
        help='Use mixed precision training (requires PyTorch >= 1.6)'
    )

    # Evaluation settings
    evaluation = parser.add_argument_group('evaluation settings')
    evaluation.add_argument(
        '--test_every_n',
        type=int,
        default=10,
        help='Test model every N epochs'
    )

    # Checkpointing settings
    checkpointing = parser.add_argument_group('checkpointing settings')
    checkpointing.add_argument(
        '--save_checkpoints',
        action='store_true',
        default=True,
        help='Save model checkpoints'
    )
    checkpointing.add_argument(
        '--no_save_checkpoints',
        action='store_false',
        dest='save_checkpoints',
        help='Disable saving model checkpoints'
    )
    checkpointing.add_argument(
        '--save_best_only',
        action='store_true',
        default=True,
        help='Save only the best model based on validation loss'
    )
    checkpointing.add_argument(
        '--save_every_n_epochs',
        type=int,
        default=None,
        help='Save checkpoint every N epochs (in addition to best model)'
    )
    checkpointing.add_argument(
        '--resume_from_checkpoint',
        type=Path,
        default=None,
        help='Path to checkpoint file to resume training from'
    )

    # Data loading settings
    data_loading = parser.add_argument_group('data loading settings')
    data_loading.add_argument(
        '--num_workers',
        type=int,
        default=4,
        help='Number of data loading workers'
    )
    data_loading.add_argument(
        '--pin_memory',
        action='store_true',
        default=True,
        help='Pin memory for faster GPU transfer'
    )
    data_loading.add_argument(
        '--no_pin_memory',
        action='store_false',
        dest='pin_memory',
        help='Disable pin memory'
    )

    # Logging settings
    logging_group = parser.add_argument_group('logging settings')
    logging_group.add_argument(
        '--python_log_level',
        type=str,
        default="INFO",
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        help='Logger level for python logging module'
    )
    logging_group.add_argument(
        '--tensorboard_log_dir',
        type=Path,
        default="runs",
        help='Directory for tensorboard logs'
    )
    logging_group.add_argument(
        '--python_log_dir',
        type=Path,
        default="logs",
        help='Directory for python logging files'
    )

    args = parser.parse_args()
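For orientation, a minimal consumer of the parsed arguments might look like the sketch below. It assumes only what the file above shows (parse_arguments() returns a validated TrainingArguments with the listed fields); src/main.py does the real wiring.

    # Hedged sketch of consuming the validated arguments.
    from src.main.parser import parse_arguments

    args = parse_arguments()   # argparse Namespace -> validated TrainingArguments
    print(args.batch_size, args.lr, args.use_mixed_precision)
    if args.resume_from_checkpoint is not None:
        print(f"Resuming from {args.resume_from_checkpoint}")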
src/main/train_helpers.py DELETED
@@ -1,265 +0,0 @@

The entire module was removed (its responsibilities now live in trainer.py's TrainingLoop and ModelEvaluator). Removed contents:

"""
Training helper functions for OFDM channel estimation models.

This module provides utility functions for training, evaluating, and testing
deep learning models for OFDM channel estimation tasks. It includes functions
for performing training epochs, model evaluation, prediction generation,
and performance statistics calculation across different test conditions.
"""

from typing import Dict, List, Tuple, Union, Callable
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ExponentialLR
from src.utils import to_db, concat_complex_channel

# Type aliases
ComplexTensor = torch.Tensor  # Complex tensor
BatchType = Tuple[ComplexTensor, ComplexTensor, Union[Dict, None]]
TestDataLoadersType = List[Tuple[str, DataLoader]]
StatsType = Dict[int, float]


def get_all_test_stats(
    model: nn.Module,
    test_dataloaders: Dict[str, TestDataLoadersType],
    loss_fn: Callable
) -> Tuple[StatsType, StatsType, StatsType]:
    """
    Evaluate model on all test datasets.

    Calculates performance statistics (MSE in dB) for a model across different
    test conditions: Delay Spread (DS), Max Doppler Shift (MDS), and
    Signal-to-Noise Ratio (SNR).

    Args:
        model: Model to evaluate
        test_dataloaders: Dictionary containing DataLoader objects for test sets:
            - "DS": Delay Spread test set
            - "MDS": Max Doppler Shift test set
            - "SNR": Signal-to-Noise Ratio test set
        loss_fn: Loss function for evaluation

    Returns:
        Tuple containing statistics (MSE in dB) for DS, MDS, and SNR test sets,
        where each set of statistics is a dictionary mapping parameter values to MSE
    """
    ds_stats = get_test_stats(model, test_dataloaders["DS"], loss_fn)
    mds_stats = get_test_stats(model, test_dataloaders["MDS"], loss_fn)
    snr_stats = get_test_stats(model, test_dataloaders["SNR"], loss_fn)
    return ds_stats, mds_stats, snr_stats


def get_test_stats(
    model: nn.Module,
    test_dataloaders: TestDataLoadersType,
    loss_fn: Callable
) -> StatsType:
    """
    Evaluate model on provided test dataloaders.

    Calculates performance statistics (MSE in dB) for a model on a
    specific set of test conditions.

    Args:
        model: Model to evaluate
        test_dataloaders: List of (name, DataLoader) tuples for test sets,
            where names are in format "parameter_value"
        loss_fn: Loss function for evaluation

    Returns:
        Dictionary mapping test parameter values (as integers) to MSE values in dB
    """
    stats: StatsType = {}
    sorted_loaders = sorted(
        test_dataloaders,
        key=lambda x: int(x[0].split("_")[1])
    )

    for name, test_dataloader in sorted_loaders:
        var, val = name.split("_")
        test_loss = eval_model(model, test_dataloader, loss_fn)
        db_error = to_db(test_loss)
        print(f"{var}:{val} Test MSE: {db_error:.4f} dB")
        stats[int(val)] = db_error

    return stats


def eval_model(
    model: nn.Module,
    eval_dataloader: DataLoader,
    loss_fn: Callable
) -> float:
    """
    Evaluate model on given dataloader.

    Calculates the average loss for a model on a dataset without
    performing parameter updates.

    Args:
        model: Model to evaluate
        eval_dataloader: DataLoader containing evaluation data
        loss_fn: Loss function for computing error

    Returns:
        Average validation loss (adjusted for complex values)

    Notes:
        Loss is multiplied by 2 to account for complex-valued matrices being
        represented as real-valued matrices of double size.
    """
    val_loss = 0.0
    model.eval()

    with torch.no_grad():
        for batch in eval_dataloader:
            estimated_channel, ideal_channel = _forward_pass(batch, model)
            output = _compute_loss(estimated_channel, ideal_channel, loss_fn)
            val_loss += (2 * output.item() * batch[0].size(0))

    val_loss /= sum(len(batch[0]) for batch in eval_dataloader)
    return val_loss


def predict_channels(
    model: nn.Module,
    test_dataloaders: TestDataLoadersType
) -> Dict[int, Dict[str, ComplexTensor]]:
    """
    Generate channel predictions for test datasets.

    Creates predictions for a sample from each test dataset to enable
    visualization and error analysis.

    Args:
        model: Model to use for predictions
        test_dataloaders: List of (name, DataLoader) tuples for test sets,
            where names are in format "parameter_value"

    Returns:
        Dictionary mapping test parameter values (as integers) to dictionaries containing
        estimated and ideal channels for a single sample
    """
    channels: Dict[int, Dict[str, ComplexTensor]] = {}
    sorted_loaders = sorted(
        test_dataloaders,
        key=lambda x: int(x[0].split("_")[1])
    )

    for name, test_dataloader in sorted_loaders:
        with torch.no_grad():
            batch = next(iter(test_dataloader))
            estimated_channels, ideal_channels = _forward_pass(batch, model)

        var, val = name.split("_")
        channels[int(val)] = {
            "estimated_channel": estimated_channels[0],
            "ideal_channel": ideal_channels[0]
        }

    return channels


def train_epoch(
    model: nn.Module,
    optimizer: Optimizer,
    loss_fn: Callable,
    scheduler: ExponentialLR,
    train_dataloader: DataLoader
) -> float:
    """
    Train model for one epoch.

    Performs a complete training iteration over the dataset, including:
    - Forward pass through the model
    - Loss calculation
    - Backpropagation
    - Parameter updates
    - Learning rate scheduling

    Args:
        model: Model to train
        optimizer: Optimizer for updating model parameters
        loss_fn: Loss function for computing error
        scheduler: Learning rate scheduler
        train_dataloader: DataLoader containing training data

    Returns:
        Average training loss for the epoch (adjusted for complex values)

    Notes:
        Loss is multiplied by 2 to account for complex-valued matrices being
        represented as real-valued matrices of double size.
    """
    train_loss = 0.0
    model.train()

    for batch in train_dataloader:
        optimizer.zero_grad()
        estimated_channel, ideal_channel = _forward_pass(batch, model)
        output = _compute_loss(estimated_channel, ideal_channel, loss_fn)
        output.backward()
        optimizer.step()
        train_loss += (2 * output.item() * batch[0].size(0))

    scheduler.step()
    train_loss /= sum(len(batch[0]) for batch in train_dataloader)
    return train_loss


def _forward_pass(batch: BatchType, model: nn.Module) -> Tuple[ComplexTensor, ComplexTensor]:
    """
    Perform forward pass through model.

    Processes input data through the appropriate model based on its type,
    handling different input requirements for different model architectures.

    Args:
        batch: Tuple containing (estimated_channel, ideal_channel, metadata)
        model: Model to use for processing

    Returns:
        Tuple of (processed_estimated_channel, ideal_channel)

    Raises:
        ValueError: If model type is not recognized
    """
    estimated_channel, ideal_channel, meta_data = batch

    # All models now handle complex input directly
    if hasattr(model, 'use_channel_adaptation') and model.use_channel_adaptation:
        # AdaFortiTran uses meta_data for channel adaptation
        estimated_channel = model(estimated_channel, meta_data)
    else:
        # Linear and FortiTran models don't use meta_data
        estimated_channel = model(estimated_channel)

    return estimated_channel, ideal_channel.to(model.device)


def _compute_loss(
    estimated_channel: ComplexTensor,
    ideal_channel: ComplexTensor,
    loss_fn: Callable
) -> torch.Tensor:
    """
    Calculate loss between estimated and ideal channels.

    Computes the loss between model output and ground truth using the specified
    loss function, with appropriate handling of complex values.

    Args:
        estimated_channel: Estimated channel from model
        ideal_channel: Ground truth ideal channel
        loss_fn: Loss function to compute error

    Returns:
        Computed loss value as a scalar tensor
    """
    return loss_fn(
        concat_complex_channel(estimated_channel),
        concat_complex_channel(ideal_channel)
    )
src/main/trainer.py CHANGED

@@ -11,9 +11,12 @@
 from torch import nn, optim
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard.writer import SummaryWriter
-from typing import Dict, Tuple, Type, Union
+from typing import Dict, Tuple, Type, Union, Optional, List, Protocol
 import logging
 from tqdm import tqdm
+from dataclasses import dataclass
+from pathlib import Path
+from abc import ABC, abstractmethod

 from .parser import TrainingArguments
 from src.data.dataset import MatDataset, get_test_dataloaders

The later hunks (@@ -86,13 +377,59 @@ through @@ -391,8 +749,12 @@) rework ModelTrainer.__init__, _get_dataloaders, _log_test_results, _run_tests, and train(), and remove the trainer's inlined evaluation and channel-prediction helpers; only fragments of those removed lines are visible in this view. The largest addition inserts the following block after the ModelType alias (new lines 39-323):
@dataclass
|
40 |
+
class TrainingMetrics:
|
41 |
+
"""Container for training metrics."""
|
42 |
+
train_loss: float
|
43 |
+
val_loss: float
|
44 |
+
epoch: int
|
45 |
+
learning_rate: float
|
46 |
+
|
47 |
+
|
48 |
+
@dataclass
|
49 |
+
class TestResults:
|
50 |
+
"""Container for test results."""
|
51 |
+
ds_stats: Dict[int, float]
|
52 |
+
mds_stats: Dict[int, float]
|
53 |
+
snr_stats: Dict[int, float]
|
54 |
+
|
55 |
+
|
56 |
+
class Callback(ABC):
|
57 |
+
"""Base class for training callbacks."""
|
58 |
+
|
59 |
+
@abstractmethod
|
60 |
+
def on_epoch_begin(self, epoch: int) -> None:
|
61 |
+
"""Called at the beginning of each epoch."""
|
62 |
+
pass
|
63 |
+
|
64 |
+
@abstractmethod
|
65 |
+
def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
|
66 |
+
"""Called at the end of each epoch."""
|
67 |
+
pass
|
68 |
+
|
69 |
+
@abstractmethod
|
70 |
+
def on_training_begin(self) -> None:
|
71 |
+
"""Called at the beginning of training."""
|
72 |
+
pass
|
73 |
+
|
74 |
+
@abstractmethod
|
75 |
+
def on_training_end(self) -> None:
|
76 |
+
"""Called at the end of training."""
|
77 |
+
pass
|
78 |
+
|
79 |
+
|
80 |
+
class CheckpointCallback(Callback):
|
81 |
+
"""Callback for saving model checkpoints."""
|
82 |
+
|
83 |
+
def __init__(self, save_dir: Path, save_best_only: bool = True,
|
84 |
+
save_every_n_epochs: Optional[int] = None):
|
85 |
+
self.save_dir = save_dir
|
86 |
+
self.save_best_only = save_best_only
|
87 |
+
self.save_every_n_epochs = save_every_n_epochs
|
88 |
+
self.best_val_loss = float('inf')
|
89 |
+
self.trainer = None
|
90 |
+
|
91 |
+
def set_trainer(self, trainer: 'ModelTrainer') -> None:
|
92 |
+
"""Set the trainer reference."""
|
93 |
+
self.trainer = trainer
|
94 |
+
|
95 |
+
def on_epoch_begin(self, epoch: int) -> None:
|
96 |
+
pass
|
97 |
+
|
98 |
+
def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
|
99 |
+
if self.trainer is None:
|
100 |
+
return
|
101 |
+
|
102 |
+
# Save best model
|
103 |
+
if self.save_best_only and metrics.val_loss < self.best_val_loss:
|
104 |
+
self.best_val_loss = metrics.val_loss
|
105 |
+
self.trainer.save_checkpoint(
|
106 |
+
epoch, metrics,
|
107 |
+
checkpoint_dir=self.save_dir / "best"
|
108 |
+
)
|
109 |
+
|
110 |
+
# Save every N epochs
|
111 |
+
if (self.save_every_n_epochs is not None and
|
112 |
+
(epoch + 1) % self.save_every_n_epochs == 0):
|
113 |
+
self.trainer.save_checkpoint(
|
114 |
+
epoch, metrics,
|
115 |
+
checkpoint_dir=self.save_dir / "periodic"
|
116 |
+
)
|
117 |
+
|
118 |
+
def on_training_begin(self) -> None:
|
119 |
+
self.save_dir.mkdir(parents=True, exist_ok=True)
|
120 |
+
|
121 |
+
def on_training_end(self) -> None:
|
122 |
+
pass
|
123 |
+
|
124 |
+
|
125 |
+
class TensorBoardCallback(Callback):
|
126 |
+
"""Callback for TensorBoard logging."""
|
127 |
+
|
128 |
+
def __init__(self, writer: SummaryWriter):
|
129 |
+
self.writer = writer
|
130 |
+
|
131 |
+
def on_epoch_begin(self, epoch: int) -> None:
|
132 |
+
pass
|
133 |
+
|
134 |
+
def on_epoch_end(self, epoch: int, metrics: TrainingMetrics) -> None:
|
135 |
+
self.writer.add_scalar('Loss/Train', metrics.train_loss, metrics.epoch + 1)
|
136 |
+
self.writer.add_scalar('Loss/Val', metrics.val_loss, metrics.epoch + 1)
|
137 |
+
self.writer.add_scalar('Learning_Rate', metrics.learning_rate, metrics.epoch + 1)
|
138 |
+
|
139 |
+
def on_training_begin(self) -> None:
|
140 |
+
pass
|
141 |
+
|
142 |
+
def on_training_end(self) -> None:
|
143 |
+
self.writer.close()
|
144 |
+
|
145 |
+
|
146 |
+
class TrainingLoop:
|
147 |
+
"""Handles the core training loop logic."""
|
148 |
+
|
149 |
+
def __init__(self, model: ModelType, optimizer: optim.Optimizer,
|
150 |
+
scheduler: optim.lr_scheduler.LRScheduler,
|
151 |
+
loss_fn: nn.Module, device: torch.device, scaler: Optional[torch.cuda.amp.GradScaler] = None,
|
152 |
+
gradient_clip_val: Optional[float] = None):
|
153 |
+
self.model = model
|
154 |
+
self.optimizer = optimizer
|
155 |
+
self.scheduler = scheduler
|
156 |
+
self.loss_fn = loss_fn
|
157 |
+
self.device = device
|
158 |
+
self.scaler = scaler
|
159 |
+
self.gradient_clip_val = gradient_clip_val
|
160 |
+
|
161 |
+
def _compute_loss(self, estimated_channel: torch.Tensor,
|
162 |
+
ideal_channel: torch.Tensor) -> torch.Tensor:
|
163 |
+
"""Compute loss between estimated and ideal channels."""
|
164 |
+
return self.loss_fn(
|
165 |
+
concat_complex_channel(estimated_channel),
|
166 |
+
concat_complex_channel(ideal_channel)
|
167 |
+
+        )
+
+    def _forward_pass(self, batch: Tuple[torch.Tensor, torch.Tensor, Tuple],
+                      model: ModelType) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Perform forward pass through the model."""
+        estimated_channel, ideal_channel, meta_data = batch
+
+        # All models now handle complex input directly
+        if isinstance(model, AdaFortiTranEstimator):
+            # AdaFortiTran uses meta_data for channel adaptation
+            estimated_channel = model(estimated_channel, meta_data)
+        else:
+            # Linear and FortiTran models don't use meta_data
+            estimated_channel = model(estimated_channel)
+
+        return estimated_channel, ideal_channel.to(model.device)
+
+    def train_epoch(self, train_loader: DataLoader) -> float:
+        """Train for one epoch."""
+        train_loss = 0.0
+        self.model.train()
+        num_samples = 0
+
+        for batch in train_loader:
+            self.optimizer.zero_grad()
+            estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
+
+            if self.scaler:
+                with torch.cuda.amp.autocast():
+                    loss = self._compute_loss(estimated_channel, ideal_channel)
+                self.scaler.scale(loss).backward()
+
+                # Gradient clipping
+                if self.gradient_clip_val:
+                    self.scaler.unscale_(self.optimizer)
+                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val)
+
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+            else:
+                loss = self._compute_loss(estimated_channel, ideal_channel)
+                loss.backward()
+
+                # Gradient clipping
+                if self.gradient_clip_val:
+                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val)
+
+                self.optimizer.step()
+
+            batch_size = batch[0].size(0)
+            train_loss += (2 * loss.item() * batch_size)
+            num_samples += batch_size
+
+        self.scheduler.step()
+        return train_loss / num_samples
+
+    def evaluate(self, eval_loader: DataLoader) -> float:
+        """Evaluate the model."""
+        val_loss = 0.0
+        self.model.eval()
+        num_samples = 0
+
+        with torch.no_grad():
+            for batch in eval_loader:
+                estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
+
+                if self.scaler:
+                    with torch.cuda.amp.autocast():
+                        loss = self._compute_loss(estimated_channel, ideal_channel)
+                else:
+                    loss = self._compute_loss(estimated_channel, ideal_channel)
+
+                batch_size = batch[0].size(0)
+                val_loss += (2 * loss.item() * batch_size)
+                num_samples += batch_size
+
+        return val_loss / num_samples
+
+
+class ModelEvaluator:
+    """Handles model evaluation and testing."""
+
+    def __init__(self, model: ModelType, device: torch.device, logger: logging.Logger):
+        self.model = model
+        self.device = device
+        self.logger = logger
+
+    def _forward_pass(self, batch: Tuple[torch.Tensor, torch.Tensor, Tuple],
+                      model: ModelType) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Perform forward pass through the model."""
+        estimated_channel, ideal_channel, meta_data = batch
+
+        if isinstance(model, AdaFortiTranEstimator):
+            estimated_channel = model(estimated_channel, meta_data)
+        else:
+            estimated_channel = model(estimated_channel)
+
+        return estimated_channel, ideal_channel.to(model.device)
+
+    def predict_channels(self, test_dataloaders: List[Tuple[str, DataLoader]]) -> Dict[int, Dict]:
+        """Predict channels for visualization."""
+        channels = {}
+        sorted_loaders = sorted(
+            test_dataloaders,
+            key=lambda x: int(x[0].split("_")[1])
+        )
+
+        for name, test_dataloader in sorted_loaders:
+            with torch.no_grad():
+                batch = next(iter(test_dataloader))
+                estimated_channels, ideal_channels = self._forward_pass(batch, self.model)
+
+            var, val = name.split("_")
+            channels[int(val)] = {
+                "estimated_channel": estimated_channels[0],
+                "ideal_channel": ideal_channels[0]
+            }
+        return channels
+
+    def get_test_stats(self, test_dataloaders: List[Tuple[str, DataLoader]],
+                       loss_fn: nn.Module) -> Dict[int, float]:
+        """Get test statistics for a set of dataloaders."""
+        stats = {}
+        sorted_loaders = sorted(
+            test_dataloaders,
+            key=lambda x: int(x[0].split("_")[1])
+        )
+
+        for name, test_dataloader in sorted_loaders:
+            var, val = name.split("_")
+            test_loss = self._evaluate_dataloader(test_dataloader, loss_fn)
+            db_error = to_db(test_loss)
+            self.logger.info(f"{var}:{val} Test MSE: {db_error:.4f} dB")
+            stats[int(val)] = db_error
+        return stats
+
+    def _evaluate_dataloader(self, dataloader: DataLoader, loss_fn: nn.Module) -> float:
+        """Evaluate a single dataloader."""
+        total_loss = 0.0
+        num_samples = 0
+        self.model.eval()
+
+        with torch.no_grad():
+            for batch in dataloader:
+                estimated_channel, ideal_channel = self._forward_pass(batch, self.model)
+                loss = loss_fn(
+                    concat_complex_channel(estimated_channel),
+                    concat_complex_channel(ideal_channel)
+                )
+
+                batch_size = batch[0].size(0)
+                total_loss += (2 * loss.item() * batch_size)
+                num_samples += batch_size
+
+        return total_loss / num_samples
+
+
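`ModelEvaluator.get_test_stats` reports each test condition's MSE in dB via `to_db`. A minimal sketch of that conversion, assuming the usual `10 * log10` mapping (the actual helper is defined in `src/utils.py`):

```python
import math


def to_db(mse: float) -> float:
    # Assumed behaviour only: convert a linear MSE value to decibels.
    # The repository's own implementation lives in src/utils.py.
    return 10 * math.log10(mse)


print(to_db(0.01))  # an MSE of 0.01 corresponds to -20.0 dB
```
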
 class ModelTrainer:
     """Handles the training and evaluation of deep learning models.
 
...
         val_loader: DataLoader for validation set (used for validation)
         test_loaders: Dictionary of test set DataLoaders (used for testing)
         logger: Logger instance for logging messages
+        training_loop: TrainingLoop instance for core training logic
+        evaluator: ModelEvaluator instance for evaluation logic
+        callbacks: List of training callbacks
     """
 
     MODEL_REGISTRY: Dict[str, Type[ModelType]] = {
...
         self.logger = logging.getLogger(__name__)
 
         self.model = self._initialize_model()
+
+        # Initialize optimizer with weight decay
+        self.optimizer = optim.Adam(
+            self.model.parameters(),
+            lr=args.lr,
+            weight_decay=args.weight_decay
+        )
+
         self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.EXP_LR_GAMMA)
         self.early_stopper = EarlyStopping(patience=args.patience)
         self.training_loss = nn.MSELoss()
 
+        # Initialize mixed precision training if requested
+        self.scaler = None
+        if args.use_mixed_precision and self.device.type == 'cuda':
+            self.scaler = torch.cuda.amp.GradScaler()
+            self.logger.info("Mixed precision training enabled")
+
         self.train_loader, self.val_loader, self.test_loaders = self._get_dataloaders()
+
+        # Initialize components
+        self.training_loop = TrainingLoop(
+            self.model, self.optimizer, self.scheduler, self.training_loss,
+            self.device, self.scaler, self.args.gradient_clip_val
+        )
+        self.evaluator = ModelEvaluator(self.model, self.device, self.logger)
+
+        # Initialize callbacks
+        self.callbacks = self._setup_callbacks()
+
+        # Resume from checkpoint if specified
+        if args.resume_from_checkpoint is not None:
+            self._resume_from_checkpoint(args.resume_from_checkpoint)
+
+    def _setup_callbacks(self) -> List[Callback]:
+        """Set up training callbacks."""
+        callbacks = []
+
+        # TensorBoard callback
+        callbacks.append(TensorBoardCallback(self.writer))
+
+        # Checkpoint callback (only if checkpointing is enabled)
+        if self.args.save_checkpoints:
+            checkpoint_dir = self.args.tensorboard_log_dir / f"{self.args.model_name}_{self.args.exp_id}"
+            checkpoint_callback = CheckpointCallback(
+                save_dir=checkpoint_dir,
+                save_best_only=self.args.save_best_only,
+                save_every_n_epochs=self.args.save_every_n_epochs
+            )
+            checkpoint_callback.set_trainer(self)
+            callbacks.append(checkpoint_callback)
+
+        return callbacks
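`_setup_callbacks` wires `TensorBoardCallback` and `CheckpointCallback` into the hooks that `ModelTrainer.train` later calls (`on_training_begin`, `on_epoch_begin`, `on_epoch_end`, `on_training_end`, plus `set_trainer`). A hedged sketch of the base interface those calls imply; the real callback classes are defined elsewhere in the repository:

```python
class Callback:
    """Sketch of the callback interface implied by the hook calls in trainer.py."""

    def set_trainer(self, trainer) -> None:
        # Give the callback a handle to the trainer, presumably so CheckpointCallback
        # can call trainer.save_checkpoint(...) from its hooks.
        self.trainer = trainer

    def on_training_begin(self) -> None:
        pass

    def on_epoch_begin(self, epoch: int) -> None:
        pass

    def on_epoch_end(self, epoch: int, metrics) -> None:
        # `metrics` is the TrainingMetrics object built each epoch in train().
        pass

    def on_training_end(self) -> None:
        pass
```
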
 
     def _setup_tensorboard(self) -> SummaryWriter:
         """Set up TensorBoard logging.
...
         return model
 
     def _get_dataloaders(self) -> Tuple[DataLoader, DataLoader, dict[str, list[tuple[str, DataLoader]]]]:
+        """Get training, validation, and test dataloaders."""
         pilot_dims = [self.system_config.pilot.num_scs, self.system_config.pilot.num_symbols]
+
         # Training and validation dataloaders
+        train_dataset = MatDataset(self.args.train_set, pilot_dims)
+        val_dataset = MatDataset(self.args.val_set, pilot_dims)
+
         train_loader = DataLoader(
             train_dataset,
             batch_size=self.args.batch_size,
+            shuffle=True,
+            num_workers=self.args.num_workers,
+            pin_memory=self.args.pin_memory and self.device.type == 'cuda'
         )
+
         val_loader = DataLoader(
             val_dataset,
             batch_size=self.args.batch_size,
+            shuffle=False,  # No need to shuffle validation data
+            num_workers=self.args.num_workers,
+            pin_memory=self.args.pin_memory and self.device.type == 'cuda'
         )
+
+        # Test dataloaders
         test_loaders = {
             "DS": get_test_dataloaders(
                 self.args.test_set / "DS_test_set",
...
         }
         return train_loader, val_loader, test_loaders
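These loaders yield the `(estimated_channel, ideal_channel, meta_data)` tuples that `_forward_pass` unpacks. A quick smoke-test sketch of that assumption; the import path, data directory, and pilot dimensions below are placeholders rather than values taken from this repository:

```python
from pathlib import Path

from torch.utils.data import DataLoader

# Placeholder import path; adjust to wherever MatDataset is defined under src/.
from src.main.dataset import MatDataset

pilot_dims = [72, 2]  # [pilot subcarriers, pilot symbols], illustrative values only
dataset = MatDataset(Path("data/train_set"), pilot_dims)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

estimated_channel, ideal_channel, meta_data = next(iter(loader))
print(estimated_channel.shape, ideal_channel.shape)  # one batch of channel estimates and targets
print(meta_data)  # channel conditions (SNR, delay spread, Doppler) consumed by AdaFortiTran
```
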
 
+    def _log_test_results(self, epoch: int, test_stats: Dict[str, Dict]) -> None:
         """Log test results to TensorBoard.
 
         Creates and logs visualizations for model performance across different test conditions.
...
             )
 
             # Plot error images
+            predicted_channels = self.evaluator.predict_channels(self.test_loaders[key])
             self.writer.add_figure(
                 tag=f"{key} Error Images (Epoch:{epoch + 1})",
                 figure=get_error_images(
...
                 )
             )
 
+    def _run_tests(self, epoch: int) -> TestResults:
         """Run tests and log results.
 
         Evaluates the model on all test datasets and logs performance metrics and visualizations.
 
         Args:
             epoch: Current training epoch
+
+        Returns:
+            TestResults containing all test statistics
         """
+        ds_stats = self.evaluator.get_test_stats(self.test_loaders["DS"], self.training_loss)
+        mds_stats = self.evaluator.get_test_stats(self.test_loaders["MDS"], self.training_loss)
+        snr_stats = self.evaluator.get_test_stats(self.test_loaders["SNR"], self.training_loss)
 
         test_stats = {
             "DS": ds_stats,
...
         }
 
         self._log_test_results(epoch, test_stats)
+
+        return TestResults(ds_stats, mds_stats, snr_stats)
 
     def _log_final_metrics(self, final_epoch: int) -> None:
         """Log final training metrics and hyperparameters.
...
         except Exception as e:
             self.writer.add_text("Error", f"Failed to log final test results: {str(e)}")
 
+    def _get_all_test_stats(self) -> Tuple[Dict[int, float], Dict[int, float], Dict[int, float]]:
+        """Get all test statistics."""
+        ds_stats = self.evaluator.get_test_stats(self.test_loaders["DS"], self.training_loss)
+        mds_stats = self.evaluator.get_test_stats(self.test_loaders["MDS"], self.training_loss)
+        snr_stats = self.evaluator.get_test_stats(self.test_loaders["SNR"], self.training_loss)
+        return ds_stats, mds_stats, snr_stats
 
+    def save_checkpoint(self, epoch: int, metrics: TrainingMetrics,
+                        checkpoint_dir: Optional[Path] = None) -> None:
+        """Save model checkpoint.
+
+        Args:
+            epoch: Current epoch number
+            metrics: Current training metrics
+            checkpoint_dir: Directory to save checkpoint (defaults to tensorboard log dir)
+        """
+        if checkpoint_dir is None:
+            checkpoint_dir = self.args.tensorboard_log_dir / f"{self.args.model_name}_{self.args.exp_id}"
+
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"
+
+        checkpoint = {
+            'epoch': epoch,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'scheduler_state_dict': self.scheduler.state_dict(),
+            'train_loss': metrics.train_loss,
+            'val_loss': metrics.val_loss,
+            'learning_rate': metrics.learning_rate,
+            'system_config': self.system_config,
+            'model_config': self.model_config,
+            'args': self.args
+        }
+
+        # Save scaler state if using mixed precision
+        if self.scaler:
+            checkpoint['scaler_state_dict'] = self.scaler.state_dict()
+
+        torch.save(checkpoint, checkpoint_path)
+        self.logger.info(f"Checkpoint saved to {checkpoint_path}")
 
+    def load_checkpoint(self, checkpoint_path: Path) -> int:
+        """Load model checkpoint.
+
+        Args:
+            checkpoint_path: Path to checkpoint file
+
+        Returns:
+            Epoch number of loaded checkpoint
+        """
+        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+        # Load scaler state if it exists
+        if self.scaler and 'scaler_state_dict' in checkpoint:
+            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
+
+        self.logger.info(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
+        return checkpoint['epoch']
 
+    def _resume_from_checkpoint(self, checkpoint_path: Path) -> None:
+        """Resume training from a checkpoint.
+
+        Args:
+            checkpoint_path: Path to checkpoint file
+        """
+        start_epoch = self.load_checkpoint(checkpoint_path)
+        self.logger.info(f"Resuming training from epoch {start_epoch}")
+
+        # Update the early stopper with the best loss from checkpoint
+        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+        if 'val_loss' in checkpoint:
+            self.early_stopper.min_loss = checkpoint['val_loss']
+            self.logger.info(f"Early stopper initialized with validation loss: {checkpoint['val_loss']:.4f}")
 
     def train(self) -> None:
         """Execute the training loop.
...
         - Early stopping when validation loss plateaus
         - Logging final metrics and results
         """
+        # Notify callbacks that training is beginning
+        for callback in self.callbacks:
+            callback.on_training_begin()
+
         last_epoch = 0
         pbar = tqdm(range(self.args.max_epoch), desc="Training")
+
         for epoch in pbar:
             last_epoch = epoch
+
+            # Notify callbacks that epoch is beginning
+            for callback in self.callbacks:
+                callback.on_epoch_begin(epoch)
+
             # Training step
+            train_loss = self.training_loop.train_epoch(self.train_loader)
+
             # Validation step
+            val_loss = self.training_loop.evaluate(self.val_loader)
+
+            # Create metrics object
+            metrics = TrainingMetrics(
+                train_loss=train_loss,
+                val_loss=val_loss,
+                epoch=epoch,
+                learning_rate=self.optimizer.param_groups[0]['lr']
+            )
+
+            # Notify callbacks that epoch has ended
+            for callback in self.callbacks:
+                callback.on_epoch_end(epoch, metrics)
 
             # Update progress bar with loss info
             pbar.set_description(
+                f"Epoch {epoch + 1}/{self.args.max_epoch} - "
+                f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
+            )
 
             if self.early_stopper.early_stop(val_loss):
                 pbar.write(f"Early stopping triggered at epoch {epoch + 1}")
...
             message = f"Test results after epoch {epoch + 1}:\n" + 50 * "-"
             pbar.write(message)
             self._run_tests(epoch)
+
         self._log_final_metrics(last_epoch)
+
+        # Notify callbacks that training has ended
+        for callback in self.callbacks:
+            callback.on_training_end()
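The loop above leans on `EarlyStopping(patience=...)`, its `early_stop(val_loss)` check, and the `min_loss` attribute that `_resume_from_checkpoint` overwrites. A plausible sketch of such a helper, offered as an illustration rather than the repository's actual implementation:

```python
class EarlyStopping:
    """Illustrative early-stopping helper matching the usage in ModelTrainer."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.counter = 0
        self.min_loss = float("inf")

    def early_stop(self, val_loss: float) -> bool:
        if val_loss < self.min_loss:
            # Improvement: remember the best loss and reset the patience counter.
            self.min_loss = val_loss
            self.counter = 0
        else:
            # No improvement this epoch.
            self.counter += 1
        return self.counter >= self.patience
```
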
 
 
 def train(system_config: SystemConfig, model_config: ModelConfig, args: TrainingArguments) -> None:

src/utils.py
CHANGED
@@ -180,24 +180,7 @@ def concat_complex_channel(channel_matrix):
     return cat_channel_m
 
 
-def inverse_concat_complex_channel(channel_matrix: torch.Tensor) -> torch.Tensor:
-    """
-    Reconstruct complex channel matrix from concatenated real matrix.
-
-    Reverses the operation performed by concat_complex_channel by
-    splitting the tensor and combining the parts into a complex tensor.
-
-    Args:
-        channel_matrix: Real-valued matrix of shape (B, F, 2*T)
 
-    Returns:
-        Complex matrix of shape (B, F, T)
-    """
-    split_idx = channel_matrix.shape[-1] // 2
-    return torch.complex(
-        channel_matrix[:, :split_idx],
-        channel_matrix[:, split_idx:]
-    )
 
 
 def get_test_stats_plot(x_name, stats, methods, show=False):
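For reference, the helper removed above was the inverse of `concat_complex_channel`: per the removed function's docstring, the forward mapping turns a complex (B, F, T) channel into a real (B, F, 2T) tensor by concatenating real and imaginary parts. A hedged round-trip sketch with illustrative sizes, splitting along the last axis as the docstring shapes suggest:

```python
import torch

B, F, T = 4, 72, 14  # illustrative batch/subcarrier/symbol sizes only
h = torch.randn(B, F, T, dtype=torch.complex64)

# concat_complex_channel (per its documented shapes): complex (B, F, T) -> real (B, F, 2T).
h_real = torch.cat([h.real, h.imag], dim=-1)

# The removed inverse_concat_complex_channel undid that split:
split = h_real.shape[-1] // 2
h_back = torch.complex(h_real[..., :split], h_real[..., split:])

assert torch.allclose(h, h_back)
```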