AndrewMayesPrezzee committed on
Commit
8abd44b
·
1 Parent(s): f7451e7

Feat - Block-like transformer structure

Files changed (10)
  1. .gitignore +2 -0
  2. README.md +389 -377
  3. blocks.py +446 -0
  4. config.json +2 -0
  5. configuration_autoencoder.py +122 -21
  6. model.safetensors +0 -0
  7. modeling_autoencoder.py +122 -748
  8. preprocessing.py +457 -0
  9. template.py +382 -0
  10. utils.py +69 -0
.gitignore CHANGED
@@ -8,3 +8,5 @@ wheels/
8
 
9
  # Virtual environments
10
  .venv
 
 
 
8
 
9
  # Virtual environments
10
  .venv
11
+
12
+ tests/*
README.md CHANGED
@@ -12,491 +12,503 @@ tags:
12
  - scaler
13
  ---
14
 
15
- # Autoencoder Implementation for Hugging Face Transformers
16
 
17
- A complete autoencoder implementation that integrates seamlessly with the Hugging Face Transformers ecosystem, providing all the standard functionality you expect from transformer models.
18
 
 
 
 
 
 
 
19
 
20
- ### Install-and-Use from the Hub (code repo)
21
-
22
- If you want to use the implementation directly from the Hub code repository (without a packaged pip install), you can download the repo and add it to `sys.path`:
23
 
24
  ```python
25
  from huggingface_hub import snapshot_download
26
  import sys, torch
27
 
28
- # 1) Download the code+weights for your repo “as is”
29
  repo_dir = snapshot_download(
30
  repo_id="amaye15/autoencoder",
31
  repo_type="model",
32
- allow_patterns=["*.py", "config.json", "*.safetensors"], # note the * wildcards
33
  )
34
-
35
- # 2) Add to import path so plain imports work
36
  sys.path.append(repo_dir)
37
 
38
- # 3) Import your classes from the repo code
39
- from configuration_autoencoder import AutoencoderConfig
40
  from modeling_autoencoder import AutoencoderForReconstruction
41
-
42
- # 4) Load the placeholder weights from the local folder (no internet, no code refresh)
43
  model = AutoencoderForReconstruction.from_pretrained(repo_dir)
44
 
45
- # 5) Quick smoke test
46
  x = torch.randn(8, 20)
47
  out = model(input_values=x)
48
  print("latent:", out.last_hidden_state.shape, "reconstructed:", out.reconstructed.shape)
49
  ```
50
 
51
- ## 🚀 Features
52
-
53
- - **Full Hugging Face Integration**: Compatible with `AutoModel`, `AutoConfig`, and `AutoTokenizer` patterns
54
- - **Standard Training Workflows**: Works with `Trainer`, `TrainingArguments`, and all HF training utilities
55
- - **Model Hub Compatible**: Save and share models on Hugging Face Hub with `push_to_hub()`
56
- - **Flexible Architecture**: Configurable encoder-decoder architecture with various activation functions
57
- - **Multiple Loss Functions**: Support for MSE, BCE, L1, Huber, Smooth L1, KL Divergence, Cosine, Focal, Dice, Tversky, SSIM, and Perceptual loss
58
- - **Multiple Autoencoder Types (7)**: Classic, Variational (VAE), Beta-VAE, Denoising, Sparse, Contractive, and Recurrent autoencoders
59
- - **Extended Activation Functions**: 18+ activation functions including ReLU, GELU, Swish, Mish, ELU, and more
60
- - **Learnable Preprocessing**: Neural Scaler, Normalizing Flow, MinMax Scaler (learnable), Robust Scaler (learnable), and Yeo-Johnson preprocessors (2D and 3D tensors)
61
- - **Extensible Design**: Easy to extend for new autoencoder variants and custom loss functions
62
- - **Production Ready**: Proper serialization, checkpointing, and inference support
63
 
 
 
 
64
 
65
- ## 🏗️ Architecture
 
66
 
67
- The implementation consists of three main components:
 
 
 
68
 
69
- ### 1. AutoencoderConfig
70
- Configuration class that inherits from `PretrainedConfig`:
71
- - Defines model architecture parameters
72
- - Handles validation and serialization
73
- - Enables `AutoConfig.from_pretrained()` functionality
 
 
 
 
 
74
 
75
- ### 2. AutoencoderModel
76
- Base model class that inherits from `PreTrainedModel`:
77
- - Implements encoder-decoder architecture
78
- - Provides latent space representation
79
- - Returns structured outputs with `AutoencoderOutput`
80
 
81
- ### 3. AutoencoderForReconstruction
82
- Task-specific model for reconstruction:
83
- - Adds reconstruction loss calculation
84
- - Compatible with `Trainer` for easy training
85
- - Returns `AutoencoderForReconstructionOutput` with loss
86
 
87
- ## 🔧 Quick Start
88
 
89
- ### Basic Usage
 
90
 
91
  ```python
92
- from configuration_autoencoder import AutoencoderConfig
93
- from modeling_autoencoder import AutoencoderForReconstruction
94
- import torch
95
-
96
- # Create configuration
97
- config = AutoencoderConfig(
98
- input_dim=784, # Input dimensionality (e.g., 28x28 images flattened)
99
- hidden_dims=[512, 256], # Encoder hidden layers
100
- latent_dim=64, # Latent space dimension
101
- activation="gelu", # Activation function (18+ options available)
102
- reconstruction_loss="mse", # Loss function (12+ options available)
103
- autoencoder_type="classic", # Autoencoder type (7 types available)
104
- # Optional learnable preprocessing
105
- use_learnable_preprocessing=True,
106
- preprocessing_type="neural_scaler", # or "normalizing_flow", "minmax_scaler", "robust_scaler", "yeo_johnson"
107
- )
108
-
109
- # Create model
110
- model = AutoencoderForReconstruction(config)
111
-
112
- # Forward pass
113
- input_data = torch.randn(32, 784) # Batch of 32 samples
114
- outputs = model(input_values=input_data)
115
-
116
- print(f"Reconstruction loss: {outputs.loss}")
117
- print(f"Latent shape: {outputs.last_hidden_state.shape}")
118
- print(f"Reconstructed shape: {outputs.reconstructed.shape}")
119
  ```
120
 
121
-
122
- ### Training with Hugging Face Trainer
123
 
124
  ```python
125
- from transformers import Trainer, TrainingArguments
126
- from torch.utils.data import Dataset
127
-
128
- class AutoencoderDataset(Dataset):
129
- def __init__(self, data):
130
- self.data = torch.FloatTensor(data)
131
-
132
- def __len__(self):
133
- return len(self.data)
134
-
135
- def __getitem__(self, idx):
136
- return {
137
- "input_values": self.data[idx],
138
- "labels": self.data[idx] # For autoencoder, input = target
139
- }
140
-
141
- # Prepare data
142
- train_dataset = AutoencoderDataset(your_training_data)
143
- val_dataset = AutoencoderDataset(your_validation_data)
144
-
145
- # Training arguments
146
- training_args = TrainingArguments(
147
- output_dir="./autoencoder_output",
148
- num_train_epochs=10,
149
- per_device_train_batch_size=64,
150
- per_device_eval_batch_size=64,
151
- warmup_steps=500,
152
- weight_decay=0.01,
153
- logging_dir="./logs",
154
- evaluation_strategy="steps",
155
- eval_steps=500,
156
- save_steps=1000,
157
- load_best_model_at_end=True,
158
- )
159
-
160
- # Create trainer
161
- trainer = Trainer(
162
- model=model,
163
- args=training_args,
164
- train_dataset=train_dataset,
165
- eval_dataset=val_dataset,
166
- )
167
 
168
- # Train
169
- trainer.train()
170
 
171
- # Save model
172
- model.save_pretrained("./my_autoencoder")
173
- config.save_pretrained("./my_autoencoder")
174
  ```
175
 
176
- ### Using AutoModel Framework
 
177
 
178
  ```python
179
- from register_autoencoder import register_autoencoder_models
180
- from transformers import AutoConfig, AutoModel
181
-
182
- # Register models with AutoModel framework
183
- register_autoencoder_models()
184
 
185
- # Now you can use standard HF patterns
186
- config = AutoConfig.from_pretrained("./my_autoencoder")
187
- model = AutoModel.from_pretrained("./my_autoencoder")
188
 
189
- # Use the model
190
- outputs = model(input_values=your_data)
 
 
 
 
191
  ```
192
 
193
- ## ⚙️ Configuration Options
194
-
195
- The `AutoencoderConfig` class supports extensive customization:
196
 
 
197
  ```python
198
- config = AutoencoderConfig(
199
- input_dim=784, # Input dimension
200
- hidden_dims=[512, 256, 128], # Encoder hidden layers
201
- latent_dim=64, # Latent space dimension
202
- activation="gelu", # Activation function (see full list below)
203
- dropout_rate=0.1, # Dropout rate (0.0 to 1.0)
204
- use_batch_norm=True, # Use batch normalization
205
- tie_weights=False, # Tie encoder/decoder weights
206
- reconstruction_loss="mse", # Loss function (see full list below)
207
- autoencoder_type="variational", # Autoencoder type (see types below)
208
- beta=0.5, # Beta parameter for β-VAE
209
- temperature=1.0, # Temperature for Gumbel softmax
210
- noise_factor=0.1, # Noise factor for denoising AE
211
- # Recurrent autoencoder parameters
212
- rnn_type="lstm", # RNN type: "lstm", "gru", "rnn"
213
- num_layers=2, # Number of RNN layers
214
- bidirectional=True, # Bidirectional encoding
215
- sequence_length=None, # Fixed sequence length (None for variable)
216
- teacher_forcing_ratio=0.5, # Teacher forcing ratio during training
217
- # Learnable preprocessing parameters
218
- use_learnable_preprocessing=False, # Enable learnable preprocessing
219
- preprocessing_type="none", # "none", "neural_scaler", "normalizing_flow"
220
- preprocessing_hidden_dim=64, # Hidden dimension for preprocessing networks
221
- preprocessing_num_layers=2, # Number of layers in preprocessing networks
222
- learn_inverse_preprocessing=True, # Learn inverse transformation
223
- flow_coupling_layers=4, # Number of coupling layers for flows
224
  )
225
  ```
226
 
227
- ### 🎛️ Available Activation Functions
228
-
229
- **Standard Activations:**
230
- - `relu`, `leaky_relu`, `relu6`, `elu`, `prelu`
231
- - `tanh`, `sigmoid`, `hardsigmoid`, `hardtanh`
232
- - `gelu`, `swish`, `silu`, `hardswish`
233
- - `mish`, `softplus`, `softsign`, `tanhshrink`, `threshold`
234
-
235
- ### 📊 Available Loss Functions
236
-
237
- **Regression Losses:**
238
- - `mse` - Mean Squared Error
239
- - `l1` - L1/MAE Loss
240
- - `huber` - Huber Loss
241
- - `smooth_l1` - Smooth L1 Loss
242
-
243
- **Classification/Probability Losses:**
244
- - `bce` - Binary Cross Entropy
245
- - `kl_div` - KL Divergence
246
- - `focal` - Focal Loss
247
-
248
- **Similarity Losses:**
249
- - `cosine` - Cosine Similarity Loss
250
- - `ssim` - Structural Similarity Loss
251
- - `perceptual` - Perceptual Loss
252
-
253
- **Segmentation Losses:**
254
- - `dice` - Dice Loss
255
- - `tversky` - Tversky Loss
256
-
257
- ### 🏗️ Available Autoencoder Types
258
-
259
- **Classic Autoencoder (`classic`)**
260
- - Standard encoder-decoder architecture
261
- - Direct reconstruction loss minimization
262
-
263
- **Variational Autoencoder (`variational`)**
264
- - Probabilistic latent space with mean and variance
265
- - KL divergence regularization
266
- - Reparameterization trick for sampling
267
-
268
- **Beta-VAE (`beta_vae`)**
269
- - Variational autoencoder with adjustable β parameter
270
- - Better disentanglement of latent factors
271
-
272
- **Denoising Autoencoder (`denoising`)**
273
- - Adds noise to input during training
274
- - Learns robust representations
275
- - Configurable noise factor
276
-
277
- **Sparse Autoencoder (`sparse`)**
278
- - Encourages sparse latent representations
279
- - L1 regularization on latent activations
280
- - Useful for feature selection
281
-
282
- **Contractive Autoencoder (`contractive`)**
283
- - Penalizes large gradients of latent w.r.t. input
284
- - Learns smooth manifold representations
285
- - Robust to small input perturbations
286
-
287
- **Recurrent Autoencoder (`recurrent`)**
288
- - LSTM/GRU/RNN encoder-decoder architecture
289
- - Bidirectional encoding for better sequence representations
290
- - Variable length sequence support with padding
291
- - Teacher forcing during training for stable learning
292
- - Sequence-to-sequence reconstruction
293
  ```
294
 
295
- ## 📊 Model Outputs
296
 
297
- ### AutoencoderOutput
 
 
 
298
 
299
- The base model `AutoencoderModel` returns the following output:
300
- ```
301
- ```python
302
 
303
- @dataclass
304
- class AutoencoderOutput(ModelOutput):
305
- last_hidden_state: torch.FloatTensor = None # Latent representation
306
- reconstructed: torch.FloatTensor = None # Reconstructed input
307
- hidden_states: Tuple[torch.FloatTensor] = None # Intermediate states
308
- attentions: Tuple[torch.FloatTensor] = None # Not used
309
- ```
 
310
 
311
- ### AutoencoderForReconstructionOutput
312
  ```python
313
- @dataclass
314
- class AutoencoderForReconstructionOutput(ModelOutput):
315
- loss: torch.FloatTensor = None # Reconstruction loss
316
- reconstructed: torch.FloatTensor = None # Reconstructed input
317
- last_hidden_state: torch.FloatTensor = None # Latent representation
318
- hidden_states: Tuple[torch.FloatTensor] = None # Intermediate states
 
 
319
  ```
320
 
321
- ## 🔬 Advanced Usage
 
 
 
 
322
 
323
- ### Custom Loss Functions
 
 
 
 
 
 
324
 
325
- You can easily extend the model with custom loss functions:
 
 
 
 
326
 
327
  ```python
328
- class CustomAutoencoder(AutoencoderForReconstruction):
329
- def _compute_reconstruction_loss(self, reconstructed, target):
330
- # Custom loss implementation
331
- return your_custom_loss(reconstructed, target)
332
  ```
333
 
334
- ### Recurrent Autoencoder for Sequences
335
 
336
- Perfect for time series, text, and sequential data:
 
 
337
 
338
  ```python
339
- config = AutoencoderConfig(
340
- input_dim=50, # Feature dimension per timestep
341
- latent_dim=32, # Compressed representation size
342
- autoencoder_type="recurrent",
343
- rnn_type="lstm", # or "gru", "rnn"
344
- num_layers=2, # Number of RNN layers
345
- bidirectional=True, # Bidirectional encoding
346
- teacher_forcing_ratio=0.7, # Teacher forcing during training
347
- sequence_length=None # Variable length sequences
348
- )
349
-
350
- # Usage with sequence data
351
- model = AutoencoderForReconstruction(config)
352
- sequence_data = torch.randn(batch_size, seq_len, input_dim)
353
- outputs = model(input_values=sequence_data)
354
  ```
355
 
356
- ### Learnable Preprocessing
 
 
 
 
 
357
 
358
- Deep learning-based data normalization that adapts to your data:
359
 
 
360
  ```python
361
- # Neural Scaler - Learnable alternative to StandardScaler
362
- config = AutoencoderConfig(
363
- input_dim=20,
364
- latent_dim=10,
365
- use_learnable_preprocessing=True,
366
- preprocessing_type="neural_scaler",
367
- preprocessing_hidden_dim=64
368
- )
369
 
370
- # Normalizing Flow - Invertible transformations
371
- config = AutoencoderConfig(
372
- input_dim=20,
373
- latent_dim=10,
374
- use_learnable_preprocessing=True,
375
- preprocessing_type="normalizing_flow",
376
- flow_coupling_layers=4
377
- )
378
 
379
- # Works with all autoencoder types and sequence data
380
- model = AutoencoderForReconstruction(config)
381
- outputs = model(input_values=data)
382
- print(f"Preprocessing loss: {outputs.preprocessing_loss}")
383
  ```
384
 
 
385
  ```python
386
- # Learnable MinMax Scaler - scales to [0, 1] with learnable bounds
387
- config = AutoencoderConfig(
388
- input_dim=20,
389
- latent_dim=10,
390
- use_learnable_preprocessing=True,
391
- preprocessing_type="minmax_scaler",
392
- )
 
 
 
 
 
393
 
394
- # Learnable Robust Scaler - robust to outliers using median/IQR
395
- config = AutoencoderConfig(
396
- input_dim=20,
397
- latent_dim=10,
398
- use_learnable_preprocessing=True,
399
- preprocessing_type="robust_scaler",
400
- )
401
 
402
- # Learnable Yeo-Johnson - power transform for skewed distributions
403
- config = AutoencoderConfig(
404
- input_dim=20,
405
- latent_dim=10,
406
- use_learnable_preprocessing=True,
407
- preprocessing_type="yeo_johnson",
408
- )
409
  ```
410
 
 
 
 
411
 
412
- ### Variational Autoencoder Extension
 
 
413
 
414
- The configuration supports variational autoencoders:
415
 
416
  ```python
417
- config = AutoencoderConfig(
418
- autoencoder_type="variational",
419
- beta=0.5, # β-VAE parameter
420
- # ... other parameters
421
- )
422
- ```
423
 
424
- ### Integration with Datasets Library
 
 
425
 
426
- ```python
427
- from datasets import Dataset
 
 
 
 
 
 
 
 
 
428
 
429
- # Convert your data to HF Dataset
430
- dataset = Dataset.from_dict({
431
- "input_values": your_data_list
432
- })
 
 
 
 
 
 
433
 
434
- # Use with Trainer
435
  trainer = Trainer(
436
  model=model,
437
- train_dataset=dataset,
438
- # ... other arguments
439
  )
440
- ```
441
 
442
- ## 📁 Project Structure
 
443
 
444
- ```
445
- autoencoder/
446
- ├── __init__.py # Package initialization
447
- ├── configuration_autoencoder.py # Configuration class
448
- ├── modeling_autoencoder.py # Model implementations
449
- ├── register_autoencoder.py # AutoModel registration
450
- ├── pyproject.toml # Project metadata and dependencies
451
- └── README.md # This file
452
  ```
453
 
454
- ## 🤝 Contributing
 
 
 
 
455
 
456
- This implementation follows Hugging Face conventions and can be easily extended:
457
 
458
- 1. **Adding new architectures**: Extend `AutoencoderModel` or create new model classes
459
- 2. **Custom configurations**: Add parameters to `AutoencoderConfig`
460
- 3. **Task-specific heads**: Create new classes like `AutoencoderForReconstruction`
461
- 4. **Integration**: Register new models with the AutoModel framework
462
 
463
- ## 📚 References
464
 
465
- - [Hugging Face Transformers Documentation](https://huggingface.co/docs/transformers)
466
- - [Custom Models Guide](https://huggingface.co/docs/transformers/custom_models)
467
- - [AutoModel Documentation](https://huggingface.co/docs/transformers/model_doc/auto)
 
468
 
469
- ## 🎯 Use Cases
470
 
471
- This autoencoder implementation is perfect for:
 
 
 
 
472
 
473
- - **Dimensionality Reduction**: Compress high-dimensional data to lower dimensions
474
- - **Anomaly Detection**: Identify outliers based on reconstruction error
475
- - **Data Denoising**: Remove noise from corrupted data
476
- - **Feature Learning**: Learn meaningful representations for downstream tasks
477
- - **Data Generation**: Generate new samples similar to training data
478
- - **Pretraining**: Initialize encoders for other tasks
479
 
480
- ## 🔍 Model Comparison
481
 
482
- | Feature | Standard PyTorch | This Implementation |
483
- |---------|------------------|-------------------|
484
- | HF Integration | ❌ | ✅ |
485
- | AutoModel Support | ❌ | ✅ |
486
- | Trainer Compatible | ❌ | ✅ |
487
- | Hub Integration | ❌ | ✅ |
488
- | Config Management | Manual | ✅ Automatic |
489
- | Serialization | Manual | ✅ Built-in |
490
- | Checkpointing | Manual | ✅ Built-in |
491
 
492
- ## 🚀 Performance Tips
 
 
 
493
 
494
- 1. **Batch Size**: Use larger batch sizes for better GPU utilization
495
- 2. **Learning Rate**: Start with 1e-3 and adjust based on convergence
496
- 3. **Architecture**: Gradually decrease hidden dimensions for better compression
497
- 4. **Regularization**: Use dropout and batch normalization for better generalization
498
- 5. **Loss Function**: Choose appropriate loss based on your data type
 
 
 
 
 
 
499
 
500
- ## 📄 License
 
 
 
 
 
 
501
 
502
- This implementation is provided as an example and follows the same license terms as Hugging Face Transformers.
 
 
12
  - scaler
13
  ---
14
 
15
+ ## Autoencoder for Hugging Face Transformers (Block-based)
16
 
17
+ A flexible, production-grade autoencoder implementation built to fit naturally into the Hugging Face Transformers ecosystem. It is organized around a block-based architecture, with ready-to-use templates for classic MLP, VAE/beta-VAE, Transformer, Recurrent, Convolutional, and mixed hybrid autoencoders, plus learnable preprocessing.
18
 
19
+ ### Key features
20
+ - Block-based architecture: Linear, Attention, Recurrent (LSTM/GRU), Convolutional, Variational blocks
21
+ - Class-based configuration presets in template.py for quick starts
22
+ - Variational and beta-VAE variants (KL-controlled)
23
+ - Learnable preprocessing and inverse transforms
24
+ - Hugging Face-compatible config/model API and from_pretrained/save_pretrained
25
 
26
+ ## Install and load from the Hub (code repo)
 
 
27
 
28
  ```python
29
  from huggingface_hub import snapshot_download
30
  import sys, torch
31
 
 
32
  repo_dir = snapshot_download(
33
  repo_id="amaye15/autoencoder",
34
  repo_type="model",
35
+ allow_patterns=["*.py", "config.json", "*.safetensors"],
36
  )
 
 
37
  sys.path.append(repo_dir)
38
 
 
 
39
  from modeling_autoencoder import AutoencoderForReconstruction
 
 
40
  model = AutoencoderForReconstruction.from_pretrained(repo_dir)
41
 
 
42
  x = torch.randn(8, 20)
43
  out = model(input_values=x)
44
  print("latent:", out.last_hidden_state.shape, "reconstructed:", out.reconstructed.shape)
45
  ```
46
 
47
+ ## Quickstart with class-based templates
48
 
49
+ ```python
50
+ import torch
+ from modeling_autoencoder import AutoencoderModel
51
+ from template import ClassicAutoencoderConfig
52
 
53
+ cfg = ClassicAutoencoderConfig(input_dim=784, latent_dim=64)
54
+ model = AutoencoderModel(cfg)
55
 
56
+ x = torch.randn(4, 784)
57
+ out = model(x, return_dict=True)
58
+ print(out.last_hidden_state.shape, out.reconstructed.shape)
59
+ ```
60
 
61
+ ### Available presets (template.py)
62
+ - ClassicAutoencoderConfig: Dense MLP AE
63
+ - VariationalAutoencoderConfig: VAE with KL regularization
64
+ - BetaVariationalAutoencoderConfig: beta-VAE (beta > 1)
65
+ - TransformerAutoencoderConfig: Attention-based encoder for sequences
66
+ - RecurrentAutoencoderConfig: LSTM/GRU encoder for sequences
67
+ - ConvolutionalAutoencoderConfig: 1D Conv encoder for sequences
68
+ - ConvAttentionAutoencoderConfig: Mixed Conv + Attention encoder
69
+ - LinearRecurrentAutoencoderConfig: Linear down-projection + RNN
70
+ - PreprocessedAutoencoderConfig: MLP AE with learnable preprocessing
71
 
72
+ ## Block-based architecture
 
 
 
 
73
 
74
+ The autoencoder uses a modular block system where you define encoder_blocks and decoder_blocks as lists of dictionaries. Each block dict specifies its type and parameters.
 
 
 
 
75
 
76
+ ### Available block types
77
 
78
+ #### LinearBlock
79
+ Dense layer with optional normalization, activation, dropout, and residual connections.
80
 
81
  ```python
82
+ {
83
+ "type": "linear",
84
+ "input_dim": 256,
85
+ "output_dim": 128,
86
+ "activation": "relu", # relu, gelu, tanh, sigmoid, etc.
87
+ "normalization": "batch", # batch, layer, group, instance, none
88
+ "dropout_rate": 0.1,
89
+ "use_residual": False, # adds skip connection if input_dim == output_dim
90
+ "residual_scale": 1.0
91
+ }
92
  ```
93
 
94
+ #### AttentionBlock
95
+ Multi-head self-attention with feed-forward network. Works with 2D (B, D) or 3D (B, T, D) inputs.
96
 
97
  ```python
98
+ {
99
+ "type": "attention",
100
+ "input_dim": 128,
101
+ "num_heads": 8,
102
+ "ffn_dim": 512, # if None, defaults to 4 * input_dim
103
+ "dropout_rate": 0.1
104
+ }
105
+ ```
106
 
107
+ #### RecurrentBlock
108
+ LSTM, GRU, or vanilla RNN encoder. Returns the last timestep's hidden representation, optionally projected to output_dim.
109
 
110
+ ```python
111
+ {
112
+ "type": "recurrent",
113
+ "input_dim": 64,
114
+ "hidden_size": 128,
115
+ "num_layers": 2,
116
+ "rnn_type": "lstm", # lstm, gru, rnn
117
+ "bidirectional": True,
118
+ "dropout_rate": 0.1,
119
+ "output_dim": 128 # final output dimension
120
+ }
121
  ```
122
 
123
+ #### ConvolutionalBlock
124
+ 1D convolution for sequence data. Expects 3D input (B, T, D).
125
 
126
  ```python
127
+ {
128
+ "type": "conv1d",
129
+ "input_dim": 64, # input channels
130
+ "output_dim": 128, # output channels
131
+ "kernel_size": 3,
132
+ "padding": "same", # "same" or integer
133
+ "activation": "relu",
134
+ "normalization": "batch",
135
+ "dropout_rate": 0.1
136
+ }
137
+ ```
138
 
139
+ #### VariationalBlock
140
+ Produces mu and logvar for VAE reparameterization. Used internally by the model when autoencoder_type="variational".
 
141
 
142
+ ```python
143
+ {
144
+ "type": "variational",
145
+ "input_dim": 128,
146
+ "latent_dim": 64
147
+ }
148
  ```
149
 
150
+ ### Custom configuration examples
 
 
151
 
152
+ #### Mixed architecture (Conv + Attention + Linear)
153
  ```python
154
+ from configuration_autoencoder import AutoencoderConfig
155
+
156
+ enc = [
157
+ # 1D convolution for local patterns
158
+ {"type": "conv1d", "input_dim": 64, "output_dim": 128, "kernel_size": 3, "padding": "same", "activation": "relu"},
159
+ {"type": "conv1d", "input_dim": 128, "output_dim": 128, "kernel_size": 3, "padding": "same", "activation": "relu"},
160
+
161
+ # Self-attention for global dependencies
162
+ {"type": "attention", "input_dim": 128, "num_heads": 8, "ffn_dim": 512, "dropout_rate": 0.1},
163
+
164
+ # Final linear projection
165
+ {"type": "linear", "input_dim": 128, "output_dim": 64, "activation": "relu", "normalization": "batch"}
166
+ ]
167
+
168
+ dec = [
169
+ {"type": "linear", "input_dim": 32, "output_dim": 64, "activation": "relu", "normalization": "batch"},
170
+ {"type": "linear", "input_dim": 64, "output_dim": 128, "activation": "relu", "normalization": "batch"},
171
+ {"type": "linear", "input_dim": 128, "output_dim": 64, "activation": "identity", "normalization": "none"}
172
+ ]
173
+
174
+ cfg = AutoencoderConfig(
175
+ input_dim=64,
176
+ latent_dim=32,
177
+ autoencoder_type="classic",
178
+ encoder_blocks=enc,
179
+ decoder_blocks=dec
180
  )
181
  ```
182
 
183
+ #### Hierarchical encoder (multiple scales)
184
+ ```python
185
+ enc = [
186
+ # Local features
187
+ {"type": "linear", "input_dim": 784, "output_dim": 512, "activation": "relu", "normalization": "batch"},
188
+ {"type": "linear", "input_dim": 512, "output_dim": 256, "activation": "relu", "normalization": "batch"},
189
+
190
+ # Mid-level features with residual
191
+ {"type": "linear", "input_dim": 256, "output_dim": 256, "activation": "relu", "normalization": "batch", "use_residual": True},
192
+ {"type": "linear", "input_dim": 256, "output_dim": 256, "activation": "relu", "normalization": "batch", "use_residual": True},
193
+
194
+ # High-level features
195
+ {"type": "linear", "input_dim": 256, "output_dim": 128, "activation": "relu", "normalization": "batch"},
196
+ {"type": "linear", "input_dim": 128, "output_dim": 64, "activation": "relu", "normalization": "batch"}
197
+ ]
198
  ```
199
 
200
+ #### Sequence-to-sequence with recurrent encoder
201
+ ```python
202
+ enc = [
203
+ {"type": "recurrent", "input_dim": 100, "hidden_size": 128, "num_layers": 2, "rnn_type": "lstm", "bidirectional": True, "output_dim": 256},
204
+ {"type": "linear", "input_dim": 256, "output_dim": 128, "activation": "tanh", "normalization": "layer"}
205
+ ]
206
+
207
+ dec = [
208
+ {"type": "linear", "input_dim": 64, "output_dim": 128, "activation": "tanh", "normalization": "layer"},
209
+ {"type": "linear", "input_dim": 128, "output_dim": 100, "activation": "identity", "normalization": "none"}
210
+ ]
211
+ ```
212
 
213
+ ### Input shape handling
214
+ - **2D inputs (B, D)**: Work with Linear blocks directly. Attention/Recurrent/Conv blocks treat as (B, 1, D)
215
+ - **3D inputs (B, T, D)**: Work with all block types. Linear blocks operate per-timestep
216
+ - **Output shapes**: Decoder typically outputs same shape as input. For sequence models, final shape depends on decoder architecture
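+
+ The shape rules above can be exercised directly; a short sketch (assuming a simple linear-block encoder/decoder, as in the configuration example later in this README):
+
+ ```python
+ import torch
+ from configuration_autoencoder import AutoencoderConfig
+ from modeling_autoencoder import AutoencoderModel
+
+ cfg = AutoencoderConfig(
+     input_dim=16,
+     latent_dim=8,
+     encoder_blocks=[{"type": "linear", "input_dim": 16, "output_dim": 8, "activation": "relu"}],
+     decoder_blocks=[{"type": "linear", "input_dim": 8, "output_dim": 16, "activation": "identity", "normalization": "none"}],
+ )
+ model = AutoencoderModel(cfg)
+
+ flat = model(torch.randn(4, 16), return_dict=True)     # 2D input (B, D)
+ seq = model(torch.randn(4, 10, 16), return_dict=True)  # 3D input (B, T, D); linear blocks act per timestep
+ print(flat.reconstructed.shape, seq.reconstructed.shape)
+ ```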
217
 
218
+ ## Configuration (configuration_autoencoder.py)
 
 
219
 
220
+ AutoencoderConfig is the core configuration class. Important fields:
221
+ - input_dim: feature dimension (D)
222
+ - latent_dim: latent size
223
+ - encoder_blocks, decoder_blocks: block lists (see block types above)
224
+ - activation, dropout_rate, use_batch_norm: defaults used by some presets
225
+ - autoencoder_type: classic | variational | beta_vae | denoising | sparse | contractive | recurrent
226
+ - Reconstruction losses: mse | bce | l1 | huber | smooth_l1 | kl_div | cosine | focal | dice | tversky | ssim | perceptual
227
+ - Preprocessing: use_learnable_preprocessing, preprocessing_type, learn_inverse_preprocessing
228
 
229
+ Example:
230
  ```python
231
+ from configuration_autoencoder import AutoencoderConfig
232
+ cfg = AutoencoderConfig(
233
+ input_dim=128,
234
+ latent_dim=32,
235
+ autoencoder_type="variational",
236
+ encoder_blocks=[{"type": "linear", "input_dim": 128, "output_dim": 64, "activation": "relu"}],
237
+ decoder_blocks=[{"type": "linear", "input_dim": 32, "output_dim": 128, "activation": "identity", "normalization": "none"}],
238
+ )
239
  ```
240
 
241
+ ## Models (modeling_autoencoder.py)
242
+
243
+ Main classes:
244
+ - AutoencoderModel: core module exposing forward that returns last_hidden_state (latent) and reconstructed
245
+ - AutoencoderForReconstruction: HF-compatible model wrapper with from_pretrained/save_pretrained
246
 
247
+ Forward usage:
248
+ ```python
249
+ import torch
+ from modeling_autoencoder import AutoencoderModel
+ from template import ClassicAutoencoderConfig
+
+ model = AutoencoderModel(ClassicAutoencoderConfig(input_dim=20, latent_dim=8))
250
+ x = torch.randn(8, 20)
251
+ out = model(x, return_dict=True)
252
+ print(out.last_hidden_state.shape, out.reconstructed.shape)
253
+ ```
254
 
255
+ ### Variational behavior
256
+ When cfg.autoencoder_type is "variational" or "beta_vae":
257
+ - The model uses an internal VariationalBlock to compute mu and logvar
258
+ - Samples z during training; uses mu during eval
259
+ - KL term available via model._mu/_logvar (exposed in hidden_states when requested)
260
 
261
  ```python
262
+ out = model(x, return_dict=True, output_hidden_states=True)
263
+ latent, mu, logvar = out.hidden_states
 
 
264
  ```
265
 
266
+ ## Preprocessing (preprocessing.py)
267
 
268
+ - PreprocessingBlock wraps LearnablePreprocessor and can be placed before/after the core encoder/decoder
269
+ - When enabled via config.use_learnable_preprocessing, the model constructs two blocks: pre (forward) and post (inverse)
270
+ - The block tracks reg_loss, which is added to preprocessing_loss in the model output
271
 
272
  ```python
273
+ from template import PreprocessedAutoencoderConfig
274
+ cfg = PreprocessedAutoencoderConfig(input_dim=64, latent_dim=32, preprocessing_type="neural_scaler")
275
+ model = AutoencoderModel(cfg)
276
  ```
277
 
278
+ ## Utilities (utils.py)
279
+
280
+ Common helpers:
281
+ - _get_activation(name)
282
+ - _get_norm(name, num_groups=None)
283
+ - _flatten_3d_to_2d(x), _maybe_restore_3d(x, ref)
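+
+ A rough illustration of how the shape helpers are used inside the blocks (a sketch based on their usage in blocks.py; exact signatures live in utils.py):
+
+ ```python
+ import torch
+ from utils import _get_activation, _flatten_3d_to_2d, _maybe_restore_3d
+
+ act = _get_activation("gelu")       # returns an nn.Module
+ x = torch.randn(4, 10, 32)          # (B, T, F)
+ x2d, hint = _flatten_3d_to_2d(x)    # 2D view plus a hint to undo the flattening
+ y = act(x2d)
+ y = _maybe_restore_3d(y, hint)      # restore the original (B, T, F) rank
+ print(y.shape)
+ ```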
284
 
285
+ ## Training examples
286
 
287
+ ### Basic MSE reconstruction
288
  ```python
289
+ import torch
+ from modeling_autoencoder import AutoencoderModel
290
+ from template import ClassicAutoencoderConfig
 
 
 
 
 
 
291
 
292
+ cfg = ClassicAutoencoderConfig(input_dim=784, latent_dim=64)
293
+ model = AutoencoderModel(cfg)
294
+ opt = torch.optim.Adam(model.parameters(), lr=1e-3)
 
 
 
 
 
295
 
296
+ for x in dataloader:  # x: (B, 784)
297
+     out = model(x, return_dict=True)
298
+     loss = torch.nn.functional.mse_loss(out.reconstructed, x)
299
+     loss.backward(); opt.step(); opt.zero_grad()
300
  ```
301
 
302
+ ### VAE with KL term
303
  ```python
304
+ from template import VariationalAutoencoderConfig
305
+ cfg = VariationalAutoencoderConfig(input_dim=784, latent_dim=32)
306
+ model = AutoencoderModel(cfg)
307
+ opt = torch.optim.Adam(model.parameters(), lr=1e-3)
308
+ for x in dataloader:
309
+     out = model(x, return_dict=True, output_hidden_states=True)
310
+     recon = torch.nn.functional.mse_loss(out.reconstructed, x)
311
+     _, mu, logvar = out.hidden_states
312
+     kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
313
+     loss = recon + cfg.beta * kl
314
+     loss.backward(); opt.step(); opt.zero_grad()
315
+ ```
316
 
317
+ ### Sequence reconstruction (Conv + Attention)
318
+ ```python
319
+ from template import ConvAttentionAutoencoderConfig
320
+ cfg = ConvAttentionAutoencoderConfig(input_dim=64, latent_dim=64)
321
+ model = AutoencoderModel(cfg)
 
 
322
 
323
+ x = torch.randn(8, 50, 64) # (B, T, D)
324
+ out = model(x, return_dict=True)
 
 
 
 
 
325
  ```
326
 
327
+ ## End-to-end saving/loading
328
+ ```python
329
+ from modeling_autoencoder import AutoencoderForReconstruction
330
 
331
+ model.save_pretrained("./my_ae")
332
+ reloaded = AutoencoderForReconstruction.from_pretrained("./my_ae")
333
+ ```
334
 
335
+ ## Troubleshooting
336
+ - Check that block input_dim/output_dim align across adjacent blocks (a small checker sketch follows this list)
337
+ - For attention/recurrent/conv blocks, prefer 3D inputs (B, T, D). 2D inputs are coerced to (B, 1, D)
338
+ - For variational/beta-VAE, ensure latent_dim is set; KL term available via hidden states
339
+ - When preprocessing is enabled, preprocessing_loss is included in the output for logging/regularization
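+
+ For the first point, a small hypothetical helper can catch dimension mismatches in a block list before the model is built (the field names follow the block schemas above):
+
+ ```python
+ def check_block_dims(blocks):
+     """Flag adjacent input_dim/output_dim mismatches in a list of block dicts."""
+     prev_out = None
+     for i, blk in enumerate(blocks):
+         in_dim = blk.get("input_dim")
+         out_dim = blk.get("output_dim", in_dim)  # attention blocks keep their input_dim
+         if prev_out is not None and in_dim is not None and prev_out != in_dim:
+             raise ValueError(f"block {i}: input_dim={in_dim} != previous output_dim={prev_out}")
+         prev_out = out_dim
+
+ check_block_dims([
+     {"type": "linear", "input_dim": 128, "output_dim": 64, "activation": "relu"},
+     {"type": "linear", "input_dim": 64, "output_dim": 32, "activation": "relu"},
+ ])  # raises if dimensions do not line up
+ ```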
340
+
341
+
342
+ ## Full AutoencoderConfig reference
343
+
344
+ Below is a comprehensive reference for all fields in configuration_autoencoder.AutoencoderConfig. Some fields are primarily used by presets or advanced features but are documented here for completeness.
345
+
346
+ - input_dim (int, default=784): Input feature dimension D. For sequences, D is per-timestep feature size.
347
+ - hidden_dims (List[int], default=[512,256,128]): Legacy convenience list for simple MLPs. Prefer encoder_blocks.
348
+ - encoder_blocks (List[dict] | None): Block list for encoder. See Block-based architecture for block schemas.
349
+ - decoder_blocks (List[dict] | None): Block list for decoder. If omitted, model may derive a simple decoder from hidden_dims.
350
+ - latent_dim (int, default=64): Latent space dimension.
351
+ - activation (str, default="relu"): Default activation for Linear blocks when using legacy paths or presets.
352
+ - dropout_rate (float, default=0.1): Default dropout used in presets and some layers.
353
+ - use_batch_norm (bool, default=True): Default normalization flag used in presets ("batch" if True, else "none").
354
+ - tie_weights (bool, default=False): If True, share/tie encoder and decoder weights (feature not always active depending on architecture).
355
+ - reconstruction_loss (str, default="mse"): Which loss to use in AutoencoderForReconstruction. One of:
356
+ - "mse", "bce", "l1", "huber", "smooth_l1", "kl_div", "cosine", "focal", "dice", "tversky", "ssim", "perceptual".
357
+ - autoencoder_type (str, default="classic"): Architecture variant. One of:
358
+ - "classic", "variational", "beta_vae", "denoising", "sparse", "contractive", "recurrent".
359
+ - beta (float, default=1.0): KL weight for VAE/beta-VAE.
360
+ - temperature (float, default=1.0): Reserved for temperature-based operations.
361
+ - noise_factor (float, default=0.1): Denoising strength used by Denoising variants.
362
+ - rnn_type (str, default="lstm"): For recurrent variants. One of: "lstm", "gru", "rnn".
363
+ - num_layers (int, default=2): Number of RNN layers for recurrent variants.
364
+ - bidirectional (bool, default=True): Whether RNN is bidirectional in recurrent variants.
365
+ - sequence_length (int | None, default=None): Optional fixed sequence length; if None, variable length is supported.
366
+ - teacher_forcing_ratio (float, default=0.5): For recurrent decoders that use teacher forcing.
367
+ - use_learnable_preprocessing (bool, default=False): Enable learnable preprocessing.
368
+ - preprocessing_type (str, default="none"): One of: "none", "neural_scaler", "normalizing_flow", "minmax_scaler", "robust_scaler", "yeo_johnson".
369
+ - preprocessing_hidden_dim (int, default=64): Hidden size for preprocessing networks.
370
+ - preprocessing_num_layers (int, default=2): Number of layers for preprocessing networks.
371
+ - learn_inverse_preprocessing (bool, default=True): Whether to learn inverse transform for reconstruction.
372
+ - flow_coupling_layers (int, default=4): Number of coupling layers for normalizing flows.
373
+
374
+ Derived helpers and flags:
375
+ - has_block_lists: True if either encoder_blocks or decoder_blocks is provided.
376
+ - is_variational: True if autoencoder_type in {"variational", "beta_vae"}.
377
+ - is_denoising, is_sparse, is_contractive, is_recurrent: Variant flags.
378
+ - has_preprocessing: True if preprocessing enabled and type != "none".
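+
+ A quick check of the derived flags (assuming they are exposed as config attributes, as listed above):
+
+ ```python
+ from configuration_autoencoder import AutoencoderConfig
+
+ cfg = AutoencoderConfig(
+     input_dim=64,
+     latent_dim=16,
+     autoencoder_type="beta_vae",
+     use_learnable_preprocessing=True,
+     preprocessing_type="neural_scaler",
+ )
+ print(cfg.is_variational)     # True for "variational" and "beta_vae"
+ print(cfg.has_preprocessing)  # True: preprocessing enabled and type != "none"
+ print(cfg.has_block_lists)    # False: no encoder_blocks/decoder_blocks given
+ ```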
379
+
380
+ Validation notes:
381
+ - activation must be one of the supported list in configuration_autoencoder.py
382
+ - reconstruction_loss must be one of the supported list
383
+ - Many numeric parameters are validated to be positive or within [0,1]
384
+
385
+ ## Training with Hugging Face Trainer
386
+
387
+ The AutoencoderForReconstruction model computes reconstruction loss internally using config.reconstruction_loss. For VAEs/beta-VAEs, it adds the KL term scaled by config.beta. You can plug it directly into transformers.Trainer.
388
 
389
  ```python
390
+ from transformers import Trainer, TrainingArguments
391
+ from modeling_autoencoder import AutoencoderForReconstruction
392
+ from template import ClassicAutoencoderConfig
393
+ import torch
394
+ from torch.utils.data import Dataset
 
395
 
396
+ # 1) Config and model
397
+ cfg = ClassicAutoencoderConfig(input_dim=64, latent_dim=16)
398
+ model = AutoencoderForReconstruction(cfg)
399
 
400
+ # 2) Dummy dataset (replace with your own)
401
+ class ToyAEDataset(Dataset):
402
+     def __init__(self, n=1024, d=64):
403
+         self.x = torch.randn(n, d)
404
+     def __len__(self):
405
+         return self.x.size(0)
406
+     def __getitem__(self, idx):
407
+         xi = self.x[idx]
408
+         return {"input_values": xi, "labels": xi}
409
+
410
+ train_ds = ToyAEDataset()
411
 
412
+ # 3) TrainingArguments
413
+ args = TrainingArguments(
414
+ output_dir="./ae-trainer",
415
+ per_device_train_batch_size=64,
416
+ learning_rate=1e-3,
417
+ num_train_epochs=3,
418
+ logging_steps=50,
419
+ save_steps=200,
420
+ report_to=[], # disable wandb if not configured
421
+ )
422
 
423
+ # 4) Trainer
424
  trainer = Trainer(
425
  model=model,
426
+ args=args,
427
+ train_dataset=train_ds,
428
  )
 
429
 
430
+ # 5) Train
431
+ trainer.train()
432
 
433
+ # 6) Use the model
434
+ x = torch.randn(4, 64)
435
+ out = model(input_values=x, return_dict=True)
436
+ print(out.last_hidden_state.shape, out.reconstructed.shape)
 
 
 
 
437
  ```
438
 
439
+ Notes:
440
+ - The dataset must yield dicts with "input_values" and optionally "labels"; if labels are missing, the model uses input as the target.
441
+ - For sequence inputs, shape is (B, T, D). For simple vectors, (B, D).
442
+ - Set cfg.reconstruction_loss to e.g. "bce" to switch the internal loss (the decoder head applies sigmoid when BCE is used); a short sketch follows this list.
443
+ - For VAE/beta-VAE, use VariationalAutoencoderConfig/BetaVariationalAutoencoderConfig.
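+
+ A minimal sketch of the loss switch mentioned above (assuming reconstruction_loss can be set like any other config field):
+
+ ```python
+ from template import ClassicAutoencoderConfig
+ from modeling_autoencoder import AutoencoderForReconstruction
+
+ cfg = ClassicAutoencoderConfig(input_dim=64, latent_dim=16)
+ cfg.reconstruction_loss = "bce"  # inputs should then lie in [0, 1]
+ model = AutoencoderForReconstruction(cfg)
+ ```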
444
 
 
445
 
446
+ ### Example using AutoencoderConfig directly
 
 
 
447
 
448
+ The example below shows how to define a configuration purely with block dicts using AutoencoderConfig, without the template classes.
449
 
450
+ ```python
451
+ from configuration_autoencoder import AutoencoderConfig
452
+ from modeling_autoencoder import AutoencoderModel
453
+ import torch
454
 
455
+ # Encoder: Linear -> Attention -> Linear
456
+ enc = [
457
+ {"type": "linear", "input_dim": 128, "output_dim": 128, "activation": "relu", "normalization": "batch", "dropout_rate": 0.1},
458
+ {"type": "attention", "input_dim": 128, "num_heads": 4, "ffn_dim": 512, "dropout_rate": 0.1},
459
+ {"type": "linear", "input_dim": 128, "output_dim": 64, "activation": "relu", "normalization": "batch"},
460
+ ]
461
+
462
+ # Decoder: Linear -> Linear (final identity)
463
+ dec = [
464
+ {"type": "linear", "input_dim": 32, "output_dim": 64, "activation": "relu", "normalization": "batch"},
465
+ {"type": "linear", "input_dim": 64, "output_dim": 128, "activation": "identity", "normalization": "none"},
466
+ ]
467
+
468
+ cfg = AutoencoderConfig(
469
+ input_dim=128,
470
+ latent_dim=32,
471
+ encoder_blocks=enc,
472
+ decoder_blocks=dec,
473
+ autoencoder_type="classic",
474
+ )
475
 
476
+ model = AutoencoderModel(cfg)
477
+ x = torch.randn(4, 128)
478
+ out = model(x, return_dict=True)
479
+ print(out.last_hidden_state.shape, out.reconstructed.shape)
480
+ ```
481
 
482
+ For a variational model, set autoencoder_type="variational" and the model will internally use a VariationalBlock for mu/logvar and sampling.
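+
+ Continuing the example above, a sketch of the variational variant (reusing the enc/dec block lists):
+
+ ```python
+ vae_cfg = AutoencoderConfig(
+     input_dim=128,
+     latent_dim=32,
+     encoder_blocks=enc,
+     decoder_blocks=dec,
+     autoencoder_type="variational",
+     beta=1.0,
+ )
+ vae_model = AutoencoderModel(vae_cfg)
+ ```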
 
 
 
 
 
483
 
 
484
 
485
+ ## Learnable preprocessing
486
+ Enable learnable preprocessing and its inverse with the PreprocessedAutoencoderConfig class or via flags.
 
 
 
 
 
 
 
487
 
488
+ ```python
489
+ from template import PreprocessedAutoencoderConfig
490
+ cfg = PreprocessedAutoencoderConfig(input_dim=64, latent_dim=32, preprocessing_type="neural_scaler")
491
+ ```
492
 
493
+ Supported preprocessing_type values include: "neural_scaler", "normalizing_flow", "minmax_scaler", "robust_scaler", "yeo_johnson".
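+
+ Equivalently, the flags can be set directly on AutoencoderConfig (a sketch; remaining fields keep their defaults):
+
+ ```python
+ from configuration_autoencoder import AutoencoderConfig
+
+ cfg = AutoencoderConfig(
+     input_dim=64,
+     latent_dim=32,
+     use_learnable_preprocessing=True,
+     preprocessing_type="robust_scaler",
+     learn_inverse_preprocessing=True,
+ )
+ ```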
494
+
495
+ ## Saving and loading
496
+ ```python
497
+ from modeling_autoencoder import AutoencoderForReconstruction
498
+
499
+ # Save
500
+ model.save_pretrained("./my_ae")
501
+ # Load
502
+ reloaded = AutoencoderForReconstruction.from_pretrained("./my_ae")
503
+ ```
504
 
505
+ ## Reference
506
+ Core modules:
507
+ - configuration_autoencoder.AutoencoderConfig
508
+ - modeling_autoencoder.AutoencoderModel, AutoencoderForReconstruction
509
+ - blocks: BlockFactory, BlockSequence, Linear/Attention/Recurrent/Convolutional/Variational blocks
510
+ - preprocessing: PreprocessingBlock (learnable preprocessing wrapper)
511
+ - template: class-based presets listed above
512
 
513
+ ## License
514
+ Apache-2.0 (see LICENSE)
blocks.py ADDED
@@ -0,0 +1,446 @@
1
+ """
2
+ Modular, block-based components for building autoencoders in PyTorch.
3
+
4
+ Core goals:
5
+ - Composable building blocks with consistent interfaces
6
+ - Support 2D (B, F) and 3D (B, T, F) tensors where applicable
7
+ - Simple configs to construct blocks and sequences
8
+ - Safe-by-default validation and helpful errors
9
+
10
+ This module is intentionally self-contained to allow gradual integration with
11
+ existing models. It does not mutate current behavior.
12
+ """
13
+ from __future__ import annotations
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ # Import config dataclasses that define block configurations
23
+ try:
24
+ from .configuration_autoencoder import (
25
+ BlockConfig,
26
+ LinearBlockConfig,
27
+ AttentionBlockConfig,
28
+ RecurrentBlockConfig,
29
+ ConvolutionalBlockConfig,
30
+ VariationalBlockConfig,
31
+ )
32
+ except Exception:
33
+ from configuration_autoencoder import (
34
+ BlockConfig,
35
+ LinearBlockConfig,
36
+ AttentionBlockConfig,
37
+ RecurrentBlockConfig,
38
+ ConvolutionalBlockConfig,
39
+ VariationalBlockConfig,
40
+ )
41
+
42
+
43
+ # Import shared utilities
44
+ try:
45
+ from .utils import _get_activation, _get_norm, _flatten_3d_to_2d, _maybe_restore_3d
46
+ except Exception:
47
+ from utils import _get_activation, _get_norm, _flatten_3d_to_2d, _maybe_restore_3d
48
+
49
+
50
+ # ---------------------------- Base Block ---------------------------- #
51
+
52
+ class BaseBlock(nn.Module):
53
+ """Abstract base for all blocks.
54
+
55
+ All blocks should accept 2D (B, F) or 3D (B, T, F) tensors and return the
56
+ same rank, with last-dim equal to `output_dim`.
57
+ """
58
+
59
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: # pragma: no cover - abstract
60
+ raise NotImplementedError
61
+
62
+ @property
63
+ def output_dim(self) -> int: # pragma: no cover - abstract
64
+ raise NotImplementedError
65
+
66
+
67
+ # ---------------------------- Residual Base ---------------------------- #
68
+
69
+ class ResidualBlock(BaseBlock):
70
+ """Base class for blocks supporting residual connections.
71
+
72
+ Implements a safe residual add when input and output dims match; otherwise
73
+ falls back to a learned projection. Residuals can be scaled.
74
+ """
75
+
76
+ def __init__(self, residual: bool = False, residual_scale: float = 1.0, proj_dim_in: Optional[int] = None, proj_dim_out: Optional[int] = None):
77
+ super().__init__()
78
+ self.use_residual = residual
79
+ self.residual_scale = residual_scale
80
+ self._proj: Optional[nn.Module] = None
81
+ if residual and proj_dim_in is not None and proj_dim_out is not None and proj_dim_in != proj_dim_out:
82
+ self._proj = nn.Linear(proj_dim_in, proj_dim_out)
83
+
84
+ def _apply_residual(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
85
+ if not self.use_residual:
86
+ return y
87
+ x2d, hint = _flatten_3d_to_2d(x)
88
+ y2d, _ = _flatten_3d_to_2d(y)
89
+ if x2d.shape[-1] != y2d.shape[-1]:
90
+ if self._proj is None:
91
+ self._proj = nn.Linear(x2d.shape[-1], y2d.shape[-1]).to(y2d.device)
92
+ x2d = self._proj(x2d)
93
+ out = x2d + self.residual_scale * y2d
94
+ return _maybe_restore_3d(out, hint)
95
+
96
+
97
+ # ---------------------------- LinearBlock ---------------------------- #
98
+
99
+ class LinearBlock(ResidualBlock):
100
+ """Basic linear transformation with normalization and activation.
101
+
102
+ - Handles both 2D (B, F) and 3D (B, T, F) tensors
103
+ - Optional normalization: batch|layer|group|instance|none
104
+ - Configurable activation
105
+ - Optional dropout
106
+ - Optional residual connection (with auto projection)
107
+ """
108
+
109
+ def __init__(self, cfg: LinearBlockConfig):
110
+ super().__init__(residual=cfg.use_residual, residual_scale=cfg.residual_scale, proj_dim_in=cfg.input_dim, proj_dim_out=cfg.output_dim)
111
+ self.cfg = cfg
112
+
113
+ self.linear = nn.Linear(cfg.input_dim, cfg.output_dim)
114
+ # Normalizations that expect N, C require 2D tensors; for 3D we flatten
115
+ # For LayerNorm, it supports last-dim directly
116
+ if cfg.normalization == "layer":
117
+ self.norm = nn.LayerNorm(cfg.output_dim)
118
+ else:
119
+ self.norm = _get_norm(cfg.normalization, cfg.output_dim)
120
+ self.act = _get_activation(cfg.activation)
121
+ self.drop = nn.Dropout(cfg.dropout_rate) if cfg.dropout_rate and cfg.dropout_rate > 0 else nn.Identity()
122
+
123
+ @property
124
+ def output_dim(self) -> int:
125
+ return self.cfg.output_dim
126
+
127
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
128
+ x_in = x
129
+ x2d, hint = _flatten_3d_to_2d(x)
130
+ y = self.linear(x2d)
131
+ # Apply norm safely
132
+ if isinstance(self.norm, (nn.BatchNorm1d, nn.InstanceNorm1d, nn.GroupNorm)):
133
+ y = self.norm(y)
134
+ else:
135
+ # LayerNorm or Identity operates on last dim and supports both 2D/3D; we already have 2D
136
+ y = self.norm(y)
137
+ y = self.act(y)
138
+ y = self.drop(y)
139
+ y = _maybe_restore_3d(y, hint)
140
+ return self._apply_residual(x_in, y)
141
+
142
+
143
+ # ---------------------------- AttentionBlock ---------------------------- #
144
+
145
+ class AttentionBlock(BaseBlock):
146
+ """Multi-head self-attention with optional FFN.
147
+
148
+ Expects inputs as 3D (B, T, D) or 2D (B, D) which will be treated as (B, 1, D).
149
+ Supports optional attn mask and key padding mask via kwargs.
150
+ """
151
+
152
+ def __init__(self, cfg: AttentionBlockConfig):
153
+ super().__init__()
154
+ self.cfg = cfg
155
+ d_model = cfg.input_dim
156
+ self.mha = nn.MultiheadAttention(d_model, num_heads=cfg.num_heads, dropout=cfg.dropout_rate, batch_first=True)
157
+ self.ln1 = nn.LayerNorm(d_model)
158
+ ffn_dim = cfg.ffn_dim or (4 * d_model)
159
+ self.ffn = nn.Sequential(
160
+ nn.Linear(d_model, ffn_dim),
161
+ _get_activation("gelu"),
162
+ nn.Dropout(cfg.dropout_rate),
163
+ nn.Linear(ffn_dim, d_model),
164
+ )
165
+ self.ln2 = nn.LayerNorm(d_model)
166
+ self.dropout = nn.Dropout(cfg.dropout_rate)
167
+
168
+ @property
169
+ def output_dim(self) -> int:
170
+ return self.cfg.input_dim
171
+
172
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
173
+ if x.dim() == 2:
174
+ x = x.unsqueeze(1)
175
+ squeeze_back = True
176
+ else:
177
+ squeeze_back = False
178
+ # Self-attention
179
+ residual = x
180
+ attn_out, _ = self.mha(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
181
+ x = self.ln1(residual + self.dropout(attn_out))
182
+ # FFN
183
+ residual = x
184
+ x = self.ffn(x)
185
+ x = self.ln2(residual + self.dropout(x))
186
+ if squeeze_back:
187
+ x = x.squeeze(1)
188
+ return x
189
+
190
+
191
+ # ---------------------------- RecurrentBlock ---------------------------- #
192
+
193
+ class RecurrentBlock(BaseBlock):
194
+ """RNN processing block supporting LSTM/GRU/RNN.
195
+
196
+ Input: 3D (B, T, F) preferred. If 2D, treated as (B, 1, F).
197
+ Output dim equals cfg.output_dim if set; otherwise hidden_size * directions.
198
+ """
199
+
200
+ def __init__(self, cfg: RecurrentBlockConfig):
201
+ super().__init__()
202
+ self.cfg = cfg
203
+ rnn_type = cfg.rnn_type.lower()
204
+ rnn_cls = {"lstm": nn.LSTM, "gru": nn.GRU, "rnn": nn.RNN}.get(rnn_type)
205
+ if rnn_cls is None:
206
+ raise ValueError(f"Unknown rnn_type: {cfg.rnn_type}")
207
+ self.rnn = rnn_cls(
208
+ input_size=cfg.input_dim,
209
+ hidden_size=cfg.hidden_size,
210
+ num_layers=cfg.num_layers,
211
+ batch_first=True,
212
+ dropout=cfg.dropout_rate if cfg.num_layers > 1 else 0.0,
213
+ bidirectional=cfg.bidirectional,
214
+ )
215
+ out_dim = cfg.hidden_size * (2 if cfg.bidirectional else 1)
216
+ self._out_dim = cfg.output_dim or out_dim
217
+ self.proj = None if self._out_dim == out_dim else nn.Linear(out_dim, self._out_dim)
218
+
219
+ @property
220
+ def output_dim(self) -> int:
221
+ return self._out_dim
222
+
223
+ def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
224
+ squeeze_back = False
225
+ if x.dim() == 2:
226
+ x = x.unsqueeze(1)
227
+ squeeze_back = True
228
+ if lengths is not None:
229
+ x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
230
+ if isinstance(self.rnn, nn.LSTM):
231
+ out, (h, c) = self.rnn(x)
232
+ else:
233
+ out, h = self.rnn(x)
234
+ if lengths is not None:
235
+ out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
236
+ # Use last timestep
237
+ y = out[:, -1, :]
238
+ if self.proj is not None:
239
+ y = self.proj(y)
240
+ if squeeze_back:
241
+ # Keep 2D output
242
+ return y
243
+ # Return (B, 1, D) to keep 3D shape consistent with sequences
244
+ return y.unsqueeze(1)
245
+
246
+
247
+ # ---------------------------- ConvolutionalBlock ---------------------------- #
248
+
249
+ class ConvolutionalBlock(BaseBlock):
250
+ """1D convolutional block for sequence-like data.
251
+ Accepts 3D (B, T, F) or 2D (B, F) which is treated as (B, 1, F).
252
+ """
253
+
254
+ def __init__(self, cfg: ConvolutionalBlockConfig):
255
+ super().__init__()
256
+ self.cfg = cfg
257
+ # Conv1d expects (B, C_in, L). We interpret features as channels and time as length.
258
+ # For inputs shaped (B, T, F): we transpose to (B, F, T), apply conv, transpose back.
259
+ padding = cfg.padding
260
+ if isinstance(padding, str) and padding == "same":
261
+ pad = cfg.kernel_size // 2
262
+ else:
263
+ pad = int(padding)
264
+ self.conv = nn.Conv1d(cfg.input_dim, cfg.output_dim, kernel_size=cfg.kernel_size, padding=pad)
265
+ # Norm: for Conv1d, use 1d norms over channels
266
+ if cfg.normalization == "layer":
267
+ self.norm = nn.GroupNorm(1, cfg.output_dim) # Layer-like over channels
268
+ else:
269
+ self.norm = _get_norm(cfg.normalization, cfg.output_dim)
270
+ self.act = _get_activation(cfg.activation)
271
+ self.drop = nn.Dropout(cfg.dropout_rate) if cfg.dropout_rate and cfg.dropout_rate > 0 else nn.Identity()
272
+
273
+ @property
274
+ def output_dim(self) -> int:
275
+ return self.cfg.output_dim
276
+
277
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
278
+ squeeze_back = False
279
+ if x.dim() == 2:
280
+ x = x.unsqueeze(1)
281
+ squeeze_back = True
282
+ # x: (B, T, F) -> (B, F, T)
283
+ x = x.transpose(1, 2)
284
+
285
+ y = self.conv(x)
286
+ if isinstance(self.norm, (nn.BatchNorm1d, nn.InstanceNorm1d, nn.GroupNorm)):
287
+ y = self.norm(y)
288
+ y = self.act(y)
289
+ y = self.drop(y)
290
+ y = y.transpose(1, 2)
291
+ if squeeze_back:
292
+ y = y.squeeze(1)
293
+ return y
294
+
295
+ # ---------------------------- VariationalBlock ---------------------------- #
296
+
297
+ class VariationalBlock(BaseBlock):
298
+ """Encapsulates mu/logvar projection and reparameterization.
299
+
300
+ Input can be 2D (B, F) or 3D (B, T, F); for 3D, operates per timestep and returns same rank.
301
+ Stores mu/logvar on the module for downstream loss usage.
302
+ """
303
+
304
+ def __init__(self, cfg: VariationalBlockConfig):
305
+ super().__init__()
306
+ self.cfg = cfg
307
+ self.fc_mu = nn.Linear(cfg.input_dim, cfg.latent_dim)
308
+ self.fc_logvar = nn.Linear(cfg.input_dim, cfg.latent_dim)
309
+ self._mu: Optional[torch.Tensor] = None
310
+ self._logvar: Optional[torch.Tensor] = None
311
+
312
+ @property
313
+ def output_dim(self) -> int:
314
+ return self.cfg.latent_dim
315
+
316
+ def forward(self, x: torch.Tensor, training: Optional[bool] = None) -> torch.Tensor:
317
+ if training is None:
318
+ training = self.training
319
+ x2d, hint = _flatten_3d_to_2d(x)
320
+ mu = self.fc_mu(x2d)
321
+ logvar = self.fc_logvar(x2d)
322
+ if training:
323
+ std = torch.exp(0.5 * logvar)
324
+ eps = torch.randn_like(std)
325
+ z = mu + eps * std
326
+ else:
327
+ z = mu
328
+ self._mu = mu
329
+ self._logvar = logvar
330
+ z = _maybe_restore_3d(z, hint)
331
+ return z
332
+
333
+
334
+
335
+
336
+ # ---------------------------- BlockSequence ---------------------------- #
337
+
338
+ class BlockSequence(nn.Module):
339
+ """Compose multiple blocks into a validated sequence.
340
+
341
+ - Validates dimension flow between blocks
342
+ - Supports gradient checkpointing (per-block) via forward(checkpoint=True)
343
+ - Supports optional skip connections: pass `skips` as list of (src_idx, dst_idx)
344
+ """
345
+
346
+ def __init__(self, blocks: Sequence[BaseBlock], validate_dims: bool = True, skips: Optional[List[Tuple[int, int]]] = None):
347
+ super().__init__()
348
+ self.blocks = nn.ModuleList(blocks)
349
+ self.skips = skips or []
350
+ if validate_dims and len(blocks) > 1:
351
+ for i in range(1, len(blocks)):
352
+ prev = blocks[i - 1]
353
+ cur = blocks[i]
354
+ if getattr(prev, "output_dim", None) is None or getattr(cur, "output_dim", None) is None:
355
+ continue
356
+ if prev.output_dim != cur.output_dim and not isinstance(cur, LinearBlock):
357
+ # Allow LinearBlock to change dims; others must preserve unless they project internally
358
+ pass # Only warn; users may know what they're doing
359
+
360
+ def forward(self, x: torch.Tensor, checkpoint: bool = False, **kwargs) -> torch.Tensor:
361
+ activations: Dict[int, torch.Tensor] = {}
362
+ for i, block in enumerate(self.blocks):
363
+ if checkpoint and x.requires_grad:
364
+ x = torch.utils.checkpoint.checkpoint(lambda inp: block(inp, **kwargs), x)
365
+ else:
366
+ x = block(x, **kwargs)
367
+ activations[i] = x
368
+ # Apply any pending skips to this idx
369
+ for src, dst in self.skips:
370
+ if dst == i and src in activations:
371
+ x = x + activations[src]
372
+ return x
373
+
374
+
375
+ # ---------------------------- Factory ---------------------------- #
376
+
377
+ class BlockFactory:
378
+ """Factory to build blocks/sequences from configs.
379
+
380
+ This is intentionally minimal; extend as needed.
381
+ """
382
+
383
+ @staticmethod
384
+ def build_block(cfg: Union[BlockConfig, Dict[str, Any]]) -> BaseBlock:
385
+ # Allow dict-like
386
+ if isinstance(cfg, dict):
387
+ type_name = cfg.get("type")
388
+ # copy and remove 'type' to satisfy dataclass init
389
+ params = dict(cfg)
390
+ params.pop("type", None)
391
+ if type_name == "linear":
392
+ return LinearBlock(LinearBlockConfig(**params))
393
+ if type_name == "attention":
394
+ return AttentionBlock(AttentionBlockConfig(**params))
395
+ if type_name == "recurrent":
396
+ return RecurrentBlock(RecurrentBlockConfig(**params))
397
+ if type_name == "conv1d":
398
+ return ConvolutionalBlock(ConvolutionalBlockConfig(**params))
399
+ raise ValueError(f"Unsupported block type in dict cfg: {type_name} cfg={cfg}")
400
+ # Dataclass path
401
+ if isinstance(cfg, LinearBlockConfig) or getattr(cfg, "type", None) == "linear":
402
+ if not isinstance(cfg, LinearBlockConfig):
403
+ cfg = LinearBlockConfig(**cfg.__dict__) # type: ignore[arg-type]
404
+ return LinearBlock(cfg)
405
+ if isinstance(cfg, AttentionBlockConfig) or getattr(cfg, "type", None) == "attention":
406
+ if not isinstance(cfg, AttentionBlockConfig):
407
+ cfg = AttentionBlockConfig(**cfg.__dict__) # type: ignore[arg-type]
408
+ return AttentionBlock(cfg)
409
+ if isinstance(cfg, RecurrentBlockConfig) or getattr(cfg, "type", None) == "recurrent":
410
+ if not isinstance(cfg, RecurrentBlockConfig):
411
+ cfg = RecurrentBlockConfig(**cfg.__dict__) # type: ignore[arg-type]
412
+ return RecurrentBlock(cfg)
413
+ if isinstance(cfg, ConvolutionalBlockConfig) or getattr(cfg, "type", None) == "conv1d":
414
+ if not isinstance(cfg, ConvolutionalBlockConfig):
415
+ cfg = ConvolutionalBlockConfig(**cfg.__dict__) # type: ignore[arg-type]
416
+ return ConvolutionalBlock(cfg)
417
+ if isinstance(cfg, VariationalBlockConfig) or getattr(cfg, "type", None) == "variational":
418
+ if not isinstance(cfg, VariationalBlockConfig):
419
+ cfg = VariationalBlockConfig(**cfg.__dict__) # type: ignore[arg-type]
420
+ return VariationalBlock(cfg)
421
+ raise ValueError(f"Unsupported block type: {cfg}")
422
+
423
+ @staticmethod
424
+ def build_sequence(configs: Sequence[Union[BlockConfig, Dict[str, Any]]]) -> BlockSequence:
425
+ blocks: List[BaseBlock] = [BlockFactory.build_block(c) for c in configs]
426
+ return BlockSequence(blocks)
427
+
428
+
429
+ __all__ = [
430
+ "BlockConfig",
431
+ "LinearBlockConfig",
432
+ "AttentionBlockConfig",
433
+ "RecurrentBlockConfig",
434
+ "ConvolutionalBlockConfig",
435
+ "VariationalBlockConfig",
436
+ "BaseBlock",
437
+ "ResidualBlock",
438
+ "LinearBlock",
439
+ "AttentionBlock",
440
+ "RecurrentBlock",
441
+ "ConvolutionalBlock",
442
+ "VariationalBlock",
443
+ "BlockSequence",
444
+ "BlockFactory",
445
+ ]
446
+
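
`BlockFactory` accepts either the dataclass configs or plain dicts (the same shape that `encoder_blocks` / `decoder_blocks` take in `config.json`). Below is a minimal sketch of composing a small encoder from dicts; it assumes `blocks.py` from this repo is importable (for example, the repo directory is on `sys.path`) and the dimensions are purely illustrative:

```python
# Sketch only: assumes blocks.py from this repo is on sys.path; dims are illustrative.
import torch
from blocks import BlockFactory

# Plain-dict configs: the "type" key selects the block class,
# the remaining keys are passed to the matching *BlockConfig.
encoder_cfgs = [
    {"type": "linear", "input_dim": 20, "output_dim": 64,
     "activation": "gelu", "normalization": "batch", "dropout_rate": 0.1},
    {"type": "linear", "input_dim": 64, "output_dim": 32,
     "activation": "gelu", "normalization": "batch", "dropout_rate": 0.1},
]

seq = BlockFactory.build_sequence(encoder_cfgs)  # BlockSequence of two LinearBlocks
x = torch.randn(8, 20)                           # (batch, features); LinearBlock assumed to take 2D input
h = seq(x)                                       # -> (8, 32)
print(h.shape)
```

Per-block gradient checkpointing can be requested at call time with `seq(x, checkpoint=True)`, which only takes effect when `x.requires_grad` is set.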
config.json CHANGED
@@ -10,7 +10,9 @@
10
  "autoencoder_type": "classic",
11
  "beta": 1.0,
12
  "bidirectional": true,
 
13
  "dropout_rate": 0.1,
 
14
  "flow_coupling_layers": 2,
15
  "hidden_dims": [
16
  16,
 
10
  "autoencoder_type": "classic",
11
  "beta": 1.0,
12
  "bidirectional": true,
13
+ "decoder_blocks": null,
14
  "dropout_rate": 0.1,
15
+ "encoder_blocks": null,
16
  "flow_coupling_layers": 2,
17
  "hidden_dims": [
18
  16,
configuration_autoencoder.py CHANGED
@@ -2,6 +2,9 @@
2
  Autoencoder configuration for Hugging Face Transformers.
3
  """
4
 
 
 
 
5
  from transformers import PretrainedConfig
6
  from typing import List, Optional
7
 
@@ -11,25 +14,114 @@ try:
11
  except Exception: # pragma: no cover
12
  _pkg_version = None
13
 
 
 
 
 
 
 
 
 
14
 
15
  class AutoencoderConfig(PretrainedConfig):
16
  """
17
  Configuration class for Autoencoder models.
18
-
19
  This configuration class stores the configuration of an autoencoder model. It is used to instantiate
20
  an autoencoder model according to the specified arguments, defining the model architecture.
21
-
22
  Args:
23
  input_dim (int, optional): Dimensionality of the input data. Defaults to 784.
24
- hidden_dims (List[int], optional): List of hidden layer dimensions for the encoder.
25
- The decoder will use the reverse of this list. Defaults to [512, 256, 128].
 
26
  latent_dim (int, optional): Dimensionality of the latent space. Defaults to 64.
27
- activation (str, optional): Activation function to use. Options: "relu", "tanh", "sigmoid",
28
- "leaky_relu", "gelu", "swish", "silu", "elu", "prelu", "relu6", "hardtanh",
29
- "hardsigmoid", "hardswish", "mish", "softplus", "softsign", "tanhshrink", "threshold".
30
- Defaults to "relu".
31
- dropout_rate (float, optional): Dropout rate for regularization. Defaults to 0.1.
32
- use_batch_norm (bool, optional): Whether to use batch normalization. Defaults to True.
33
  tie_weights (bool, optional): Whether to tie encoder and decoder weights. Defaults to False.
34
  reconstruction_loss (str, optional): Type of reconstruction loss. Options: "mse", "bce", "l1",
35
  "huber", "smooth_l1", "kl_div", "cosine", "focal", "dice", "tversky", "ssim", "perceptual".
@@ -57,13 +149,15 @@ class AutoencoderConfig(PretrainedConfig):
57
  flow_coupling_layers (int, optional): Number of coupling layers for normalizing flows. Defaults to 4.
58
  **kwargs: Additional keyword arguments passed to the parent class.
59
  """
60
-
61
  model_type = "autoencoder"
62
-
63
  def __init__(
64
  self,
65
  input_dim: int = 784,
66
  hidden_dims: List[int] = None,
 
 
67
  latent_dim: int = 64,
68
  activation: str = "relu",
69
  dropout_rate: float = 0.1,
@@ -92,7 +186,7 @@ class AutoencoderConfig(PretrainedConfig):
92
  # Validate parameters
93
  if hidden_dims is None:
94
  hidden_dims = [512, 256, 128]
95
-
96
  # Extended activation functions
97
  valid_activations = [
98
  "relu", "tanh", "sigmoid", "leaky_relu", "gelu", "swish", "silu",
@@ -127,19 +221,19 @@ class AutoencoderConfig(PretrainedConfig):
127
  raise ValueError(
128
  f"`rnn_type` must be one of {valid_rnn_types}, got {rnn_type}."
129
  )
130
-
131
  if not (0.0 <= dropout_rate <= 1.0):
132
  raise ValueError(f"`dropout_rate` must be between 0.0 and 1.0, got {dropout_rate}.")
133
-
134
  if input_dim <= 0:
135
  raise ValueError(f"`input_dim` must be positive, got {input_dim}.")
136
-
137
  if latent_dim <= 0:
138
  raise ValueError(f"`latent_dim` must be positive, got {latent_dim}.")
139
-
140
  if not all(dim > 0 for dim in hidden_dims):
141
  raise ValueError("All dimensions in `hidden_dims` must be positive.")
142
-
143
  if beta <= 0:
144
  raise ValueError(f"`beta` must be positive, got {beta}.")
145
 
@@ -174,10 +268,12 @@ class AutoencoderConfig(PretrainedConfig):
174
 
175
  if flow_coupling_layers <= 0:
176
  raise ValueError(f"`flow_coupling_layers` must be positive, got {flow_coupling_layers}.")
177
-
178
  # Set configuration attributes
179
  self.input_dim = input_dim
180
  self.hidden_dims = hidden_dims
 
 
181
  self.latent_dim = latent_dim
182
  self.activation = activation
183
  self.dropout_rate = dropout_rate
@@ -199,15 +295,20 @@ class AutoencoderConfig(PretrainedConfig):
199
  self.preprocessing_num_layers = preprocessing_num_layers
200
  self.learn_inverse_preprocessing = learn_inverse_preprocessing
201
  self.flow_coupling_layers = flow_coupling_layers
202
-
203
  # Call parent constructor
204
  super().__init__(**kwargs)
205
-
206
  @property
207
  def decoder_dims(self) -> List[int]:
208
  """Get decoder dimensions (reverse of encoder hidden dims)."""
209
  return list(reversed(self.hidden_dims))
210
 
 
 
 
 
 
211
  @property
212
  def is_variational(self) -> bool:
213
  """Check if this is a variational autoencoder."""
 
2
  Autoencoder configuration for Hugging Face Transformers.
3
  """
4
 
5
+ from dataclasses import dataclass
6
+ from typing import Union
7
+
8
  from transformers import PretrainedConfig
9
  from typing import List, Optional
10
 
 
14
  except Exception: # pragma: no cover
15
  _pkg_version = None
16
 
17
+ @dataclass
18
+ class BlockConfig:
19
+ type: str
20
+
21
+
22
+ @dataclass
23
+ class LinearBlockConfig(BlockConfig):
24
+ input_dim: int
25
+ output_dim: int
26
+ activation: str = "relu"
27
+ normalization: Optional[str] = "batch" # batch|layer|group|instance|none
28
+ dropout_rate: float = 0.0
29
+ use_residual: bool = False
30
+ residual_scale: float = 1.0
31
+
32
+ def __init__(self, input_dim: int, output_dim: int, activation: str = "relu", normalization: Optional[str] = "batch", dropout_rate: float = 0.0, use_residual: bool = False, residual_scale: float = 1.0):
33
+ super().__init__(type="linear")
34
+ self.input_dim = input_dim
35
+ self.output_dim = output_dim
36
+ self.activation = activation
37
+ self.normalization = normalization
38
+ self.dropout_rate = dropout_rate
39
+ self.use_residual = use_residual
40
+ self.residual_scale = residual_scale
41
+
42
+
43
+ @dataclass
44
+ class AttentionBlockConfig(BlockConfig):
45
+ input_dim: int
46
+ num_heads: int = 8
47
+ ffn_dim: Optional[int] = None
48
+ dropout_rate: float = 0.0
49
+
50
+ def __init__(self, input_dim: int, num_heads: int = 8, ffn_dim: Optional[int] = None, dropout_rate: float = 0.0):
51
+ super().__init__(type="attention")
52
+ self.input_dim = input_dim
53
+ self.num_heads = num_heads
54
+ self.ffn_dim = ffn_dim
55
+ self.dropout_rate = dropout_rate
56
+
57
+ @dataclass
58
+ class RecurrentBlockConfig(BlockConfig):
59
+ input_dim: int
60
+ hidden_size: int
61
+ num_layers: int = 1
62
+ rnn_type: str = "lstm" # lstm|gru|rnn
63
+ bidirectional: bool = False
64
+ dropout_rate: float = 0.0
65
+ output_dim: Optional[int] = None # if None, use hidden_size * directions
66
+
67
+ def __init__(self, input_dim: int, hidden_size: int, num_layers: int = 1, rnn_type: str = "lstm", bidirectional: bool = False, dropout_rate: float = 0.0, output_dim: Optional[int] = None):
68
+ super().__init__(type="recurrent")
69
+ self.input_dim = input_dim
70
+ self.hidden_size = hidden_size
71
+ self.num_layers = num_layers
72
+ self.rnn_type = rnn_type
73
+ self.bidirectional = bidirectional
74
+ self.dropout_rate = dropout_rate
75
+ self.output_dim = output_dim
76
+
77
+
78
+ @dataclass
79
+ class ConvolutionalBlockConfig(BlockConfig):
80
+ input_dim: int # channels in (features)
81
+ output_dim: int # channels out
82
+ kernel_size: int = 3
83
+ padding: Union[int, str] = "same" # "same" or int
84
+ activation: str = "relu"
85
+ normalization: Optional[str] = "batch"
86
+ dropout_rate: float = 0.0
87
+
88
+ def __init__(self, input_dim: int, output_dim: int, kernel_size: int = 3, padding: Union[int, str] = "same", activation: str = "relu", normalization: Optional[str] = "batch", dropout_rate: float = 0.0):
89
+ super().__init__(type="conv1d")
90
+ self.input_dim = input_dim
91
+ self.output_dim = output_dim
92
+ self.kernel_size = kernel_size
93
+ self.padding = padding
94
+ self.activation = activation
95
+ self.normalization = normalization
96
+ self.dropout_rate = dropout_rate
97
+
98
+ @dataclass
99
+ class VariationalBlockConfig(BlockConfig):
100
+ input_dim: int
101
+ latent_dim: int
102
+
103
+ def __init__(self, input_dim: int, latent_dim: int):
104
+ super().__init__(type="variational")
105
+ self.input_dim = input_dim
106
+ self.latent_dim = latent_dim
107
+
108
 
109
  class AutoencoderConfig(PretrainedConfig):
110
  """
111
  Configuration class for Autoencoder models.
112
+
113
  This configuration class stores the configuration of an autoencoder model. It is used to instantiate
114
  an autoencoder model according to the specified arguments, defining the model architecture.
115
+
116
  Args:
117
  input_dim (int, optional): Dimensionality of the input data. Defaults to 784.
118
+ hidden_dims (List[int], optional): Legacy: List of hidden layer dims for simple MLP encoder.
119
+ encoder_blocks (List[dict], optional): New: List of block configs for encoder.
120
+ decoder_blocks (List[dict], optional): New: List of block configs for decoder.
121
  latent_dim (int, optional): Dimensionality of the latent space. Defaults to 64.
122
+ activation (str, optional): Default activation for Linear blocks. Must be one of the names in `valid_activations` in `__init__` (e.g. "relu", "gelu", "swish"). Defaults to "relu".
123
+ dropout_rate (float, optional): Default dropout for Linear blocks. Defaults to 0.1.
124
+ use_batch_norm (bool, optional): Default normalization for Linear blocks (batch vs none). Defaults to True.
 
 
 
125
  tie_weights (bool, optional): Whether to tie encoder and decoder weights. Defaults to False.
126
  reconstruction_loss (str, optional): Type of reconstruction loss. Options: "mse", "bce", "l1",
127
  "huber", "smooth_l1", "kl_div", "cosine", "focal", "dice", "tversky", "ssim", "perceptual".
 
149
  flow_coupling_layers (int, optional): Number of coupling layers for normalizing flows. Defaults to 4.
150
  **kwargs: Additional keyword arguments passed to the parent class.
151
  """
152
+
153
  model_type = "autoencoder"
154
+
155
  def __init__(
156
  self,
157
  input_dim: int = 784,
158
  hidden_dims: List[int] = None,
159
+ encoder_blocks: Optional[List[dict]] = None,
160
+ decoder_blocks: Optional[List[dict]] = None,
161
  latent_dim: int = 64,
162
  activation: str = "relu",
163
  dropout_rate: float = 0.1,
 
186
  # Validate parameters
187
  if hidden_dims is None:
188
  hidden_dims = [512, 256, 128]
189
+
190
  # Extended activation functions
191
  valid_activations = [
192
  "relu", "tanh", "sigmoid", "leaky_relu", "gelu", "swish", "silu",
 
221
  raise ValueError(
222
  f"`rnn_type` must be one of {valid_rnn_types}, got {rnn_type}."
223
  )
224
+
225
  if not (0.0 <= dropout_rate <= 1.0):
226
  raise ValueError(f"`dropout_rate` must be between 0.0 and 1.0, got {dropout_rate}.")
227
+
228
  if input_dim <= 0:
229
  raise ValueError(f"`input_dim` must be positive, got {input_dim}.")
230
+
231
  if latent_dim <= 0:
232
  raise ValueError(f"`latent_dim` must be positive, got {latent_dim}.")
233
+
234
  if not all(dim > 0 for dim in hidden_dims):
235
  raise ValueError("All dimensions in `hidden_dims` must be positive.")
236
+
237
  if beta <= 0:
238
  raise ValueError(f"`beta` must be positive, got {beta}.")
239
 
 
268
 
269
  if flow_coupling_layers <= 0:
270
  raise ValueError(f"`flow_coupling_layers` must be positive, got {flow_coupling_layers}.")
271
+
272
  # Set configuration attributes
273
  self.input_dim = input_dim
274
  self.hidden_dims = hidden_dims
275
+ self.encoder_blocks = encoder_blocks
276
+ self.decoder_blocks = decoder_blocks
277
  self.latent_dim = latent_dim
278
  self.activation = activation
279
  self.dropout_rate = dropout_rate
 
295
  self.preprocessing_num_layers = preprocessing_num_layers
296
  self.learn_inverse_preprocessing = learn_inverse_preprocessing
297
  self.flow_coupling_layers = flow_coupling_layers
298
+
299
  # Call parent constructor
300
  super().__init__(**kwargs)
301
+
302
  @property
303
  def decoder_dims(self) -> List[int]:
304
  """Get decoder dimensions (reverse of encoder hidden dims)."""
305
  return list(reversed(self.hidden_dims))
306
 
307
+ @property
308
+ def has_block_lists(self) -> bool:
309
+ """Whether explicit encoder/decoder block configs are provided."""
310
+ return (self.encoder_blocks is not None) or (self.decoder_blocks is not None)
311
+
312
  @property
313
  def is_variational(self) -> bool:
314
  """Check if this is a variational autoencoder."""
model.safetensors CHANGED
Binary files a/model.safetensors and b/model.safetensors differ
 
modeling_autoencoder.py CHANGED
@@ -8,6 +8,7 @@ import torch.nn.functional as F
8
  from typing import Optional, Tuple, Union, Dict, Any, List
9
  from dataclasses import dataclass
10
  import random
 
11
 
12
  from transformers import PreTrainedModel
13
  from transformers.modeling_outputs import BaseModelOutput
@@ -18,653 +19,41 @@ try:
18
  except Exception:
19
  from configuration_autoencoder import AutoencoderConfig # local usage
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- class NeuralScaler(nn.Module):
23
- """Learnable alternative to StandardScaler using neural networks."""
24
-
25
- def __init__(self, config: AutoencoderConfig):
26
- super().__init__()
27
- self.config = config
28
- input_dim = config.input_dim
29
- hidden_dim = config.preprocessing_hidden_dim
30
-
31
- # Networks to learn data-dependent statistics
32
- self.mean_estimator = nn.Sequential(
33
- nn.Linear(input_dim, hidden_dim),
34
- nn.ReLU(),
35
- nn.Linear(hidden_dim, hidden_dim),
36
- nn.ReLU(),
37
- nn.Linear(hidden_dim, input_dim)
38
- )
39
-
40
- self.std_estimator = nn.Sequential(
41
- nn.Linear(input_dim, hidden_dim),
42
- nn.ReLU(),
43
- nn.Linear(hidden_dim, hidden_dim),
44
- nn.ReLU(),
45
- nn.Linear(hidden_dim, input_dim),
46
- nn.Softplus() # Ensure positive standard deviation
47
- )
48
-
49
- # Learnable affine transformation parameters
50
- self.weight = nn.Parameter(torch.ones(input_dim))
51
- self.bias = nn.Parameter(torch.zeros(input_dim))
52
-
53
- # Running statistics for inference (like BatchNorm)
54
- self.register_buffer('running_mean', torch.zeros(input_dim))
55
- self.register_buffer('running_std', torch.ones(input_dim))
56
- self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
57
-
58
- # Momentum for running statistics
59
- self.momentum = 0.1
60
-
61
- def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
62
- """
63
- Forward pass through neural scaler.
64
-
65
- Args:
66
- x: Input tensor (2D or 3D)
67
- inverse: Whether to apply inverse transformation
68
-
69
- Returns:
70
- Tuple of (transformed_tensor, regularization_loss)
71
- """
72
- if inverse:
73
- return self._inverse_transform(x)
74
-
75
- # Handle both 2D and 3D tensors
76
- original_shape = x.shape
77
- if x.dim() == 3:
78
- # Reshape (batch, seq, features) -> (batch*seq, features)
79
- x = x.view(-1, x.size(-1))
80
-
81
- if self.training:
82
- # Training mode: learn statistics from current batch
83
- batch_mean = x.mean(dim=0, keepdim=True)
84
- batch_std = x.std(dim=0, keepdim=True)
85
-
86
- # Learn data-dependent adjustments
87
- learned_mean_adj = self.mean_estimator(batch_mean)
88
- learned_std_adj = self.std_estimator(batch_std)
89
-
90
- # Combine batch statistics with learned adjustments
91
- effective_mean = batch_mean + learned_mean_adj
92
- effective_std = batch_std + learned_std_adj + 1e-8
93
-
94
- # Update running statistics
95
- with torch.no_grad():
96
- self.num_batches_tracked += 1
97
- if self.num_batches_tracked == 1:
98
- self.running_mean.copy_(batch_mean.squeeze())
99
- self.running_std.copy_(batch_std.squeeze())
100
- else:
101
- self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
102
- self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
103
- else:
104
- # Inference mode: use running statistics
105
- effective_mean = self.running_mean.unsqueeze(0)
106
- effective_std = self.running_std.unsqueeze(0) + 1e-8
107
-
108
- # Normalize
109
- normalized = (x - effective_mean) / effective_std
110
-
111
- # Apply learnable affine transformation
112
- transformed = normalized * self.weight + self.bias
113
-
114
- # Reshape back to original shape if needed
115
- if len(original_shape) == 3:
116
- transformed = transformed.view(original_shape)
117
-
118
- # Regularization loss to encourage meaningful learning
119
- reg_loss = 0.01 * (self.weight.var() + self.bias.var())
120
-
121
- return transformed, reg_loss
122
-
123
- def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
124
- """Apply inverse transformation to get back original scale."""
125
- if not self.config.learn_inverse_preprocessing:
126
- return x, torch.tensor(0.0, device=x.device)
127
-
128
- # Handle both 2D and 3D tensors
129
- original_shape = x.shape
130
- if x.dim() == 3:
131
- # Reshape (batch, seq, features) -> (batch*seq, features)
132
- x = x.view(-1, x.size(-1))
133
-
134
- # Reverse affine transformation
135
- x = (x - self.bias) / (self.weight + 1e-8)
136
-
137
- # Reverse normalization using running statistics
138
- effective_mean = self.running_mean.unsqueeze(0)
139
- effective_std = self.running_std.unsqueeze(0) + 1e-8
140
- x = x * effective_std + effective_mean
141
-
142
- # Reshape back to original shape if needed
143
- if len(original_shape) == 3:
144
- x = x.view(original_shape)
145
-
146
- return x, torch.tensor(0.0, device=x.device)
147
-
148
-
149
-
150
- class LearnableMinMaxScaler(nn.Module):
151
- """Learnable MinMax scaler that adapts bounds during training.
152
-
153
- Scales features to [0, 1] using batch min/range with learnable adjustments and
154
- a learnable affine transform. Supports 2D (B, F) and 3D (B, T, F) inputs.
155
- """
156
-
157
- def __init__(self, config: AutoencoderConfig):
158
- super().__init__()
159
- self.config = config
160
- input_dim = config.input_dim
161
- hidden_dim = config.preprocessing_hidden_dim
162
-
163
- # Networks to learn adjustments to batch min and range
164
- self.min_estimator = nn.Sequential(
165
- nn.Linear(input_dim, hidden_dim),
166
- nn.ReLU(),
167
- nn.Linear(hidden_dim, hidden_dim),
168
- nn.ReLU(),
169
- nn.Linear(hidden_dim, input_dim),
170
- )
171
- self.range_estimator = nn.Sequential(
172
- nn.Linear(input_dim, hidden_dim),
173
- nn.ReLU(),
174
- nn.Linear(hidden_dim, hidden_dim),
175
- nn.ReLU(),
176
- nn.Linear(hidden_dim, input_dim),
177
- nn.Softplus(), # Ensure positive adjustment to range
178
- )
179
-
180
- # Learnable affine transformation parameters
181
- self.weight = nn.Parameter(torch.ones(input_dim))
182
- self.bias = nn.Parameter(torch.zeros(input_dim))
183
-
184
- # Running statistics for inference
185
- self.register_buffer("running_min", torch.zeros(input_dim))
186
- self.register_buffer("running_range", torch.ones(input_dim))
187
- self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
188
-
189
- self.momentum = 0.1
190
-
191
- def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
192
- if inverse:
193
- return self._inverse_transform(x)
194
-
195
- original_shape = x.shape
196
- if x.dim() == 3:
197
- x = x.view(-1, x.size(-1))
198
-
199
- eps = 1e-8
200
- if self.training:
201
- batch_min = x.min(dim=0, keepdim=True).values
202
- batch_max = x.max(dim=0, keepdim=True).values
203
- batch_range = (batch_max - batch_min).clamp_min(eps)
204
-
205
- # Learn adjustments
206
- learned_min_adj = self.min_estimator(batch_min)
207
- learned_range_adj = self.range_estimator(batch_range)
208
-
209
- effective_min = batch_min + learned_min_adj
210
- effective_range = batch_range + learned_range_adj + eps
211
-
212
- # Update running stats with raw batch min/range for stable inversion
213
- with torch.no_grad():
214
- self.num_batches_tracked += 1
215
- if self.num_batches_tracked == 1:
216
- self.running_min.copy_(batch_min.squeeze())
217
- self.running_range.copy_(batch_range.squeeze())
218
- else:
219
- self.running_min.mul_(1 - self.momentum).add_(batch_min.squeeze(), alpha=self.momentum)
220
- self.running_range.mul_(1 - self.momentum).add_(batch_range.squeeze(), alpha=self.momentum)
221
- else:
222
- effective_min = self.running_min.unsqueeze(0)
223
- effective_range = self.running_range.unsqueeze(0)
224
-
225
- # Scale to [0, 1]
226
- scaled = (x - effective_min) / effective_range
227
-
228
- # Learnable affine transform
229
- transformed = scaled * self.weight + self.bias
230
-
231
- if len(original_shape) == 3:
232
- transformed = transformed.view(original_shape)
233
-
234
- # Regularization: encourage non-degenerate range and modest affine params
235
- reg_loss = 0.01 * (self.weight.var() + self.bias.var())
236
- if self.training:
237
- reg_loss = reg_loss + 0.001 * (1.0 / effective_range.clamp_min(1e-3)).mean()
238
-
239
- return transformed, reg_loss
240
-
241
- def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
242
- if not self.config.learn_inverse_preprocessing:
243
- return x, torch.tensor(0.0, device=x.device)
244
-
245
- original_shape = x.shape
246
- if x.dim() == 3:
247
- x = x.view(-1, x.size(-1))
248
-
249
- # Reverse affine
250
- x = (x - self.bias) / (self.weight + 1e-8)
251
- # Reverse MinMax using running stats
252
- x = x * self.running_range.unsqueeze(0) + self.running_min.unsqueeze(0)
253
-
254
- if len(original_shape) == 3:
255
- x = x.view(original_shape)
256
-
257
- return x, torch.tensor(0.0, device=x.device)
258
-
259
-
260
- class LearnableRobustScaler(nn.Module):
261
- """Learnable Robust scaler using median and IQR with learnable adjustments.
262
-
263
- Normalizes as (x - median) / IQR with learnable adjustments and an affine head.
264
- Supports 2D (B, F) and 3D (B, T, F) inputs.
265
- """
266
-
267
- def __init__(self, config: AutoencoderConfig):
268
- super().__init__()
269
- self.config = config
270
- input_dim = config.input_dim
271
- hidden_dim = config.preprocessing_hidden_dim
272
-
273
- self.median_estimator = nn.Sequential(
274
- nn.Linear(input_dim, hidden_dim),
275
- nn.ReLU(),
276
- nn.Linear(hidden_dim, hidden_dim),
277
- nn.ReLU(),
278
- nn.Linear(hidden_dim, input_dim),
279
- )
280
- self.iqr_estimator = nn.Sequential(
281
- nn.Linear(input_dim, hidden_dim),
282
- nn.ReLU(),
283
- nn.Linear(hidden_dim, hidden_dim),
284
- nn.ReLU(),
285
- nn.Linear(hidden_dim, input_dim),
286
- nn.Softplus(), # Ensure positive IQR adjustment
287
- )
288
-
289
- self.weight = nn.Parameter(torch.ones(input_dim))
290
- self.bias = nn.Parameter(torch.zeros(input_dim))
291
-
292
- self.register_buffer("running_median", torch.zeros(input_dim))
293
- self.register_buffer("running_iqr", torch.ones(input_dim))
294
- self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
295
-
296
- self.momentum = 0.1
297
-
298
- def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
299
- if inverse:
300
- return self._inverse_transform(x)
301
-
302
- original_shape = x.shape
303
- if x.dim() == 3:
304
- x = x.view(-1, x.size(-1))
305
-
306
- eps = 1e-8
307
- if self.training:
308
- qs = torch.quantile(x, torch.tensor([0.25, 0.5, 0.75], device=x.device), dim=0)
309
- q25, med, q75 = qs[0:1, :], qs[1:2, :], qs[2:3, :]
310
- iqr = (q75 - q25).clamp_min(eps)
311
-
312
- learned_med_adj = self.median_estimator(med)
313
- learned_iqr_adj = self.iqr_estimator(iqr)
314
-
315
- effective_median = med + learned_med_adj
316
- effective_iqr = iqr + learned_iqr_adj + eps
317
-
318
- with torch.no_grad():
319
- self.num_batches_tracked += 1
320
- if self.num_batches_tracked == 1:
321
- self.running_median.copy_(med.squeeze())
322
- self.running_iqr.copy_(iqr.squeeze())
323
- else:
324
- self.running_median.mul_(1 - self.momentum).add_(med.squeeze(), alpha=self.momentum)
325
- self.running_iqr.mul_(1 - self.momentum).add_(iqr.squeeze(), alpha=self.momentum)
326
- else:
327
- effective_median = self.running_median.unsqueeze(0)
328
- effective_iqr = self.running_iqr.unsqueeze(0)
329
-
330
- normalized = (x - effective_median) / effective_iqr
331
- transformed = normalized * self.weight + self.bias
332
-
333
- if len(original_shape) == 3:
334
- transformed = transformed.view(original_shape)
335
-
336
- reg_loss = 0.01 * (self.weight.var() + self.bias.var())
337
- if self.training:
338
- reg_loss = reg_loss + 0.001 * (1.0 / effective_iqr.clamp_min(1e-3)).mean()
339
-
340
- return transformed, reg_loss
341
-
342
- def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
343
- if not self.config.learn_inverse_preprocessing:
344
- return x, torch.tensor(0.0, device=x.device)
345
-
346
- original_shape = x.shape
347
- if x.dim() == 3:
348
- x = x.view(-1, x.size(-1))
349
-
350
- x = (x - self.bias) / (self.weight + 1e-8)
351
- x = x * self.running_iqr.unsqueeze(0) + self.running_median.unsqueeze(0)
352
-
353
- if len(original_shape) == 3:
354
- x = x.view(original_shape)
355
-
356
- return x, torch.tensor(0.0, device=x.device)
357
-
358
-
359
- class LearnableYeoJohnsonPreprocessor(nn.Module):
360
- """Learnable Yeo-Johnson power transform with per-feature λ and affine head.
361
-
362
- Applies Yeo-Johnson transform elementwise with learnable lambda per feature,
363
- followed by standardization and a learnable affine transform. Supports 2D and 3D inputs.
364
- """
365
-
366
- def __init__(self, config: AutoencoderConfig):
367
- super().__init__()
368
- self.config = config
369
- input_dim = config.input_dim
370
-
371
- # Learnable lambda per feature (unconstrained). Initialize around 1.0
372
- self.lmbda = nn.Parameter(torch.ones(input_dim))
373
-
374
- # Learnable affine parameters after standardization
375
- self.weight = nn.Parameter(torch.ones(input_dim))
376
- self.bias = nn.Parameter(torch.zeros(input_dim))
377
-
378
- # Running stats for transformed data
379
- self.register_buffer("running_mean", torch.zeros(input_dim))
380
- self.register_buffer("running_std", torch.ones(input_dim))
381
- self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
382
- self.momentum = 0.1
383
-
384
- def _yeo_johnson(self, x: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
385
- eps = 1e-6
386
- lmbda = lmbda.unsqueeze(0) # broadcast over batch
387
- pos = x >= 0
388
- # For x >= 0
389
- if_part = torch.where(
390
- torch.abs(lmbda) > eps,
391
- ((x + 1.0).clamp_min(eps) ** lmbda - 1.0) / lmbda,
392
- torch.log((x + 1.0).clamp_min(eps)),
393
- )
394
- # For x < 0
395
- two_minus_lambda = 2.0 - lmbda
396
- else_part = torch.where(
397
- torch.abs(two_minus_lambda) > eps,
398
- -(((1.0 - x).clamp_min(eps)) ** two_minus_lambda - 1.0) / two_minus_lambda,
399
- -torch.log((1.0 - x).clamp_min(eps)),
400
- )
401
- return torch.where(pos, if_part, else_part)
402
-
403
- def _yeo_johnson_inverse(self, y: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
404
- eps = 1e-6
405
- lmbda = lmbda.unsqueeze(0)
406
- pos = y >= 0
407
- # Inverse for y >= 0
408
- x_pos = torch.where(
409
- torch.abs(lmbda) > eps,
410
- (y * lmbda + 1.0).clamp_min(eps) ** (1.0 / lmbda) - 1.0,
411
- torch.exp(y) - 1.0,
412
- )
413
- # Inverse for y < 0
414
- two_minus_lambda = 2.0 - lmbda
415
- x_neg = torch.where(
416
- torch.abs(two_minus_lambda) > eps,
417
- 1.0 - (1.0 - y * two_minus_lambda).clamp_min(eps) ** (1.0 / two_minus_lambda),
418
- 1.0 - torch.exp(-y),
419
- )
420
- return torch.where(pos, x_pos, x_neg)
421
-
422
- def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
423
- if inverse:
424
- return self._inverse_transform(x)
425
-
426
- orig_shape = x.shape
427
- if x.dim() == 3:
428
- x = x.view(-1, x.size(-1))
429
-
430
- # Apply Yeo-Johnson
431
- y = self._yeo_johnson(x, self.lmbda)
432
-
433
- # Batch stats and running stats on transformed data
434
- if self.training:
435
- batch_mean = y.mean(dim=0, keepdim=True)
436
- batch_std = y.std(dim=0, keepdim=True).clamp_min(1e-6)
437
- with torch.no_grad():
438
- self.num_batches_tracked += 1
439
- if self.num_batches_tracked == 1:
440
- self.running_mean.copy_(batch_mean.squeeze())
441
- self.running_std.copy_(batch_std.squeeze())
442
- else:
443
- self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
444
- self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
445
- mean = batch_mean
446
- std = batch_std
447
- else:
448
- mean = self.running_mean.unsqueeze(0)
449
- std = self.running_std.unsqueeze(0)
450
-
451
- y_norm = (y - mean) / std
452
- out = y_norm * self.weight + self.bias
453
-
454
- if len(orig_shape) == 3:
455
- out = out.view(orig_shape)
456
-
457
- # Regularize lambda to avoid extreme values; encourage identity around 1
458
- reg = 0.001 * (self.lmbda - 1.0).pow(2).mean() + 0.01 * (self.weight.var() + self.bias.var())
459
- return out, reg
460
-
461
- def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
462
- if not self.config.learn_inverse_preprocessing:
463
- return x, torch.tensor(0.0, device=x.device)
464
-
465
- orig_shape = x.shape
466
- if x.dim() == 3:
467
- x = x.view(-1, x.size(-1))
468
-
469
- # Reverse affine and normalization with running stats
470
- y = (x - self.bias) / (self.weight + 1e-8)
471
- y = y * self.running_std.unsqueeze(0) + self.running_mean.unsqueeze(0)
472
-
473
- # Inverse Yeo-Johnson
474
- out = self._yeo_johnson_inverse(y, self.lmbda)
475
-
476
- if len(orig_shape) == 3:
477
- out = out.view(orig_shape)
478
-
479
- return out, torch.tensor(0.0, device=x.device)
480
-
481
- class CouplingLayer(nn.Module):
482
- """Coupling layer for normalizing flows."""
483
-
484
- def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"):
485
- super().__init__()
486
- self.input_dim = input_dim
487
- self.hidden_dim = hidden_dim
488
-
489
- # Create mask for coupling
490
- if mask_type == "alternating":
491
- self.register_buffer('mask', torch.arange(input_dim) % 2)
492
- elif mask_type == "half":
493
- mask = torch.zeros(input_dim)
494
- mask[:input_dim // 2] = 1
495
- self.register_buffer('mask', mask)
496
- else:
497
- raise ValueError(f"Unknown mask type: {mask_type}")
498
-
499
- # Scale and translation networks
500
- masked_dim = int(self.mask.sum().item())
501
- unmasked_dim = input_dim - masked_dim
502
-
503
- self.scale_net = nn.Sequential(
504
- nn.Linear(masked_dim, hidden_dim),
505
- nn.ReLU(),
506
- nn.Linear(hidden_dim, hidden_dim),
507
- nn.ReLU(),
508
- nn.Linear(hidden_dim, unmasked_dim),
509
- nn.Tanh() # Bounded output for stability
510
- )
511
-
512
- self.translate_net = nn.Sequential(
513
- nn.Linear(masked_dim, hidden_dim),
514
- nn.ReLU(),
515
- nn.Linear(hidden_dim, hidden_dim),
516
- nn.ReLU(),
517
- nn.Linear(hidden_dim, unmasked_dim)
518
- )
519
-
520
- def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
521
- """
522
- Forward pass through coupling layer.
523
-
524
- Args:
525
- x: Input tensor
526
- inverse: Whether to apply inverse transformation
527
-
528
- Returns:
529
- Tuple of (transformed_tensor, log_determinant)
530
- """
531
- mask = self.mask.bool()
532
- x_masked = x[:, mask]
533
- x_unmasked = x[:, ~mask]
534
-
535
- # Compute scale and translation
536
- s = self.scale_net(x_masked)
537
- t = self.translate_net(x_masked)
538
-
539
- if not inverse:
540
- # Forward transformation
541
- y_unmasked = x_unmasked * torch.exp(s) + t
542
- log_det = s.sum(dim=1)
543
- else:
544
- # Inverse transformation
545
- y_unmasked = (x_unmasked - t) * torch.exp(-s)
546
- log_det = -s.sum(dim=1)
547
-
548
- # Reconstruct output
549
- y = torch.zeros_like(x)
550
- y[:, mask] = x_masked
551
- y[:, ~mask] = y_unmasked
552
-
553
- return y, log_det
554
-
555
-
556
- class NormalizingFlowPreprocessor(nn.Module):
557
- """Normalizing flow for learnable data preprocessing."""
558
-
559
- def __init__(self, config: AutoencoderConfig):
560
- super().__init__()
561
- self.config = config
562
- input_dim = config.input_dim
563
- hidden_dim = config.preprocessing_hidden_dim
564
- num_layers = config.flow_coupling_layers
565
-
566
- # Create coupling layers with alternating masks
567
- self.layers = nn.ModuleList()
568
- for i in range(num_layers):
569
- mask_type = "alternating" if i % 2 == 0 else "half"
570
- self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type))
571
-
572
- # Optional: Add batch normalization between layers
573
- if config.use_batch_norm:
574
- self.batch_norms = nn.ModuleList([
575
- nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1)
576
- ])
577
- else:
578
- self.batch_norms = None
579
-
580
- def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
581
- """
582
- Forward pass through normalizing flow.
583
-
584
- Args:
585
- x: Input tensor (2D or 3D)
586
- inverse: Whether to apply inverse transformation
587
-
588
- Returns:
589
- Tuple of (transformed_tensor, total_log_determinant)
590
- """
591
- # Handle both 2D and 3D tensors
592
- original_shape = x.shape
593
- if x.dim() == 3:
594
- # Reshape (batch, seq, features) -> (batch*seq, features)
595
- x = x.view(-1, x.size(-1))
596
-
597
- log_det_total = torch.zeros(x.size(0), device=x.device)
598
-
599
- if not inverse:
600
- # Forward pass
601
- for i, layer in enumerate(self.layers):
602
- x, log_det = layer(x, inverse=False)
603
- log_det_total += log_det
604
-
605
- # Apply batch normalization (except for last layer)
606
- if self.batch_norms and i < len(self.layers) - 1:
607
- x = self.batch_norms[i](x)
608
- else:
609
- # Inverse pass
610
- for i, layer in enumerate(reversed(self.layers)):
611
- # Reverse batch normalization (except for first layer in reverse)
612
- if self.batch_norms and i > 0:
613
- # Note: This is approximate inverse of batch norm
614
- bn_idx = len(self.layers) - 1 - i
615
- x = self.batch_norms[bn_idx](x)
616
-
617
- x, log_det = layer(x, inverse=True)
618
- log_det_total += log_det
619
-
620
- # Reshape back to original shape if needed
621
- if len(original_shape) == 3:
622
- x = x.view(original_shape)
623
-
624
- # Convert log determinant to regularization loss
625
- # Encourage the flow to preserve information (log_det close to 0)
626
- reg_loss = 0.01 * log_det_total.abs().mean()
627
-
628
- return x, reg_loss
629
-
630
-
631
- class LearnablePreprocessor(nn.Module):
632
- """Unified interface for learnable preprocessing methods."""
633
-
634
- def __init__(self, config: AutoencoderConfig):
635
- super().__init__()
636
- self.config = config
637
-
638
- if not config.has_preprocessing:
639
- self.preprocessor = nn.Identity()
640
- elif config.is_neural_scaler:
641
- self.preprocessor = NeuralScaler(config)
642
- elif config.is_normalizing_flow:
643
- self.preprocessor = NormalizingFlowPreprocessor(config)
644
- elif getattr(config, "is_minmax_scaler", False):
645
- self.preprocessor = LearnableMinMaxScaler(config)
646
- elif getattr(config, "is_robust_scaler", False):
647
- self.preprocessor = LearnableRobustScaler(config)
648
- elif getattr(config, "is_yeo_johnson", False):
649
- self.preprocessor = LearnableYeoJohnsonPreprocessor(config)
650
- else:
651
- raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")
652
-
653
- def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
654
- """
655
- Apply preprocessing transformation.
656
-
657
- Args:
658
- x: Input tensor
659
- inverse: Whether to apply inverse transformation
660
-
661
- Returns:
662
- Tuple of (transformed_tensor, regularization_loss)
663
- """
664
- if isinstance(self.preprocessor, nn.Identity):
665
- return x, torch.tensor(0.0, device=x.device)
666
-
667
- return self.preprocessor(x, inverse=inverse)
668
 
669
 
670
  @dataclass
@@ -741,29 +130,6 @@ class AutoencoderEncoder(nn.Module):
741
  # Standard encoder output
742
  self.fc_out = nn.Linear(input_dim, config.latent_dim)
743
 
744
- def _get_activation(self, activation: str) -> nn.Module:
745
- """Get activation function by name."""
746
- activations = {
747
- "relu": nn.ReLU(),
748
- "tanh": nn.Tanh(),
749
- "sigmoid": nn.Sigmoid(),
750
- "leaky_relu": nn.LeakyReLU(),
751
- "gelu": nn.GELU(),
752
- "swish": nn.SiLU(),
753
- "silu": nn.SiLU(),
754
- "elu": nn.ELU(),
755
- "prelu": nn.PReLU(),
756
- "relu6": nn.ReLU6(),
757
- "hardtanh": nn.Hardtanh(),
758
- "hardsigmoid": nn.Hardsigmoid(),
759
- "hardswish": nn.Hardswish(),
760
- "mish": nn.Mish(),
761
- "softplus": nn.Softplus(),
762
- "softsign": nn.Softsign(),
763
- "tanhshrink": nn.Tanhshrink(),
764
- "threshold": nn.Threshold(threshold=0.1, value=0),
765
- }
766
- return activations[activation]
767
 
768
  def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
769
  """Forward pass through encoder."""
@@ -820,7 +186,7 @@ class AutoencoderDecoder(nn.Module):
820
  if config.use_batch_norm:
821
  layers.append(nn.BatchNorm1d(hidden_dim))
822
 
823
- layers.append(self._get_activation(config.activation))
824
 
825
  if config.dropout_rate > 0:
826
  layers.append(nn.Dropout(config.dropout_rate))
@@ -833,29 +199,6 @@ class AutoencoderDecoder(nn.Module):
833
 
834
  self.decoder = nn.Sequential(*layers)
835
 
836
- def _get_activation(self, activation: str) -> nn.Module:
837
- """Get activation function by name."""
838
- activations = {
839
- "relu": nn.ReLU(),
840
- "tanh": nn.Tanh(),
841
- "sigmoid": nn.Sigmoid(),
842
- "leaky_relu": nn.LeakyReLU(),
843
- "gelu": nn.GELU(),
844
- "swish": nn.SiLU(),
845
- "silu": nn.SiLU(),
846
- "elu": nn.ELU(),
847
- "prelu": nn.PReLU(),
848
- "relu6": nn.ReLU6(),
849
- "hardtanh": nn.Hardtanh(),
850
- "hardsigmoid": nn.Hardsigmoid(),
851
- "hardswish": nn.Hardswish(),
852
- "mish": nn.Mish(),
853
- "softplus": nn.Softplus(),
854
- "softsign": nn.Softsign(),
855
- "tanhshrink": nn.Tanhshrink(),
856
- "threshold": nn.Threshold(threshold=0.1, value=0),
857
- }
858
- return activations[activation]
859
 
860
  def forward(self, x: torch.Tensor) -> torch.Tensor:
861
  """Forward pass through decoder."""
@@ -1111,21 +454,75 @@ class AutoencoderModel(PreTrainedModel):
1111
  super().__init__(config)
1112
  self.config = config
1113
 
1114
- # Initialize learnable preprocessing
1115
  if config.has_preprocessing:
1116
- self.preprocessor = LearnablePreprocessor(config)
 
 
 
 
 
 
 
 
 
 
 
 
1117
  else:
1118
- self.preprocessor = None
 
 
 
 
1119
 
1120
- # Initialize encoder and decoder based on type
1121
- if config.is_recurrent:
1122
- self.encoder = RecurrentEncoder(config)
1123
- self.decoder = RecurrentDecoder(config)
 
1124
  else:
1125
- self.encoder = AutoencoderEncoder(config)
1126
- self.decoder = AutoencoderDecoder(config)
 
1127
 
1128
- # Tie weights if specified
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1129
  if config.tie_weights:
1130
  self._tie_weights()
1131
 
@@ -1173,62 +570,37 @@ class AutoencoderModel(PreTrainedModel):
1173
  )
1174
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1175
 
1176
- # Apply learnable preprocessing
 
 
1177
  preprocessing_loss = torch.tensor(0.0, device=input_values.device)
1178
- if self.preprocessor is not None:
1179
- input_values, preprocessing_loss = self.preprocessor(input_values, inverse=False)
1180
-
1181
- # Handle different autoencoder types
1182
- if self.config.is_recurrent:
1183
- # Recurrent autoencoder
1184
- if sequence_lengths is not None:
1185
- encoder_output = self.encoder(input_values, sequence_lengths)
1186
- else:
1187
- encoder_output = self.encoder(input_values)
1188
 
1189
- if self.config.is_variational:
1190
- latent, mu, logvar = encoder_output
1191
- self._mu = mu
1192
- self._logvar = logvar
1193
- else:
1194
- latent = encoder_output
1195
- self._mu = None
1196
- self._logvar = None
1197
-
1198
- # Determine target length for decoder
1199
- if target_length is None:
1200
- if self.config.sequence_length is not None:
1201
- target_length = self.config.sequence_length
1202
- else:
1203
- target_length = input_values.size(1) # Use input sequence length
1204
 
1205
- # Decode latent back to sequence space
1206
- reconstructed = self.decoder(latent, target_length, input_values if self.training else None)
 
 
 
 
 
 
 
1207
  else:
1208
- # Standard autoencoder
1209
- encoder_output = self.encoder(input_values)
1210
 
1211
- if self.config.is_variational:
1212
- latent, mu, logvar = encoder_output
1213
- self._mu = mu
1214
- self._logvar = logvar
1215
- else:
1216
- latent = encoder_output
1217
- self._mu = None
1218
- self._logvar = None
1219
 
1220
- # Decode latent back to input space
1221
- reconstructed = self.decoder(latent)
1222
 
1223
- # Apply inverse preprocessing to reconstruction
1224
- if self.preprocessor is not None and self.config.learn_inverse_preprocessing:
1225
- reconstructed, inverse_loss = self.preprocessor(reconstructed, inverse=True)
1226
- preprocessing_loss += inverse_loss
1227
 
1228
  hidden_states = None
1229
  if output_hidden_states:
1230
  if self.config.is_variational:
1231
- hidden_states = (latent, mu, logvar)
1232
  else:
1233
  hidden_states = (latent,)
1234
 
@@ -1263,6 +635,8 @@ class AutoencoderForReconstruction(PreTrainedModel):
1263
  # Initialize weights
1264
  self.post_init()
1265
 
 
 
1266
  def get_input_embeddings(self):
1267
  """Get input embeddings."""
1268
  return self.autoencoder.get_input_embeddings()
 
8
  from typing import Optional, Tuple, Union, Dict, Any, List
9
  from dataclasses import dataclass
10
  import random
11
+ import re
12
 
13
  from transformers import PreTrainedModel
14
  from transformers.modeling_outputs import BaseModelOutput
 
19
  except Exception:
20
  from configuration_autoencoder import AutoencoderConfig # local usage
21
 
22
+ # Block-based architecture components
23
+ try:
24
+ from .blocks import (
25
+ BlockFactory,
26
+ BlockSequence,
27
+ LinearBlockConfig,
28
+ AttentionBlockConfig,
29
+ RecurrentBlockConfig,
30
+ ConvolutionalBlockConfig,
31
+ VariationalBlockConfig,
32
+ VariationalBlock,
33
+ ) # when in package
34
+ except Exception:
35
+ from blocks import (
36
+ BlockFactory,
37
+ BlockSequence,
38
+ LinearBlockConfig,
39
+ AttentionBlockConfig,
40
+ RecurrentBlockConfig,
41
+ ConvolutionalBlockConfig,
42
+ VariationalBlockConfig,
43
+ VariationalBlock,
44
+ ) # local usage
45
+
46
+ # Shared utilities
47
+ try:
48
+ from .utils import _get_activation
49
+ except Exception:
50
+ from utils import _get_activation
51
 
52
+ # Preprocessing components
53
+ try:
54
+ from .preprocessing import PreprocessingBlock # when in package
55
+ except Exception:
56
+ from preprocessing import PreprocessingBlock # local usage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  @dataclass
 
130
  # Standard encoder output
131
  self.fc_out = nn.Linear(input_dim, config.latent_dim)
132
 
 
 
 
 
 
 
 
 
 
133
 
134
  def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
135
  """Forward pass through encoder."""
 
186
  if config.use_batch_norm:
187
  layers.append(nn.BatchNorm1d(hidden_dim))
188
 
189
+ layers.append(_get_activation(config.activation))
190
 
191
  if config.dropout_rate > 0:
192
  layers.append(nn.Dropout(config.dropout_rate))
 
199
 
200
  self.decoder = nn.Sequential(*layers)
201
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  def forward(self, x: torch.Tensor) -> torch.Tensor:
204
  """Forward pass through decoder."""
 
454
  super().__init__(config)
455
  self.config = config
456
 
457
+ # Initialize learnable preprocessing as a single forward block only
458
  if config.has_preprocessing:
459
+ self.pre_block = PreprocessingBlock(config, inverse=False)
460
+ else:
461
+ self.pre_block = None
462
+
463
+ # Build block-based encoder/decoder sequences (breaking change refactor)
464
+ norm = "batch" if config.use_batch_norm else "none"
465
+
466
+ def default_linear_sequence(in_dim: int, dims: List[int], activation: str, normalization: str, dropout: float) -> List[LinearBlockConfig]:
467
+ cfgs: List[LinearBlockConfig] = []
468
+ prev = in_dim
469
+ for h in dims:
470
+ cfgs.append(
471
+ LinearBlockConfig(
472
+ input_dim=prev,
473
+ output_dim=h,
474
+ activation=activation,
475
+ normalization=normalization,
476
+ dropout_rate=dropout,
477
+ use_residual=False,
478
+ )
479
+ )
480
+ prev = h
481
+ return cfgs
482
+
483
+ # Encoder: use explicit block list if provided, else hidden_dims default
484
+ if getattr(config, "encoder_blocks", None):
485
+ enc_cfgs = config.encoder_blocks
486
+ # Compute enc_out_dim from last block's output_dim if linear/conv, else assume input_dim
487
+ last_out = None
488
+ for b in enc_cfgs:
489
+ if isinstance(b, dict):
490
+ last_out = b.get("output_dim", last_out)
491
+ else:
492
+ last_out = getattr(b, "output_dim", last_out)
493
+ enc_out_dim = last_out or (config.hidden_dims[-1] if config.hidden_dims else config.input_dim)
494
  else:
495
+ enc_cfgs = default_linear_sequence(config.input_dim, config.hidden_dims, config.activation, norm, config.dropout_rate)
496
+ enc_out_dim = config.hidden_dims[-1] if config.hidden_dims else config.input_dim
497
+ base_encoder_seq: BlockSequence = BlockFactory.build_sequence(enc_cfgs) if len(enc_cfgs) > 0 else BlockSequence([])
498
+ # Do not inject pre_block into encoder sequence; apply it explicitly in forward
499
+ self.encoder_seq = base_encoder_seq
500
 
501
+ # Project to latent
502
+ if config.is_variational:
503
+ self.fc_mu = nn.Linear(enc_out_dim, config.latent_dim)
504
+ self.fc_logvar = nn.Linear(enc_out_dim, config.latent_dim)
505
+ self.to_latent = None
506
  else:
507
+ self.fc_mu = None
508
+ self.fc_logvar = None
509
+ self.to_latent = nn.Linear(enc_out_dim, config.latent_dim)
510
 
511
+ # Decoder: use explicit block list if provided, else default MLP back to input
512
+ if getattr(config, "decoder_blocks", None):
513
+ dec_cfgs = config.decoder_blocks
514
+ else:
515
+ dec_dims = config.decoder_dims + [config.input_dim]
516
+ dec_cfgs = default_linear_sequence(config.latent_dim, dec_dims, config.activation, norm, config.dropout_rate)
517
+ # For final projection to input_dim: identity activation and no norm/dropout
518
+ if len(dec_cfgs) > 0:
519
+ last = dec_cfgs[-1]
520
+ last.activation = "identity"
521
+ last.normalization = "none"
522
+ last.dropout_rate = 0.0
523
+ self.decoder_seq: BlockSequence = BlockFactory.build_sequence(dec_cfgs) if len(dec_cfgs) > 0 else BlockSequence([])
524
+
525
+ # Tie weights if specified (no-op for now)
526
  if config.tie_weights:
527
  self._tie_weights()
528
 
 
570
  )
571
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
572
 
573
+ # Apply learnable preprocessing via block (forward only)
574
+ if self.pre_block is not None:
575
+ input_values = self.pre_block(input_values)
576
  preprocessing_loss = torch.tensor(0.0, device=input_values.device)
 
 
 
 
 
 
 
 
 
 
577
 
578
+ # Block-based forward
579
+ # Encode through block sequence
580
+ enc_out = self.encoder_seq(input_values)
 
 
 
 
 
 
 
 
 
 
 
 
581
 
582
+ # Sample or project to latent
583
+ if self.config.is_variational:
584
+ # Use VariationalBlock to encapsulate VAE behavior
585
+ self._variational = getattr(self, '_variational', None)
586
+ if self._variational is None:
587
+ self._variational = VariationalBlock(VariationalBlockConfig(input_dim=enc_out.shape[-1], latent_dim=self.config.latent_dim)).to(enc_out.device)
588
+ latent = self._variational(enc_out, training=self.training)
589
+ self._mu = self._variational._mu
590
+ self._logvar = self._variational._logvar
591
  else:
592
+ latent = self.to_latent(enc_out) if self.to_latent is not None else enc_out
593
+ self._mu, self._logvar = None, None
594
 
595
+ # Decode back to input space
596
+ reconstructed = self.decoder_seq(latent)
 
 
 
 
 
 
597
 
 
 
598
 
 
 
 
 
599
 
600
  hidden_states = None
601
  if output_hidden_states:
602
  if self.config.is_variational:
603
+ hidden_states = (latent, getattr(self, '_mu', None), getattr(self, '_logvar', None))
604
  else:
605
  hidden_states = (latent,)
606
 
 
635
  # Initialize weights
636
  self.post_init()
637
 
638
+
639
+
640
  def get_input_embeddings(self):
641
  """Get input embeddings."""
642
  return self.autoencoder.get_input_embeddings()
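
The block-based forward above delegates VAE sampling to `VariationalBlock`: during training it draws `z = mu + eps * exp(0.5 * logvar)` and caches `mu` / `logvar` for the KL term, while in eval mode it returns `mu` directly. A minimal sketch of that block in isolation, assuming `blocks.py` is importable and with illustrative shapes:

```python
# Sketch only: assumes blocks.py from this repo is on sys.path; dims are illustrative.
import torch
from blocks import VariationalBlock, VariationalBlockConfig

vb = VariationalBlock(VariationalBlockConfig(input_dim=32, latent_dim=8))
h = torch.randn(4, 32)       # encoder features (batch, enc_out_dim)

vb.train()
z_train = vb(h)              # stochastic: z = mu + eps * std
vb.eval()
z_eval = vb(h)               # deterministic: z = mu

print(z_train.shape, z_eval.shape)     # torch.Size([4, 8]) torch.Size([4, 8])
print(vb._mu.shape, vb._logvar.shape)  # cached for the KL divergence term
```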
preprocessing.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ """
4
+ Learnable preprocessing components for the block-based autoencoder.
5
+ Extracted from modeling_autoencoder.py to a dedicated module.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Optional, Tuple
11
+
12
+ import torch
13
+ from typing import Tuple
14
+
15
+ try:
16
+ from .blocks import BaseBlock
17
+ except Exception:
18
+ from blocks import BaseBlock
19
+
20
+ import torch.nn as nn
21
+
22
+ try:
23
+ from .configuration_autoencoder import AutoencoderConfig # when loaded via HF dynamic module
24
+ except Exception:
25
+ from configuration_autoencoder import AutoencoderConfig # local usage
26
+
27
+
28
+ class NeuralScaler(nn.Module):
29
+ """Learnable alternative to StandardScaler using neural networks."""
30
+
31
+ def __init__(self, config: AutoencoderConfig):
32
+ super().__init__()
33
+ self.config = config
34
+ input_dim = config.input_dim
35
+ hidden_dim = config.preprocessing_hidden_dim
36
+
37
+ self.mean_estimator = nn.Sequential(
38
+ nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim)
39
+ )
40
+ self.std_estimator = nn.Sequential(
41
+ nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Softplus()
42
+ )
43
+
44
+ self.weight = nn.Parameter(torch.ones(input_dim))
45
+ self.bias = nn.Parameter(torch.zeros(input_dim))
46
+
47
+ self.register_buffer("running_mean", torch.zeros(input_dim))
48
+ self.register_buffer("running_std", torch.ones(input_dim))
49
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
50
+ self.momentum = 0.1
51
+
52
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
53
+ if inverse:
54
+ return self._inverse_transform(x)
55
+ original_shape = x.shape
56
+ if x.dim() == 3:
57
+ x = x.view(-1, x.size(-1))
58
+ if self.training:
59
+ batch_mean = x.mean(dim=0, keepdim=True)
60
+ batch_std = x.std(dim=0, keepdim=True)
61
+ learned_mean_adj = self.mean_estimator(batch_mean)
62
+ learned_std_adj = self.std_estimator(batch_std)
63
+ effective_mean = batch_mean + learned_mean_adj
64
+ effective_std = batch_std + learned_std_adj + 1e-8
65
+ with torch.no_grad():
66
+ self.num_batches_tracked += 1
67
+ if self.num_batches_tracked == 1:
68
+ self.running_mean.copy_(batch_mean.squeeze())
69
+ self.running_std.copy_(batch_std.squeeze())
70
+ else:
71
+ self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
72
+ self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
73
+ else:
74
+ effective_mean = self.running_mean.unsqueeze(0)
75
+ effective_std = self.running_std.unsqueeze(0) + 1e-8
76
+ normalized = (x - effective_mean) / effective_std
77
+ transformed = normalized * self.weight + self.bias
78
+ if len(original_shape) == 3:
79
+ transformed = transformed.view(original_shape)
80
+ reg_loss = 0.01 * (self.weight.var() + self.bias.var())
81
+ return transformed, reg_loss
82
+
83
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
84
+ if not self.config.learn_inverse_preprocessing:
85
+ return x, torch.tensor(0.0, device=x.device)
86
+ original_shape = x.shape
87
+ if x.dim() == 3:
88
+ x = x.view(-1, x.size(-1))
89
+ x = (x - self.bias) / (self.weight + 1e-8)
90
+ effective_mean = self.running_mean.unsqueeze(0)
91
+ effective_std = self.running_std.unsqueeze(0) + 1e-8
92
+ x = x * effective_std + effective_mean
93
+ if len(original_shape) == 3:
94
+ x = x.view(original_shape)
95
+ return x, torch.tensor(0.0, device=x.device)
96
+
97
+
98
+ class LearnableMinMaxScaler(nn.Module):
99
+ """Learnable MinMax scaler that adapts bounds during training."""
100
+
101
+ def __init__(self, config: AutoencoderConfig):
102
+ super().__init__()
103
+ self.config = config
104
+ input_dim = config.input_dim
105
+ hidden_dim = config.preprocessing_hidden_dim
106
+ self.min_estimator = nn.Sequential(
107
+ nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim)
108
+ )
109
+ self.range_estimator = nn.Sequential(
110
+ nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Softplus()
111
+ )
112
+ self.weight = nn.Parameter(torch.ones(input_dim))
113
+ self.bias = nn.Parameter(torch.zeros(input_dim))
114
+ self.register_buffer("running_min", torch.zeros(input_dim))
115
+ self.register_buffer("running_range", torch.ones(input_dim))
116
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
117
+ self.momentum = 0.1
118
+
119
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
120
+ if inverse:
121
+ return self._inverse_transform(x)
122
+ original_shape = x.shape
123
+ if x.dim() == 3:
124
+ x = x.view(-1, x.size(-1))
125
+ eps = 1e-8
126
+ if self.training:
127
+ batch_min = x.min(dim=0, keepdim=True).values
128
+ batch_max = x.max(dim=0, keepdim=True).values
129
+ batch_range = (batch_max - batch_min).clamp_min(eps)
130
+ learned_min_adj = self.min_estimator(batch_min)
131
+ learned_range_adj = self.range_estimator(batch_range)
132
+ effective_min = batch_min + learned_min_adj
133
+ effective_range = batch_range + learned_range_adj + eps
134
+ with torch.no_grad():
135
+ self.num_batches_tracked += 1
136
+ if self.num_batches_tracked == 1:
137
+ self.running_min.copy_(batch_min.squeeze())
138
+ self.running_range.copy_(batch_range.squeeze())
139
+ else:
140
+ self.running_min.mul_(1 - self.momentum).add_(batch_min.squeeze(), alpha=self.momentum)
141
+ self.running_range.mul_(1 - self.momentum).add_(batch_range.squeeze(), alpha=self.momentum)
142
+ else:
143
+ effective_min = self.running_min.unsqueeze(0)
144
+ effective_range = self.running_range.unsqueeze(0)
145
+ scaled = (x - effective_min) / effective_range
146
+ transformed = scaled * self.weight + self.bias
147
+ if len(original_shape) == 3:
148
+ transformed = transformed.view(original_shape)
149
+ reg_loss = 0.01 * (self.weight.var() + self.bias.var())
150
+ if self.training:
151
+ reg_loss = reg_loss + 0.001 * (1.0 / effective_range.clamp_min(1e-3)).mean()
152
+ return transformed, reg_loss
153
+
154
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
155
+ if not self.config.learn_inverse_preprocessing:
156
+ return x, torch.tensor(0.0, device=x.device)
157
+ original_shape = x.shape
158
+ if x.dim() == 3:
159
+ x = x.view(-1, x.size(-1))
160
+ x = (x - self.bias) / (self.weight + 1e-8)
161
+ x = x * self.running_range.unsqueeze(0) + self.running_min.unsqueeze(0)
162
+ if len(original_shape) == 3:
163
+ x = x.view(original_shape)
164
+ return x, torch.tensor(0.0, device=x.device)
165
+
166
+
167
+ class LearnableRobustScaler(nn.Module):
168
+ """Learnable Robust scaler using median and IQR with learnable adjustments."""
169
+
170
+ def __init__(self, config: AutoencoderConfig):
171
+ super().__init__()
172
+ self.config = config
173
+ input_dim = config.input_dim
174
+ hidden_dim = config.preprocessing_hidden_dim
175
+ self.median_estimator = nn.Sequential(
176
+ nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim)
177
+ )
178
+ self.iqr_estimator = nn.Sequential(
179
+ nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Softplus()
180
+ )
181
+ self.weight = nn.Parameter(torch.ones(input_dim))
182
+ self.bias = nn.Parameter(torch.zeros(input_dim))
183
+ self.register_buffer("running_median", torch.zeros(input_dim))
184
+ self.register_buffer("running_iqr", torch.ones(input_dim))
185
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
186
+ self.momentum = 0.1
187
+
188
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
189
+ if inverse:
190
+ return self._inverse_transform(x)
191
+ original_shape = x.shape
192
+ if x.dim() == 3:
193
+ x = x.view(-1, x.size(-1))
194
+ eps = 1e-8
195
+ if self.training:
196
+ qs = torch.quantile(x, torch.tensor([0.25, 0.5, 0.75], device=x.device), dim=0)
197
+ q25, med, q75 = qs[0:1, :], qs[1:2, :], qs[2:3, :]
198
+ iqr = (q75 - q25).clamp_min(eps)
199
+ learned_med_adj = self.median_estimator(med)
200
+ learned_iqr_adj = self.iqr_estimator(iqr)
201
+ effective_median = med + learned_med_adj
202
+ effective_iqr = iqr + learned_iqr_adj + eps
203
+ with torch.no_grad():
204
+ self.num_batches_tracked += 1
205
+ if self.num_batches_tracked == 1:
206
+ self.running_median.copy_(med.squeeze())
207
+ self.running_iqr.copy_(iqr.squeeze())
208
+ else:
209
+ self.running_median.mul_(1 - self.momentum).add_(med.squeeze(), alpha=self.momentum)
210
+ self.running_iqr.mul_(1 - self.momentum).add_(iqr.squeeze(), alpha=self.momentum)
211
+ else:
212
+ effective_median = self.running_median.unsqueeze(0)
213
+ effective_iqr = self.running_iqr.unsqueeze(0)
214
+ normalized = (x - effective_median) / effective_iqr
215
+ transformed = normalized * self.weight + self.bias
216
+ if len(original_shape) == 3:
217
+ transformed = transformed.view(original_shape)
218
+ reg_loss = 0.01 * (self.weight.var() + self.bias.var())
219
+ if self.training:
220
+ reg_loss = reg_loss + 0.001 * (1.0 / effective_iqr.clamp_min(1e-3)).mean()
221
+ return transformed, reg_loss
222
+
223
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
224
+ if not self.config.learn_inverse_preprocessing:
225
+ return x, torch.tensor(0.0, device=x.device)
226
+ original_shape = x.shape
227
+ if x.dim() == 3:
228
+ x = x.view(-1, x.size(-1))
229
+ x = (x - self.bias) / (self.weight + 1e-8)
230
+ x = x * self.running_iqr.unsqueeze(0) + self.running_median.unsqueeze(0)
231
+ if len(original_shape) == 3:
232
+ x = x.view(original_shape)
233
+ return x, torch.tensor(0.0, device=x.device)
234
+
235
+
236
+ class LearnableYeoJohnsonPreprocessor(nn.Module):
237
+ """Learnable Yeo-Johnson power transform with per-feature lambda and affine head."""
238
+
239
+ def __init__(self, config: AutoencoderConfig):
240
+ super().__init__()
241
+ self.config = config
242
+ input_dim = config.input_dim
243
+ self.lmbda = nn.Parameter(torch.ones(input_dim))
244
+ self.weight = nn.Parameter(torch.ones(input_dim))
245
+ self.bias = nn.Parameter(torch.zeros(input_dim))
246
+ self.register_buffer("running_mean", torch.zeros(input_dim))
247
+ self.register_buffer("running_std", torch.ones(input_dim))
248
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
249
+ self.momentum = 0.1
250
+
251
+ def _yeo_johnson(self, x: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
252
+ eps = 1e-6
253
+ lmbda = lmbda.unsqueeze(0)
254
+ pos = x >= 0
255
+ if_part = torch.where(torch.abs(lmbda) > eps, ((x + 1.0).clamp_min(eps) ** lmbda - 1.0) / lmbda, torch.log((x + 1.0).clamp_min(eps)))
256
+ two_minus_lambda = 2.0 - lmbda
257
+ else_part = torch.where(torch.abs(two_minus_lambda) > eps, -(((1.0 - x).clamp_min(eps)) ** two_minus_lambda - 1.0) / two_minus_lambda, -torch.log((1.0 - x).clamp_min(eps)))
258
+ return torch.where(pos, if_part, else_part)
259
+
260
+ def _yeo_johnson_inverse(self, y: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
261
+ eps = 1e-6
262
+ lmbda = lmbda.unsqueeze(0)
263
+ pos = y >= 0
264
+ x_pos = torch.where(torch.abs(lmbda) > eps, (y * lmbda + 1.0).clamp_min(eps) ** (1.0 / lmbda) - 1.0, torch.exp(y) - 1.0)
265
+ two_minus_lambda = 2.0 - lmbda
266
+ x_neg = torch.where(torch.abs(two_minus_lambda) > eps, 1.0 - (1.0 - y * two_minus_lambda).clamp_min(eps) ** (1.0 / two_minus_lambda), 1.0 - torch.exp(-y))
267
+ return torch.where(pos, x_pos, x_neg)
268
+
269
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
270
+ if inverse:
271
+ return self._inverse_transform(x)
272
+ orig_shape = x.shape
273
+ if x.dim() == 3:
274
+ x = x.view(-1, x.size(-1))
275
+ y = self._yeo_johnson(x, self.lmbda)
276
+ if self.training:
277
+ batch_mean = y.mean(dim=0, keepdim=True)
278
+ batch_std = y.std(dim=0, keepdim=True).clamp_min(1e-6)
279
+ with torch.no_grad():
280
+ self.num_batches_tracked += 1
281
+ if self.num_batches_tracked == 1:
282
+ self.running_mean.copy_(batch_mean.squeeze())
283
+ self.running_std.copy_(batch_std.squeeze())
284
+ else:
285
+ self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
286
+ self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
287
+ mean = batch_mean
288
+ std = batch_std
289
+ else:
290
+ mean = self.running_mean.unsqueeze(0)
291
+ std = self.running_std.unsqueeze(0)
292
+ y_norm = (y - mean) / std
293
+ out = y_norm * self.weight + self.bias
294
+ if len(orig_shape) == 3:
295
+ out = out.view(orig_shape)
296
+ reg = 0.001 * (self.lmbda - 1.0).pow(2).mean() + 0.01 * (self.weight.var() + self.bias.var())
297
+ return out, reg
298
+
299
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
300
+ if not self.config.learn_inverse_preprocessing:
301
+ return x, torch.tensor(0.0, device=x.device)
302
+ orig_shape = x.shape
303
+ if x.dim() == 3:
304
+ x = x.view(-1, x.size(-1))
305
+ y = (x - self.bias) / (self.weight + 1e-8)
306
+ y = y * self.running_std.unsqueeze(0) + self.running_mean.unsqueeze(0)
307
+ out = self._yeo_johnson_inverse(y, self.lmbda)
308
+ if len(orig_shape) == 3:
309
+ out = out.view(orig_shape)
310
+ return out, torch.tensor(0.0, device=x.device)
311
+
312
+
313
+
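As a quick sanity check of the Yeo-Johnson pair above, the forward and inverse power transforms round-trip exactly at the initial `lmbda = 1` (a sketch that calls the internal helpers directly; only `input_dim` is needed from the config):

```python
import torch
from configuration_autoencoder import AutoencoderConfig
from preprocessing import LearnableYeoJohnsonPreprocessor

yj = LearnableYeoJohnsonPreprocessor(AutoencoderConfig(input_dim=8))
x = torch.randn(128, 8) * 2.0                       # mixed positive/negative values
with torch.no_grad():
    y = yj._yeo_johnson(x, yj.lmbda)                # per-feature power transform
    x_back = yj._yeo_johnson_inverse(y, yj.lmbda)   # analytic inverse
print(torch.allclose(x, x_back, atol=1e-5))         # True
```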
314
+ class PreprocessingBlock(BaseBlock):
315
+ """Wraps a LearnablePreprocessor into a BaseBlock-compatible interface.
316
+ Forward returns the transformed tensor and stores the regularization loss in .reg_loss.
317
+ The inverse flag is configured at initialization to avoid leaking kwargs to other blocks.
318
+ """
319
+
320
+ def __init__(self, config: AutoencoderConfig, inverse: bool = False, proc: Optional[LearnablePreprocessor] = None):
321
+ super().__init__()
322
+ self.proc = proc if proc is not None else LearnablePreprocessor(config)
323
+ self._output_dim = config.input_dim
324
+ self.inverse = inverse
325
+ self.reg_loss: torch.Tensor = torch.tensor(0.0)
326
+
327
+ @property
328
+ def output_dim(self) -> int:
329
+ return self._output_dim
330
+
331
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
332
+ y, reg = self.proc(x, inverse=self.inverse)
333
+ self.reg_loss = reg
334
+ return y
335
+
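`PreprocessingBlock` keeps the block pipeline simple: callers get a plain tensor back and the regularization term is parked on the block for the model to collect. A rough usage sketch, assuming the `use_learnable_preprocessing` / `preprocessing_type` kwargs (the same ones `PreprocessedAutoencoderConfig` in `template.py` passes) drive the `has_preprocessing` / `is_neural_scaler` properties checked by `LearnablePreprocessor`, and that `preprocessing_hidden_dim` has a default:

```python
import torch
from configuration_autoencoder import AutoencoderConfig
from preprocessing import PreprocessingBlock

cfg = AutoencoderConfig(
    input_dim=8,
    use_learnable_preprocessing=True,
    preprocessing_type="neural_scaler",
)
block = PreprocessingBlock(cfg)              # wraps a LearnablePreprocessor internally
y = block(torch.randn(32, 8))                # plain tensor in, plain tensor out
print(y.shape, block.output_dim, block.reg_loss.item())
```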
336
+ class CouplingLayer(nn.Module):
337
+ """Coupling layer for normalizing flows."""
338
+
339
+ def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"):
340
+ super().__init__()
341
+ self.input_dim = input_dim
342
+ self.hidden_dim = hidden_dim
343
+ if mask_type == "alternating":
344
+ self.register_buffer("mask", torch.arange(input_dim) % 2)
345
+ elif mask_type == "half":
346
+ mask = torch.zeros(input_dim)
347
+ mask[: input_dim // 2] = 1
348
+ self.register_buffer("mask", mask)
349
+ else:
350
+ raise ValueError(f"Unknown mask type: {mask_type}")
351
+ masked_dim = int(self.mask.sum().item())
352
+ unmasked_dim = input_dim - masked_dim
353
+ self.scale_net = nn.Sequential(
354
+ nn.Linear(masked_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, unmasked_dim), nn.Tanh()
355
+ )
356
+ self.translate_net = nn.Sequential(
357
+ nn.Linear(masked_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, unmasked_dim)
358
+ )
359
+
360
+ def forward(self, x: torch.Tensor, inverse: bool = False):
361
+ mask = self.mask.bool()
362
+ x_masked = x[:, mask]
363
+ x_unmasked = x[:, ~mask]
364
+ s = self.scale_net(x_masked)
365
+ t = self.translate_net(x_masked)
366
+ if not inverse:
367
+ y_unmasked = x_unmasked * torch.exp(s) + t
368
+ log_det = s.sum(dim=1)
369
+ else:
370
+ y_unmasked = (x_unmasked - t) * torch.exp(-s)
371
+ log_det = -s.sum(dim=1)
372
+ y = torch.zeros_like(x)
373
+ y[:, mask] = x_masked
374
+ y[:, ~mask] = y_unmasked
375
+ return y, log_det
376
+
377
+
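Each coupling layer is invertible by construction, and the forward and inverse log-determinants cancel; a small self-contained check:

```python
import torch
from preprocessing import CouplingLayer

layer = CouplingLayer(input_dim=6, hidden_dim=32, mask_type="alternating")
x = torch.randn(16, 6)
with torch.no_grad():
    y, log_det = layer(x, inverse=False)
    x_back, log_det_inv = layer(y, inverse=True)
print(torch.allclose(x, x_back, atol=1e-5))   # exact inverse up to float error
print(torch.allclose(log_det, -log_det_inv))  # log-dets are negatives of each other
```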
378
+ class NormalizingFlowPreprocessor(nn.Module):
379
+ """Normalizing flow for learnable data preprocessing."""
380
+
381
+ def __init__(self, config: AutoencoderConfig):
382
+ super().__init__()
383
+ self.config = config
384
+ input_dim = config.input_dim
385
+ hidden_dim = config.preprocessing_hidden_dim
386
+ num_layers = config.flow_coupling_layers
387
+ self.layers = nn.ModuleList()
388
+ for i in range(num_layers):
389
+ mask_type = "alternating" if i % 2 == 0 else "half"
390
+ self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type))
391
+ if config.use_batch_norm:
392
+ self.batch_norms = nn.ModuleList([nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1)])
393
+ else:
394
+ self.batch_norms = None
395
+
396
+ def forward(self, x: torch.Tensor, inverse: bool = False):
397
+ original_shape = x.shape
398
+ if x.dim() == 3:
399
+ x = x.view(-1, x.size(-1))
400
+ log_det_total = torch.zeros(x.size(0), device=x.device)
401
+ if not inverse:
402
+ for i, layer in enumerate(self.layers):
403
+ x, log_det = layer(x, inverse=False)
404
+ log_det_total += log_det
405
+ if self.batch_norms and i < len(self.layers) - 1:
406
+ x = self.batch_norms[i](x)
407
+ else:
408
+ for i, layer in enumerate(reversed(self.layers)):
409
+ if self.batch_norms and i > 0:
410
+ bn_idx = len(self.layers) - 1 - i
411
+ x = self.batch_norms[bn_idx](x)
412
+ x, log_det = layer(x, inverse=True)
413
+ log_det_total += log_det
414
+ if len(original_shape) == 3:
415
+ x = x.view(original_shape)
416
+ reg_loss = 0.01 * log_det_total.abs().mean()
417
+ return x, reg_loss
418
+
419
+
420
+ class LearnablePreprocessor(nn.Module):
421
+ """Unified interface for learnable preprocessing methods."""
422
+
423
+ def __init__(self, config: AutoencoderConfig):
424
+ super().__init__()
425
+ self.config = config
426
+ if not config.has_preprocessing:
427
+ self.preprocessor = nn.Identity()
428
+ elif config.is_neural_scaler:
429
+ self.preprocessor = NeuralScaler(config)
430
+ elif config.is_normalizing_flow:
431
+ self.preprocessor = NormalizingFlowPreprocessor(config)
432
+ elif getattr(config, "is_minmax_scaler", False):
433
+ self.preprocessor = LearnableMinMaxScaler(config)
434
+ elif getattr(config, "is_robust_scaler", False):
435
+ self.preprocessor = LearnableRobustScaler(config)
436
+ elif getattr(config, "is_yeo_johnson", False):
437
+ self.preprocessor = LearnableYeoJohnsonPreprocessor(config)
438
+ else:
439
+ raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")
440
+
441
+ def forward(self, x: torch.Tensor, inverse: bool = False):
442
+ if isinstance(self.preprocessor, nn.Identity):
443
+ return x, torch.tensor(0.0, device=x.device)
444
+ return self.preprocessor(x, inverse=inverse)
445
+
446
+
447
+
448
+ __all__ = [
449
+ "NeuralScaler",
450
+ "LearnableMinMaxScaler",
451
+ "LearnableRobustScaler",
452
+ "LearnableYeoJohnsonPreprocessor",
453
+ "CouplingLayer",
454
+ "NormalizingFlowPreprocessor",
455
+ "LearnablePreprocessor",
456
+ "PreprocessingBlock",
457
+ ]
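`LearnablePreprocessor` is the single entry point the model interacts with: it dispatches on the config and degrades to an identity transform when preprocessing is disabled. A rough end-to-end sketch, assuming `use_learnable_preprocessing` / `preprocessing_type` are the config kwargs behind the `has_preprocessing` / `is_*` properties (as in `PreprocessedAutoencoderConfig` below) and that `preprocessing_hidden_dim` has a sensible default:

```python
import torch
from configuration_autoencoder import AutoencoderConfig
from preprocessing import LearnablePreprocessor

cfg = AutoencoderConfig(
    input_dim=12,
    use_learnable_preprocessing=True,
    preprocessing_type="neural_scaler",
)
prep = LearnablePreprocessor(cfg)
y, reg = prep(torch.randn(4, 10, 12))      # (batch, time, features) is flattened internally
print(y.shape, reg.item())

cfg_off = AutoencoderConfig(input_dim=12, use_learnable_preprocessing=False)
prep_off = LearnablePreprocessor(cfg_off)  # identity pass-through, zero regularization
x = torch.randn(4, 12)
y_off, reg_off = prep_off(x)
print(torch.equal(x, y_off), reg_off.item())   # True 0.0
```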
template.py ADDED
@@ -0,0 +1,382 @@
1
+ """
2
+ Ready-to-use configuration templates for the block-based Autoencoder.
3
+
4
+ These helpers demonstrate how to assemble encoder_blocks and decoder_blocks
5
+ for a variety of architectures using the new block system. Each class extends
6
+ AutoencoderConfig and can be passed directly to AutoencoderModel.
7
+
8
+ Example:
9
+ from modeling_autoencoder import AutoencoderModel
10
+ from template import ClassicAutoencoderConfig
11
+
12
+ cfg = ClassicAutoencoderConfig(input_dim=784, latent_dim=64)
13
+ model = AutoencoderModel(cfg)
14
+ """
15
+ from __future__ import annotations
16
+
17
+ from typing import List
18
+
19
+ # Support both package-relative and flat import
20
+ try:
21
+ from .configuration_autoencoder import (
22
+ AutoencoderConfig,
23
+ )
24
+ except Exception: # pragma: no cover
25
+ from configuration_autoencoder import (
26
+ AutoencoderConfig,
27
+ )
28
+
29
+
30
+ # ------------------------------- Helpers ------------------------------- #
31
+
32
+ def _linear_stack(input_dim: int, dims: List[int], activation: str = "relu", normalization: str = "batch", dropout: float = 0.0):
33
+ """Build a list of Linear block dict configs mapping input_dim -> dims sequentially."""
34
+ blocks = []
35
+ prev = input_dim
36
+ for h in dims:
37
+ blocks.append({
38
+ "type": "linear",
39
+ "input_dim": prev,
40
+ "output_dim": h,
41
+ "activation": activation,
42
+ "normalization": normalization,
43
+ "dropout_rate": dropout,
44
+ "use_residual": False,
45
+ })
46
+ prev = h
47
+ return blocks
48
+
49
+
50
+ def _default_decoder(latent_dim: int, hidden: List[int], out_dim: int, activation: str = "relu", normalization: str = "batch", dropout: float = 0.0):
51
+ """Linear decoder: latent_dim -> hidden -> out_dim (final layer identity)."""
52
+ blocks = _linear_stack(latent_dim, hidden + [out_dim], activation, normalization, dropout)
53
+ if blocks:
54
+ blocks[-1]["activation"] = "identity"
55
+ blocks[-1]["normalization"] = "none"
56
+ blocks[-1]["dropout_rate"] = 0.0
57
+ return blocks
58
+
59
+
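For reference, these helpers emit plain dictionaries that `AutoencoderConfig` stores as `encoder_blocks` / `decoder_blocks`; a tiny sketch of the structures they return:

```python
from template import _linear_stack, _default_decoder

enc = _linear_stack(784, [512, 256], activation="relu", normalization="batch", dropout=0.1)
dec = _default_decoder(64, [256, 512], 784, activation="relu", normalization="batch", dropout=0.1)

print(enc[0])    # {'type': 'linear', 'input_dim': 784, 'output_dim': 512, 'activation': 'relu', ...}
print(dec[-1])   # final layer uses 'identity' activation, no normalization, no dropout
```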
60
+ # ---------------------------- Class-based templates ---------------------------- #
61
+
62
+ class ClassicAutoencoderConfig(AutoencoderConfig):
63
+ """Classic dense autoencoder using Linear blocks.
64
+ Example:
65
+ cfg = ClassicAutoencoderConfig(input_dim=784, latent_dim=64)
66
+ """
67
+ def __init__(self, input_dim: int = 784, latent_dim: int = 64, hidden: List[int] = (512, 256, 128), activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = True, **kwargs):
68
+ hidden = list(hidden)
69
+ norm = "batch" if use_batch_norm else "none"
70
+ enc = _linear_stack(input_dim, hidden, activation, norm, dropout)
71
+ dec = _default_decoder(latent_dim, list(reversed(hidden)), input_dim, activation, norm, dropout)
72
+ super().__init__(
73
+ input_dim=input_dim,
74
+ latent_dim=latent_dim,
75
+ activation=activation,
76
+ dropout_rate=dropout,
77
+ use_batch_norm=use_batch_norm,
78
+ autoencoder_type="classic",
79
+ encoder_blocks=enc,
80
+ decoder_blocks=dec,
81
+ **kwargs,
82
+ )
83
+
84
+
85
+ class VariationalAutoencoderConfig(AutoencoderConfig):
86
+ """Variational autoencoder (MLP). Uses VariationalBlock in the model.
87
+ Example:
88
+ cfg = VariationalAutoencoderConfig(input_dim=784, latent_dim=32)
89
+ """
90
+ def __init__(self, input_dim: int = 784, latent_dim: int = 32, hidden: List[int] = (512, 256, 128), activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = True, beta: float = 1.0, **kwargs):
91
+ hidden = list(hidden)
92
+ norm = "batch" if use_batch_norm else "none"
93
+ enc = _linear_stack(input_dim, hidden, activation, norm, dropout)
94
+ dec = _default_decoder(latent_dim, list(reversed(hidden)), input_dim, activation, norm, dropout)
95
+ super().__init__(
96
+ input_dim=input_dim,
97
+ latent_dim=latent_dim,
98
+ activation=activation,
99
+ dropout_rate=dropout,
100
+ use_batch_norm=use_batch_norm,
101
+ autoencoder_type="variational",
102
+ beta=beta,
103
+ encoder_blocks=enc,
104
+ decoder_blocks=dec,
105
+ **kwargs,
106
+ )
107
+
108
+
109
+ class TransformerAutoencoderConfig(AutoencoderConfig):
110
+ """Transformer-style autoencoder with attention encoder and MLP decoder.
111
+ Works with (batch, input_dim) or (batch, time, input_dim).
112
+ Example:
113
+ cfg = TransformerAutoencoderConfig(input_dim=256, latent_dim=128)
114
+ """
115
+ def __init__(self, input_dim: int = 256, latent_dim: int = 128, num_layers: int = 2, num_heads: int = 4, ffn_mult: int = 4, activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = False, **kwargs):
116
+ norm = "batch" if use_batch_norm else "none"
117
+ enc = []
118
+ enc.append({"type": "linear", "input_dim": input_dim, "output_dim": input_dim, "activation": activation, "normalization": norm, "dropout_rate": dropout})
119
+ for _ in range(num_layers):
120
+ enc.append({"type": "attention", "input_dim": input_dim, "num_heads": num_heads, "ffn_dim": ffn_mult * input_dim, "dropout_rate": dropout})
121
+ enc.append({"type": "linear", "input_dim": input_dim, "output_dim": input_dim, "activation": activation, "normalization": norm, "dropout_rate": dropout})
122
+ dec = _default_decoder(latent_dim, [input_dim], input_dim, activation, norm, dropout)
123
+ super().__init__(
124
+ input_dim=input_dim,
125
+ latent_dim=latent_dim,
126
+ activation=activation,
127
+ dropout_rate=dropout,
128
+ use_batch_norm=use_batch_norm,
129
+ autoencoder_type="classic",
130
+ encoder_blocks=enc,
131
+ decoder_blocks=dec,
132
+ **kwargs,
133
+ )
134
+
135
+
136
+ class RecurrentAutoencoderConfig(AutoencoderConfig):
137
+ """Recurrent encoder (LSTM/GRU/RNN) for sequence data.
138
+ Expected input: (batch, time, input_dim). Decoder is MLP back to features per step.
139
+ Example:
140
+ cfg = RecurrentAutoencoderConfig(input_dim=128, latent_dim=64, rnn_type="lstm")
141
+ """
142
+ def __init__(self, input_dim: int = 128, latent_dim: int = 64, rnn_type: str = "lstm", num_layers: int = 2, bidirectional: bool = False, activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = False, **kwargs):
143
+ norm = "batch" if use_batch_norm else "none"
144
+ enc = [{
145
+ "type": "recurrent",
146
+ "input_dim": input_dim,
147
+ "hidden_size": latent_dim,
148
+ "num_layers": num_layers,
149
+ "rnn_type": rnn_type,
150
+ "bidirectional": bidirectional,
151
+ "dropout_rate": dropout,
152
+ "output_dim": latent_dim,
153
+ }]
154
+ dec = _default_decoder(latent_dim, [max(latent_dim, input_dim)], input_dim, activation, norm, dropout)
155
+ super().__init__(
156
+ input_dim=input_dim,
157
+ latent_dim=latent_dim,
158
+ activation=activation,
159
+ dropout_rate=dropout,
160
+ use_batch_norm=use_batch_norm,
161
+ autoencoder_type="classic",
162
+ encoder_blocks=enc,
163
+ decoder_blocks=dec,
164
+ **kwargs,
165
+ )
166
+
167
+
168
+ class ConvolutionalAutoencoderConfig(AutoencoderConfig):
169
+ """1D convolutional encoder for sequence data; decoder is per-step MLP.
170
+ Expected input: (batch, time, input_dim).
171
+ Example:
172
+ cfg = ConvolutionalAutoencoderConfig(input_dim=64, conv_channels=(64, 64))
173
+ """
174
+ def __init__(self, input_dim: int = 64, latent_dim: int = 64, conv_channels: List[int] = (64, 64), kernel_size: int = 3, activation: str = "relu", dropout: float = 0.0, use_batch_norm: bool = True, **kwargs):
175
+ norm = "batch" if use_batch_norm else "none"
176
+ enc = []
177
+ prev = input_dim
178
+ for ch in conv_channels:
179
+ enc.append({"type": "conv1d", "input_dim": prev, "output_dim": ch, "kernel_size": kernel_size, "padding": "same", "activation": activation, "normalization": norm, "dropout_rate": dropout})
180
+ prev = ch
181
+ enc.append({"type": "linear", "input_dim": prev, "output_dim": latent_dim, "activation": activation, "normalization": norm, "dropout_rate": dropout})
182
+ dec = _default_decoder(latent_dim, [prev], input_dim, activation, norm, dropout)
183
+ super().__init__(
184
+ input_dim=input_dim,
185
+ latent_dim=latent_dim,
186
+ activation=activation,
187
+ dropout_rate=dropout,
188
+ use_batch_norm=use_batch_norm,
189
+ autoencoder_type="classic",
190
+ encoder_blocks=enc,
191
+ decoder_blocks=dec,
192
+ **kwargs,
193
+ )
194
+
195
+
196
+ class ConvAttentionAutoencoderConfig(AutoencoderConfig):
197
+ """Mixed Conv + Attention encoder for sequence data.
198
+ Example:
199
+ cfg = ConvAttentionAutoencoderConfig(input_dim=64, latent_dim=64)
200
+ """
201
+ def __init__(self, input_dim: int = 64, latent_dim: int = 64, conv_channels: List[int] = (64,), num_heads: int = 4, activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = True, **kwargs):
202
+ norm = "batch" if use_batch_norm else "none"
203
+ enc = []
204
+ prev = input_dim
205
+ for ch in conv_channels:
206
+ enc.append({"type": "conv1d", "input_dim": prev, "output_dim": ch, "kernel_size": 3, "padding": "same", "activation": activation, "normalization": norm, "dropout_rate": dropout})
207
+ prev = ch
208
+ enc.append({"type": "attention", "input_dim": prev, "num_heads": num_heads, "ffn_dim": 4 * prev, "dropout_rate": dropout})
209
+ enc.append({"type": "linear", "input_dim": prev, "output_dim": latent_dim, "activation": activation, "normalization": norm, "dropout_rate": dropout})
210
+ dec = _default_decoder(latent_dim, [prev], input_dim, activation, norm, dropout)
211
+ super().__init__(
212
+ input_dim=input_dim,
213
+ latent_dim=latent_dim,
214
+ activation=activation,
215
+ dropout_rate=dropout,
216
+ use_batch_norm=use_batch_norm,
217
+ autoencoder_type="classic",
218
+ encoder_blocks=enc,
219
+ decoder_blocks=dec,
220
+ **kwargs,
221
+ )
222
+
223
+
224
+ class LinearRecurrentAutoencoderConfig(AutoencoderConfig):
225
+ """Linear down-projection then Recurrent encoder.
226
+ Example:
227
+ cfg = LinearRecurrentAutoencoderConfig(input_dim=256, latent_dim=64, rnn_type="gru")
228
+ """
229
+ def __init__(self, input_dim: int = 256, latent_dim: int = 64, rnn_type: str = "gru", activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = False, **kwargs):
230
+ norm = "batch" if use_batch_norm else "none"
231
+ enc = [
232
+ {"type": "linear", "input_dim": input_dim, "output_dim": latent_dim, "activation": activation, "normalization": norm, "dropout_rate": dropout},
233
+ {"type": "recurrent", "input_dim": latent_dim, "hidden_size": latent_dim, "num_layers": 1, "rnn_type": rnn_type, "bidirectional": False, "dropout_rate": dropout, "output_dim": latent_dim},
234
+ ]
235
+ dec = _default_decoder(latent_dim, [], input_dim, activation, norm, dropout)
236
+ super().__init__(
237
+ input_dim=input_dim,
238
+ latent_dim=latent_dim,
239
+ activation=activation,
240
+ dropout_rate=dropout,
241
+ use_batch_norm=use_batch_norm,
242
+ autoencoder_type="classic",
243
+ encoder_blocks=enc,
244
+ decoder_blocks=dec,
245
+ **kwargs,
246
+ )
247
+
248
+
249
+ class PreprocessedAutoencoderConfig(AutoencoderConfig):
250
+ """Classic MLP AE with learnable preprocessing/inverse.
251
+ Example:
252
+ cfg = PreprocessedAutoencoderConfig(input_dim=64, preprocessing_type="neural_scaler")
253
+ """
254
+ def __init__(self, input_dim: int = 64, latent_dim: int = 32, preprocessing_type: str = "neural_scaler", hidden: List[int] = (128, 64), activation: str = "relu", dropout: float = 0.0, use_batch_norm: bool = True, **kwargs):
255
+ norm = "batch" if use_batch_norm else "none"
256
+ enc = _linear_stack(input_dim, list(hidden), activation, norm, dropout)
257
+ dec = _default_decoder(latent_dim, list(reversed(list(hidden))), input_dim, activation, norm, dropout)
258
+ super().__init__(
259
+ input_dim=input_dim,
260
+ latent_dim=latent_dim,
261
+ activation=activation,
262
+ dropout_rate=dropout,
263
+ use_batch_norm=use_batch_norm,
264
+ autoencoder_type="classic",
265
+ use_learnable_preprocessing=True,
266
+ preprocessing_type=preprocessing_type,
267
+ encoder_blocks=enc,
268
+ decoder_blocks=dec,
269
+ **kwargs,
270
+ )
271
+
272
+
273
+
274
+ class BetaVariationalAutoencoderConfig(AutoencoderConfig):
275
+ """Beta-VAE (MLP). Like VAE but with beta > 1 controlling KL weight.
276
+ Example:
277
+ cfg = BetaVariationalAutoencoderConfig(input_dim=784, latent_dim=32, beta=4.0)
278
+ """
279
+ def __init__(self, input_dim: int = 784, latent_dim: int = 32, hidden: List[int] = (512, 256, 128), activation: str = "relu", dropout: float = 0.1, use_batch_norm: bool = True, beta: float = 4.0, **kwargs):
280
+ hidden = list(hidden)
281
+ norm = "batch" if use_batch_norm else "none"
282
+ enc = _linear_stack(input_dim, hidden, activation, norm, dropout)
283
+ dec = _default_decoder(latent_dim, list(reversed(hidden)), input_dim, activation, norm, dropout)
284
+ super().__init__(
285
+ input_dim=input_dim,
286
+ latent_dim=latent_dim,
287
+ activation=activation,
288
+ dropout_rate=dropout,
289
+ use_batch_norm=use_batch_norm,
290
+ autoencoder_type="beta_vae",
291
+ beta=beta,
292
+ encoder_blocks=enc,
293
+ decoder_blocks=dec,
294
+ **kwargs,
295
+ )
296
+
297
+
298
+ class DenoisingAutoencoderConfig(AutoencoderConfig):
299
+ """Denoising AE: adds noise during training (handled by training loop/model if supported).
300
+ Example:
301
+ cfg = DenoisingAutoencoderConfig(input_dim=128, latent_dim=32, noise_factor=0.2)
302
+ """
303
+ def __init__(self, input_dim: int = 128, latent_dim: int = 32, hidden: List[int] = (128, 64), activation: str = "relu", dropout: float = 0.0, use_batch_norm: bool = True, noise_factor: float = 0.2, **kwargs):
304
+ hidden = list(hidden)
305
+ norm = "batch" if use_batch_norm else "none"
306
+ enc = _linear_stack(input_dim, hidden, activation, norm, dropout)
307
+ dec = _default_decoder(latent_dim, list(reversed(hidden)), input_dim, activation, norm, dropout)
308
+ super().__init__(
309
+ input_dim=input_dim,
310
+ latent_dim=latent_dim,
311
+ activation=activation,
312
+ dropout_rate=dropout,
313
+ use_batch_norm=use_batch_norm,
314
+ autoencoder_type="denoising",
315
+ noise_factor=noise_factor,
316
+ encoder_blocks=enc,
317
+ decoder_blocks=dec,
318
+ **kwargs,
319
+ )
320
+
321
+
322
+ class SparseAutoencoderConfig(AutoencoderConfig):
323
+     """Sparse AE (typically an L1 activation penalty applied in the training loop).
324
+ Example:
325
+ cfg = SparseAutoencoderConfig(input_dim=256, latent_dim=64)
326
+ """
327
+ def __init__(self, input_dim: int = 256, latent_dim: int = 64, hidden: List[int] = (128, 64), activation: str = "relu", dropout: float = 0.0, use_batch_norm: bool = True, **kwargs):
328
+ hidden = list(hidden)
329
+ norm = "batch" if use_batch_norm else "none"
330
+ enc = _linear_stack(input_dim, hidden, activation, norm, dropout)
331
+ dec = _default_decoder(latent_dim, list(reversed(hidden)), input_dim, activation, norm, dropout)
332
+ super().__init__(
333
+ input_dim=input_dim,
334
+ latent_dim=latent_dim,
335
+ activation=activation,
336
+ dropout_rate=dropout,
337
+ use_batch_norm=use_batch_norm,
338
+ autoencoder_type="sparse",
339
+ encoder_blocks=enc,
340
+ decoder_blocks=dec,
341
+ **kwargs,
342
+ )
343
+
344
+
345
+ class ContractiveAutoencoderConfig(AutoencoderConfig):
346
+ """Contractive AE (requires Jacobian penalty in training loop).
347
+ Example:
348
+ cfg = ContractiveAutoencoderConfig(input_dim=64, latent_dim=16)
349
+ """
350
+ def __init__(self, input_dim: int = 64, latent_dim: int = 16, hidden: List[int] = (64, 32), activation: str = "relu", dropout: float = 0.0, use_batch_norm: bool = True, **kwargs):
351
+ hidden = list(hidden)
352
+ norm = "batch" if use_batch_norm else "none"
353
+ enc = _linear_stack(input_dim, hidden, activation, norm, dropout)
354
+ dec = _default_decoder(latent_dim, list(reversed(hidden)), input_dim, activation, norm, dropout)
355
+ super().__init__(
356
+ input_dim=input_dim,
357
+ latent_dim=latent_dim,
358
+ activation=activation,
359
+ dropout_rate=dropout,
360
+ use_batch_norm=use_batch_norm,
361
+ autoencoder_type="contractive",
362
+ encoder_blocks=enc,
363
+ decoder_blocks=dec,
364
+ **kwargs,
365
+ )
366
+
367
+
368
+ __all__ = [
369
+ "ClassicAutoencoderConfig",
370
+ "VariationalAutoencoderConfig",
371
+ "TransformerAutoencoderConfig",
372
+ "RecurrentAutoencoderConfig",
373
+ "ConvolutionalAutoencoderConfig",
374
+ "ConvAttentionAutoencoderConfig",
375
+ "LinearRecurrentAutoencoderConfig",
376
+ "PreprocessedAutoencoderConfig",
377
+ "BetaVariationalAutoencoderConfig",
378
+ "DenoisingAutoencoderConfig",
379
+ "SparseAutoencoderConfig",
380
+ "ContractiveAutoencoderConfig",
381
+ ]
382
+
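A short sketch of driving one of the templates end to end (the `AutoencoderModel(cfg)` call follows the module docstring above; the block lists are assumed to be stored verbatim on the config):

```python
from modeling_autoencoder import AutoencoderModel
from template import TransformerAutoencoderConfig

cfg = TransformerAutoencoderConfig(input_dim=256, latent_dim=128, num_layers=2, num_heads=4)
print([b["type"] for b in cfg.encoder_blocks])   # ['linear', 'attention', 'attention', 'linear']
print([b["type"] for b in cfg.decoder_blocks])   # ['linear', 'linear']

model = AutoencoderModel(cfg)
print(sum(p.numel() for p in model.parameters()), "parameters")
```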
utils.py ADDED
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ # ---------------------------- Utilities ---------------------------- #
11
+
12
+ def _get_activation(name: Optional[str]) -> nn.Module:
13
+ if name is None:
14
+ return nn.Identity()
15
+ name = name.lower()
16
+ mapping = {
17
+ "relu": nn.ReLU(),
18
+ "gelu": nn.GELU(),
19
+ "silu": nn.SiLU(),
20
+ "swish": nn.SiLU(),
21
+ "tanh": nn.Tanh(),
22
+ "sigmoid": nn.Sigmoid(),
23
+ "leaky_relu": nn.LeakyReLU(0.2),
24
+ "elu": nn.ELU(),
25
+ "mish": nn.Mish(),
26
+ "softplus": nn.Softplus(),
27
+ "identity": nn.Identity(),
28
+ None: nn.Identity(),
29
+ }
30
+ if name not in mapping:
31
+ raise ValueError(f"Unknown activation: {name}")
32
+ return mapping[name]
33
+
34
+
35
+ def _get_norm(name: Optional[str], num_features: int) -> nn.Module:
36
+ if name is None or name == "none":
37
+ return nn.Identity()
38
+ name = name.lower()
39
+ if name == "batch":
40
+ return nn.BatchNorm1d(num_features)
41
+ if name == "layer":
42
+ return nn.LayerNorm(num_features)
43
+ if name == "instance":
44
+ return nn.InstanceNorm1d(num_features)
45
+ if name == "group":
46
+ # default 8 groups or min that divides
47
+ groups = max(1, min(8, num_features))
48
+ # ensure divisible
49
+ while num_features % groups != 0 and groups > 1:
50
+ groups -= 1
51
+ if groups == 1:
52
+ return nn.LayerNorm(num_features)
53
+ return nn.GroupNorm(groups, num_features)
54
+ raise ValueError(f"Unknown normalization: {name}")
55
+
56
+
57
+ def _flatten_3d_to_2d(x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]:
58
+ if x.dim() == 3:
59
+ b, t, f = x.shape
60
+ return x.reshape(b * t, f), (b, t)
61
+ return x, None
62
+
63
+
64
+ def _maybe_restore_3d(x: torch.Tensor, shape_hint: Optional[Tuple[int, int]]) -> torch.Tensor:
65
+ if shape_hint is None:
66
+ return x
67
+ b, t = shape_hint
68
+ f = x.shape[-1]
69
+ return x.reshape(b, t, f)
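A quick sketch of the helpers above on a sequence-shaped batch (assuming the repo directory is on the import path):

```python
import torch
from utils import _get_activation, _get_norm, _flatten_3d_to_2d, _maybe_restore_3d

act = _get_activation("leaky_relu")       # nn.LeakyReLU(0.2)
norm = _get_norm("group", 12)             # GroupNorm(6, 12): largest group count <= 8 that divides 12
x = torch.randn(4, 10, 12)
flat, hint = _flatten_3d_to_2d(x)         # (40, 12) plus the (batch, time) hint
y = _maybe_restore_3d(norm(act(flat)), hint)
print(type(act).__name__, type(norm).__name__, y.shape)   # LeakyReLU GroupNorm torch.Size([4, 10, 12])
```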