Kuangdai committed on
Commit 6fb6c07
0 Parent(s)

Initial release of SoilFormer

Files changed (42)
  1. .gitattributes +9 -0
  2. .gitignore +181 -0
  3. LICENSE +21 -0
  4. README.md +168 -0
  5. config/column_rules_numeric.json +30 -0
  6. config/config_data.json +15 -0
  7. config/config_model.json +25 -0
  8. config/config_train.json +51 -0
  9. data/cat_vocab.json +3 -0
  10. data/numeric_vocab.json +3 -0
  11. data/photo_map.json +3 -0
  12. data/tabular_data.csv +3 -0
  13. data/tabular_meta.json +3 -0
  14. data/tabular_meta_numeric_stats.csv +3 -0
  15. example/input_card.json +3 -0
  16. example/input_card__masked.json +3 -0
  17. example/input_card__unmasked.json +3 -0
  18. example/output_card.json +3 -0
  19. example/output_card__acc.json +3 -0
  20. inference_create_input_card.py +318 -0
  21. inference_predict_output_card.py +545 -0
  22. model_weights/gemma3n_E2B_vision_only/config.json +3 -0
  23. model_weights/gemma3n_E2B_vision_only/model.safetensors +3 -0
  24. model_weights/gemma3n_E2B_vision_only/modeling_gemma3n.py +3 -0
  25. model_weights/gemma3n_E2B_vision_only/processor_config.json +3 -0
  26. model_weights/gemma3n_E2B_vision_only/tokenizer.json +3 -0
  27. model_weights/gemma3n_E2B_vision_only/tokenizer_config.json +3 -0
  28. model_weights/gemma3n_E2B_vision_only/vision_extractor_config.json +3 -0
  29. model_weights/soilformer_pretrain/hetero_epoch_200.pt +3 -0
  30. modelling/__init__.py +0 -0
  31. modelling/decode_categorical.py +423 -0
  32. modelling/decode_numeric.py +238 -0
  33. modelling/embed_categorical.py +322 -0
  34. modelling/embed_numeric.py +547 -0
  35. modelling/embed_vision_gemma3n.py +552 -0
  36. modelling/layer.py +353 -0
  37. modelling/loader.py +1025 -0
  38. modelling/soilformer.py +696 -0
  39. modelling/train.py +552 -0
  40. modelling/utils.py +132 -0
  41. requirements.txt +10 -0
  42. resources/arch.png +3 -0
.gitattributes ADDED
@@ -0,0 +1,9 @@
+ # Auto detect text files and perform LF normalization
+ * text=auto
+ model_weights/** filter=lfs diff=lfs merge=lfs -text
+ data/*.csv filter=lfs diff=lfs merge=lfs -text
+ data/*.json filter=lfs diff=lfs merge=lfs -text
+ example/*.json filter=lfs diff=lfs merge=lfs -text
+ resources/*.png filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,181 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Cursor
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data;
+ # refer to https://docs.cursor.com/context/ignore-files
+ .cursorignore
+ .cursorindexingignore
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Kuangdai
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,168 @@
+ ---
+ license: mit
+ library_name: pytorch
+ language:
+ - en
+ tags:
+ - soil
+ - soil-science
+ - earth-science
+ - environmental-science
+ - multimodal
+ - tabular
+ - transformer
+ - representation-learning
+ - masked-feature-modeling
+ - remote-sensing
+ - europe
+ datasets:
+ - earthroverprogram/lucas-mega
+ ---
+
+ # SoilFormer
+
+ A multimodal tabular transformer trained on [LUCAS-MEGA](https://huggingface.co/datasets/earthroverprogram/lucas-mega).
+
+ [Manuscript](https://huggingface.co/datasets/earthroverprogram/lucas-mega/manuscript.pdf)
+
+ ## Introduction
+
+ SoilFormer is a multimodal transformer for representation learning in soil–environment systems. It is trained on
+ LUCAS-MEGA, a large-scale dataset built from European soil and environmental observations, with the LUCAS soil survey as
+ its backbone. LUCAS-MEGA integrates heterogeneous sources into a machine-learning-ready sample–feature table, covering
+ numerical, categorical, textual, and visual modalities across soil physical, chemical, hydrological, environmental, and
+ site-related properties.
+
+ SoilFormer learns from partially observed multimodal samples using masked feature modeling. During training, a subset of
+ observed categorical and numerical features is masked, and the model reconstructs them from the remaining tabular and
+ visual context. The architecture combines grouped categorical embedding, grouped numerical encoding/decoding, vision
+ feature extraction and compression, transformer layers, and heteroscedastic prediction heads for uncertainty-aware
+ reconstruction.
+
+ <img src="resources/arch.png" alt="SoilFormer architecture" width="70%">
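+
+ To make the masked-feature-modeling objective concrete, here is a minimal self-contained sketch
+ (illustrative only, not the repository's training code; the 0.15 masking ratio matches
+ `config/config_data.json`, while the toy shapes and random "predictions" are placeholders):
+
+ ```python
+ import torch
+
+ # Toy batch: 4 samples, 6 observed numeric features (z-scored).
+ values = torch.randn(4, 6)
+ observed = torch.ones(4, 6, dtype=torch.bool)  # all features observed here
+
+ # Actively mask ~15% of the observed entries; the model must reconstruct them.
+ active_mask = observed & (torch.rand(4, 6) < 0.15)
+ inputs = values.masked_fill(active_mask, 0.0)  # masked entries are hidden
+
+ # Stand-ins for the network outputs: a mean and a log-variance `s` per feature,
+ # mirroring the heteroscedastic prediction heads described above.
+ pred_mean, pred_s = torch.randn(4, 6), torch.zeros(4, 6)
+
+ # Heteroscedastic Gaussian NLL, averaged over actively masked positions only.
+ nll = 0.5 * (pred_s + (values - pred_mean) ** 2 / pred_s.exp())
+ loss = nll[active_mask].mean()
+ ```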
+
+ ## Training
+
+ Train SoilFormer with:
+
+ ```bash
+ python modelling/train.py
+ ```
+
+ Main configuration files:
+
+ * `config/config_model.json`: model architecture parameters, including embedding sizes, transformer layer settings,
+   decoder settings, dtype, and vision model configuration.
+ * `config/config_data.json`: data parameters, including CSV path, vocab paths, numeric statistics, photo mapping, image
+   root, train/eval split, batch size, and masking ratios.
+ * `config/config_train.json`: training hyperparameters, including runtime device, seed, optimizer settings, scheduler
+   settings, checkpoint behavior, loss options, logging, and output paths.
+
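+ The three files are plain JSON and reference each other only by path, so they are easy to inspect
+ programmatically; a minimal sketch using the standard library (`modelling/utils.load_json`
+ presumably wraps something similar):
+
+ ```python
+ import json
+
+ def load_json(path: str) -> dict:
+     with open(path, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+ config_train = load_json("config/config_train.json")
+ config_data = load_json(config_train["paths"]["config_data_path"])
+ config_model = load_json(config_train["paths"]["config_model_path"])
+
+ print(config_data["batch_size"], config_model["layer_num_layers"])
+ ```
+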
+ ## Inference
+
+ Inference uses readable JSON input cards. The workflow is:
+
+ 1. Create input cards from one dataset row.
+ 2. Edit the masked card manually if desired.
+ 3. Run model prediction from the edited card.
+ 4. Optionally compare predictions against the unmasked answer card.
+
+ ### 1. Create input cards
+
+ ```bash
+ python inference_create_input_card.py \
+     --row_index 10 \
+     --output example/input_card.json
+ ```
+
+ This writes two files:
+
+ ```text
+ example/input_card__unmasked.json
+ example/input_card__masked.json
+ ```
+
+ The unmasked card contains the raw readable values from the CSV row. The masked card randomly replaces a fraction of
+ categorical and numeric values with `null`. Natural missing values remain as empty strings `""`, while active masks are
+ represented as `null`.
+
+ Default masking ratios are 0.15 for both categorical and numeric features:
+
+ ```bash
+ python inference_create_input_card.py \
+     --row_index 10 \
+     --output example/input_card.json \
+     --cat_mask_ratio 0.15 \
+     --num_mask_ratio 0.15 \
+     --seed 42
+ ```
+
+ The card format is intentionally simple and user-editable. Users can copy this card as a template, replace the values
+ with their own soil sample information, and set variables to `null` to indicate which fields should be predicted during
+ inference:
+
+ ```json
+ {
+   "categorical": {
+     "land_site:land_cover_primary": "B16: Cropland => Cereals => Maize",
+     "land_site:land_use_primary": null,
+     "soil_type:WRB_soil_group": "Cambisol",
+     "texture:ISSS_class": "silty clay",
+     "...": "..."
+   },
+   "numeric": {
+     "carbon:CaCO3_content (g/kg)": 7.0,
+     "carbon:SOC_saturation_ratio": 0.3647958934307098,
+     "geographic:latitude (deg)": 38.8513900000485,
+     "geographic:longitude (deg)": -9.29050000007487,
+     "mass_density:bulk_density (g/cm³)": null,
+     "...": "..."
+   },
+   "vision": {
+     "image_path_suffix": "relative/path/to/photo.jpg"
+   }
+ }
+ ```
+
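+ Because the card is plain JSON, it can also be prepared programmatically rather than by hand;
+ a small sketch using only the standard library (`example/my_card.json` is an arbitrary name):
+
+ ```python
+ import json
+
+ with open("example/input_card__masked.json", "r", encoding="utf-8") as f:
+     card = json.load(f)
+
+ # null (None in Python) marks a field the model should predict;
+ # "" means the value is naturally missing and stays unused.
+ card["numeric"]["mass_density:bulk_density (g/cm³)"] = None
+
+ with open("example/my_card.json", "w", encoding="utf-8") as f:
+     json.dump(card, f, ensure_ascii=False, indent=2)
+ ```
+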
+ ### 2. Run prediction
+
+ ```bash
+ python inference_predict_output_card.py \
+     --checkpoint model_weights/soilformer_pretrain/hetero_epoch_200.pt \
+     --input_card example/input_card__masked.json \
+     --output example/output_card.json
+ ```
+
+ This writes:
+
+ ```text
+ example/output_card.json
+ ```
+
+ `output_card.json` contains readable predictions:
+
+ * categorical outputs are decoded back to raw category labels;
+ * numeric outputs are converted from z-score space back to the original physical units;
+ * vision input is read from `vision.image_path_suffix` together with `photo_root` in `config/config_data.json`.
+
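+ The z-score round trip is a simple affine map using the per-column statistics in
+ `data/tabular_meta_numeric_stats.csv` (the mean/std below are made-up numbers):
+
+ ```python
+ mean, std = 1.35, 0.21  # illustrative per-column statistics
+
+ def denormalize(z: float) -> float:
+     # Inverse of z = (raw - mean) / std, applied when decoding predictions.
+     return z * std + mean
+ ```
+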
+ ### 3. Evaluation with an answer card
+
+ ```bash
+ python inference_predict_output_card.py \
+     --checkpoint model_weights/soilformer_pretrain/hetero_epoch_200.pt \
+     --input_card example/input_card__masked.json \
+     --answer_card example/input_card__unmasked.json \
+     --output example/output_card.json
+ ```
+
+ This additionally writes:
+
+ ```text
+ example/output_card__acc.json
+ ```
+
+ When `--answer_card` is provided, `output_card__acc.json` reports reconstruction metrics over fields that are `null` in
+ the masked input card:
+
+ * categorical accuracy for masked categorical fields;
+ * numeric MAE for masked numeric fields, measured in the original feature units.
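+
+ For reference, the reported metrics reduce to the following over the actively masked fields
+ (a simplified sketch for scalar fields; the shipped `evaluate_against_answer` in
+ `inference_predict_output_card.py` also handles vector and string-encoded values):
+
+ ```python
+ def evaluate(masked: dict, predicted: dict, answer: dict) -> dict:
+     """Score only fields that were actively masked (null) in the input card."""
+     cat_hits = [
+         predicted["categorical"][k] == answer["categorical"][k]
+         for k, v in masked["categorical"].items()
+         if v is None and answer["categorical"].get(k) not in (None, "")
+     ]
+     abs_errors = [
+         abs(float(predicted["numeric"][k]) - float(answer["numeric"][k]))
+         for k, v in masked["numeric"].items()
+         if v is None and answer["numeric"].get(k) not in (None, "")
+     ]
+     return {
+         "categorical_accuracy": sum(cat_hits) / len(cat_hits) if cat_hits else None,
+         "numeric_mae": sum(abs_errors) / len(abs_errors) if abs_errors else None,
+     }
+ ```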
config/column_rules_numeric.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "texture:silt_percentage (%)": ">=0",
+   "chemical:pH_in_H2O": ">0",
+   "chemical:pH_in_CaCl2": ">0",
+   "carbon:organic_carbon_content (g/kg)": ">0",
+   "carbon:CaCO3_content (g/kg)": ">0",
+   "carbon:observed_vs_typical_soc_index_confidence_zone": "exclude",
+   "carbon:observed_vs_typical_soc_index": "exclude",
+   "fertility:N_extractable (g/kg)": ">0",
+   "fertility:K_extractable (mg/kg)": ">0",
+   "fertility:P_extractable (mg/kg)": ">0",
+   "fertility:P_available_stock (kg ha⁻¹)": ">0",
+   "land_degradation:soil_erosion_exceeding_10Mg_ha_yr (t ha⁻¹ yr⁻¹)": "exclude",
+   "crop_plant:cover_crop_fraction_5th_percentile (‱)": "exclude",
+   "crop_plant:cover_crop_fraction_95th_percentile (‱)": "exclude",
+   "mass_density:bulk_density_0_10cm (g/cm³)": ">0",
+   "mass_density:bulk_density_10_20cm (g/cm³)": ">0",
+   "mass_density:bulk_density (g/cm³)": ">0",
+   "biodiversity:land_use_change_pressure_index": "exclude",
+   "biodiversity:genetically_modified_organism_use_pressure_index": "exclude",
+   "trace_elements:Zn_concentration_5th_percentile (mg/kg)": "exclude",
+   "trace_elements:Zn_concentration_95th_percentile (mg/kg)": "exclude",
+   "trace_elements:As_concentration_std (log10 mg/kg)": "exclude",
+   "trace_elements:As_concentration_skewness": "exclude",
+   "trace_elements:As_concentration_kurtosis": "exclude",
+   "trace_elements:Hg_residual (µg/kg)": "exclude",
+   "climate:monthly_temperature_JAN_to_DEC (°C)": ">-100",
+   "climate:monthly_precipitation_JAN_to_DEC (mm)": ">-100",
+   "topography_geology:elevation (m)": "<4000"
+ }
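The rule strings read as compact per-column constraints: a comparison such as ">=0" or "<4000" bounds physically plausible values, while "exclude" drops the column from numeric modeling. That reading is an inference from the file itself, not documented behavior; a hedged sketch of one possible interpreter (hypothetical helper, not code from this repository):

```python
import operator

_OPS = {">=": operator.ge, "<=": operator.le, ">": operator.gt, "<": operator.lt}

def check_rule(value: float, rule: str) -> bool:
    """Return True if `value` satisfies a rule string like '>=0' or '<4000'."""
    if rule == "exclude":
        return False  # column is dropped entirely
    for symbol in (">=", "<=", ">", "<"):  # try two-character operators first
        if rule.startswith(symbol):
            return _OPS[symbol](value, float(rule[len(symbol):]))
    raise ValueError(f"Unrecognized rule: {rule!r}")

assert check_rule(7.0, ">0") and not check_rule(5000.0, "<4000")
```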
config/config_data.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "data_csv_path": "data/tabular_data.csv",
+   "photo_map_path": "data/photo_map.json",
+   "cat_vocab_path": "data/cat_vocab.json",
+   "numeric_vocab_path": "data/numeric_vocab.json",
+   "numeric_stats_path": "data/tabular_meta_numeric_stats.csv",
+   "photo_root": "",
+   "image_size": 512,
+   "train_ratio": 0.8,
+   "train_eval_split_seed": 42,
+   "batch_size": 64,
+   "cat_mask_ratio": 0.15,
+   "num_mask_ratio": 0.15,
+   "active_mask_seed": 42
+ }
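Here `train_ratio` and `train_eval_split_seed` describe a deterministic 80/20 split. A sketch of one plausible realization (the actual split lives in `modelling/loader.py` and may differ in detail):

```python
import numpy as np

num_rows = 1000                      # e.g., len(df) for tabular_data.csv
rng = np.random.default_rng(42)      # train_eval_split_seed
perm = rng.permutation(num_rows)
n_train = int(0.8 * num_rows)        # train_ratio
train_idx, eval_idx = perm[:n_train], perm[n_train:]
```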
config/config_model.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "dtype": "float32",
+   "tabular_meta": "data/tabular_meta.json",
+   "vision_model_dir": "./model_weights/gemma3n_E2B_vision_only",
+   "vision_num_output_tokens_reduced": 32,
+   "vision_num_heads_for_token_reduction": 4,
+   "vision_reducer_bottleneck_dim": 768,
+   "vision_reducer_project_back": false,
+   "cat_vocab_json": "data/cat_vocab.json",
+   "cat_hidden_size": 768,
+   "cat_decode_middle_size": null,
+   "numeric_vocab_json": "data/numeric_vocab.json",
+   "numeric_hidden_size": 768,
+   "numeric_encode_middle_size": null,
+   "numeric_decode_middle_size": null,
+   "layer_num_query_heads": 8,
+   "layer_num_kv_heads": 2,
+   "layer_head_dim": 128,
+   "layer_mlp_ratio": 1.5,
+   "layer_dropout": 0.1,
+   "layer_num_layers": 4,
+   "disable_tabular_attention_mask": true,
+   "cat_homoscedastic": false,
+   "num_homoscedastic": false
+ }
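The `vision_*` keys describe compressing the Gemma 3n vision tokens down to 32 output tokens with 4 heads at a 768-dimensional bottleneck. A generic learned-query attention reducer with those shapes, purely illustrative (the repository's reducer in `modelling/embed_vision_gemma3n.py` may be wired differently):

```python
import torch
import torch.nn as nn

class TokenReducer(nn.Module):
    """Compress a variable number of vision tokens to a fixed 32 via
    learned-query multi-head attention (a generic pattern, not necessarily
    what the repo implements)."""
    def __init__(self, dim: int = 768, num_queries: int = 32, num_heads: int = 4):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, dim) * 0.02)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:  # tokens: [B, N, dim]
        q = self.queries.unsqueeze(0).expand(tokens.size(0), -1, -1)
        reduced, _ = self.attn(q, tokens, tokens)  # [B, 32, dim]
        return reduced

x = torch.randn(2, 196, 768)
print(TokenReducer()(x).shape)  # torch.Size([2, 32, 768])
```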
config/config_train.json ADDED
@@ -0,0 +1,51 @@
+ {
+   "paths": {
+     "config_data_path": "config/config_data.json",
+     "config_model_path": "config/config_model.json",
+     "output_dir": "runs/soilformer_hetero"
+   },
+   "seed": {
+     "seed": 42,
+     "deterministic": true
+   },
+   "runtime": {
+     "device": "cuda",
+     "num_epochs": 500,
+     "init_weight_std": 0.02
+   },
+   "optimization": {
+     "lr": 1e-4,
+     "beta1": 0.9,
+     "beta2": 0.999,
+     "eps": 1e-8,
+     "weight_decay": 0.02,
+     "max_grad_norm": 1.0,
+     "scheduler": {
+       "type": "cosine",
+       "total_epochs": 500,
+       "eta_min": 2e-5,
+       "warmup_epochs": 5,
+       "warmup_start_factor": 0.1
+     }
+   },
+   "loss": {
+     "cat_s_bound": 2,
+     "num_s_bound": 4
+   },
+   "checkpoint": {
+     "resume_checkpoint_path": null,
+     "epochs_per_save": 100,
+     "max_saved_checkpoints": 5
+   },
+   "logging": {
+     "tqdm": true,
+     "wandb": {
+       "enabled": true,
+       "project": "soilformer",
+       "entity": "kuangdai-leng",
+       "run_name": "train-hetero",
+       "mode": "online",
+       "dir": null
+     }
+   }
+ }
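The `scheduler` block reads as a 5-epoch linear warmup from 10% of the base LR into a cosine decay down to `eta_min`. One plausible PyTorch realization of exactly these numbers (how `modelling/train.py` actually wires it may differ):

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

params = [torch.nn.Parameter(torch.zeros(1))]  # placeholder parameters
opt = torch.optim.AdamW(params, lr=1e-4, betas=(0.9, 0.999),
                        eps=1e-8, weight_decay=0.02)

warmup = LinearLR(opt, start_factor=0.1, total_iters=5)   # warmup_epochs=5
cosine = CosineAnnealingLR(opt, T_max=495, eta_min=2e-5)  # remaining epochs
scheduler = SequentialLR(opt, [warmup, cosine], milestones=[5])
```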
data/cat_vocab.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da160e500f0bf01207642f39b666d84d2787fae0f8ec21bb630e10e079780843
+ size 14934
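This and the other `data/`, `example/`, and `model_weights/` entries below are Git LFS pointer files; only the hash and size are stored in the repository. After cloning, the payloads are fetched with:

```bash
git lfs install
git lfs pull
```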
data/numeric_vocab.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ddcfc729da9e6f5830d58f6b53928a6fa6dcd108a0ddac3eb7fe67abed3dcadc
+ size 17492
data/photo_map.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:018b7f5baaa58e8e3e5e2c6cf98d02aa547a13c6de55f1628984010fc331235c
+ size 4651435
data/tabular_data.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4e87d387791bb95e3accab9afee8bda1e7e8722bad6e75d04c47a56787b24608
+ size 103677102
data/tabular_meta.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:de393215e2bbe46b111cc5b604b7fd04c14d28a634a52b06dcb94fd7073200eb
+ size 84654
data/tabular_meta_numeric_stats.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:73ef56458b2d7d6cbb730b153a7fd9f445dba4a96d6f29483364e38a9102c150
+ size 7714
example/input_card.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82977f50ba8a6a7d7542a4098434d232ad0feff5b1797b088a99b78504604420
+ size 6114
example/input_card__masked.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fadbf9afee9bd61fb87c4fd174306bcf6cae441e973b50afdbebd4bd433cb0be
+ size 5902
example/input_card__unmasked.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82977f50ba8a6a7d7542a4098434d232ad0feff5b1797b088a99b78504604420
+ size 6114
example/output_card.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c11d1ae2c9ada543038ca38dcb5a2a496a8392ca52b07cea215cfd46f0172af0
+ size 7261
example/output_card__acc.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d76e4284a3c9e5be054cceb8f96bb5c20434dc1ea11f1904a2d1663d910efd4e
+ size 3388
inference_create_input_card.py ADDED
@@ -0,0 +1,318 @@
+ import argparse
+ import ast
+ import json
+ import random
+ from pathlib import Path
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ import pandas as pd
+
+ from modelling.utils import load_json
+
+
+ def to_jsonable(value: Any) -> Any:
+     if value is None:
+         return None
+     if isinstance(value, float) and pd.isna(value):
+         return None
+     if isinstance(value, np.generic):
+         return value.item()
+     return value
+
+
+ def parse_optional_int(value: Optional[str]) -> Optional[int]:
+     if value is None:
+         return None
+     value = str(value).strip().lower()
+     if value in {"", "none", "null", "random"}:
+         return None
+     return int(value)
+
+
+ def choose_row_index(num_rows: int, row_index: Optional[int], seed: int) -> int:
+     if num_rows <= 0:
+         raise RuntimeError("CSV has no rows")
+     if row_index is None:
+         return random.Random(seed).randrange(num_rows)
+     if row_index < 0 or row_index >= num_rows:
+         raise IndexError(f"row_index out of range: {row_index}; num_rows={num_rows}")
+     return row_index
+
+
+ def validate_ratio(name: str, value: float) -> float:
+     value = float(value)
+     if not 0.0 <= value <= 1.0:
+         raise ValueError(f"{name} must be in [0, 1], got {value}")
+     return value
+
+
+ def load_json_if_exists(path: Optional[str]) -> Optional[Dict[str, Any]]:
+     if not path:
+         return None
+     p = Path(path)
+     if not p.exists() or not p.is_file():
+         return None
+     return load_json(str(p))
+
+
+ def get_categorical_columns(config_data: Dict[str, Any]) -> list[str]:
+     cat_vocab = load_json_if_exists(config_data.get("cat_vocab_path"))
+     if not isinstance(cat_vocab, dict):
+         return []
+     return list(cat_vocab.keys())
+
+
+ def get_numeric_columns(config_data: Dict[str, Any]) -> list[str]:
+     numeric_vocab = load_json_if_exists(config_data.get("numeric_vocab_path"))
+     if not isinstance(numeric_vocab, dict):
+         return []
+
+     columns: list[str] = []
+     for group in numeric_vocab.get("groups", []):
+         for name in group.get("feature_names", []):
+             columns.append(str(name))
+     return columns
+
+
+ def get_vision_input(config_data: Dict[str, Any], row: Dict[str, Any]) -> Dict[str, Any]:
+     photo_map = load_json_if_exists(config_data.get("photo_map_path"))
+     id_column = str(config_data.get("id_column", "id"))
+     sample_id = row.get(id_column)
+
+     if not isinstance(photo_map, dict) or sample_id is None:
+         return {"image_path_suffix": ""}
+
+     relative_path = photo_map.get(sample_id)
+     if relative_path is None:
+         relative_path = photo_map.get(str(sample_id))
+
+     if relative_path is None or relative_path == "":
+         return {"image_path_suffix": ""}
+
+     return {"image_path_suffix": str(relative_path)}
+
+
+ def parse_numeric_value(value: Any) -> Any:
+     """
+     Convert known numeric CSV cells into readable JSON numbers.
+
+     Loader convention:
+     - missing numeric cell is ""
+     - scalar numeric cell is something like "12.3"
+     - vector numeric cell is something like "[1.2, 3.4]"
+     """
+     value = to_jsonable(value)
+
+     if value == "" or value is None:
+         return ""
+
+     if isinstance(value, (int, float)) and not isinstance(value, bool):
+         return value
+
+     if isinstance(value, str):
+         s = value.strip()
+         if s == "":
+             return ""
+
+         if s.startswith("[") and s.endswith("]"):
+             parsed = ast.literal_eval(s)
+             if not isinstance(parsed, (list, tuple)):
+                 raise ValueError(f"Expected numeric vector list, got: {value!r}")
+             return [float(x) for x in parsed]
+
+         return float(s)
+
+     return value
+
+
+ def create_unmasked_card(
+     row: Dict[str, Any],
+     cat_columns: list[str],
+     numeric_columns: list[str],
+     vision: Dict[str, Any],
+ ) -> Dict[str, Any]:
+     categorical = {col: row.get(col, "") for col in cat_columns if col in row}
+     numeric = {
+         col: parse_numeric_value(row.get(col, ""))
+         for col in numeric_columns
+         if col in row
+     }
+
+     return {
+         "categorical": categorical,
+         "numeric": numeric,
+         "vision": vision,
+     }
+
+
+ def choose_mask_keys(values: Dict[str, Any], ratio: float, rng: random.Random) -> list[str]:
+     valid_keys = [k for k, v in values.items() if v not in ("", None)]
+     if ratio <= 0.0 or not valid_keys:
+         return []
+
+     k = int(round(len(valid_keys) * ratio))
+     k = max(0, min(k, len(valid_keys)))
+     if k == 0:
+         return []
+
+     return rng.sample(valid_keys, k)
+
+
+ def create_masked_card(
+     unmasked_card: Dict[str, Any],
+     cat_mask_ratio: float,
+     num_mask_ratio: float,
+     seed: int,
+ ) -> Dict[str, Any]:
+     rng = random.Random(seed)
+     masked = json.loads(json.dumps(unmasked_card, ensure_ascii=False))
+
+     cat_keys = choose_mask_keys(masked["categorical"], cat_mask_ratio, rng)
+     num_keys = choose_mask_keys(masked["numeric"], num_mask_ratio, rng)
+
+     for key in cat_keys:
+         masked["categorical"][key] = None
+
+     for key in num_keys:
+         masked["numeric"][key] = None
+
+     return masked
+
+
+ def output_paths_from_given_name(given_name: str) -> tuple[Path, Path]:
+     path = Path(given_name)
+     base = path.with_suffix("") if path.suffix == ".json" else path
+
+     unmasked_path = base.with_name(base.name + "__unmasked.json")
+     masked_path = base.with_name(base.name + "__masked.json")
+     return unmasked_path, masked_path
+
+
+ def create_cards(
+     config_data_path: str,
+     row_index: Optional[int],
+     seed: int,
+     cat_mask_ratio: float,
+     num_mask_ratio: float,
+ ) -> tuple[Dict[str, Any], Dict[str, Any]]:
+     config_data = load_json(config_data_path)
+     csv_path = config_data["data_csv_path"]
+
+     # Match loader.py: empty cells remain "" instead of becoming NaN.
+     df = pd.read_csv(
+         csv_path,
+         keep_default_na=False,
+         na_filter=False,
+         low_memory=False,
+     )
+
+     chosen_row_index = choose_row_index(
+         num_rows=len(df),
+         row_index=row_index,
+         seed=seed,
+     )
+
+     row = {
+         str(k): to_jsonable(v)
+         for k, v in df.iloc[chosen_row_index].to_dict().items()
+     }
+
+     cat_columns = get_categorical_columns(config_data)
+     numeric_columns = get_numeric_columns(config_data)
+     vision = get_vision_input(config_data, row)
+
+     unmasked_card = create_unmasked_card(
+         row=row,
+         cat_columns=cat_columns,
+         numeric_columns=numeric_columns,
+         vision=vision,
+     )
+     masked_card = create_masked_card(
+         unmasked_card=unmasked_card,
+         cat_mask_ratio=cat_mask_ratio,
+         num_mask_ratio=num_mask_ratio,
+         seed=seed,
+     )
+
+     return unmasked_card, masked_card
+
+
+ def save_json_pretty(obj: Dict[str, Any], path: Path) -> None:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     with path.open("w", encoding="utf-8") as f:
+         json.dump(obj, f, ensure_ascii=False, indent=2)
+         f.write("\n")
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(
+         description="Create readable/editable SoilFormer input cards from one CSV row."
+     )
+     parser.add_argument(
+         "--config_data",
+         type=str,
+         default="config/config_data.json",
+         help="Path to config_data.json. Default: config/config_data.json",
+     )
+     parser.add_argument(
+         "--row_index",
+         type=str,
+         default=None,
+         help="CSV row index. Use None/null/random or omit for a random row.",
+     )
+     parser.add_argument(
+         "--output",
+         type=str,
+         required=True,
+         help="Given output name. Writes given_name__unmasked.json and given_name__masked.json.",
+     )
+     parser.add_argument(
+         "--cat_mask_ratio",
+         type=float,
+         default=0.15,
+         help="Ratio of non-missing categorical features to mask. Default: 0.15",
+     )
+     parser.add_argument(
+         "--num_mask_ratio",
+         type=float,
+         default=0.15,
+         help="Ratio of non-missing numeric features to mask. Default: 0.15",
+     )
+     parser.add_argument(
+         "--seed",
+         type=int,
+         default=0,
+         help="Seed for random row selection and feature masking. Default: 0",
+     )
+     args = parser.parse_args()
+
+     cat_mask_ratio = validate_ratio("cat_mask_ratio", args.cat_mask_ratio)
+     num_mask_ratio = validate_ratio("num_mask_ratio", args.num_mask_ratio)
+
+     unmasked_card, masked_card = create_cards(
+         config_data_path=args.config_data,
+         row_index=parse_optional_int(args.row_index),
+         seed=args.seed,
+         cat_mask_ratio=cat_mask_ratio,
+         num_mask_ratio=num_mask_ratio,
+     )
+
+     unmasked_path, masked_path = output_paths_from_given_name(args.output)
+     save_json_pretty(unmasked_card, unmasked_path)
+     save_json_pretty(masked_card, masked_path)
+
+     print(
+         json.dumps(
+             {
+                 "status": "ok",
+                 "unmasked_output": str(unmasked_path),
+                 "masked_output": str(masked_path),
+             },
+             ensure_ascii=False,
+         )
+     )
+
+
+ if __name__ == "__main__":
+     main()
inference_predict_output_card.py ADDED
@@ -0,0 +1,545 @@
+ import argparse
+ import ast
+ import json
+ import sys
+ from io import BytesIO
+ from pathlib import Path
+ from typing import Any, Dict, Optional, Tuple
+ from urllib.parse import urljoin
+
+ import numpy as np
+ import pandas as pd
+ import requests
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+ # The script is intended to live one level above ./modelling.
+ # modelling/ modules still contain some legacy absolute imports, so expose the
+ # modelling directory on sys.path as well.
+ PROJECT_ROOT = Path(__file__).resolve().parent
+ MODELLING_DIR = PROJECT_ROOT / "modelling"
+ if str(MODELLING_DIR) not in sys.path:
+     sys.path.insert(0, str(MODELLING_DIR))
+
+ from modelling.soilformer import SoilFormer  # noqa: E402
+ from modelling.utils import get_dtype, load_json  # noqa: E402
+
+
+ # -----------------------------------------------------------------------------
+ # JSON helpers
+ # -----------------------------------------------------------------------------
+
+ def load_card(path: str) -> Dict[str, Any]:
+     with open(path, "r", encoding="utf-8") as f:
+         obj = json.load(f)
+     if not isinstance(obj, dict):
+         raise ValueError(f"Card must be a JSON object: {path}")
+     return obj
+
+
+ def save_json_pretty(obj: Dict[str, Any], path: Path) -> None:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     with path.open("w", encoding="utf-8") as f:
+         json.dump(obj, f, ensure_ascii=False, indent=2)
+         f.write("\n")
+
+
+ def to_jsonable(x: Any) -> Any:
+     if isinstance(x, np.generic):
+         return x.item()
+     if isinstance(x, np.ndarray):
+         return x.tolist()
+     if isinstance(x, torch.Tensor):
+         x = x.detach().cpu()
+         if x.ndim == 0:
+             return x.item()
+         return x.tolist()
+     if isinstance(x, dict):
+         return {str(k): to_jsonable(v) for k, v in x.items()}
+     if isinstance(x, (list, tuple)):
+         return [to_jsonable(v) for v in x]
+     return x
+
+
+ # -----------------------------------------------------------------------------
+ # Runtime / model loading
+ # -----------------------------------------------------------------------------
+
+ def resolve_device(device_str: str) -> torch.device:
+     device_str = str(device_str).lower()
+     if device_str == "auto":
+         if torch.cuda.is_available():
+             return torch.device("cuda")
+         if torch.backends.mps.is_available():
+             return torch.device("mps")
+         return torch.device("cpu")
+     if device_str == "cuda":
+         if not torch.cuda.is_available():
+             raise RuntimeError("--device cuda requested, but CUDA is not available")
+         return torch.device("cuda")
+     if device_str == "mps":
+         if not torch.backends.mps.is_available():
+             raise RuntimeError("--device mps requested, but MPS is not available")
+         return torch.device("mps")
+     if device_str == "cpu":
+         return torch.device("cpu")
+     raise ValueError(f"Unsupported device: {device_str}")
+
+
+ def load_model(args: argparse.Namespace, config_model: Dict[str, Any], device: torch.device, dtype: torch.dtype) -> SoilFormer:
+     print("[INFO] Initializing model...")
+     model = SoilFormer(config=config_model, device=str(device))
+
+     print("[INFO] Loading checkpoint...")
+     checkpoint = torch.load(args.checkpoint, map_location="cpu")
+     missing, unexpected = model.load_state_dict(
+         checkpoint["model_state_dict"], strict=False
+     )
+
+     non_vision_missing = [k for k in missing if not k.startswith("vision_extractor.")]
+     if len(non_vision_missing) > 0:
+         raise RuntimeError(
+             f"[ERROR] Missing non-vision keys detected: {non_vision_missing[:10]}"
+         )
+
+     print(f"[INFO] Missing keys (vision only): {len(missing)}")
+     print(f"[INFO] Unexpected keys: {len(unexpected)}")
+
+     model.to(device=device, dtype=dtype)
+     model.eval()
+     return model
+
+
+ # -----------------------------------------------------------------------------
+ # Metadata loading
+ # -----------------------------------------------------------------------------
+
+ def load_metadata(config_data: Dict[str, Any]) -> Dict[str, Any]:
+     cat_vocab = load_json(config_data["cat_vocab_path"])
+     numeric_vocab = load_json(config_data["numeric_vocab_path"])
+
+     stats_df = pd.read_csv(config_data["numeric_stats_path"])
+     numeric_stats = {}
+     for _, row in stats_df.iterrows():
+         col = row["column"]
+         mean = float(row["mean"])
+         std = float(row["std"])
+         if std == 0.0:
+             std = 1.0
+         numeric_stats[str(col)] = (mean, std)
+
+     cat_columns = list(cat_vocab.keys())
+     cat_mask_local_ids = [int(cat_vocab[col]["mask_local_id"]) for col in cat_columns]
+
+     id_to_label_by_col = {}
+     for col in cat_columns:
+         label2id = cat_vocab[col]["label2id"]
+         id_to_label_by_col[col] = {int(v): str(k) for k, v in label2id.items()}
+
+     return {
+         "cat_vocab": cat_vocab,
+         "numeric_vocab": numeric_vocab,
+         "numeric_stats": numeric_stats,
+         "cat_columns": cat_columns,
+         "cat_mask_local_ids": cat_mask_local_ids,
+         "id_to_label_by_col": id_to_label_by_col,
+     }
+
+
+ # -----------------------------------------------------------------------------
+ # Image handling, matching loader.py behavior
+ # -----------------------------------------------------------------------------
+
+ class CenterSquareCrop:
+     def __call__(self, img: Image.Image) -> Image.Image:
+         w, h = img.size
+         if w == h:
+             return img
+         if w > h:
+             left = (w - h) // 2
+             return img.crop((left, 0, left + h, h))
+         top = (h - w) // 2
+         return img.crop((0, top, w, top + w))
+
+
+ def build_image_transform(image_size: int):
+     return transforms.Compose([
+         CenterSquareCrop(),
+         transforms.Resize((image_size, image_size)),
+         transforms.ToTensor(),
+     ])
+
+
+ def join_photo_root(photo_root: str, relative_path: str) -> str:
+     if photo_root.startswith("http://") or photo_root.startswith("https://"):
+         return urljoin(photo_root.rstrip("/") + "/", relative_path)
+     return photo_root.rstrip("/") + "/" + relative_path.lstrip("/")
+
+
+ def load_image_tensor(image_path: str, image_size: int) -> torch.Tensor:
+     if image_path.startswith("http://") or image_path.startswith("https://"):
+         resp = requests.get(image_path, timeout=(3, 10))
+         resp.raise_for_status()
+         img = Image.open(BytesIO(resp.content)).convert("RGB")
+     else:
+         img = Image.open(image_path).convert("RGB")
+     return build_image_transform(image_size)(img)
+
+
+ # -----------------------------------------------------------------------------
+ # Tensorization from readable input card
+ # -----------------------------------------------------------------------------
+
+ def is_masked_or_missing(value: Any) -> bool:
+     return value is None or value == ""
+
+
+ def parse_numeric_card_value(value: Any, n_in: int) -> Tuple[list[float], bool]:
+     if value is None or value == "":
+         return [0.0] * n_in, False
+
+     if n_in == 1:
+         if isinstance(value, list):
+             if len(value) != 1:
+                 raise ValueError(f"Expected scalar or length-1 list for n_in=1, got {value!r}")
+             return [float(value[0])], True
+         return [float(value)], True
+
+     if isinstance(value, str):
+         parsed = ast.literal_eval(value)
+     else:
+         parsed = value
+
+     if not isinstance(parsed, (list, tuple)):
+         raise ValueError(f"Expected list-like numeric vector for n_in={n_in}, got {value!r}")
+     if len(parsed) != n_in:
+         raise ValueError(f"Numeric vector length mismatch: expected {n_in}, got {len(parsed)}")
+     return [float(v) for v in parsed], True
+
+
+ def tensorize_card(
+     input_card: Dict[str, Any],
+     config_data: Dict[str, Any],
+     meta: Dict[str, Any],
+ ) -> Dict[str, Any]:
+     categorical = input_card.get("categorical", {})
+     numeric = input_card.get("numeric", {})
+     vision = input_card.get("vision", {})
+
+     if not isinstance(categorical, dict):
+         raise ValueError("input_card['categorical'] must be an object")
+     if not isinstance(numeric, dict):
+         raise ValueError("input_card['numeric'] must be an object")
+     if not isinstance(vision, dict):
+         vision = {}
+
+     # Categorical: raw label -> local id, null/"" -> mask id and invalid.
+     cat_ids = []
+     cat_valids = []
+     for col, mask_id in zip(meta["cat_columns"], meta["cat_mask_local_ids"]):
+         value = categorical.get(col, "")
+         if is_masked_or_missing(value):
+             cat_ids.append(mask_id)
+             cat_valids.append(False)
+         else:
+             label2id = meta["cat_vocab"][col]["label2id"]
+             if value not in label2id:
+                 raise KeyError(f"Unknown categorical value: column={col}, value={value!r}")
+             cat_ids.append(int(label2id[value]))
+             cat_valids.append(True)
+
+     cat_local_ids = torch.tensor([cat_ids], dtype=torch.long)
+     cat_valid_positions = torch.tensor([cat_valids], dtype=torch.bool)
+
+     # Numeric: raw actual units -> z-score grouped tensors.
+     numeric_values_by_nin = {}
+     numeric_valid_positions_by_nin = {}
+
+     for group in meta["numeric_vocab"]["groups"]:
+         n_in = int(group["n_in"])
+         values = []
+         valids = []
+         for feat in group["feature_names"]:
+             feat = str(feat)
+             raw_value = numeric.get(feat, "")
+             parsed, is_valid = parse_numeric_card_value(raw_value, n_in)
+             if is_valid:
+                 mean, std = meta["numeric_stats"][feat]
+                 parsed = [(v - mean) / std for v in parsed]
+             values.append(parsed)
+             valids.append(is_valid)
+
+         numeric_values_by_nin[n_in] = torch.tensor([values], dtype=torch.float32)
+         numeric_valid_positions_by_nin[n_in] = torch.tensor([valids], dtype=torch.bool)
+
+     # Vision: readable card stores suffix only. Load/transform here.
+     image_size = int(config_data["image_size"])
+     image_path_suffix = vision.get("image_path_suffix", "")
+     if image_path_suffix is None or image_path_suffix == "":
+         pixel_values = torch.zeros(1, 3, image_size, image_size, dtype=torch.float32)
+         vision_valid_positions = torch.tensor([False], dtype=torch.bool)
+     else:
+         image_path = join_photo_root(str(config_data["photo_root"]), str(image_path_suffix))
+         try:
+             image = load_image_tensor(image_path, image_size=image_size)
+             pixel_values = image.unsqueeze(0)
+             vision_valid_positions = torch.tensor([True], dtype=torch.bool)
+         except Exception as exc:
+             print(f"[WARN] Could not load image; using zero vision input: {exc}")
+             pixel_values = torch.zeros(1, 3, image_size, image_size, dtype=torch.float32)
+             vision_valid_positions = torch.tensor([False], dtype=torch.bool)
+
+     return {
+         "cat_local_ids": cat_local_ids,
+         "cat_valid_positions": cat_valid_positions,
+         "numeric_values_by_nin": numeric_values_by_nin,
+         "numeric_valid_positions_by_nin": numeric_valid_positions_by_nin,
+         "pixel_values": pixel_values,
+         "vision_valid_positions": vision_valid_positions,
+     }
+
+
+ def move_batch_to_device(batch: Dict[str, Any], device: torch.device, dtype: torch.dtype) -> Dict[str, Any]:
+     out = {}
+     for key, value in batch.items():
+         if isinstance(value, torch.Tensor):
+             if value.dtype.is_floating_point:
+                 out[key] = value.to(device=device, dtype=dtype)
+             else:
+                 out[key] = value.to(device=device)
+         elif isinstance(value, dict):
+             sub = {}
+             for k, v in value.items():
+                 if isinstance(v, torch.Tensor):
+                     if v.dtype.is_floating_point:
+                         sub[k] = v.to(device=device, dtype=dtype)
+                     else:
+                         sub[k] = v.to(device=device)
+                 else:
+                     sub[k] = v
+             out[key] = sub
+         else:
+             out[key] = value
+     return out
+
+
+ # -----------------------------------------------------------------------------
+ # Decoding model outputs to readable card
+ # -----------------------------------------------------------------------------
+
+ def denormalize_numeric(values_z: list[float], mean: float, std: float) -> list[float]:
+     return [float(v) * float(std) + float(mean) for v in values_z]
+
+
+ def decode_outputs(
+     cat_logits_padded: torch.Tensor,
+     valid_class_mask: torch.Tensor,
+     value_by_nin: Dict[int, torch.Tensor],
+     meta: Dict[str, Any],
+ ) -> Dict[str, Any]:
+     cat_logits = cat_logits_padded.detach().float().cpu()
+     valid_class_mask = valid_class_mask.detach().cpu().bool()
+
+     categorical_out = {}
+     for m, col in enumerate(meta["cat_columns"]):
+         cm = int(valid_class_mask[m].sum().item())
+         logits = cat_logits[0, m, :cm]
+         probs = torch.softmax(logits, dim=-1)
+         pred_id = int(torch.argmax(probs).item())
+         pred_label = meta["id_to_label_by_col"][col].get(pred_id, str(pred_id))
+         categorical_out[col] = pred_label
+
+     numeric_out = {}
+     for group in meta["numeric_vocab"]["groups"]:
+         n_in = int(group["n_in"])
+         preds_z = value_by_nin[n_in].detach().float().cpu()[0]  # [V, n_in]
+         for v_idx, feat in enumerate(group["feature_names"]):
+             feat = str(feat)
+             mean, std = meta["numeric_stats"][feat]
+             raw_pred_values = denormalize_numeric(preds_z[v_idx].tolist(), mean, std)
+             if n_in == 1:
+                 numeric_out[feat] = raw_pred_values[0]
+             else:
+                 numeric_out[feat] = raw_pred_values
+
+     return {
+         "categorical": categorical_out,
+         "numeric": numeric_out,
+     }
+
+
+ # -----------------------------------------------------------------------------
+ # Accuracy / MAE analysis
+ # -----------------------------------------------------------------------------
+
+ def masked_feature_names(input_card: Dict[str, Any], section: str) -> list[str]:
+     values = input_card.get(section, {})
+     if not isinstance(values, dict):
+         return []
+     return [k for k, v in values.items() if v is None]
+
+
+ def numeric_abs_errors(pred_value: Any, answer_value: Any) -> list[float]:
+     if answer_value is None or answer_value == "":
+         return []
+     if pred_value is None or pred_value == "":
+         return []
+
+     if isinstance(answer_value, str):
+         s = answer_value.strip()
+         if s == "":
+             return []
+         if s.startswith("[") and s.endswith("]"):
+             answer_value = [float(x) for x in ast.literal_eval(s)]
+         else:
+             answer_value = float(s)
+
+     if isinstance(pred_value, str):
+         s = pred_value.strip()
+         if s.startswith("[") and s.endswith("]"):
+             pred_value = [float(x) for x in ast.literal_eval(s)]
+         else:
+             pred_value = float(s)
+
+     if isinstance(answer_value, (list, tuple)):
+         if not isinstance(pred_value, (list, tuple)):
+             return []
+         if len(pred_value) != len(answer_value):
+             return []
+         return [abs(float(p) - float(a)) for p, a in zip(pred_value, answer_value)]
+
+     return [abs(float(pred_value) - float(answer_value))]
+
+
+ def evaluate_against_answer(
+     input_card: Dict[str, Any],
+     output_card: Dict[str, Any],
+     answer_card: Dict[str, Any],
+ ) -> Dict[str, Any]:
+     cat_masked = masked_feature_names(input_card, "categorical")
+     num_masked = masked_feature_names(input_card, "numeric")
+
+     cat_details = {}
+     correct = 0
+     total = 0
+     for feat in cat_masked:
+         answer = answer_card.get("categorical", {}).get(feat)
+         pred = output_card.get("categorical", {}).get(feat)
+         if answer is None or answer == "":
+             continue
+         is_correct = pred == answer
+         cat_details[feat] = {
+             "predicted": pred,
+             "answer": answer,
+             "correct": bool(is_correct),
+         }
+         correct += int(is_correct)
+         total += 1
+
+     num_details = {}
+     abs_errors_all = []
+     for feat in num_masked:
+         answer = answer_card.get("numeric", {}).get(feat)
+         pred = output_card.get("numeric", {}).get(feat)
+         errors = numeric_abs_errors(pred, answer)
+         if not errors:
+             continue
+         mae = sum(errors) / len(errors)
+         num_details[feat] = {
+             "predicted": pred,
+             "answer": answer,
+             "absolute_error": errors[0] if len(errors) == 1 else errors,
+             "mae": mae,
+         }
+         abs_errors_all.extend(errors)
+
+     return {
+         "categorical": {
+             "accuracy": None if total == 0 else correct / total,
+             "correct": correct,
+             "total": total,
+             "details": cat_details,
+         },
+         "numeric": {
+             "mae": None if len(abs_errors_all) == 0 else sum(abs_errors_all) / len(abs_errors_all),
+             "count": len(abs_errors_all),
+             "details": num_details,
+         },
+         "note": "Metrics are computed only on fields that are null in input_card. Natural missing values \"\" are ignored.",
+     }
+
+
+ def acc_path_from_output(output: str) -> Path:
+     path = Path(output)
+     if path.suffix == ".json":
+         base = path.with_suffix("")
+     else:
+         base = path
+     return base.with_name(base.name + "__acc.json")
+
+
+ # -----------------------------------------------------------------------------
+ # CLI
+ # -----------------------------------------------------------------------------
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Run SoilFormer inference from a readable input card.")
+     parser.add_argument("--input_card", type=str, required=True)
+     parser.add_argument("--output", type=str, required=True)
+     parser.add_argument("--answer_card", type=str, default=None)
+     parser.add_argument("--checkpoint", type=str, required=True)
+     parser.add_argument("--config_data", type=str, default="config/config_data.json")
+     parser.add_argument("--config_model", type=str, default="config/config_model.json")
+     parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "cpu"])
+     args = parser.parse_args()
+
+     config_data = load_json(args.config_data)
+     config_model = load_json(args.config_model)
+     dtype = get_dtype(config_model.get("dtype", "bfloat16"))
+     device = resolve_device(args.device)
+
+     meta = load_metadata(config_data)
+     input_card = load_card(args.input_card)
+     batch = tensorize_card(input_card=input_card, config_data=config_data, meta=meta)
+     batch = move_batch_to_device(batch, device=device, dtype=dtype)
+
+     model = load_model(args=args, config_model=config_model, device=device, dtype=dtype)
+
+     with torch.no_grad():
+         cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, _ = model(
+             cat_local_ids=batch["cat_local_ids"],
+             numeric_values_by_nin=batch["numeric_values_by_nin"],
+             cat_valid_positions=batch["cat_valid_positions"],
+             numeric_valid_positions_by_nin=batch["numeric_valid_positions_by_nin"],
+             pixel_values=batch["pixel_values"],
+             vision_valid_positions=batch["vision_valid_positions"],
+         )
+
+     output_card = decode_outputs(
+         cat_logits_padded=cat_logits_padded,
+         valid_class_mask=valid_class_mask,
+         value_by_nin=value_by_nin,
+         meta=meta,
+     )
+
+     save_json_pretty(to_jsonable(output_card), Path(args.output))
+
+     result = {"status": "ok", "output": args.output}
+
+     if args.answer_card:
+         answer_card = load_card(args.answer_card)
+         acc_card = evaluate_against_answer(
+             input_card=input_card,
+             output_card=output_card,
+             answer_card=answer_card,
+         )
+         acc_path = acc_path_from_output(args.output)
+         save_json_pretty(to_jsonable(acc_card), acc_path)
+         result["acc_output"] = str(acc_path)
+
+     print(json.dumps(result, ensure_ascii=False))
+
+
+ if __name__ == "__main__":
+     main()
model_weights/gemma3n_E2B_vision_only/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df49c2835315d4de6753bea989198e66157d84aa831738227f3bc705eab2d746
+ size 4455
model_weights/gemma3n_E2B_vision_only/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eed8742f2e68b0d28bac29ee591a97e6738b6d040e0a5b69d270fca1d1453e20
+ size 597245920
model_weights/gemma3n_E2B_vision_only/modeling_gemma3n.py ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:78b0b5d14177913d7956279f7a08b62f45f5b0ca6ab1993507fc653ad9579b0c
+ size 114392
model_weights/gemma3n_E2B_vision_only/processor_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a3f52ae9fb2eeed632fc99f14fa8b4405b17cd4b760a369cddf366f9ccf6855b
+ size 2262
model_weights/gemma3n_E2B_vision_only/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7fad9b5f6f930b43d292eb3c56c176a69292850ddd0abc02d9ea1dac3292c87a
+ size 33442428
model_weights/gemma3n_E2B_vision_only/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:10c688d1767007b8614f275427198205507d941aefa6ae63c3e429ef87de7999
+ size 936
model_weights/gemma3n_E2B_vision_only/vision_extractor_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea31eaf2aec2df075d62a4bca2209763e97a0141122257b07e62fe79e3cf4564
+ size 156
model_weights/soilformer_pretrain/hetero_epoch_200.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:057cd623e72bbf477bd46f346506acfd5741c2b57326d2bc73e723ac3ea949fc
+ size 276126967
modelling/__init__.py ADDED
File without changes
modelling/decode_categorical.py ADDED
@@ -0,0 +1,423 @@
+ # decode_categorical.py
+ # -*- coding: utf-8 -*-
+
+ """
+ Categorical decoder for tabular transformer.
+
+ Design (column-wise heads):
+ - Each categorical column corresponds to exactly 1 token.
+ - Each column has its own classifier head:
+       hidden_size -> num_classes[col]
+   Optionally with a small MLP:
+       hidden_size -> middle_size -> num_classes[col]
+
+ No loss is included here (caller will apply CrossEntropyLoss).
+ """
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from utils import load_json, GroupedMLP
+
+
+ # ============================================================
+ # Small head builder
+ # ============================================================
+
+ def _make_head(
+     hidden_size: int,
+     num_classes: int,
+     middle_size: Optional[int],
+     bias: bool = True,
+ ) -> nn.Module:
+     """
+     Build a lightweight per-column classifier head.
+     """
+     if middle_size is None:
+         return nn.Linear(hidden_size, num_classes, bias=bias)
+
+     return nn.Sequential(
+         nn.Linear(hidden_size, middle_size, bias=bias),
+         nn.GELU(),
+         nn.Linear(middle_size, num_classes, bias=bias),
+     )
+
+
+ # ============================================================
+ # Decoder
+ # ============================================================
+
+ class CategoricalDecoder(nn.Module):
+     """
+     Column-wise categorical decoder.
+
+     Design:
+     - Each categorical column corresponds to exactly one token.
+     - Each column has its own classifier head:
+           hidden_size -> num_classes[col]
+       Optionally with a small MLP:
+           hidden_size -> middle_size -> num_classes[col]
+
+     - In addition, the decoder predicts a per-sample, per-column
+       log-variance term `s` used for heteroscedastic loss weighting.
+
+     Input:
+         x_cat_tokens: [B, M, H]
+             B = batch size
+             M = number of categorical columns (ordered by col_id)
+             H = hidden size
+
+     Outputs:
+
+         Case 1 (return_padded=False):
+             logits_list: List[Tensor] length M
+                 logits_list[m]: [B, num_classes[m]]
+
+             s: [B, M]
+                 Predicted log-variance per sample and column:
+                     s[b, m] = log sigma^2_{b,m}
+                 Intended for heteroscedastic loss weighting.
+
+         Case 2 (return_padded=True):
+             logits_padded: [B, M, Cmax]
+                 Logits padded to the maximum class count across columns.
+
+             s: [B, M]
+                 Same uncertainty prediction as above.
+
+             valid_mask: [M, Cmax]
+                 True for valid class indices for each column.
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         cat_vocab_json: str,
+         middle_size: Optional[int] = None,
+         bias: bool = True,
+         homoscedastic: bool = True,
+     ):
+         super().__init__()
+
+         spec = load_json(cat_vocab_json)
+         items = sorted(spec.items(), key=lambda x: x[1]["col_id"])
+
+         col_ids: List[int] = []
+         num_classes: List[int] = []
+
+         for _, val in items:
+             col_ids.append(int(val["col_id"]))
+             num_classes.append(int(val["num_classes"]))
+
+         self.hidden_size = int(hidden_size)
+         self.num_cols = len(num_classes)
+         self.middle_size = middle_size
+         self.homoscedastic = bool(homoscedastic)
+
+         # Buffers for debugging / validation / optional padded output
+         self.register_buffer("cat_col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True)  # [M]
+         self.register_buffer("num_classes", torch.tensor(num_classes, dtype=torch.long), persistent=True)  # [M]
+
+         # Build per-column heads
+         heads = []
+         for c in num_classes:
+             head = _make_head(self.hidden_size, c, middle_size, bias=bias)
+             heads.append(head)
+
+         self.heads = nn.ModuleList(heads)
+
+         if self.homoscedastic:
+             self.s_param = nn.Parameter(torch.zeros(self.num_cols))
+             self.s_head = None
+         else:
+             self.s_head = GroupedMLP(
+                 n_var=self.num_cols,
+                 n_in=self.hidden_size,
+                 n_out=1,
+                 middle_size=self.middle_size,
+             )
+             self.s_param = None
+
+     def init_weights(self, std: float = 0.02):
+         for head in self.heads:
+             for module in head.modules():
+                 if isinstance(module, nn.Linear):
+                     nn.init.normal_(module.weight, std=std)
+                     if module.bias is not None:
149
+ nn.init.zeros_(module.bias)
150
+
151
+ if self.homoscedastic:
152
+ nn.init.zeros_(self.s_param)
153
+ else:
154
+ self.s_head.init_weights(std=0.0)
155
+
156
+ def _check_input(self, x_cat_tokens: torch.Tensor) -> Tuple[int, int, int]:
157
+ if x_cat_tokens.dim() != 3:
158
+ raise ValueError(f"x_cat_tokens must be [B,M,H], got {tuple(x_cat_tokens.shape)}")
159
+ B, M, H = x_cat_tokens.shape
160
+ if H != self.hidden_size:
161
+ raise ValueError(f"hidden_size mismatch: got {H}, expected {self.hidden_size}")
162
+ if M != self.num_cols:
163
+ raise ValueError(f"categorical token count mismatch: got M={M}, expected {self.num_cols}")
164
+ return B, M, H
165
+
166
+ @torch.no_grad()
167
+ def _build_valid_mask(self, device: torch.device) -> torch.Tensor:
168
+ """
169
+ valid_mask[m, j] = True iff j < num_classes[m]
170
+ """
171
+ M = self.num_cols
172
+ cmax = int(self.num_classes.max().item())
173
+ ar = torch.arange(cmax, device=device).view(1, cmax).expand(M, cmax)
174
+ nc = self.num_classes.view(M, 1).expand(M, cmax)
175
+ return ar < nc
176
+
177
+ def forward(
178
+ self,
179
+ x_cat_tokens: torch.Tensor,
180
+ return_padded: bool = False,
181
+ pad_value: Optional[float] = None,
182
+ ) -> Union[
183
+ Tuple[List[torch.Tensor], torch.Tensor],
184
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
185
+ ]:
186
+ """
187
+ Args:
188
+ x_cat_tokens: [B, M, H]
189
+ B = batch size
190
+ M = number of categorical columns
191
+ H = hidden size (per-column token embedding dim)
192
+
193
+ return_padded:
194
+ False:
195
+ return (logits_list, s)
196
+ True:
197
+ return (logits_padded, s, valid_mask)
198
+
199
+ pad_value:
200
+ Value used to fill invalid class positions in padded logits.
201
+
202
+ Returns:
203
+
204
+ Case 1 (return_padded=False):
205
+ logits_list: List length M
206
+ logits_list[m]: [B, C_m]
207
+ s: [B, M]
208
+ s[b, m] = log sigma^2 for sample b, column m
209
+
210
+ Case 2 (return_padded=True):
211
+ logits_padded: [B, M, Cmax]
212
+ s: [B, M]
213
+ valid_mask: [M, Cmax]
214
+ """
215
+
216
+ # --------------------------------------------------------
217
+ # 1) Basic shape validation
218
+ # --------------------------------------------------------
219
+ # Ensures x_cat_tokens is [B,M,H] and matches decoder config
220
+ B, M, _ = self._check_input(x_cat_tokens)
221
+
222
+ # --------------------------------------------------------
223
+ # 2) Per-column categorical logits
224
+ # --------------------------------------------------------
225
+ # We still use per-column heads because each column
226
+ # can have a different number of classes C_m.
227
+ #
228
+ # logits_list[m] shape: [B, C_m]
229
+ logits_list: List[torch.Tensor] = []
230
+ for m in range(M):
231
+ # x_cat_tokens[:, m, :] -> [B,H]
232
+ # heads[m] maps H -> C_m
233
+ logits_m = self.heads[m](x_cat_tokens[:, m, :])
234
+ logits_list.append(logits_m)
235
+
236
+ # --------------------------------------------------------
237
+ # 3) Sample-wise & column-wise uncertainty (log-variance)
238
+ # --------------------------------------------------------
239
+ # s_head processes all columns at once (grouped, no loop)
240
+ #
241
+ # Input: [B,M,H]
242
+ # Output: [B,M]
243
+ #
244
+ # s[b,m] = log(sigma_{b,m}^2)
245
+ if self.homoscedastic:
246
+ s = self.s_param.unsqueeze(0).expand(B, -1)
247
+ else:
248
+ s = self.s_head(x_cat_tokens).squeeze(-1)
249
+
250
+ # --------------------------------------------------------
251
+ # 4) If no padded output requested
252
+ # --------------------------------------------------------
253
+ if not return_padded:
254
+ # Return:
255
+ # logits_list: List of length M
256
+ # s: [B,M]
257
+ return logits_list, s
258
+
259
+ # --------------------------------------------------------
260
+ # 5) Build padded logits tensor
261
+ # --------------------------------------------------------
262
+ # We unify different C_m into a common Cmax.
263
+ #
264
+ # logits_padded shape: [B,M,Cmax]
265
+ cmax = int(self.num_classes.max().item())
266
+
267
+ if pad_value is None:
268
+ pad_value = torch.finfo(x_cat_tokens.dtype).min
269
+ logits_padded = torch.full(
270
+ (B, M, cmax),
271
+ pad_value,
272
+ device=x_cat_tokens.device,
273
+ dtype=x_cat_tokens.dtype,
274
+ )
275
+
276
+ # Fill valid class positions per column
277
+ for m in range(M):
278
+ cm = logits_list[m].size(-1) # C_m
279
+ logits_padded[:, m, :cm] = logits_list[m]
280
+
281
+ # --------------------------------------------------------
282
+ # 6) Build validity mask
283
+ # --------------------------------------------------------
284
+ # valid_mask[m,j] = True if j < C_m
285
+ # = False otherwise
286
+ #
287
+ # Shape: [M, Cmax]
288
+ valid_class_mask = self._build_valid_mask(device=x_cat_tokens.device)
289
+
290
+ # --------------------------------------------------------
291
+ # 7) Return padded outputs
292
+ # --------------------------------------------------------
293
+ return logits_padded, s, valid_class_mask
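# Note, not from this commit: because padded positions default to the dtype
# minimum, greedy decoding needs no extra masking, e.g.
#   logits_padded, s, valid_mask = model(x, return_padded=True)
#   pred_local_ids = logits_padded.argmax(dim=-1)  # [B, M] local class ids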
294
+
295
+
296
+ # ============================================================
297
+ # DEMO
298
+ # ============================================================
299
+
300
+ def _demo_main():
301
+ import argparse
302
+
303
+ parser = argparse.ArgumentParser()
304
+ parser.add_argument("--cat_vocab_json", type=str, default="data/cat_vocab.json")
305
+ parser.add_argument("--hidden_size", type=int, default=768)
306
+ parser.add_argument("--middle_size", type=int, default=None)
307
+ parser.add_argument("--batch_size", type=int, default=4)
308
+ parser.add_argument("--device", type=str, default=None)
309
+ parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
310
+ args = parser.parse_args()
311
+
312
+ device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
313
+ dtype_map = {
314
+ "float16": torch.float16,
315
+ "bfloat16": torch.bfloat16,
316
+ "float32": torch.float32,
317
+ }
318
+ dtype = dtype_map[args.dtype]
319
+
320
+ # --------------------------------------------------------
321
+ # Load vocab spec
322
+ # --------------------------------------------------------
323
+ spec = load_json(args.cat_vocab_json)
324
+ items = sorted(spec.items(), key=lambda x_: x_[1]["col_id"])
325
+
326
+ M = len(items)
327
+ B = args.batch_size
328
+ H = args.hidden_size
329
+
330
+ num_classes = [int(s["num_classes"]) for _, s in items]
331
+
332
+ print("===== Categorical Columns =====")
333
+ for i, (name, s) in enumerate(items):
334
+ print(f"{i:03d} {name:20s} classes={s['num_classes']}")
335
+ print()
336
+
337
+ # --------------------------------------------------------
338
+ # Build model
339
+ # --------------------------------------------------------
340
+ model = CategoricalDecoder(
341
+ hidden_size=args.hidden_size,
342
+ cat_vocab_json=args.cat_vocab_json,
343
+ middle_size=args.middle_size,
344
+ ).to(device=device, dtype=dtype)
345
+
346
+ total_params = sum(p.numel() for p in model.parameters())
347
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
348
+
349
+ print(f"Model parameters: {total_params:,} (trainable: {trainable_params:,})")
350
+ print()
351
+
352
+ # --------------------------------------------------------
353
+ # Fake input tokens
354
+ # --------------------------------------------------------
355
+ x = torch.randn(B, M, H, device=device, dtype=dtype)
356
+
357
+ print("Input tokens shape:", tuple(x.shape))
358
+ print()
359
+
360
+ # --------------------------------------------------------
361
+ # Case 1: logits_list
362
+ # --------------------------------------------------------
363
+ print("===== Forward: logits_list mode =====")
364
+
365
+ with torch.no_grad():
366
+ logits_list, s = model(x, return_padded=False)
367
+
368
+ for m, (name, spec_item) in enumerate(items):
369
+ C = spec_item["num_classes"]
370
+ print(f"{m:03d} {name:20s} logits:", tuple(logits_list[m].shape), f"(expected {(B, C)})")
371
+
372
+ print("s shape:", tuple(s.shape))
373
+ print()
374
+
375
+ # --------------------------------------------------------
376
+ # Case 2: padded logits
377
+ # --------------------------------------------------------
378
+ print("===== Forward: padded mode =====")
379
+
380
+ with torch.no_grad():
381
+ logits_padded, s2, valid_mask = model(x, return_padded=True)
382
+
383
+ print("logits_padded:", tuple(logits_padded.shape))
384
+ print("s:", tuple(s2.shape))
385
+ print("valid_mask:", tuple(valid_mask.shape))
386
+ print()
387
+
388
+ # --------------------------------------------------------
389
+ # Visualize valid mask
390
+ # --------------------------------------------------------
391
+ print("===== Valid class mask (first 10 columns) =====")
392
+
393
+ cols_to_show = min(10, M)
394
+ for m in range(cols_to_show):
395
+ cm = num_classes[m]
396
+ valid = valid_mask[m].sum().item()
397
+ print(f"col {m:02d} num_classes={cm} valid_mask_sum={valid}")
398
+
399
+ print()
400
+
401
+ # --------------------------------------------------------
402
+ # Check padded logits correctness
403
+ # --------------------------------------------------------
404
+ print("===== Padded logits sanity check =====")
405
+
406
+ for m in range(cols_to_show):
407
+ cm = num_classes[m]
408
+
409
+ valid_region = logits_padded[:, m, :cm]
410
+ padded_region = logits_padded[:, m, cm:]
411
+
412
+ print(f"col {m:02d} valid region shape:", tuple(valid_region.shape))
413
+
414
+ if padded_region.numel() > 0:
415
+ print(f"col {m:02d} padded region mean:", padded_region.mean().item())
416
+
417
+ print()
418
+
419
+ print("Demo finished successfully.")
420
+
421
+
422
+ if __name__ == "__main__":
423
+ _demo_main()
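A minimal loss sketch, not part of this file: one way a caller could combine the (logits_list, s) outputs above into the heteroscedastic cross-entropy weighting the decoder is designed for, in the style of Kendall & Gal (2017). The targets tensor and the ignore_index convention for unlabeled cells are illustrative assumptions.

import torch
import torch.nn.functional as F

def hetero_ce_loss(logits_list, s, targets, ignore_index=-100):
    # logits_list[m]: [B, C_m]; s: [B, M] with s = log(sigma^2); targets: [B, M]
    total, count = 0.0, 0
    for m, logits_m in enumerate(logits_list):
        tgt = targets[:, m]
        valid = tgt != ignore_index
        if valid.any():
            ce = F.cross_entropy(logits_m[valid], tgt[valid], reduction="none")
            # exp(-s) down-weights uncertain columns; the +s term penalizes
            # inflating sigma to escape the data term
            total = total + (torch.exp(-s[valid, m]) * ce + s[valid, m]).sum()
            count += int(valid.sum())
    return total / max(count, 1)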
modelling/decode_numeric.py ADDED
@@ -0,0 +1,238 @@
1
+ # decode_numeric.py
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Numeric decoder module for tabular transformer.
6
+
7
+ Symmetric to embed_numeric.py (bucketed by n_in):
8
+ - For each bucket (same n_in), we decode tokens without a Python for-loop over columns.
9
+ - Uses a batched per-variable MLP with per-column parameters (NOT shared across V).
10
+
11
+ Input:
12
+ x_tokens: [B, total_numeric_tokens, H]
13
+ token order must match numeric_vocab.json:
14
+ groups by n_in ascending, within group by feature name,
15
+ and within each feature: n_in tokens.
16
+
17
+ Output:
18
+ values_by_nin: Dict[int, Tensor]
19
+ n_in -> x_hat [B, V, n_in]
+ s_by_nin: Dict[int, Tensor]
+ n_in -> s = log(sigma^2) [B, V]
20
+
21
+ middle_size:
22
+ - None: 1-layer per-variable Linear
23
+ - int : 2-layer per-variable MLP (Linear -> GELU -> Linear)
24
+ """
25
+
26
+ from typing import Dict, List, Optional
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+
31
+ from utils import GroupedMLP, load_json
32
+
33
+
34
+ class NumericDecoder(nn.Module):
35
+ """
36
+ Decode numeric tokens back to numeric values, bucketed by n_in.
37
+
38
+ Input:
39
+ x_tokens: [B, total_numeric_tokens, H]
40
+
41
+ Output:
42
+ values_by_nin:
43
+ n_in -> y_hat [B, V, n_in]
44
+
45
+ s_by_nin:
46
+ n_in -> s [B, V]
47
+ where s = log(sigma^2), shared across the n_in dimensions
48
+ of each variable, intended for heteroscedastic loss computation.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ hidden_size: int,
54
+ numeric_vocab_json: str,
55
+ middle_size: Optional[int] = None,
56
+ homoscedastic: bool = True,
57
+ ):
58
+ super().__init__()
59
+ self.hidden_size = int(hidden_size)
60
+ self.middle_size = None if middle_size is None else int(middle_size)
61
+ self.homoscedastic = bool(homoscedastic)
62
+
63
+ spec = load_json(numeric_vocab_json)
64
+ self.groups: List[Dict] = list(spec["groups"])
65
+ self.total_numeric_tokens = int(spec["total_numeric_tokens"])
66
+ self.group_token_offsets: Dict[str, int] = dict(spec.get("group_token_offsets", {}))
67
+
68
+ self.group_v_decoders = nn.ModuleList()
69
+ self.group_s_decoders = nn.ModuleList()
70
+ self.group_nins: List[int] = []
71
+ self.group_Vs: List[int] = []
72
+
73
+ for g in self.groups:
74
+ n_in = int(g["n_in"])
75
+ names = list(g["feature_names"])
76
+ V = len(names)
77
+
78
+ self.group_nins.append(n_in) # noqa
79
+ self.group_Vs.append(V)
80
+
81
+ # value decoder: [B,V,n_in*H] -> [B,V,n_in]
82
+ self.group_v_decoders.append(
83
+ GroupedMLP(
84
+ n_var=V,
85
+ n_in=n_in * self.hidden_size,
86
+ n_out=n_in,
87
+ middle_size=self.middle_size,
88
+ )
89
+ )
90
+
91
+ # uncertainty decoder: [B,V,H] -> [B,V,1] -> [B,V]
92
+ if not self.homoscedastic:
93
+ self.group_s_decoders.append(
94
+ GroupedMLP(
95
+ n_var=V,
96
+ n_in=self.hidden_size,
97
+ n_out=1,
98
+ middle_size=self.middle_size,
99
+ )
100
+ )
101
+
102
+ if self.homoscedastic:
103
+ self.group_s_params = nn.ParameterList(
104
+ [nn.Parameter(torch.zeros(V)) for V in self.group_Vs]
105
+ )
106
+ else:
107
+ self.group_s_params = None
108
+
109
+ # spec integrity check
110
+ running = 0
111
+ for g in self.groups:
112
+ n_in = int(g["n_in"])
113
+ V = len(g["feature_names"])
114
+ key = str(n_in)
115
+
116
+ if key not in self.group_token_offsets:
117
+ raise ValueError(f"Missing group_token_offsets entry for n_in={n_in}")
118
+ if int(self.group_token_offsets[key]) != running:
119
+ raise ValueError(
120
+ f"group_token_offsets[{key}]={self.group_token_offsets[key]} does not match expected {running}"
121
+ )
122
+
123
+ running += V * n_in
124
+
125
+ if running != self.total_numeric_tokens:
126
+ raise ValueError(
127
+ f"total_numeric_tokens={self.total_numeric_tokens} does not match expected {running}"
128
+ )
129
+
130
+ def init_weights(self, std: float = 0.02):
131
+ for dec in self.group_v_decoders:
132
+ dec.init_weights(std=std)
133
+
134
+ if self.homoscedastic:
135
+ for p in self.group_s_params:
136
+ nn.init.zeros_(p)
137
+ else:
138
+ for dec in self.group_s_decoders:
139
+ dec.init_weights(std=0.0)
140
+
141
+ def forward(self, x_tokens: torch.Tensor):
142
+ if x_tokens.dim() != 3:
143
+ raise ValueError(f"x_tokens must be [B,T,H], got {tuple(x_tokens.shape)}")
144
+
145
+ B, T, H = x_tokens.shape
146
+ if H != self.hidden_size:
147
+ raise ValueError(f"hidden_size mismatch: got H={H}, expected {self.hidden_size}")
148
+ if T != self.total_numeric_tokens:
149
+ raise ValueError(f"token length mismatch: got T={T}, expected {self.total_numeric_tokens}")
150
+
151
+ value_out: Dict[int, torch.Tensor] = {}
152
+ s_out: Dict[int, torch.Tensor] = {}
153
+
154
+ for gi, n_in in enumerate(self.group_nins):
155
+ key = str(n_in)
156
+ start = int(self.group_token_offsets[key])
157
+
158
+ V = self.group_Vs[gi]
159
+ length = V * n_in
160
+
161
+ xg_tok = x_tokens[:, start:start + length, :] # [B, V*n_in, H]
162
+ xg_tok4 = xg_tok.reshape(B, V, n_in, H) # [B, V, n_in, H]
163
+ xg_flat = xg_tok4.reshape(B, V, n_in * H) # [B, V, n_in*H]
164
+
165
+ # values: [B, V, n_in]
166
+ y = self.group_v_decoders[gi](xg_flat)
167
+
168
+ # s = log sigma^2: [B, V]
169
+ if self.homoscedastic:
170
+ s = self.group_s_params[gi].unsqueeze(0).expand(B, -1)
171
+ else:
172
+ x_var = xg_tok4.mean(dim=2) # [B, V, H]
173
+ s = self.group_s_decoders[gi](x_var).squeeze(-1) # [B, V]
174
+
175
+ value_out[n_in] = y
176
+ s_out[n_in] = s
177
+
178
+ return value_out, s_out
179
+
180
+
181
+ # ============================================================
182
+ # DEMO
183
+ # ============================================================
184
+
185
+ def _demo_main():
186
+ import argparse
187
+
188
+ parser = argparse.ArgumentParser()
189
+ parser.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
190
+ parser.add_argument("--hidden_size", type=int, default=768)
191
+ parser.add_argument("--middle_size", type=int, default=-1,
192
+ help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
193
+ parser.add_argument("--batch_size", type=int, default=4)
194
+ parser.add_argument("--device", type=str, default=None)
195
+ parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
196
+ args = parser.parse_args()
197
+
198
+ device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
199
+ dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
200
+ dtype = dtype_map[args.dtype]
201
+
202
+ # Directly load existing numeric vocab spec
203
+ spec = load_json(args.numeric_vocab_json)
204
+ print(f"Loaded numeric vocab spec from: {args.numeric_vocab_json}")
205
+ print("Groups (n_in -> V):", {int(g['n_in']): len(g['feature_names']) for g in spec["groups"]})
206
+ print("total_numeric_tokens:", spec["total_numeric_tokens"])
207
+ print("group_token_offsets:", spec["group_token_offsets"])
208
+
209
+ middle_size = None if args.middle_size < 0 else int(args.middle_size)
210
+ model = NumericDecoder(
211
+ hidden_size=args.hidden_size,
212
+ numeric_vocab_json=args.numeric_vocab_json,
213
+ middle_size=middle_size,
214
+ ).to(device=device, dtype=dtype)
215
+ model.eval()
216
+
217
+ total_params = sum(p.numel() for p in model.parameters())
218
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
219
+ print(f"Total parameters (NumericDecoder): {total_params:,} (trainable: {trainable_params:,})")
220
+
221
+ B = args.batch_size
222
+ T = int(spec["total_numeric_tokens"])
223
+ H = args.hidden_size
224
+
225
+ x_tokens = torch.randn(B, T, H, device=device, dtype=dtype)
226
+
227
+ with torch.no_grad():
228
+ values_by_nin, s_by_nin = model(x_tokens)
229
+
230
+ print("Input tokens:", tuple(x_tokens.shape), x_tokens.dtype, x_tokens.device)
231
+ print("Decoded values:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
232
+ print("Decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()})
233
+ # values_by_nin[n_in]: [B, V, n_in]
234
+ # s_by_nin[n_in]: [B, V]
235
+
236
+
237
+ if __name__ == "__main__":
238
+ _demo_main()
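Symmetrically, a minimal sketch, not part of this file, of the heteroscedastic Gaussian loss implied by s = log(sigma^2): the per-element NLL is 0.5 * (exp(-s) * (y - y_hat)^2 + s). The targets_by_nin dict and the column-level valid masks mirror the bucketed conventions above and are illustrative assumptions.

import torch

def hetero_gaussian_nll(values_by_nin, s_by_nin, targets_by_nin, valid_by_nin):
    total, count = 0.0, 0
    for n_in, y_hat in values_by_nin.items():
        y = targets_by_nin[n_in]                  # [B, V, n_in]
        s = s_by_nin[n_in].unsqueeze(-1)          # [B, V, 1], s = log(sigma^2)
        valid = valid_by_nin[n_in].unsqueeze(-1)  # [B, V, 1] bool, column-level
        nll = 0.5 * (torch.exp(-s) * (y - y_hat) ** 2 + s)  # [B, V, n_in]
        nll = nll * valid                         # zero out missing columns
        total = total + nll.sum()
        count += int(valid.expand_as(nll).sum())
    return total / max(count, 1)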
modelling/embed_categorical.py ADDED
@@ -0,0 +1,322 @@
1
+ # embed_categorical.py
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Categorical embedding module for tabular transformer.
6
+
7
+ Design:
8
+ - Each categorical column = 1 token
9
+ - Value embedding: ONE global lookup table using (offset + local_id)
10
+ - ID embedding: ONE categorical column-ID embedding table
11
+ - Explicit col_id stored in cat_vocab.json (no implicit ordering assumptions)
12
+
13
+ Outputs:
14
+ local_ids [B,M] -> tokens [B,M,H]
15
+ """
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Dict, List, Optional, Tuple
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from utils import load_json, save_json
24
+
25
+ SPECIAL_MASK = "__MASK__"
26
+
27
+
28
+ # ============================================================
29
+ # Meta → categorical column list
30
+ # ============================================================
31
+
32
+ def get_categorical_feature_names_from_meta(tabular_meta: Dict) -> List[str]:
33
+ """
34
+ Deterministic ordering:
35
+ alphabetical by feature name.
36
+ """
37
+ cols = []
38
+ for k, v in tabular_meta.items():
39
+ if v.get("dataclass") == "categorical" and not v.get("is_array_valued", False):
40
+ cols.append(k)
41
+ return sorted(cols)
42
+
43
+
44
+ # ============================================================
45
+ # Vocab spec
46
+ # ============================================================
47
+
48
+ @dataclass
49
+ class CatColSpec:
50
+ name: str
51
+ col_id: int
52
+ offset: int
53
+ num_classes: int
54
+ mask_local_id: int
55
+ label2id: Dict[str, int]
56
+
57
+
58
+ def build_cat_vocab_spec_from_meta(
59
+ tabular_meta: Dict,
60
+ categorical_feature_names: List[str],
61
+ label_order: str = "alpha",
62
+ ) -> Dict[str, CatColSpec]:
63
+ vocab: Dict[str, CatColSpec] = {}
64
+
65
+ offset = 0
66
+ for j, col in enumerate(categorical_feature_names):
67
+ info = tabular_meta[col]
68
+ class_stats = info.get("class_stats", {}) or {}
69
+
70
+ # deterministic label order
71
+ if label_order == "alpha":
72
+ labels = sorted(class_stats.keys())
73
+ elif label_order == "freq_desc":
74
+ labels = sorted(class_stats.keys(), key=lambda k: (-class_stats[k], k))
75
+ else:
76
+ raise ValueError("label_order must be alpha or freq_desc")
77
+
78
+ label2id = {lab: i for i, lab in enumerate(labels)}
79
+
80
+ mask_local_id = len(labels)
81
+ label2id[SPECIAL_MASK] = mask_local_id
82
+
83
+ spec = CatColSpec(
84
+ name=col,
85
+ col_id=j, # EXPLICIT categorical column id
86
+ offset=offset,
87
+ num_classes=mask_local_id + 1,
88
+ mask_local_id=mask_local_id,
89
+ label2id=label2id,
90
+ )
91
+ vocab[col] = spec
92
+
93
+ offset += spec.num_classes
94
+
95
+ return vocab
96
+
97
+
98
+ def save_cat_vocab_json(vocab: Dict[str, CatColSpec], path: str) -> None:
99
+ out = {}
100
+
101
+ for col, spec in vocab.items():
102
+ out[col] = {
103
+ "col_id": spec.col_id,
104
+ "offset": spec.offset,
105
+ "num_classes": spec.num_classes,
106
+ "mask_local_id": spec.mask_local_id,
107
+ "global_id_start": spec.offset,
108
+ "global_id_end": spec.offset + spec.num_classes - 1,
109
+ "label2id": spec.label2id,
110
+ }
111
+
112
+ save_json(out, path)
113
+
114
+
115
+ # ============================================================
116
+ # Embedding modules
117
+ # ============================================================
118
+
119
+ class CategoricalValueEmbedding(nn.Module):
120
+ """
121
+ Global value embedding using offsets.
122
+ """
123
+
124
+ def __init__(self, hidden_size: int, cat_vocab_json: str):
125
+ super().__init__()
126
+
127
+ spec = load_json(cat_vocab_json)
128
+
129
+ # sort by col_id to ensure consistent tensor layout
130
+ items = sorted(spec.items(), key=lambda x: x[1]["col_id"])
131
+
132
+ offsets = []
133
+ num_classes = []
134
+ col_ids = []
135
+
136
+ total_vocab = 0
137
+
138
+ for name, s in items:
139
+ offsets.append(int(s["offset"]))
140
+ num_classes.append(int(s["num_classes"]))
141
+ col_ids.append(int(s["col_id"]))
142
+ total_vocab = max(total_vocab, s["offset"] + s["num_classes"])
143
+
144
+ self.hidden_size = int(hidden_size)
145
+ self.total_vocab_size = int(total_vocab)
146
+ # Merge all columns' class vocabularies into one table to avoid many small nn.Embedding modules
147
+ self.emb = nn.Embedding(self.total_vocab_size, self.hidden_size)
148
+
149
+ self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long), persistent=True)
150
+ self.register_buffer("num_classes", torch.tensor(num_classes, dtype=torch.long), persistent=True)
151
+ self.register_buffer("col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True)
152
+
153
+ def init_weights(self, std=0.02):
154
+ nn.init.normal_(self.emb.weight, std=std)
155
+
156
+ def forward(self, local_ids: torch.LongTensor) -> torch.Tensor:
157
+ """
158
+ local_ids: [B,M]
159
+ returns: [B,M,H]
160
+ """
161
+
162
+ if local_ids.dim() != 2:
163
+ raise ValueError("local_ids must be [B,M]")
164
+
165
+ B, M = local_ids.shape
166
+
167
+ if M != self.offsets.numel():
168
+ raise ValueError("Column count mismatch")
169
+
170
+ if torch.any(local_ids < 0):
171
+ raise ValueError("Negative local_id")
172
+
173
+ nc = self.num_classes.view(1, M).expand(B, M)
174
+ if torch.any(local_ids >= nc):
175
+ raise ValueError("local_ids out of range")
176
+
177
+ gid = self.offsets.view(1, M) + local_ids
178
+ return self.emb(gid)
179
+
180
+
181
+ class CategoricalIdEmbedding(nn.Module):
182
+ """
183
+ Explicit categorical column ID embedding.
184
+ """
185
+
186
+ def __init__(self, hidden_size: int, cat_vocab_json: str):
187
+ super().__init__()
188
+
189
+ spec = load_json(cat_vocab_json)
190
+ items = sorted(spec.items(), key=lambda x: x[1]["col_id"])
191
+
192
+ col_ids = [s["col_id"] for _, s in items]
193
+ max_col_id = max(col_ids)
194
+
195
+ self.emb = nn.Embedding(max_col_id + 1, hidden_size)
196
+
197
+ self.register_buffer(
198
+ "cat_col_ids",
199
+ torch.tensor(col_ids, dtype=torch.long),
200
+ persistent=True,
201
+ )
202
+
203
+ self.hidden_size = hidden_size
204
+
205
+ def init_weights(self, std=0.02):
206
+ nn.init.normal_(self.emb.weight, std=std)
207
+
208
+ def forward(self, batch_size: int) -> torch.Tensor:
209
+ """
210
+ returns [B,M,H]
211
+ """
212
+ id_vec = self.emb(self.cat_col_ids) # [M,H]
213
+ return id_vec.view(1, -1, self.hidden_size).expand(batch_size, -1, -1)
214
+
215
+
216
+ class CategoricalEmbedding(nn.Module):
217
+ """
218
+ token = value_embedding + categorical_id_embedding
219
+ """
220
+
221
+ def __init__(self, hidden_size: int, cat_vocab_json: str):
222
+ super().__init__()
223
+
224
+ self.value_emb = CategoricalValueEmbedding(hidden_size, cat_vocab_json)
225
+ self.id_emb = CategoricalIdEmbedding(hidden_size, cat_vocab_json)
226
+
227
+ def init_weights(self, std=0.02):
228
+ self.value_emb.init_weights(std=std)
229
+ self.id_emb.init_weights(std=std)
230
+
231
+ def forward(
232
+ self,
233
+ local_ids: torch.LongTensor, # [B, M]
234
+ valid_positions: Optional[torch.Tensor] = None, # Bool [B,M] (True=valid) or indices [K,2]
235
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
236
+ """
237
+ Returns:
238
+ tokens: [B, M, H]
239
+ token_mask: [B, M] (1=valid, 0=invalid)
240
+ """
241
+ if local_ids.dim() != 2:
242
+ raise ValueError(f"local_ids must be [B,M], got {tuple(local_ids.shape)}")
243
+ B, M = local_ids.shape
244
+
245
+ tokens = self.value_emb(local_ids) + self.id_emb(B) # [B,M,H]
246
+
247
+ # Default: all tokens are valid
248
+ valid = torch.ones((B, M), dtype=torch.bool, device=local_ids.device)
249
+
250
+ if valid_positions is not None:
251
+ if valid_positions.dtype == torch.bool:
252
+ if valid_positions.shape != (B, M):
253
+ raise ValueError(
254
+ f"valid_positions (bool) must be [B,M]=({B}, {M}), got {tuple(valid_positions.shape)}")
255
+ valid = valid_positions.to(device=local_ids.device)
256
+ else:
257
+ # Optional: support index pairs [K,2] where each row is (b_idx, m_idx) for valid positions
258
+ if valid_positions.dim() != 2 or valid_positions.size(1) != 2:
259
+ raise ValueError("valid_positions (indices) must be [K,2] with (batch_idx, col_idx)")
260
+ valid = torch.zeros((B, M), dtype=torch.bool, device=local_ids.device)
261
+ b_idx = valid_positions[:, 0].to(device=local_ids.device, dtype=torch.long)
262
+ m_idx = valid_positions[:, 1].to(device=local_ids.device, dtype=torch.long)
263
+ valid[b_idx, m_idx] = True
264
+
265
+ # Token mask: 1=valid, 0=invalid
266
+ token_mask = valid.to(dtype=torch.long) # [B,M]
267
+
268
+ # This is WRONG: we should allow __MASK__ tokens to attend to other columns
269
+ # # Invalid tokens must not contribute
270
+ # invalid = ~valid
271
+ # if invalid.any():
272
+ # tokens = tokens.masked_fill(invalid.unsqueeze(-1), 0.0)
273
+
274
+ return tokens, token_mask
275
+
276
+
277
+ # ============================================================
278
+ # DEMO
279
+ # ============================================================
280
+
281
+ def _demo_main():
282
+ import argparse
283
+
284
+ parser = argparse.ArgumentParser()
285
+ parser.add_argument("--tabular_meta", type=str, default="data/tabular_meta.json")
286
+ parser.add_argument("--cat_vocab_json", type=str, default="data/cat_vocab.json")
287
+ parser.add_argument("--hidden_size", type=int, default=768)
288
+ parser.add_argument("--batch_size", type=int, default=4)
289
+ args = parser.parse_args()
290
+
291
+ tabular_meta = load_json(args.tabular_meta)
292
+
293
+ cat_names = get_categorical_feature_names_from_meta(tabular_meta)
294
+ print(f"Found {len(cat_names)} categorical columns")
295
+
296
+ vocab = build_cat_vocab_spec_from_meta(tabular_meta, cat_names)
297
+ save_cat_vocab_json(vocab, args.cat_vocab_json)
298
+ print(f"Saved vocab to {args.cat_vocab_json}")
299
+
300
+ model = CategoricalEmbedding(
301
+ hidden_size=args.hidden_size,
302
+ cat_vocab_json=args.cat_vocab_json,
303
+ )
304
+ total_params = sum(p.numel() for p in model.parameters())
305
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
306
+ print(f"Total parameters (CategoricalEmbedding): {total_params:,} (trainable: {trainable_params:,})")
307
+
308
+ B = args.batch_size
309
+ M = len(cat_names)
310
+
311
+ local_ids = torch.zeros((B, M), dtype=torch.long)
312
+
313
+ with torch.no_grad():
314
+ out, mask = model(local_ids)
315
+
316
+ print("local_ids:", tuple(local_ids.shape))
317
+ print("output:", tuple(out.shape)) # [B,M,H]
318
+ print("mask:", tuple(mask.shape)) # [B,M]
319
+
320
+
321
+ if __name__ == "__main__":
322
+ _demo_main()
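A minimal encoding sketch, not part of this file: turning raw label rows into the [B, M] local_ids tensor expected by CategoricalEmbedding, using the label2id / mask_local_id fields written to cat_vocab.json. The rows format (a list of {column_name: label} dicts) is an illustrative assumption.

import torch
from utils import load_json

def encode_categorical_rows(rows, cat_vocab_json="data/cat_vocab.json"):
    spec = load_json(cat_vocab_json)
    cols = sorted(spec.items(), key=lambda kv: kv[1]["col_id"])
    ids = torch.empty(len(rows), len(cols), dtype=torch.long)
    for b, row in enumerate(rows):
        for m, (name, col) in enumerate(cols):
            label = row.get(name, "__MASK__")  # missing cell -> mask token
            ids[b, m] = col["label2id"].get(label, col["mask_local_id"])
    return ids  # ready for CategoricalEmbedding(...)(ids)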
modelling/embed_numeric.py ADDED
@@ -0,0 +1,547 @@
1
+ # embed_numeric.py
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ Numeric embedding module for tabular transformer.
6
+
7
+ Updates in this version:
8
+ - numeric_vocab.json now includes:
9
+ - total_numeric_tokens
10
+ - group_token_offsets (by n_in)
11
+ - demo_main prints total parameter count
12
+
13
+ Design:
14
+ - scalar numeric (n_in=1): 1 token
15
+ - vector numeric (n_in=L): L tokens
16
+ - per bucket (same n_in): GroupedMLP with per-column weights (no for-loop over columns)
17
+ input : [B, V, n_in]
18
+ output : [B, V*n_in, H]
19
+ - middle_size:
20
+ - None: 1-layer
21
+ - int : 2-layer (Linear -> GELU -> Linear)
22
+ - NumericIdEmbedding:
23
+ - per numeric column id embedding [H]
24
+ - broadcast across that column's n_in tokens
25
+ """
26
+
27
+ from dataclasses import dataclass
28
+ from typing import Dict, List, Optional, Tuple
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+
33
+ from utils import load_json, save_json, GroupedMLP
34
+
35
+
36
+ # ============================================================
37
+ # Meta parsing
38
+ # ============================================================
39
+
40
+ def infer_n_in_from_meta_item(info: Dict) -> int:
41
+ return int(info["array_length"]) if info["is_array_valued"] else 1
42
+
43
+
44
+ def get_numeric_feature_names_and_dims_from_meta(tabular_meta: Dict) -> List[Tuple[str, int]]:
45
+ """
46
+ Return list of (feature_name, n_in) for numeric features.
47
+
48
+ Heuristic:
49
+ - info['dataclass'] == 'numeric' is treated as numeric.
50
+ """
51
+ out: List[Tuple[str, int]] = []
52
+ for name, info in tabular_meta.items():
53
+ if info.get("dataclass") != "numeric":
54
+ continue
55
+ n_in = infer_n_in_from_meta_item(info)
56
+ out.append((name, n_in))
57
+ # deterministic: group by n_in then name
58
+ out.sort(key=lambda x: (x[1], x[0]))
59
+ return out
60
+
61
+
62
+ # ============================================================
63
+ # Vocab/spec building
64
+ # ============================================================
65
+
66
+ @dataclass
67
+ class NumColSpec:
68
+ name: str
69
+ col_id: int
70
+ n_in: int
71
+ group_index: int
72
+ index_within_group: int
73
+
74
+
75
+ def build_numeric_vocab_spec_from_meta(tabular_meta: Dict) -> Dict:
76
+ """
77
+ Build numeric_vocab.json dict.
78
+
79
+ Output keys:
80
+ - ordered_feature_names
81
+ - features[name] = {col_id, n_in, group_index, index_within_group}
82
+ - groups = [{n_in, feature_names}, ...] sorted by n_in asc
83
+ - total_numeric_tokens
84
+ - group_token_offsets: { "<n_in>": <start_token_index> }
85
+ token order is groups by n_in asc, within group by feature name
86
+ """
87
+ feats = get_numeric_feature_names_and_dims_from_meta(tabular_meta)
88
+ if not feats:
89
+ raise ValueError("No numeric features found (dataclass=='numeric').")
90
+
91
+ # group by n_in
92
+ groups_map: Dict[int, List[str]] = {}
93
+ for name, n_in in feats:
94
+ groups_map.setdefault(n_in, []).append(name)
95
+
96
+ for n_in in groups_map:
97
+ groups_map[n_in] = sorted(groups_map[n_in])
98
+
99
+ group_nins = sorted(groups_map.keys())
100
+
101
+ groups: List[Dict] = []
102
+ ordered_feature_names: List[str] = []
103
+
104
+ for n_in in group_nins:
105
+ names = groups_map[n_in]
106
+ groups.append({"n_in": int(n_in), "feature_names": names})
107
+ ordered_feature_names.extend(names)
108
+
109
+ # build per-feature mapping
110
+ name_to_group: Dict[str, Tuple[int, int]] = {}
111
+ for gi, g in enumerate(groups):
112
+ for idx, nm in enumerate(g["feature_names"]):
113
+ name_to_group[nm] = (gi, idx)
114
+
115
+ features: Dict[str, Dict] = {}
116
+ for col_id, nm in enumerate(ordered_feature_names):
117
+ gi, idx = name_to_group[nm]
118
+ n_in = int(groups[gi]["n_in"])
119
+ features[nm] = {
120
+ "col_id": int(col_id),
121
+ "n_in": int(n_in),
122
+ "group_index": int(gi),
123
+ "index_within_group": int(idx),
124
+ }
125
+
126
+ # total tokens + group token offsets
127
+ total_numeric_tokens = 0
128
+ group_token_offsets: Dict[str, int] = {}
129
+ running = 0
130
+ for g in groups:
131
+ n_in = int(g["n_in"])
132
+ group_token_offsets[str(n_in)] = int(running)
133
+ V = len(g["feature_names"])
134
+ running += V * n_in
135
+ total_numeric_tokens += V * n_in
136
+
137
+ spec = {
138
+ "ordered_feature_names": ordered_feature_names,
139
+ "features": features,
140
+ "groups": groups,
141
+ "total_numeric_tokens": int(total_numeric_tokens),
142
+ "group_token_offsets": group_token_offsets, # keys are strings to be JSON-friendly
143
+ }
144
+ return spec
145
+
146
+
147
+ # ============================================================
148
+ # Core modules
149
+ # ============================================================
150
+
151
+ class NumericIdEmbedding(nn.Module):
152
+ """
153
+ Per-numeric-column ID embedding in the GLOBAL numeric namespace.
154
+ Broadcast each global column id vector across its n_in tokens.
155
+ """
156
+
157
+ def __init__(self, num_numeric_cols: int, hidden_size: int):
158
+ super().__init__()
159
+ self.num_numeric_cols = int(num_numeric_cols)
160
+ self.hidden_size = int(hidden_size)
161
+ self.emb = nn.Embedding(self.num_numeric_cols, self.hidden_size)
162
+
163
+ def forward(self, global_col_ids: torch.LongTensor, batch_size: int, n_in: int) -> torch.Tensor:
164
+ """
165
+ global_col_ids: [V] in global numeric namespace
166
+ returns: [B, V*n_in, H]
167
+ """
168
+ if global_col_ids.dim() != 1:
169
+ raise ValueError(f"global_col_ids must be [V], got {tuple(global_col_ids.shape)}")
170
+
171
+ V = global_col_ids.numel()
172
+ n_in = int(n_in)
173
+
174
+ id_vec = self.emb(global_col_ids) # [V, H]
175
+ id_vec = id_vec.view(1, V, 1, self.hidden_size).expand(batch_size, V, n_in, self.hidden_size)
176
+ return id_vec.reshape(batch_size, V * n_in, self.hidden_size)
177
+
178
+ def init_weights(self, std: float = 0.02):
179
+ nn.init.normal_(self.emb.weight, std=std)
180
+
181
+
182
+ class NumericMaskEmbedding(nn.Module):
183
+ """
184
+ Per-bucket numeric mask embedding.
185
+ Local to one (n_in) group / bucket.
186
+
187
+ Parameter shape:
188
+ [num_bucket_cols, n_in, H]
189
+
190
+ So missing numeric columns are represented by:
191
+ (bucket-local column index, sub-token index)
192
+ """
193
+
194
+ def __init__(self, num_bucket_cols: int, n_in: int, hidden_size: int):
195
+ super().__init__()
196
+ self.num_bucket_cols = int(num_bucket_cols)
197
+ self.n_in = int(n_in)
198
+ self.hidden_size = int(hidden_size)
199
+
200
+ self.emb = nn.Parameter(
201
+ torch.empty(self.num_bucket_cols, self.n_in, self.hidden_size)
202
+ )
203
+
204
+ def forward(self, local_col_ids: torch.LongTensor, batch_size: int) -> torch.Tensor:
205
+ """
206
+ local_col_ids: [V] bucket-local ids, usually 0 to V-1
207
+ returns: [B, V*n_in, H]
208
+ """
209
+ if local_col_ids.dim() != 1:
210
+ raise ValueError(f"local_col_ids must be [V], got {tuple(local_col_ids.shape)}")
211
+
212
+ V = local_col_ids.numel()
213
+ mask_vec = self.emb[local_col_ids] # [V, n_in, H]
214
+ mask_vec = mask_vec.unsqueeze(0).expand(batch_size, V, self.n_in, self.hidden_size)
215
+ return mask_vec.reshape(batch_size, V * self.n_in, self.hidden_size)
216
+
217
+ def init_weights(self, std: float = 0.02):
218
+ nn.init.normal_(self.emb, std=std)
219
+
220
+
221
+ class NumericEmbedding(nn.Module):
222
+ """
223
+ Full numeric embedding for all numeric columns described by numeric_vocab.json.
224
+
225
+ Forward expects bucketed input:
226
+ values_by_nin: { n_in: x[B, V, n_in] }
227
+ where V must match the feature count and order of that n_in group.
228
+
229
+ Output token ordering:
230
+ groups by n_in ascending (as stored in spec["groups"]),
231
+ within each group by feature_names order.
232
+ """
233
+
234
+ def __init__(self, hidden_size: int, numeric_vocab_json: str, middle_size: Optional[int] = None):
235
+ super().__init__()
236
+ self.hidden_size = int(hidden_size)
237
+ self.middle_size = None if middle_size is None else int(middle_size)
238
+
239
+ spec = load_json(numeric_vocab_json)
240
+ self.ordered_feature_names: List[str] = list(spec["ordered_feature_names"])
241
+ self.features: Dict[str, Dict] = dict(spec["features"])
242
+ self.groups: List[Dict] = list(spec["groups"])
243
+ self.total_numeric_tokens = int(spec.get("total_numeric_tokens", -1))
244
+
245
+ num_cols = len(self.ordered_feature_names)
246
+
247
+ # Global numeric namespace id embedding
248
+ self.id_emb = NumericIdEmbedding(
249
+ num_numeric_cols=num_cols,
250
+ hidden_size=self.hidden_size,
251
+ )
252
+
253
+ # Per-group mask embedding
254
+ self.mask_emb = nn.ModuleDict()
255
+
256
+ # Per-group value embedding
257
+ self.group_mlps = nn.ModuleList()
258
+
259
+ self.group_nins: List[int] = []
260
+ self._num_groups = len(self.groups)
261
+
262
+ # Optional: useful for debugging / downstream checks
263
+ self.group_sizes: List[int] = []
264
+
265
+ # Build one block per group
266
+ for gi, g in enumerate(self.groups):
267
+ n_in = int(g["n_in"])
268
+ names = list(g["feature_names"])
269
+ V = len(names)
270
+
271
+ self.group_nins.append(n_in)
272
+ self.group_sizes.append(V)
273
+
274
+ # ---- spec consistency check
275
+ # group_index and index_within_group in features must match groups[gi]["feature_names"] order
276
+ local_ids = []
277
+ for local_idx, nm in enumerate(names):
278
+ f = self.features[nm]
279
+
280
+ if int(f["group_index"]) != gi:
281
+ raise ValueError(
282
+ f"Feature {nm} has group_index={f['group_index']}, expected {gi}"
283
+ )
284
+ if int(f["n_in"]) != n_in:
285
+ raise ValueError(
286
+ f"Feature {nm} has n_in={f['n_in']}, expected {n_in}"
287
+ )
288
+ if int(f["index_within_group"]) != local_idx:
289
+ raise ValueError(
290
+ f"Feature {nm} has index_within_group={f['index_within_group']}, expected {local_idx}"
291
+ )
292
+
293
+ local_ids.append(int(f["index_within_group"]))
294
+
295
+ # strict check: local ids must be exactly 0 to V-1 with no gap / no duplicate
296
+ if sorted(local_ids) != list(range(V)):
297
+ raise ValueError(
298
+ f"Group gi={gi}, n_in={n_in} has invalid index_within_group set: "
299
+ f"got {sorted(local_ids)}, expected {list(range(V))}"
300
+ )
301
+
302
+ # ---- observed value path: bucket-local ordering
303
+ self.group_mlps.append(
304
+ GroupedMLP(
305
+ n_var=V,
306
+ n_in=n_in,
307
+ n_out=n_in * self.hidden_size,
308
+ middle_size=self.middle_size,
309
+ )
310
+ )
311
+
312
+ # ---- global ids for NumericIdEmbedding
313
+ global_col_ids = [int(self.features[nm]["col_id"]) for nm in names]
314
+ self.register_buffer(
315
+ f"group_global_col_ids_{gi}",
316
+ torch.tensor(global_col_ids, dtype=torch.long),
317
+ persistent=True,
318
+ )
319
+
320
+ # ---- local ids for NumericMaskEmbedding
321
+ local_col_ids = [int(self.features[nm]["index_within_group"]) for nm in names]
322
+ self.register_buffer(
323
+ f"group_local_col_ids_{gi}",
324
+ torch.tensor(local_col_ids, dtype=torch.long),
325
+ persistent=True,
326
+ )
327
+
328
+ # one mask embedding per bucket
329
+ self.mask_emb[str(n_in)] = NumericMaskEmbedding(
330
+ num_bucket_cols=V,
331
+ n_in=n_in,
332
+ hidden_size=self.hidden_size,
333
+ )
334
+
335
+ if self.total_numeric_tokens < 0:
336
+ self.total_numeric_tokens = sum(
337
+ len(g["feature_names"]) * int(g["n_in"]) for g in self.groups
338
+ )
339
+
340
+ def init_weights(self, std: float = 0.02):
341
+ self.id_emb.init_weights(std=std)
342
+
343
+ for _, mask_mod in self.mask_emb.items():
344
+ mask_mod.init_weights(std=std)
345
+
346
+ for mlp in self.group_mlps:
347
+ mlp.init_weights(std=std)
348
+
349
+ def forward(
350
+ self,
351
+ values_by_nin: Dict[int, torch.Tensor],
352
+ valid_positions_by_nin: Optional[Dict[int, torch.Tensor]] = None,
353
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
354
+ """
355
+ Args:
356
+ values_by_nin:
357
+ { n_in: x } where x is [B, V, n_in]
358
+ Missing numeric values are assumed already filled in x.
359
+
360
+ valid_positions_by_nin (optional):
361
+ { n_in: valid_cols } where valid_cols is BoolTensor [B, V]
362
+ True means this COLUMN is observed/valid.
363
+
364
+ Note:
365
+ This is COLUMN-level mask, not token-level.
366
+ It is expanded to token-level by repeating across n_in.
367
+
368
+ Returns:
369
+ tokens: [B, total_numeric_tokens, H]
370
+ token_mask: [B, total_numeric_tokens] (1=valid, 0=missing)
371
+ """
372
+ outs = []
373
+ masks = []
374
+ batch_size = None
375
+
376
+ for gi, n_in in enumerate(self.group_nins):
377
+ if n_in not in values_by_nin:
378
+ raise KeyError(f"Missing bucket input for n_in={n_in}")
379
+
380
+ x = values_by_nin[n_in] # [B, V, n_in]
381
+ if x.dim() != 3 or x.size(-1) != n_in:
382
+ raise ValueError(f"Bucket n_in={n_in} expects x [B,V,{n_in}], got {tuple(x.shape)}")
383
+
384
+ if batch_size is None:
385
+ batch_size = x.size(0)
386
+ elif x.size(0) != batch_size:
387
+ raise ValueError("All buckets must share the same batch size")
388
+
389
+ B, V, _ = x.shape
390
+
391
+ expected_V = self.group_sizes[gi]
392
+ if V != expected_V:
393
+ raise ValueError(
394
+ f"Bucket n_in={n_in} expects V={expected_V}, got V={V}"
395
+ )
396
+
397
+ # column-level valid mask [B, V]
398
+ if valid_positions_by_nin is None:
399
+ valid_cols = torch.ones((B, V), dtype=torch.bool, device=x.device)
400
+ else:
401
+ if n_in not in valid_positions_by_nin:
402
+ raise KeyError(f"Missing valid mask for bucket n_in={n_in}")
403
+
404
+ valid_cols = valid_positions_by_nin[n_in]
405
+ if valid_cols.dtype != torch.bool:
406
+ raise ValueError(
407
+ f"valid_positions_by_nin[{n_in}] must be bool tensor, got {valid_cols.dtype}"
408
+ )
409
+ if valid_cols.shape != (B, V):
410
+ raise ValueError(
411
+ f"valid_positions_by_nin[{n_in}] must be [B,V]=[{B},{V}], got {tuple(valid_cols.shape)}"
412
+ )
413
+ valid_cols = valid_cols.to(device=x.device)
414
+
415
+ # ---- observed numeric value embedding
416
+ mlp = self.group_mlps[gi]
417
+ param = next(mlp.parameters())
418
+ x = x.to(device=param.device, dtype=param.dtype)
419
+
420
+ # [B, V, n_in] -> [B, V, n_in*H]
421
+ y = mlp(x)
422
+
423
+ # [B, V, n_in*H] -> [B, V*n_in, H]
424
+ y_tok = y.view(B, V, n_in, self.hidden_size).reshape(B, V * n_in, self.hidden_size)
425
+
426
+ # [B, V] -> [B, V*n_in]
427
+ valid_tok = valid_cols.unsqueeze(-1).expand(B, V, n_in).reshape(B, V * n_in)
428
+
429
+ # ---- missing replacement: bucket-local mask embedding
430
+ local_col_ids = getattr(self, f"group_local_col_ids_{gi}") # [V]
431
+ mask_tok = self.mask_emb[str(n_in)](local_col_ids, batch_size=B)
432
+
433
+ if (~valid_tok).any():
434
+ y_tok = torch.where(
435
+ valid_tok.unsqueeze(-1),
436
+ y_tok,
437
+ mask_tok,
438
+ )
439
+
440
+ # ---- add global numeric column id embedding
441
+ global_col_ids = getattr(self, f"group_global_col_ids_{gi}") # [V]
442
+ y_tok = y_tok + self.id_emb(global_col_ids, batch_size=B, n_in=n_in)
443
+
444
+ token_mask = valid_tok.to(dtype=torch.long)
445
+
446
+ outs.append(y_tok)
447
+ masks.append(token_mask)
448
+
449
+ tokens = torch.cat(outs, dim=1)
450
+ token_mask = torch.cat(masks, dim=1)
451
+
452
+ if token_mask.shape[:2] != tokens.shape[:2]:
453
+ raise RuntimeError("token_mask shape mismatch with tokens")
454
+
455
+ return tokens, token_mask
456
+
457
+
458
+ # ============================================================
459
+ # DEMO
460
+ # ============================================================
461
+
462
+ def _demo_main():
463
+ import argparse
464
+
465
+ parser = argparse.ArgumentParser()
466
+ parser.add_argument("--tabular_meta", type=str, default="data/tabular_meta.json")
467
+ parser.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
468
+ parser.add_argument("--hidden_size", type=int, default=768)
469
+ parser.add_argument("--middle_size", type=int, default=-1,
470
+ help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
471
+ parser.add_argument("--batch_size", type=int, default=4)
472
+ parser.add_argument("--device", type=str, default=None)
473
+ parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
474
+ args = parser.parse_args()
475
+
476
+ device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
477
+ dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
478
+ dtype = dtype_map[args.dtype]
479
+
480
+ meta = load_json(args.tabular_meta)
481
+
482
+ spec = build_numeric_vocab_spec_from_meta(meta)
483
+ save_json(spec, args.numeric_vocab_json)
484
+ print(f"Saved numeric vocab spec to: {args.numeric_vocab_json}")
485
+ print("Groups (n_in -> V):", {g["n_in"]: len(g["feature_names"]) for g in spec["groups"]})
486
+ print("total_numeric_tokens:", spec["total_numeric_tokens"])
487
+ print("group_token_offsets:", spec["group_token_offsets"])
488
+
489
+ middle_size = None if args.middle_size < 0 else int(args.middle_size)
490
+ model = NumericEmbedding(
491
+ hidden_size=args.hidden_size,
492
+ numeric_vocab_json=args.numeric_vocab_json,
493
+ middle_size=middle_size,
494
+ ).to(device=device, dtype=dtype)
495
+ model.init_weights()
496
+ model.eval()
497
+
498
+ total_params = sum(p.numel() for p in model.parameters())
499
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
500
+ print(f"Total parameters (NumericEmbedding): {total_params:,} (trainable: {trainable_params:,})")
501
+
502
+ # create demo inputs bucketed by n_in
503
+ B = args.batch_size
504
+ values_by_nin: Dict[int, torch.Tensor] = {}
505
+ valid_positions_by_nin: Dict[int, torch.Tensor] = {}
506
+
507
+ for g in spec["groups"]:
508
+ n_in = int(g["n_in"])
509
+ V = len(g["feature_names"])
510
+
511
+ # random numeric inputs
512
+ x = torch.randn(B, V, n_in, device=device, dtype=dtype)
513
+ values_by_nin[n_in] = x
514
+
515
+ # Build valid mask (column-level)
516
+ # shape: [B, V], True = valid
517
+ valid_cols = torch.ones((B, V), dtype=torch.bool, device=device)
518
+
519
+ # Mark first sample's first 2 columns as invalid
520
+ num_to_invalidate = min(2, V)
521
+ valid_cols[0, :num_to_invalidate] = False
522
+
523
+ valid_positions_by_nin[n_in] = valid_cols
524
+
525
+ with torch.no_grad():
526
+ out, mask = model(values_by_nin, valid_positions_by_nin)
527
+
528
+ print("Buckets:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
529
+ print("Output tokens:", tuple(out.shape), out.dtype, out.device) # [B, total_numeric_tokens, H]
530
+ print("Masks:", tuple(mask.shape), mask.dtype, mask.device) # [B, total_numeric_tokens]
531
+
532
+ # ---- Inspect first sample
533
+ print("\nFirst sample mask (first 5 tokens):")
534
+ print(mask[0, :5])
535
+
536
+ print("\nFirst sample token L2 norms (first 5 tokens):")
537
+ print(out[0, :5].norm(dim=-1))
538
+
539
+ print("\nSecond sample mask (first 5 tokens):")
540
+ print(mask[1, :5])
541
+
542
+ print("\nSecond sample token L2 norms (first 5 tokens):")
543
+ print(out[1, :5].norm(dim=-1))
544
+
545
+
546
+ if __name__ == "__main__":
547
+ _demo_main()
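A minimal packing sketch, not part of this file: grouping a batch of per-sample {feature_name: value} records into the values_by_nin / valid_positions_by_nin buckets that NumericEmbedding.forward expects, following the group order stored in numeric_vocab.json. The records format is an illustrative assumption.

import torch
from utils import load_json

def bucket_numeric_records(records, numeric_vocab_json="data/numeric_vocab.json"):
    spec = load_json(numeric_vocab_json)
    values_by_nin, valid_by_nin = {}, {}
    B = len(records)
    for g in spec["groups"]:
        n_in, names = int(g["n_in"]), g["feature_names"]
        x = torch.zeros(B, len(names), n_in)
        ok = torch.zeros(B, len(names), dtype=torch.bool)
        for b, rec in enumerate(records):
            for v, nm in enumerate(names):
                if rec.get(nm) is not None:
                    x[b, v] = torch.as_tensor(rec[nm], dtype=torch.float32).view(n_in)
                    ok[b, v] = True
        values_by_nin[n_in], valid_by_nin[n_in] = x, ok
    return values_by_nin, valid_by_nin  # feed to NumericEmbedding(...)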
modelling/embed_vision_gemma3n.py ADDED
@@ -0,0 +1,552 @@
1
+ # embed_vision_gemma3n.py
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ from typing import Optional, Tuple, Dict
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from safetensors.torch import load_file as safetensors_load_file
10
+ from transformers import AutoConfig, AutoModel
11
+ from transformers.models.gemma3n.modeling_gemma3n import Gemma3nMultimodalEmbedder # noqa
12
+
13
+ from utils import load_json
14
+
15
+
16
+ def _split_state_dict_from_tmp(sd: Dict[str, torch.Tensor]) \
17
+ -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
18
+ """
19
+ The model extractor saved tmp.state_dict(), where tmp has the attributes:
20
+ - vision_tower
21
+ - embed_vision (optional)
22
+ So keys look like:
23
+ - vision_tower.xxx
24
+ - embed_vision.xxx
25
+ """
26
+ vt = {}
27
+ ev = {}
28
+ for k, v in sd.items():
29
+ if k.startswith("vision_tower."):
30
+ vt[k[len("vision_tower."):]] = v
31
+ elif k.startswith("embed_vision."):
32
+ ev[k[len("embed_vision."):]] = v
33
+ return vt, ev
34
+
35
+
36
+ # ============================================================
37
+ # Optional lightweight learnable token reducer
38
+ # ============================================================
39
+
40
+
41
+ class VisionTokenReducer(nn.Module):
42
+ """
43
+ Perceiver-style learnable cross-attention pooling with optional bottleneck.
44
+
45
+ Base (no bottleneck):
46
+ [B,T,D] -> [B,K,D]
47
+
48
+ Bottleneck mode (bottleneck_dim=d):
49
+ [B,T,D] -> down -> [B,T,d] -> cross-attn -> [B,K,d] -> (optional up) -> [B,K,D]
50
+
51
+ Notes:
52
+ - num_heads does NOT change the parameter count of MultiheadAttention (it depends on the embed dim only).
53
+ - perform_norm_latent controls whether to pre-norm the learnable latent queries.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ vision_dim: int,
59
+ num_output_tokens: int,
60
+ num_heads: int = 4,
61
+ perform_norm_latent: bool = True,
62
+ bottleneck_dim: Optional[int] = None,
63
+ project_back: bool = True,
64
+ ):
65
+ super().__init__()
66
+
67
+ self.vision_dim = int(vision_dim)
68
+ self.num_output_tokens = int(num_output_tokens)
69
+ self.num_heads = int(num_heads)
70
+ self.perform_norm_latent = bool(perform_norm_latent)
71
+
72
+ self.bottleneck_dim = None if bottleneck_dim is None else int(bottleneck_dim)
73
+ self.project_back = bool(project_back)
74
+
75
+ # Decide the attention working dimension: D (base) or d (bottleneck)
76
+ attn_dim = self.vision_dim if self.bottleneck_dim is None else self.bottleneck_dim
77
+ if attn_dim % self.num_heads != 0:
78
+ raise ValueError(f"embed_dim ({attn_dim}) must be divisible by num_heads ({self.num_heads})")
79
+
80
+ # Optional projection layers for bottleneck mode
81
+ if self.bottleneck_dim is None:
82
+ self.down = None
83
+ self.up = None
84
+ else:
85
+ # bias=False keeps it lightweight; switch to True if you prefer
86
+ self.down = nn.Linear(self.vision_dim, attn_dim, bias=False)
87
+ self.up = nn.Linear(attn_dim, self.vision_dim, bias=False) if self.project_back else None
88
+
89
+ # Learnable latent tokens (K, attn_dim)
90
+ self.latents = nn.Parameter(torch.randn(self.num_output_tokens, attn_dim) * 0.02)
91
+
92
+ # Separate norms: typically more stable than sharing one LN
93
+ self.norm_latents = nn.LayerNorm(attn_dim)
94
+ self.norm_x = nn.LayerNorm(attn_dim)
95
+
96
+ # Cross-attention: query=latents, key/value=x
97
+ self.attn = nn.MultiheadAttention(
98
+ embed_dim=attn_dim,
99
+ num_heads=self.num_heads,
100
+ batch_first=True,
101
+ )
102
+
103
+ def init_weights(self, std: float = 0.02):
104
+ # Optional bottleneck projections
105
+ if self.down is not None:
106
+ nn.init.normal_(self.down.weight, std=std)
107
+ if self.up is not None:
108
+ nn.init.normal_(self.up.weight, std=std)
109
+
110
+ # Learnable latent queries
111
+ nn.init.normal_(self.latents, std=std)
112
+
113
+ # LayerNorm
114
+ nn.init.ones_(self.norm_latents.weight)
115
+ nn.init.zeros_(self.norm_latents.bias)
116
+ nn.init.ones_(self.norm_x.weight)
117
+ nn.init.zeros_(self.norm_x.bias)
118
+
119
+ # MultiheadAttention: use PyTorch's own reset only
120
+ self.attn._reset_parameters() # noqa
121
+
122
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
123
+ """
124
+ Args:
125
+ x: [B, T, D] where D == vision_dim
126
+
127
+ Returns:
128
+ out: [B, K, D] if (bottleneck_dim is None) or project_back=True
129
+ [B, K, d] if bottleneck_dim is not None and project_back=False
130
+ """
131
+ if x.dim() != 3:
132
+ raise ValueError(f"Expected x [B,T,D], got {tuple(x.shape)}")
133
+ if x.size(-1) != self.vision_dim:
134
+ raise ValueError(f"Expected last dim D={self.vision_dim}, got {x.size(-1)}")
135
+
136
+ B = x.size(0)
137
+
138
+ # Bottleneck projection if enabled
139
+ if self.down is not None:
140
+ x = self.down(x) # [B,T,d]
141
+
142
+ # Expand learnable latents across batch
143
+ latents = self.latents.unsqueeze(0).expand(B, -1, -1) # [B,K,attn_dim]
144
+
145
+ # Pre-norm (optional for latents, always for input tokens)
146
+ if self.perform_norm_latent:
147
+ latents = self.norm_latents(latents)
148
+ x = self.norm_x(x)
149
+
150
+ # Cross-attention pooling
151
+ out, _ = self.attn(query=latents, key=x, value=x) # [B,K,attn_dim]
152
+
153
+ # Project back to original dim if requested
154
+ if self.up is not None:
155
+ out = self.up(out) # [B,K,D]
156
+
157
+ return out
158
+
159
+
160
+ # ============================================================
161
+ # Main body
162
+ # ============================================================
163
+
164
+ class Gemma3nVisionFeatureExtractor(nn.Module):
165
+ """
166
+ Vision-only feature extractor for Gemma-3n that matches transformers' Gemma3nModel.get_image_features().
167
+
168
+ Input: pixel_values [B, 3, H, W]
169
+ Output: image_features [B, vision_soft_tokens_per_image, text_hidden_size]
170
+ """
171
+
172
+ def __init__(
173
+ self,
174
+ vision_tower: nn.Module,
175
+ embed_vision: Optional[nn.Module],
176
+ vision_hidden_size: int,
177
+ vision_soft_tokens_per_image: int,
178
+ text_hidden_size: int,
179
+ num_output_tokens_reduced: Optional[int] = None,
180
+ num_heads_for_token_reduction: int = 4,
181
+ perform_norm_latent_for_token_reduction: bool = True,
182
+ reducer_bottleneck_dim: Optional[int] = None,
183
+ reducer_project_back: bool = True,
184
+ ):
185
+ super().__init__()
186
+ self.vision_tower = vision_tower
187
+ self.embed_vision = embed_vision
188
+ self.vision_hidden_size = int(vision_hidden_size)
189
+ self.vision_soft_tokens_per_image = int(vision_soft_tokens_per_image)
190
+ self.text_hidden_size = int(text_hidden_size)
191
+ self.has_embed_vision = embed_vision is not None
192
+
193
+ # Freeze vision modules
194
+ self.vision_tower.requires_grad_(False)
195
+ if self.embed_vision is not None:
196
+ self.embed_vision.requires_grad_(False)
197
+
198
+ # Optionally reduce the number of vision tokens with a learnable pooler
199
+ if num_output_tokens_reduced is not None:
200
+ reducer_dim = text_hidden_size if self.has_embed_vision else vision_hidden_size
201
+ self.reducer = VisionTokenReducer(
202
+ vision_dim=reducer_dim,
203
+ num_output_tokens=num_output_tokens_reduced,
204
+ num_heads=num_heads_for_token_reduction,
205
+ perform_norm_latent=perform_norm_latent_for_token_reduction,
206
+ bottleneck_dim=reducer_bottleneck_dim,
207
+ project_back=reducer_project_back,
208
+ )
209
+ else:
210
+ self.reducer = None
211
+
212
+ def init_weights(self, std: float = 0.02):
213
+ if self.reducer is not None:
214
+ self.reducer.init_weights(std)
215
+
216
+ def get_actual_hidden_dim(self) -> int:
217
+ """
218
+ Return the actual feature hidden dimension produced by this extractor.
219
+
220
+ The output dimension depends on:
221
+ - whether embed_vision is used
222
+ - whether a reducer is present
223
+ - reducer bottleneck + project_back configuration
224
+
225
+ Returns:
226
+ int: feature hidden size of output tokens
227
+ """
228
+
229
+ # Base dimension before reducer
230
+ base_dim = self.text_hidden_size if self.has_embed_vision else self.vision_hidden_size
231
+
232
+ # No reducer
233
+ if self.reducer is None:
234
+ return base_dim
235
+
236
+ # Reducer without bottleneck
237
+ if self.reducer.bottleneck_dim is None:
238
+ return base_dim
239
+
240
+ # Bottleneck reducer
241
+ if self.reducer.project_back:
242
+ return base_dim
243
+
244
+ # Bottleneck without projection back
245
+ return int(self.reducer.bottleneck_dim)
246
+
247
+ def train(self, mode: bool = True) -> "Gemma3nVisionFeatureExtractor":
248
+ """ Override train(): vision is not trainable"""
249
+ super().train(mode=mode)
250
+ self.vision_tower.eval()
251
+ if self.embed_vision is not None:
252
+ self.embed_vision.eval()
253
+ return self
254
+
255
+ def forward(
256
+ self,
257
+ pixel_values: torch.Tensor,
258
+ valid_positions: Optional[torch.Tensor] = None,
259
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
260
+ """
261
+ Args:
262
+ pixel_values: [B, 3, H, W]
263
+ valid_positions:
264
+ Indicates which samples have valid images.
265
+ Supported formats:
266
+ - BoolTensor [B] where True means "has image"
267
+ - LongTensor [K] with indices of samples that have images
268
+ If None: assume all samples have images.
269
+
270
+ Returns:
271
+ features: [B, T_img, D]
272
+ vision_mask: [B, T_img] (1=valid vision token, 0=masked out)
273
+ """
274
+ if pixel_values.dim() != 4:
275
+ raise ValueError(f"pixel_values must be [B,3,H,W], got {tuple(pixel_values.shape)}")
276
+
277
+ B = pixel_values.size(0)
278
+ device = next(self.vision_tower.parameters()).device
279
+ dtype = next(self.vision_tower.parameters()).dtype
280
+
281
+ # --------------------------------------------------------
282
+ # Build per-sample valid-image mask
283
+ # --------------------------------------------------------
284
+ if valid_positions is None:
285
+ valid_mask = torch.ones(B, dtype=torch.bool, device=pixel_values.device)
286
+ else:
287
+ if valid_positions.dtype == torch.bool:
288
+ if valid_positions.shape != (B,):
289
+ raise ValueError(f"valid_positions (bool) must be [B], got {tuple(valid_positions.shape)}")
290
+ valid_mask = valid_positions.to(device=pixel_values.device)
291
+ else:
292
+ if valid_positions.dim() != 1:
293
+ raise ValueError(f"valid_positions (indices) must be 1D, got {tuple(valid_positions.shape)}")
294
+ valid_mask = torch.zeros(B, dtype=torch.bool, device=pixel_values.device)
295
+ valid_mask[valid_positions.to(device=pixel_values.device, dtype=torch.long)] = True
296
+
297
+ num_valid = int(valid_mask.sum().item())
298
+
299
+ # --------------------------------------------------------
300
+ # Figure out final output shape in advance
301
+ # --------------------------------------------------------
302
+ if self.reducer is None:
303
+ T_img = self.vision_soft_tokens_per_image
304
+ else:
305
+ T_img = self.reducer.num_output_tokens
306
+
307
+ D_out = self.get_actual_hidden_dim()
308
+
309
+ # vision_mask always returned for full batch
310
+ vision_mask = valid_mask[:, None].expand(B, T_img).to(dtype=torch.long)
311
+
312
+ # Fast path: no valid image at all
313
+ if num_valid == 0:
314
+ features = torch.zeros(B, T_img, D_out, device=device, dtype=dtype)
315
+ return features, vision_mask
316
+
317
+ # --------------------------------------------------------
318
+ # Run only valid samples through frozen vision stack
319
+ # --------------------------------------------------------
320
+ pixel_values_valid = pixel_values[valid_mask].to(device=device, dtype=dtype)
321
+
322
+ with torch.no_grad():
323
+ vision_last = self.vision_tower(
324
+ pixel_values=pixel_values_valid,
325
+ do_pooling=False,
326
+ return_dict=True,
327
+ ).last_hidden_state
328
+
329
+ if vision_last.dim() != 4:
330
+ raise RuntimeError(f"Expected vision last_hidden_state (B,C,h,w), got {tuple(vision_last.shape)}")
331
+
332
+ Bv, C, h, w = vision_last.shape
333
+ if Bv != num_valid:
334
+ raise RuntimeError("Batch size mismatch between valid pixel_values and vision_last")
335
+ if C != self.vision_hidden_size:
336
+ raise RuntimeError(f"Expected vision_hidden_size={self.vision_hidden_size}, got C={C}")
337
+ if h * w != self.vision_soft_tokens_per_image:
338
+ raise RuntimeError(
339
+ f"Expected h*w={self.vision_soft_tokens_per_image}, got {h * w}. "
340
+ f"Check processor image size/crop or config."
341
+ )
342
+
343
+ # (Bv, C, h, w) -> (Bv, C, HW) -> (Bv, HW, C)
344
+ vision_tokens = vision_last.reshape(Bv, C, self.vision_soft_tokens_per_image).permute(0, 2, 1).contiguous()
345
+
346
+ # Scale by sqrt(C) (matches Gemma codepath)
347
+ vision_tokens = vision_tokens * (self.vision_hidden_size ** 0.5)
348
+
349
+ # --------------------------------------------------------
350
+ # Extract valid-image features only
351
+ # --------------------------------------------------------
352
+ if not self.has_embed_vision:
353
+ valid_features = vision_tokens # [Bv, HW, C]
354
+ if self.reducer is not None:
355
+ valid_features = self.reducer(valid_features) # [Bv, T_img, C or d]
356
+ else:
357
+ with torch.no_grad():
358
+ valid_features = self.embed_vision(inputs_embeds=vision_tokens)
359
+
360
+ if valid_features.shape != (Bv, self.vision_soft_tokens_per_image, self.text_hidden_size):
361
+ raise RuntimeError(
362
+ f"Bad output shape {tuple(valid_features.shape)}; expected "
363
+ f"({Bv}, {self.vision_soft_tokens_per_image}, {self.text_hidden_size})"
364
+ )
365
+
366
+ if self.reducer is not None:
367
+ valid_features = self.reducer(valid_features)
368
+
369
+ # --------------------------------------------------------
370
+ # Scatter back to full batch; invalid samples stay zero
371
+ # --------------------------------------------------------
372
+ if valid_features.size(1) != T_img:
373
+ raise RuntimeError(f"T_img mismatch: expected {T_img}, got {valid_features.size(1)}")
374
+ if valid_features.size(2) != D_out:
375
+ raise RuntimeError(f"D_out mismatch: expected {D_out}, got {valid_features.size(2)}")
376
+
377
+ features = torch.zeros(B, T_img, D_out, device=valid_features.device, dtype=valid_features.dtype)
378
+ features[valid_mask.to(features.device)] = valid_features  # mask may still live on the input device
379
+
380
+ return features, vision_mask
381
+
382
+ @classmethod
383
+ def from_pretrained_vision_only_dir(
384
+ cls,
385
+ model_dir: str,
386
+ map_location: str = "cpu",
387
+ num_output_tokens_reduced: Optional[int] = None,
388
+ num_heads_for_token_reduction: int = 4,
389
+ perform_norm_latent_for_token_reduction: bool = True,
390
+ reducer_bottleneck_dim: Optional[int] = None,
391
+ reducer_project_back: bool = True,
392
+ ) -> "Gemma3nVisionFeatureExtractor":
393
+ weights_path = os.path.join(model_dir, "model.safetensors")
394
+ if not os.path.isfile(weights_path):
395
+ raise FileNotFoundError(f"Missing weights: {weights_path}")
396
+
397
+ ve_cfg_path = os.path.join(model_dir, "vision_extractor_config.json")
398
+ if not os.path.isfile(ve_cfg_path):
399
+ raise FileNotFoundError(f"Missing {ve_cfg_path}")
400
+ ve_cfg = load_json(ve_cfg_path)
401
+
402
+ vision_soft_tokens_per_image = int(ve_cfg.get("vision_soft_tokens_per_image", 256))
403
+ vision_hidden_size = int(ve_cfg.get("vision_hidden_size", -1))
404
+ text_hidden_size = int(ve_cfg.get("text_hidden_size", -1))
405
+ has_embed_vision = bool(ve_cfg.get("has_embed_vision", True))
406
+
407
+ if vision_hidden_size <= 0:
408
+ raise ValueError("vision_hidden_size missing/invalid in vision_extractor_config.json")
409
+ if has_embed_vision and text_hidden_size <= 0:
410
+ raise ValueError("text_hidden_size missing/invalid in vision_extractor_config.json")
411
+
412
+ cfg = AutoConfig.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
413
+ vision_cfg = getattr(cfg, "vision_config", cfg)
414
+ text_cfg = getattr(cfg, "text_config", None)
415
+
416
+ vision_tower = AutoModel.from_config(vision_cfg, trust_remote_code=True)
417
+
418
+ embed_vision = None
419
+ if has_embed_vision:
420
+ if text_cfg is None:
421
+ raise RuntimeError(
422
+ "config.json does not contain text_config, but has_embed_vision=True. "
423
+ "You need a Gemma3nConfig-like config.json in this folder."
424
+ )
425
+ embed_vision = Gemma3nMultimodalEmbedder(vision_cfg, text_cfg)
426
+
427
+ sd = safetensors_load_file(weights_path, device=map_location)
428
+
429
+ vt_sd, ev_sd = _split_state_dict_from_tmp(sd)
430
+ if not vt_sd:
431
+ raise RuntimeError("No vision_tower.* keys found in model.safetensors")
432
+ if has_embed_vision and not ev_sd:
433
+ raise RuntimeError("has_embed_vision=True but no embed_vision.* keys found in model.safetensors")
434
+
435
+ missing_vt, unexpected_vt = vision_tower.load_state_dict(vt_sd, strict=True)
436
+ if missing_vt or unexpected_vt:
437
+ raise RuntimeError(f"vision_tower load mismatch: missing={missing_vt}, unexpected={unexpected_vt}")
438
+
439
+ if has_embed_vision:
440
+ missing_ev, unexpected_ev = embed_vision.load_state_dict(ev_sd, strict=True)
441
+ if missing_ev or unexpected_ev:
442
+ raise RuntimeError(f"embed_vision load mismatch: missing={missing_ev}, unexpected={unexpected_ev}")
443
+
444
+ vision_tower.eval()
445
+ if embed_vision is not None:
446
+ embed_vision.eval()
447
+
448
+ model = cls(
449
+ vision_tower=vision_tower,
450
+ embed_vision=embed_vision,
451
+ vision_hidden_size=vision_hidden_size,
452
+ vision_soft_tokens_per_image=vision_soft_tokens_per_image,
453
+ text_hidden_size=text_hidden_size if has_embed_vision else vision_hidden_size,
454
+ num_output_tokens_reduced=num_output_tokens_reduced,
455
+ num_heads_for_token_reduction=num_heads_for_token_reduction,
456
+ perform_norm_latent_for_token_reduction=perform_norm_latent_for_token_reduction,
457
+ reducer_bottleneck_dim=reducer_bottleneck_dim,
458
+ reducer_project_back=reducer_project_back,
459
+ )
460
+ model.eval()
461
+ return model
462
+
463
+
464
+ def _demo_main():
465
+ import argparse
466
+ from PIL import Image
467
+ from transformers import AutoProcessor
468
+ from pathlib import Path
469
+
470
+ parser = argparse.ArgumentParser()
471
+ parser.add_argument("--model_dir", type=str, default="./model_weights/gemma3n_E2B_vision_only")
472
+ parser.add_argument("--device", type=str, default=None)
473
+ parser.add_argument("--dtype", type=str, default="float32", choices=["bfloat16", "float16", "float32"])
474
+ parser.add_argument("--num_output_tokens_reduced", type=int, default=32)
475
+ parser.add_argument("--reducer_bottleneck_dim", type=int, default=768)
476
+ parser.add_argument("--reducer_project_back", action="store_true")
477
+ args = parser.parse_args()
478
+
479
+ model_dir = str(Path(args.model_dir).resolve())
480
+
481
+ # Force local loading
482
+ processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
483
+
484
+ model = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir(
485
+ model_dir=model_dir,
486
+ map_location="cpu",
487
+ num_output_tokens_reduced=args.num_output_tokens_reduced,
488
+ num_heads_for_token_reduction=4,
489
+ reducer_bottleneck_dim=args.reducer_bottleneck_dim,
490
+ reducer_project_back=args.reducer_project_back,
491
+ )
492
+ model.init_weights()
493
+ model.to(device=args.device, dtype=getattr(torch, args.dtype))  # map the dtype string to a torch.dtype
494
+ model.eval()
495
+
496
+ def count_params(module):
497
+ return sum(p.numel() for p in module.parameters())
498
+
499
+ vision_params = count_params(model.vision_tower)
500
+
501
+ embed_params = 0
502
+ if model.has_embed_vision and model.embed_vision is not None:
503
+ embed_params = count_params(model.embed_vision)
504
+
505
+ reducer_params = 0
506
+ if model.reducer is not None:
507
+ reducer_params = count_params(model.reducer)
508
+
509
+ frozen_params = vision_params + embed_params
510
+ total_params = frozen_params + reducer_params
511
+
512
+ print(f"Vision tower parameters (frozen): {vision_params:,}")
513
+
514
+ if model.has_embed_vision:
515
+ print(f"Embed vision parameters (frozen): {embed_params:,}")
516
+ else:
517
+ print("Embed vision: NONE")
518
+
519
+ if model.reducer is not None:
520
+ print(f"Reducer parameters (trainable): {reducer_params:,}")
521
+ else:
522
+ print("Reducer: NONE")
523
+
524
+ print(f"Total frozen parameters: {frozen_params:,}")
525
+ print(f"Total trainable parameters: {reducer_params:,}")
526
+ print(f"Total parameters: {total_params:,}")
527
+
528
+ img1 = Image.new("RGB", (768, 768), color=(0, 0, 0))
529
+ img2 = Image.new("RGB", (768, 768), color=(255, 255, 255))
530
+
531
+ inputs = processor(
532
+ text=["", ""],
533
+ images=[[img1], [img2]],
534
+ return_tensors="pt",
535
+ )
536
+
537
+ pixel_values = inputs["pixel_values"].to(
538
+ device=next(model.parameters()).device,
539
+ dtype=next(model.parameters()).dtype,
540
+ )
541
+
542
+ print("pixel_values:", tuple(pixel_values.shape), pixel_values.dtype, pixel_values.device)
543
+
544
+ with torch.no_grad():
545
+ feats, masks = model(pixel_values)
546
+
547
+ print("features:", tuple(feats.shape), feats.dtype, feats.device)
548
+ print("masks:", tuple(masks.shape), masks.dtype, masks.device)
549
+
550
+
551
+ if __name__ == "__main__":
552
+ _demo_main()
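
As a quick sanity check of the token-reducer shape flow implemented above — a minimal sketch, assuming the module is importable as modelling.embed_vision_gemma3n and that the tower emits 256 soft tokens of width 2048 (both numbers are illustrative, not taken from the shipped config):

import torch
from modelling.embed_vision_gemma3n import VisionTokenReducer

# K=32 latent queries pooled from T=256 input tokens; cross-attention runs
# at the bottleneck width d=256 and is projected back up to D=2048.
reducer = VisionTokenReducer(
    vision_dim=2048,
    num_output_tokens=32,
    num_heads=4,
    bottleneck_dim=256,
    project_back=True,
)
reducer.init_weights()

x = torch.randn(2, 256, 2048)   # [B, T, D]
out = reducer(x)                # -> [B, K, D] == (2, 32, 2048)
print(tuple(out.shape))

With project_back=False the same call would instead return [B, K, d] = (2, 32, 256), which is the case get_actual_hidden_dim() reports to downstream modules.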
modelling/layer.py ADDED
@@ -0,0 +1,353 @@
1
+ # layer.py
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import math
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F # noqa
10
+
11
+
12
+ class RMSNorm(nn.Module):
13
+ def __init__(self, dim: int, eps: float = 1e-6):
14
+ super().__init__()
15
+ self.eps = float(eps)
16
+ self.weight = nn.Parameter(torch.ones(dim))
17
+
18
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
19
+ # x: [..., dim]
20
+ x_float = x.float()
21
+ rms = x_float.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
22
+ y = (x_float / rms).to(dtype=x.dtype)
23
+ return y * self.weight.to(dtype=x.dtype, device=x.device)
24
+
25
+
26
+ class SwiGLU(nn.Module):
27
+ @staticmethod
28
+ def forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
29
+ return nn.functional.silu(gate) * up
30
+
31
+
32
+ class TabularImageGQALayer(nn.Module):
33
+ """
34
+ Pre-norm Transformer block with:
35
+ - Tabular tokens produce Q; tabular+image produce KV (image optional)
36
+ - GQA: num_query_heads is a multiple of num_kv_heads
37
+ - Numeric+categorical must be concatenated before calling this layer (one tabular stream)
38
+ - attention_mask is 1D [B, T_tab] and does not include vision tokens
39
+ - If vision_features is None, attention is tabular-only
40
+ - Vision tokens are not updated (no Q for vision)
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ tabular_dim: int,
46
+ vision_dim: int,
47
+ num_query_heads: int,
48
+ num_kv_heads: int,
49
+ head_dim: int,
50
+ mlp_ratio: float = 4.0,
51
+ dropout: float = 0.0,
52
+ rmsnorm_eps: float = 1e-6,
53
+ ):
54
+ super().__init__()
55
+
56
+ if num_query_heads % num_kv_heads != 0:
57
+ raise ValueError("num_query_heads must be a multiple of num_kv_heads")
58
+
59
+ self.tabular_dim = int(tabular_dim)
60
+ self.vision_dim = int(vision_dim)
61
+ self.num_query_heads = int(num_query_heads)
62
+ self.num_kv_heads = int(num_kv_heads)
63
+ self.head_dim = int(head_dim)
64
+
65
+ self.q_dim = self.num_query_heads * self.head_dim
66
+ self.kv_dim = self.num_kv_heads * self.head_dim
67
+ self.group_size = self.num_query_heads // self.num_kv_heads
68
+
69
+ self.attn_norm = RMSNorm(self.tabular_dim, eps=rmsnorm_eps)
70
+
71
+ # Tabular projections (shared for numeric+categorical stream)
72
+ self.q_proj_tab = nn.Linear(self.tabular_dim, self.q_dim, bias=False)
73
+ self.k_proj_tab = nn.Linear(self.tabular_dim, self.kv_dim, bias=False)
74
+ self.v_proj_tab = nn.Linear(self.tabular_dim, self.kv_dim, bias=False)
75
+
76
+ # Vision KV projections (separate; vision has no Q)
77
+ self.k_proj_img = nn.Linear(self.vision_dim, self.kv_dim, bias=False)
78
+ self.v_proj_img = nn.Linear(self.vision_dim, self.kv_dim, bias=False)
79
+
80
+ self.o_proj = nn.Linear(self.q_dim, self.tabular_dim, bias=False)
81
+
82
+ self.attn_dropout = float(dropout)
83
+ self.resid_dropout = float(dropout)
84
+
85
+ # FFN (LLM-style: gated MLP with SwiGLU)
86
+ self.ffn_norm = RMSNorm(self.tabular_dim, eps=rmsnorm_eps)
87
+ ffn_dim = int(round(self.tabular_dim * float(mlp_ratio)))
88
+
89
+ self.gate_proj = nn.Linear(self.tabular_dim, ffn_dim, bias=False)
90
+ self.up_proj = nn.Linear(self.tabular_dim, ffn_dim, bias=False)
91
+ self.down_proj = nn.Linear(ffn_dim, self.tabular_dim, bias=False)
92
+ self.act = SwiGLU()
93
+
94
+ def init_weights(self, std: float = 0.02):
95
+ # RMSNorm
96
+ nn.init.ones_(self.attn_norm.weight)
97
+ nn.init.ones_(self.ffn_norm.weight)
98
+
99
+ # Attention projections
100
+ nn.init.normal_(self.q_proj_tab.weight, std=std)
101
+ nn.init.normal_(self.k_proj_tab.weight, std=std)
102
+ nn.init.normal_(self.v_proj_tab.weight, std=std)
103
+ nn.init.normal_(self.k_proj_img.weight, std=std)
104
+ nn.init.normal_(self.v_proj_img.weight, std=std)
105
+ nn.init.normal_(self.o_proj.weight, std=std)
106
+
107
+ # FFN
108
+ nn.init.normal_(self.gate_proj.weight, std=std)
109
+ nn.init.normal_(self.up_proj.weight, std=std)
110
+ nn.init.normal_(self.down_proj.weight, std=std)
111
+
112
+ @staticmethod
113
+ def _make_key_bias_from_mask(mask_1d: torch.Tensor, key_len: int) -> torch.Tensor:
114
+ """
115
+ mask_1d: [B, T_key] with 1=keep, 0=mask
116
+ returns: [B, 1, 1, T_key] float bias with 0 for keep and -inf for mask
117
+ """
118
+ if mask_1d.dtype != torch.float32:
119
+ mask_f = mask_1d.float()
120
+ else:
121
+ mask_f = mask_1d
122
+ if mask_f.shape[1] != key_len:
123
+ raise ValueError(f"mask_1d width mismatch: got {mask_f.shape[1]} expected {key_len}")
124
+ bias = (1.0 - mask_f) * -1e9
125
+ return bias.view(mask_f.shape[0], 1, 1, key_len)
126
+
127
+ def _split_heads_q(self, x: torch.Tensor) -> torch.Tensor:
128
+ # x: [B, T, Hq*d] -> [B, Hq, T, d]
129
+ B, T, _ = x.shape
130
+ return x.view(B, T, self.num_query_heads, self.head_dim).transpose(1, 2).contiguous()
131
+
132
+ def _split_heads_kv(self, x: torch.Tensor) -> torch.Tensor:
133
+ # x: [B, T, Hkv*d] -> [B, Hkv, T, d]
134
+ B, T, _ = x.shape
135
+ return x.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()
136
+
137
+ @staticmethod
138
+ def _merge_heads_q(x: torch.Tensor) -> torch.Tensor:
139
+ # x: [B, Hq, T, d] -> [B, T, Hq*d]
140
+ B, H, T, d = x.shape
141
+ return x.transpose(1, 2).contiguous().view(B, T, H * d)
142
+
143
+ def forward(
144
+ self,
145
+ x_tab: torch.Tensor,
146
+ attention_mask: torch.Tensor,
147
+ vision_features: Optional[torch.Tensor] = None,
148
+ vision_mask: Optional[torch.Tensor] = None,
149
+ ) -> torch.Tensor:
150
+ """
151
+ x_tab: [B, T_tab, tabular_dim]
152
+ attention_mask: [B, T_tab] (1=valid tab token, 0=masked tab token). Does NOT include vision.
153
+ vision_features: None or [B, T_img, vision_dim]
154
+ vision_mask: None or [B, T_img] (1=valid vision token, 0=masked). Required if vision_features is not None.
155
+ returns: updated x_tab [B, T_tab, tabular_dim]
156
+ """
157
+ if x_tab.dim() != 3:
158
+ raise ValueError(f"x_tab must be [B,T,D], got {tuple(x_tab.shape)}")
159
+ if attention_mask.dim() != 2:
160
+ raise ValueError(f"attention_mask must be [B,T_tab], got {tuple(attention_mask.shape)}")
161
+
162
+ B, T_tab, D = x_tab.shape
163
+ if D != self.tabular_dim:
164
+ raise ValueError(f"tabular_dim mismatch: got {D}, expected {self.tabular_dim}")
165
+ if attention_mask.shape != (B, T_tab):
166
+ raise ValueError("attention_mask shape mismatch with x_tab")
167
+ if attention_mask.device != x_tab.device:
168
+ attention_mask = attention_mask.to(device=x_tab.device)
169
+
170
+ # ---- Attention block (pre-norm)
171
+ h = self.attn_norm(x_tab)
172
+
173
+ q_tab = self.q_proj_tab(h) # [B, T_tab, Hq*d]
174
+ k_tab = self.k_proj_tab(h) # [B, T_tab, Hkv*d]
175
+ v_tab = self.v_proj_tab(h) # [B, T_tab, Hkv*d]
176
+
177
+ q = self._split_heads_q(q_tab) # [B, Hq, T_tab, d]
178
+ k_tab = self._split_heads_kv(k_tab) # [B, Hkv, T_tab, d]
179
+ v_tab = self._split_heads_kv(v_tab) # [B, Hkv, T_tab, d]
180
+
181
+ if vision_features is None:
182
+ # Keys/values = tab only
183
+ k = k_tab
184
+ v = v_tab
185
+ key_mask = attention_mask # [B, T_tab]
186
+ else:
187
+ if vision_features.dim() != 3:
188
+ raise ValueError(f"vision_features must be [B,T_img,Dv], got {tuple(vision_features.shape)}")
189
+ if vision_features.shape[0] != B:
190
+ raise ValueError("vision_features batch mismatch")
191
+ if vision_features.shape[2] != self.vision_dim:
192
+ raise ValueError(f"vision_dim mismatch: got {vision_features.shape[2]}, expected {self.vision_dim}")
193
+
194
+ # Require vision_mask for strict missing handling
195
+ if vision_mask is None:
196
+ raise ValueError("vision_mask must be provided when vision_features is not None")
197
+ if vision_mask.dim() != 2:
198
+ raise ValueError(f"vision_mask must be [B,T_img], got {tuple(vision_mask.shape)}")
199
+
200
+ T_img = vision_features.shape[1]
201
+ if vision_mask.shape != (B, T_img):
202
+ raise ValueError(f"vision_mask shape mismatch: expected {(B, T_img)}, got {tuple(vision_mask.shape)}")
203
+
204
+ # Ensure mask dtype matches attention_mask dtype for concatenation
205
+ if vision_mask.dtype != attention_mask.dtype:
206
+ vision_mask = vision_mask.to(dtype=attention_mask.dtype)
207
+ if vision_mask.device != attention_mask.device:
208
+ vision_mask = vision_mask.to(device=attention_mask.device)
209
+
210
+ param = self.k_proj_img.weight
211
+ vision_features = vision_features.to(device=param.device, dtype=param.dtype)
212
+ k_img = self.k_proj_img(vision_features) # [B, T_img, Hkv*d]
213
+ v_img = self.v_proj_img(vision_features) # [B, T_img, Hkv*d]
214
+ k_img = self._split_heads_kv(k_img) # [B, Hkv, T_img, d]
215
+ v_img = self._split_heads_kv(v_img) # [B, Hkv, T_img, d]
216
+
217
+ k = torch.cat([k_tab, k_img], dim=2) # [B, Hkv, T_tab+T_img, d]
218
+ v = torch.cat([v_tab, v_img], dim=2) # [B, Hkv, T_tab+T_img, d]
219
+
220
+ # STRICT key mask: tab_mask + vision_mask
221
+ key_mask = torch.cat([attention_mask, vision_mask], dim=1) # [B, T_tab+T_img]
222
+
223
+ # Expand KV heads to Q heads (GQA)
224
+ if self.group_size != 1:
225
+ k = k.repeat_interleave(self.group_size, dim=1) # [B, Hq, T_k, d]
226
+ v = v.repeat_interleave(self.group_size, dim=1) # [B, Hq, T_k, d]
227
+
228
+ T_k = k.shape[2]
229
+ key_bias = self._make_key_bias_from_mask(key_mask, key_len=T_k) # [B,1,1,T_k]
230
+
231
+ # Attention scores: [B, Hq, T_tab, T_k]
232
+ scale = 1.0 / math.sqrt(self.head_dim)
233
+ attn_scores = torch.einsum("bhtd,bhkd->bhtk", q, k) * scale
234
+ attn_scores = attn_scores + key_bias # broadcast
235
+
236
+ attn_probs = F.softmax(attn_scores.float(), dim=-1)
237
+ if self.attn_dropout > 0.0 and self.training:
238
+ attn_probs = F.dropout(attn_probs, p=self.attn_dropout)
239
+ attn_probs = attn_probs.to(v.dtype)
240
+
241
+ attn_out = torch.einsum("bhtk,bhkd->bhtd", attn_probs, v) # [B,Hq,T_tab,d]
242
+ attn_out = self._merge_heads_q(attn_out) # [B,T_tab,Hq*d]
243
+ attn_out = self.o_proj(attn_out) # [B,T_tab,tab_dim]
244
+
245
+ # Query-side masking (tab only): prevents masked tab tokens from updating residual path
246
+ attn_out = attn_out * attention_mask.to(attn_out.dtype).unsqueeze(-1)
247
+
248
+ if self.resid_dropout > 0.0 and self.training:
249
+ attn_out = F.dropout(attn_out, p=self.resid_dropout)
250
+
251
+ x = x_tab + attn_out
252
+
253
+ # ---- FFN block (pre-norm)
254
+ h2 = self.ffn_norm(x)
255
+ gate = self.gate_proj(h2)
256
+ up = self.up_proj(h2)
257
+ f = self.act(gate, up)
258
+ f = self.down_proj(f)
259
+
260
+ # Query-side masking (tab only)
261
+ f = f * attention_mask.to(f.dtype).unsqueeze(-1)
262
+
263
+ if self.resid_dropout > 0.0 and self.training:
264
+ f = F.dropout(f, p=self.resid_dropout)
265
+
266
+ x = x + f
267
+ return x
268
+
269
+
270
+ def _count_params(m: nn.Module) -> Tuple[int, int]:
271
+ total = sum(p.numel() for p in m.parameters())
272
+ trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
273
+ return total, trainable
274
+
275
+
276
+ def _demo_main():
277
+ import argparse
278
+
279
+ parser = argparse.ArgumentParser()
280
+ parser.add_argument("--batch_size", type=int, default=4)
281
+ parser.add_argument("--t_tab", type=int, default=126)
282
+ parser.add_argument("--t_img", type=int, default=256)
283
+ parser.add_argument("--tabular_dim", type=int, default=768)
284
+ parser.add_argument("--vision_dim", type=int, default=768)
285
+ parser.add_argument("--num_query_heads", type=int, default=8)
286
+ parser.add_argument("--num_kv_heads", type=int, default=2)
287
+ parser.add_argument("--head_dim", type=int, default=128)
288
+ parser.add_argument("--mlp_ratio", type=float, default=1.5)
289
+ parser.add_argument("--dropout", type=float, default=0.0)
290
+ parser.add_argument("--with_vision", action="store_true")
291
+ parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
292
+ parser.add_argument("--device", type=str, default=None)
293
+ args = parser.parse_args()
294
+
295
+ device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
296
+ dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
297
+ dtype = dtype_map[args.dtype]
298
+
299
+ layer = TabularImageGQALayer(
300
+ tabular_dim=args.tabular_dim,
301
+ vision_dim=args.vision_dim,
302
+ num_query_heads=args.num_query_heads,
303
+ num_kv_heads=args.num_kv_heads,
304
+ head_dim=args.head_dim,
305
+ mlp_ratio=args.mlp_ratio,
306
+ dropout=args.dropout,
307
+ ).to(device=device, dtype=dtype)
308
+
309
+ total, trainable = _count_params(layer)
310
+ print(f"Layer parameters: {total:,} (trainable: {trainable:,})")
311
+
312
+ B = args.batch_size
313
+ T_tab = args.t_tab
314
+
315
+ x_tab = torch.randn(B, T_tab, args.tabular_dim, device=device, dtype=dtype)
316
+
317
+ # Build a typical HF-style 1D attention mask: 1 for valid, 0 for masked/padded.
318
+ # Here we create variable valid lengths.
319
+ lengths = torch.randint(low=max(1, T_tab // 2), high=T_tab + 1, size=(B,), device=device)
320
+ attention_mask = torch.zeros(B, T_tab, device=device, dtype=torch.long)
321
+ for b in range(B):
322
+ attention_mask[b, : int(lengths[b].item())] = 1
323
+
324
+ if args.with_vision:
325
+ vision = torch.randn(B, args.t_img, args.vision_dim, device=device, dtype=dtype)
326
+
327
+ # Example vision mask: first half valid for sample 0, all valid for others
328
+ vision_mask = torch.ones(B, args.t_img, device=device, dtype=torch.long)
329
+ if args.t_img > 0:
330
+ vision_mask[0, args.t_img // 2:] = 0
331
+ else:
332
+ vision = None
333
+ vision_mask = None
334
+
335
+ print("Input x_tab:", tuple(x_tab.shape), x_tab.dtype, x_tab.device)
336
+ print("Input attention_mask:", tuple(attention_mask.shape), attention_mask.dtype, attention_mask.device)
337
+ print("Input vision_features:", None if vision is None else (tuple(vision.shape), vision.dtype, vision.device))
338
+ print("Input vision_mask:",
339
+ None if vision_mask is None else (tuple(vision_mask.shape), vision_mask.dtype, vision_mask.device))
340
+
341
+ with torch.no_grad():
342
+ y = layer(
343
+ x_tab=x_tab,
344
+ attention_mask=attention_mask,
345
+ vision_features=vision,
346
+ vision_mask=vision_mask,
347
+ )
348
+
349
+ print("Output y_tab:", tuple(y.shape), y.dtype, y.device)
350
+
351
+
352
+ if __name__ == "__main__":
353
+ _demo_main()
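
To make the masking semantics of _make_key_bias_from_mask concrete, here is a tiny numeric sketch (illustrative only, mirroring the helper rather than calling it): masked keys receive an additive bias of -1e9, so their softmax weight collapses to zero while the valid keys share the probability mass:

import torch

mask = torch.tensor([[1, 1, 0, 1]])       # [B=1, T_key=4]; 0 marks a masked key
bias = (1.0 - mask.float()) * -1e9        # 0 for keep, -1e9 for mask
bias = bias.view(1, 1, 1, 4)              # broadcast over heads and queries

scores = torch.zeros(1, 2, 3, 4)          # [B, Hq, T_q, T_k], uniform logits
probs = torch.softmax(scores + bias, dim=-1)
print(probs[0, 0, 0])                     # tensor([0.3333, 0.3333, 0.0000, 0.3333])

Because the bias is broadcast as [B, 1, 1, T_k], every query and every head sees the same set of reachable keys, which is exactly how the layer applies the concatenated tabular and vision key masks.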
modelling/loader.py ADDED
@@ -0,0 +1,1025 @@
1
+ # loader.py
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import ast
5
+ from io import BytesIO
6
+ from urllib.parse import urljoin
7
+
8
+ import pandas as pd
9
+ import requests
10
+ import torch
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from torchvision import transforms
14
+
15
+ from utils import load_json
16
+
17
+
18
+ class CenterSquareCrop:
19
+ """
20
+ Crop image to a centered square without resizing.
21
+ """
22
+
23
+ def __call__(self, img: Image.Image):
24
+ w, h = img.size
25
+
26
+ if w == h:
27
+ return img
28
+
29
+ if w > h:
30
+ left = (w - h) // 2
31
+ right = left + h
32
+ top = 0
33
+ bottom = h
34
+ else:
35
+ top = (h - w) // 2
36
+ bottom = top + w
37
+ left = 0
38
+ right = w
39
+ return img.crop((left, top, right, bottom))
40
+
41
+
42
+ def build_image_transform(image_size: int):
43
+ return transforms.Compose([
44
+ CenterSquareCrop(),
45
+ transforms.Resize((image_size, image_size)),
46
+ transforms.ToTensor(),
47
+ ])
48
+
49
+
50
+ def join_photo_root(photo_root: str, relative_path: str) -> str:
51
+ """
52
+ Join photo_root and relative path.
53
+
54
+ Supports:
55
+ - local filesystem roots
56
+ - http / https roots
57
+ """
58
+ if photo_root.startswith("http://") or photo_root.startswith("https://"): # noqa
59
+ return urljoin(photo_root.rstrip("/") + "/", relative_path)
60
+
61
+ return photo_root.rstrip("/") + "/" + relative_path.lstrip("/")
62
+
63
+
64
+ def parse_numeric_cell(value: str, n_in: int):
65
+ """
66
+ Convert numeric csv cell to list[float].
67
+
68
+ Returns:
69
+ values, is_valid
70
+
71
+ Data assumption:
72
+ - Empty value is always ""
73
+ - Scalar numeric -> "12.3"
74
+ - Vector numeric -> "[1.2,3.4,5.6]"
75
+ """
76
+ if value == "":
77
+ return [0.0] * n_in, False
78
+
79
+ if n_in == 1:
80
+ return [float(value)], True
81
+
82
+ vec = ast.literal_eval(value)
83
+ if len(vec) != n_in:
84
+ raise ValueError(f"Numeric vector length mismatch: expected {n_in}, got {len(vec)}")
85
+ return [float(v) for v in vec], True
86
+
87
+
88
+ class SoilFormerDataset(Dataset):
89
+
90
+ def __init__(
91
+ self,
92
+ csv_path: str,
93
+ photo_map_path: str,
94
+ cat_vocab_path: str,
95
+ numeric_vocab_path: str,
96
+ numeric_stats_path: str,
97
+ photo_root: str,
98
+ image_size: int = 512,
99
+ id_column: str = "id",
100
+ ):
101
+ self.df = pd.read_csv(
102
+ csv_path,
103
+ keep_default_na=False,
104
+ na_filter=False,
105
+ low_memory=False,
106
+ )
107
+
108
+ self.photo_map = load_json(photo_map_path)
109
+ self.cat_vocab = load_json(cat_vocab_path)
110
+ self.numeric_vocab = load_json(numeric_vocab_path)
111
+
112
+ self.photo_root = photo_root
113
+ self.id_column = id_column
114
+ self.image_size = int(image_size)
115
+ self.image_transform = build_image_transform(self.image_size)
116
+
117
+ # Keep json order exactly
118
+ self.cat_columns = list(self.cat_vocab.keys())
119
+ self.numeric_groups = self.numeric_vocab["groups"]
120
+ self.numeric_stats_df = pd.read_csv(numeric_stats_path)
121
+ self.numeric_stats_index = self.numeric_stats_df.set_index("column")
122
+
123
+ # Numeric mean/std
124
+ self.numeric_stats = {}
125
+ for _, row in self.numeric_stats_df.iterrows():
126
+ col = row["column"]
127
+ mean = float(row["mean"])
128
+ std = float(row["std"])
129
+ if std == 0.0:
130
+ std = 1.0
131
+ self.numeric_stats[col] = (mean, std)
132
+
133
+ # For active masking
134
+ self.cat_mask_local_ids = torch.tensor(
135
+ [int(self.cat_vocab[col]["mask_local_id"]) for col in self.cat_columns],
136
+ dtype=torch.long,
137
+ )
138
+
139
+ def __len__(self):
140
+ return len(self.df)
141
+
142
+ def load_image(self, path: str):
143
+ if path.startswith("http://") or path.startswith("https://"): # noqa
144
+ resp = requests.get(path, timeout=(3, 10))
145
+ resp.raise_for_status()
146
+ img = Image.open(BytesIO(resp.content)).convert("RGB")
147
+ else:
148
+ img = Image.open(path).convert("RGB")
149
+
150
+ return self.image_transform(img)
151
+
152
+ def __getitem__(self, idx):
153
+ row = self.df.iloc[idx]
154
+ sample_id = row[self.id_column]
155
+
156
+ # -----------------------
157
+ # categorical features
158
+ # -----------------------
159
+ cat_ids = []
160
+ cat_valids = []
161
+
162
+ for col in self.cat_columns:
163
+ spec = self.cat_vocab[col]
164
+ label2id = spec["label2id"]
165
+ mask_id = spec["mask_local_id"]
166
+
167
+ value = row[col]
168
+
169
+ if value == "":
170
+ cat_ids.append(mask_id)
171
+ cat_valids.append(False)
172
+ else:
173
+ if value not in label2id:
174
+ raise KeyError(f"Unknown categorical value: column={col}, value={value!r}")
175
+ cat_ids.append(label2id[value])
176
+ cat_valids.append(True)
177
+
178
+ cat_ids = torch.tensor(cat_ids, dtype=torch.long)
179
+ cat_valids = torch.tensor(cat_valids, dtype=torch.bool)
180
+
181
+ # -----------------------
182
+ # numeric features
183
+ # -----------------------
184
+ numeric_values_by_nin = {}
185
+ numeric_valid_positions_by_nin = {}
186
+
187
+ for group in self.numeric_groups:
188
+ n_in = int(group["n_in"])
189
+ features = group["feature_names"]
190
+
191
+ values = []
192
+ valids = []
193
+
194
+ for feat in features:
195
+ cell = row[feat]
196
+ parsed, is_valid = parse_numeric_cell(cell, n_in)
197
+ if is_valid:
198
+ mean, std = self.numeric_stats[feat]
199
+ parsed = [(v - mean) / std for v in parsed]
200
+ values.append(parsed)
201
+ valids.append(is_valid)
202
+
203
+ numeric_values_by_nin[n_in] = torch.tensor(values, dtype=torch.float32)
204
+ numeric_valid_positions_by_nin[n_in] = torch.tensor(valids, dtype=torch.bool)
205
+
206
+ # -----------------------
207
+ # vision
208
+ # -----------------------
209
+ try:
210
+ relative_path = self.photo_map[sample_id]
211
+ full_path = join_photo_root(self.photo_root, relative_path)
212
+ image = self.load_image(full_path)
213
+ vision_valid = True
214
+ except Exception: # noqa
215
+ image = torch.zeros(3, self.image_size, self.image_size, dtype=torch.float32)
216
+ vision_valid = False
217
+
218
+ vision_valid = torch.tensor(vision_valid, dtype=torch.bool)
219
+
220
+ return {
221
+ "row_idx": torch.tensor(idx, dtype=torch.long),
222
+ "sample_id": sample_id,
223
+ "cat_local_ids": cat_ids,
224
+ "cat_valid_positions": cat_valids,
225
+ "numeric_values_by_nin": numeric_values_by_nin,
226
+ "numeric_valid_positions_by_nin": numeric_valid_positions_by_nin,
227
+ "pixel_values": image,
228
+ "vision_valid_positions": vision_valid,
229
+ }
230
+
231
+ @staticmethod
232
+ def collate_fn(batch):
233
+ cat_ids = torch.stack([b["cat_local_ids"] for b in batch], dim=0)
234
+ cat_valids = torch.stack([b["cat_valid_positions"] for b in batch], dim=0)
235
+
236
+ group_keys = list(batch[0]["numeric_values_by_nin"].keys())
237
+
238
+ numeric_values_by_nin = {}
239
+ numeric_valid_positions_by_nin = {}
240
+
241
+ for k in group_keys:
242
+ numeric_values_by_nin[k] = torch.stack(
243
+ [b["numeric_values_by_nin"][k] for b in batch],
244
+ dim=0,
245
+ )
246
+ numeric_valid_positions_by_nin[k] = torch.stack(
247
+ [b["numeric_valid_positions_by_nin"][k] for b in batch],
248
+ dim=0,
249
+ )
250
+
251
+ pixel_values = torch.stack([b["pixel_values"] for b in batch], dim=0)
252
+ vision_valid_positions = torch.stack([b["vision_valid_positions"] for b in batch], dim=0)
253
+ row_idx = torch.stack([b["row_idx"] for b in batch], dim=0)
254
+ sample_ids = [b["sample_id"] for b in batch]
255
+
256
+ return {
257
+ "row_idx": row_idx,
258
+ "sample_id": sample_ids,
259
+ "cat_local_ids": cat_ids,
260
+ "numeric_values_by_nin": numeric_values_by_nin,
261
+ "cat_valid_positions": cat_valids,
262
+ "numeric_valid_positions_by_nin": numeric_valid_positions_by_nin,
263
+ "pixel_values": pixel_values,
264
+ "vision_valid_positions": vision_valid_positions,
265
+ }
266
+
267
+ def perform_active_mask(self, batch, cat_ratio=0.15, num_ratio=0.15, seed=None):
268
+ """
269
+ Apply active masking to categorical and numeric inputs.
270
+
271
+ Conventions
272
+ -----------
273
+ Input batch must contain:
274
+ - cat_local_ids: [B, M] LongTensor
275
+ - cat_valid_positions: [B, M] Bool/0-1 tensor
276
+ - numeric_values_by_nin: Dict[int, Tensor[B, V, n_in]]
277
+ - numeric_valid_positions_by_nin: Dict[int, Tensor[B, V]]
278
+
279
+ Output batch will additionally contain:
280
+ - original_cat_local_ids
281
+ - original_cat_valid_positions
282
+ - original_numeric_values_by_nin
283
+ - original_numeric_valid_positions_by_nin
284
+
285
+ - masked_cat_local_ids
286
+ - masked_cat_valid_positions
287
+ - masked_numeric_values_by_nin
288
+ - masked_numeric_valid_positions_by_nin
289
+
290
+ - cat_loss_mask: [B, M] BoolTensor
291
+ - numeric_loss_mask_by_nin: Dict[int, BoolTensor[B, V]]
292
+
293
+ Semantics
294
+ ---------
295
+ - Only originally valid positions can be actively masked.
296
+ - Masked categorical positions:
297
+ local_id -> self.cat_mask_local_ids[col]
298
+ valid -> False
299
+ - Masked numeric positions:
300
+ values -> 0
301
+ valid -> False
302
+ - original_* fields always preserve the unmodified input batch content.
303
+ """
304
+ # --------------------------------------------------
305
+ # Validate ratios
306
+ # --------------------------------------------------
307
+ if not (0.0 <= cat_ratio <= 1.0):
308
+ raise ValueError(f"cat_ratio must be in [0, 1], got {cat_ratio}")
309
+ if not (0.0 <= num_ratio <= 1.0):
310
+ raise ValueError(f"num_ratio must be in [0, 1], got {num_ratio}")
311
+
312
+ # --------------------------------------------------
313
+ # Validate required keys
314
+ # --------------------------------------------------
315
+ required_keys = [
316
+ "cat_local_ids",
317
+ "cat_valid_positions",
318
+ "numeric_values_by_nin",
319
+ "numeric_valid_positions_by_nin",
320
+ ]
321
+ for k in required_keys:
322
+ if k not in batch:
323
+ raise KeyError(f"Missing key in batch: {k}")
324
+
325
+ cat_local_ids = batch["cat_local_ids"]
326
+ cat_valid_positions = batch["cat_valid_positions"]
327
+ numeric_values_by_nin = batch["numeric_values_by_nin"]
328
+ numeric_valid_positions_by_nin = batch["numeric_valid_positions_by_nin"]
329
+
330
+ if cat_local_ids.dim() != 2:
331
+ raise ValueError(f"cat_local_ids must be [B, M], got {tuple(cat_local_ids.shape)}")
332
+ if cat_valid_positions.shape != cat_local_ids.shape:
333
+ raise ValueError(
334
+ f"cat_valid_positions must match cat_local_ids shape, got "
335
+ f"{tuple(cat_valid_positions.shape)} vs {tuple(cat_local_ids.shape)}"
336
+ )
337
+
338
+ if not isinstance(numeric_values_by_nin, dict):
339
+ raise ValueError("numeric_values_by_nin must be a dict")
340
+ if not isinstance(numeric_valid_positions_by_nin, dict):
341
+ raise ValueError("numeric_valid_positions_by_nin must be a dict")
342
+
343
+ B, M = cat_local_ids.shape
344
+ device = cat_local_ids.device
345
+
346
+ if self.cat_mask_local_ids.dim() != 1 or self.cat_mask_local_ids.numel() != M:
347
+ raise ValueError(
348
+ f"self.cat_mask_local_ids must be [M] with M={M}, got {tuple(self.cat_mask_local_ids.shape)}"
349
+ )
350
+ cat_mask_local_ids = self.cat_mask_local_ids.to(device=device, dtype=cat_local_ids.dtype)
351
+
352
+ # --------------------------------------------------
353
+ # Random generator
354
+ # --------------------------------------------------
355
+ if device.type == "cuda":
356
+ generator = torch.Generator(device=device)
357
+ else:
358
+ generator = torch.Generator()
359
+
360
+ if seed is not None:
361
+ generator.manual_seed(seed)
362
+
363
+ # --------------------------------------------------
364
+ # Start from shallow copy only
365
+ # --------------------------------------------------
366
+ masked_batch = dict(batch)
367
+
368
+ # Preserve original aliases (do NOT deepcopy)
369
+ masked_batch["original_cat_local_ids"] = batch["cat_local_ids"]
370
+ masked_batch["original_cat_valid_positions"] = batch["cat_valid_positions"]
371
+ masked_batch["original_numeric_values_by_nin"] = batch["numeric_values_by_nin"]
372
+ masked_batch["original_numeric_valid_positions_by_nin"] = batch["numeric_valid_positions_by_nin"]
373
+
374
+ # --------------------------------------------------
375
+ # Fast path: no active masking at all
376
+ # --------------------------------------------------
377
+ if cat_ratio == 0.0 and num_ratio == 0.0:
378
+ masked_batch["masked_cat_local_ids"] = batch["cat_local_ids"]
379
+ masked_batch["masked_cat_valid_positions"] = batch["cat_valid_positions"]
380
+
381
+ masked_batch["masked_numeric_values_by_nin"] = batch["numeric_values_by_nin"]
382
+ masked_batch["masked_numeric_valid_positions_by_nin"] = batch["numeric_valid_positions_by_nin"]
383
+
384
+ masked_batch["cat_loss_mask"] = torch.zeros(
385
+ (B, M), dtype=torch.bool, device=device
386
+ )
387
+ masked_batch["numeric_loss_mask_by_nin"] = {
388
+ n_in: torch.zeros_like(valid_positions, dtype=torch.bool)
389
+ for n_in, valid_positions in numeric_valid_positions_by_nin.items()
390
+ }
391
+ return masked_batch
392
+
393
+ # --------------------------------------------------
394
+ # Categorical masking
395
+ # --------------------------------------------------
396
+ original_cat_valid_positions = cat_valid_positions.bool()
397
+
398
+ masked_cat_local_ids = cat_local_ids.clone()
399
+ masked_cat_valid_positions = original_cat_valid_positions.clone()
400
+ cat_loss_mask = torch.zeros((B, M), dtype=torch.bool, device=device)
401
+
402
+ if cat_ratio > 0.0:
403
+ for b in range(B):
404
+ valid_idx = torch.nonzero(original_cat_valid_positions[b], as_tuple=False).squeeze(1)
405
+ n_valid = valid_idx.numel()
406
+ if n_valid == 0:
407
+ continue
408
+
409
+ k = int(round(n_valid * cat_ratio))
410
+ if k <= 0:
411
+ continue
412
+ if k > n_valid:
413
+ k = n_valid
414
+
415
+ perm = valid_idx[
416
+ torch.randperm(n_valid, generator=generator, device=device)[:k]
417
+ ]
418
+ cat_loss_mask[b, perm] = True
419
+
420
+ expanded_cat_mask_ids = cat_mask_local_ids.view(1, M).expand(B, M)
421
+ masked_cat_local_ids[cat_loss_mask] = expanded_cat_mask_ids[cat_loss_mask]
422
+ masked_cat_valid_positions = masked_cat_valid_positions & (~cat_loss_mask)
423
+
424
+ masked_batch["masked_cat_local_ids"] = masked_cat_local_ids
425
+ masked_batch["masked_cat_valid_positions"] = masked_cat_valid_positions
426
+ masked_batch["cat_loss_mask"] = cat_loss_mask
427
+
428
+ # --------------------------------------------------
429
+ # Numeric masking
430
+ # --------------------------------------------------
431
+ masked_numeric_values_by_nin = {}
432
+ masked_numeric_valid_positions_by_nin = {}
433
+ numeric_loss_mask_by_nin = {}
434
+
435
+ # keep deterministic ordering if caller passed mixed int-like keys
436
+ for n_in in sorted(numeric_values_by_nin.keys(), key=int):
437
+ values = numeric_values_by_nin[n_in]
438
+ if n_in not in numeric_valid_positions_by_nin:
439
+ raise KeyError(f"Missing numeric_valid_positions_by_nin[{n_in}]")
440
+
441
+ valid_positions = numeric_valid_positions_by_nin[n_in]
442
+
443
+ if values.dim() != 3:
444
+ raise ValueError(
445
+ f"numeric_values_by_nin[{n_in}] must be [B, V, n_in], got {tuple(values.shape)}"
446
+ )
447
+
448
+ Bn, V, Nin = values.shape
449
+ if Bn != B:
450
+ raise ValueError(
451
+ f"numeric_values_by_nin[{n_in}] batch mismatch: got {Bn}, expected {B}"
452
+ )
453
+ if int(Nin) != int(n_in):
454
+ raise ValueError(
455
+ f"numeric_values_by_nin[{n_in}] last dim mismatch: got {Nin}, expected {n_in}"
456
+ )
457
+ if valid_positions.shape != (B, V):
458
+ raise ValueError(
459
+ f"numeric_valid_positions_by_nin[{n_in}] must be [B,V]=({B},{V}), "
460
+ f"got {tuple(valid_positions.shape)}"
461
+ )
462
+
463
+ original_valid = valid_positions.bool()
464
+
465
+ # IMPORTANT: clone before modifying
466
+ masked_values = values.clone()
467
+ masked_valid_positions = original_valid.clone()
468
+ num_loss_mask = torch.zeros((B, V), dtype=torch.bool, device=values.device)
469
+
470
+ if num_ratio > 0.0:
471
+ for b in range(B):
472
+ valid_idx = torch.nonzero(original_valid[b], as_tuple=False).squeeze(1)
473
+ n_valid = valid_idx.numel()
474
+ if n_valid == 0:
475
+ continue
476
+
477
+ k = int(round(n_valid * num_ratio))
478
+ if k <= 0:
479
+ continue
480
+ if k > n_valid:
481
+ k = n_valid
482
+
483
+ perm = valid_idx[
484
+ torch.randperm(n_valid, generator=generator, device=values.device)[:k]
485
+ ]
486
+ num_loss_mask[b, perm] = True
487
+
488
+ # masked numeric columns become zero and invalid
489
+ masked_values[num_loss_mask] = 0.0
490
+ masked_valid_positions = masked_valid_positions & (~num_loss_mask)
491
+
492
+ masked_numeric_values_by_nin[n_in] = masked_values
493
+ masked_numeric_valid_positions_by_nin[n_in] = masked_valid_positions
494
+ numeric_loss_mask_by_nin[n_in] = num_loss_mask
495
+
496
+ masked_batch["masked_numeric_values_by_nin"] = masked_numeric_values_by_nin
497
+ masked_batch["masked_numeric_valid_positions_by_nin"] = masked_numeric_valid_positions_by_nin
498
+ masked_batch["numeric_loss_mask_by_nin"] = numeric_loss_mask_by_nin
499
+
500
+ return masked_batch
501
+
502
+
503
+ def perform_active_mask_single(self, batch, feature_name, assert_not_missing=True):
504
+ """
505
+ Actively mask exactly one feature specified by feature_name.
506
+
507
+ Parameters
508
+ ----------
509
+ batch : dict
510
+ Same input convention as perform_active_mask(...).
511
+ feature_name : str
512
+ Full feature name. Can be either categorical or numeric.
513
+ assert_not_missing : bool
514
+ If True, require the target feature to be originally valid for all samples
515
+ in the batch. Otherwise raise ValueError.
516
+ If False, only originally valid positions are masked; naturally missing
517
+ positions remain missing and are not included in the loss mask.
518
+
519
+ Returns
520
+ -------
521
+ masked_batch : dict
522
+ Same output convention as perform_active_mask(...), except that exactly
523
+ one feature is actively masked.
524
+ """
525
+
526
+ # --------------------------------------------------
527
+ # Validate required keys
528
+ # --------------------------------------------------
529
+ required_keys = [
530
+ "cat_local_ids",
531
+ "cat_valid_positions",
532
+ "numeric_values_by_nin",
533
+ "numeric_valid_positions_by_nin",
534
+ ]
535
+ for k in required_keys:
536
+ if k not in batch:
537
+ raise KeyError(f"Missing key in batch: {k}")
538
+
539
+ cat_local_ids = batch["cat_local_ids"]
540
+ cat_valid_positions = batch["cat_valid_positions"]
541
+ numeric_values_by_nin = batch["numeric_values_by_nin"]
542
+ numeric_valid_positions_by_nin = batch["numeric_valid_positions_by_nin"]
543
+
544
+ if cat_local_ids.dim() != 2:
545
+ raise ValueError(f"cat_local_ids must be [B, M], got {tuple(cat_local_ids.shape)}")
546
+ if cat_valid_positions.shape != cat_local_ids.shape:
547
+ raise ValueError(
548
+ f"cat_valid_positions must match cat_local_ids shape, got "
549
+ f"{tuple(cat_valid_positions.shape)} vs {tuple(cat_local_ids.shape)}"
550
+ )
551
+
552
+ if not isinstance(numeric_values_by_nin, dict):
553
+ raise ValueError("numeric_values_by_nin must be a dict")
554
+ if not isinstance(numeric_valid_positions_by_nin, dict):
555
+ raise ValueError("numeric_valid_positions_by_nin must be a dict")
556
+
557
+ B, M = cat_local_ids.shape
558
+ device = cat_local_ids.device
559
+
560
+ if self.cat_mask_local_ids.dim() != 1 or self.cat_mask_local_ids.numel() != M:
561
+ raise ValueError(
562
+ f"self.cat_mask_local_ids must be [M] with M={M}, got {tuple(self.cat_mask_local_ids.shape)}"
563
+ )
564
+ cat_mask_local_ids = self.cat_mask_local_ids.to(device=device, dtype=cat_local_ids.dtype)
565
+
566
+ # --------------------------------------------------
567
+ # Resolve feature_name -> categorical col or numeric (n_in, v_idx)
568
+ # --------------------------------------------------
569
+ # Assumptions:
570
+ # - self.cat_vocab is the categorical vocab dict keyed by full feature name
571
+ # - self.numeric_vocab contains:
572
+ # numeric_vocab["ordered_feature_names"]
573
+ # numeric_vocab["features"][name]["n_in"]
574
+ # numeric_vocab["features"][name]["col_id"]
575
+ #
576
+ # If your actual attribute names differ, only this block needs adaptation.
577
+ is_cat = False
578
+ is_num = False
579
+ cat_col = None
580
+ num_n_in = None
581
+ num_v_idx = None
582
+
583
+ # categorical
584
+ if hasattr(self, "cat_vocab") and feature_name in self.cat_vocab:
585
+ is_cat = True
586
+ cat_col = int(self.cat_vocab[feature_name]["col_id"])
587
+
588
+ # numeric
589
+ if hasattr(self, "numeric_vocab"):
590
+ num_features = self.numeric_vocab.get("features", {})
591
+ if feature_name in num_features:
592
+ is_num = True
593
+ meta = num_features[feature_name]
594
+ num_n_in = int(meta["n_in"])
595
+ num_v_idx = int(meta["col_id"])
596
+
597
+ if is_cat and is_num:
598
+ raise ValueError(f"Feature name appears in both categorical and numeric vocab: {feature_name}")
599
+ if not is_cat and not is_num:
600
+ raise KeyError(f"Unknown feature_name: {feature_name}")
601
+
602
+ # --------------------------------------------------
603
+ # Start from shallow copy only
604
+ # --------------------------------------------------
605
+ masked_batch = dict(batch)
606
+
607
+ # Preserve original aliases (do NOT deepcopy)
608
+ masked_batch["original_cat_local_ids"] = batch["cat_local_ids"]
609
+ masked_batch["original_cat_valid_positions"] = batch["cat_valid_positions"]
610
+ masked_batch["original_numeric_values_by_nin"] = batch["numeric_values_by_nin"]
611
+ masked_batch["original_numeric_valid_positions_by_nin"] = batch["numeric_valid_positions_by_nin"]
612
+
613
+ # --------------------------------------------------
614
+ # Default: no masking anywhere
615
+ # --------------------------------------------------
616
+ masked_cat_local_ids = batch["cat_local_ids"].clone()
617
+ masked_cat_valid_positions = batch["cat_valid_positions"].bool().clone()
618
+ cat_loss_mask = torch.zeros((B, M), dtype=torch.bool, device=device)
619
+
620
+ masked_numeric_values_by_nin = {}
621
+ masked_numeric_valid_positions_by_nin = {}
622
+ numeric_loss_mask_by_nin = {}
623
+
624
+ for n_in in sorted(numeric_values_by_nin.keys(), key=int):
625
+ values = numeric_values_by_nin[n_in]
626
+ if n_in not in numeric_valid_positions_by_nin:
627
+ raise KeyError(f"Missing numeric_valid_positions_by_nin[{n_in}]")
628
+
629
+ valid_positions = numeric_valid_positions_by_nin[n_in]
630
+
631
+ if values.dim() != 3:
632
+ raise ValueError(
633
+ f"numeric_values_by_nin[{n_in}] must be [B, V, n_in], got {tuple(values.shape)}"
634
+ )
635
+
636
+ Bn, V, Nin = values.shape
637
+ if Bn != B:
638
+ raise ValueError(
639
+ f"numeric_values_by_nin[{n_in}] batch mismatch: got {Bn}, expected {B}"
640
+ )
641
+ if int(Nin) != int(n_in):
642
+ raise ValueError(
643
+ f"numeric_values_by_nin[{n_in}] last dim mismatch: got {Nin}, expected {n_in}"
644
+ )
645
+ if valid_positions.shape != (B, V):
646
+ raise ValueError(
647
+ f"numeric_valid_positions_by_nin[{n_in}] must be [B,V]=({B},{V}), "
648
+ f"got {tuple(valid_positions.shape)}"
649
+ )
650
+
651
+ masked_numeric_values_by_nin[n_in] = values.clone()
652
+ masked_numeric_valid_positions_by_nin[n_in] = valid_positions.bool().clone()
653
+ numeric_loss_mask_by_nin[n_in] = torch.zeros((B, V), dtype=torch.bool, device=values.device)
654
+
655
+ # --------------------------------------------------
656
+ # Apply single-feature masking
657
+ # --------------------------------------------------
658
+ if is_cat:
659
+ original_valid = cat_valid_positions[:, cat_col].bool() # [B]
660
+
661
+ if assert_not_missing and not bool(original_valid.all().item()):
662
+ n_bad = int((~original_valid).sum().item())
663
+ raise ValueError(
664
+ f"Categorical feature '{feature_name}' has {n_bad} naturally missing samples in batch"
665
+ )
666
+
667
+ # only originally valid positions are actively masked
668
+ cat_loss_mask[:, cat_col] = original_valid
669
+
670
+ masked_cat_local_ids[cat_loss_mask] = cat_mask_local_ids.view(1, M).expand(B, M)[cat_loss_mask]
671
+ masked_cat_valid_positions = masked_cat_valid_positions & (~cat_loss_mask)
672
+
673
+ else:
674
+ if num_n_in not in masked_numeric_values_by_nin:
675
+ raise KeyError(f"numeric_values_by_nin does not contain n_in={num_n_in} for {feature_name}")
676
+
677
+ values = masked_numeric_values_by_nin[num_n_in]
678
+ valid_positions = masked_numeric_valid_positions_by_nin[num_n_in]
679
+ num_loss_mask = numeric_loss_mask_by_nin[num_n_in]
680
+
681
+ if num_v_idx >= values.shape[1]:
682
+ raise IndexError(
683
+ f"Numeric feature '{feature_name}' resolved to v_idx={num_v_idx}, "
684
+ f"but numeric_values_by_nin[{num_n_in}] has V={values.shape[1]}"
685
+ )
686
+
687
+ original_valid = valid_positions[:, num_v_idx].bool() # [B]
688
+
689
+ if assert_not_missing and not bool(original_valid.all().item()):
690
+ n_bad = int((~original_valid).sum().item())
691
+ raise ValueError(
692
+ f"Numeric feature '{feature_name}' has {n_bad} naturally missing samples in batch"
693
+ )
694
+
695
+ # only originally valid positions are actively masked
696
+ num_loss_mask[:, num_v_idx] = original_valid
697
+
698
+ values[num_loss_mask] = 0.0
699
+ valid_positions[:] = valid_positions & (~num_loss_mask)
700
+
701
+ # --------------------------------------------------
702
+ # Finalize outputs
703
+ # --------------------------------------------------
704
+ masked_batch["masked_cat_local_ids"] = masked_cat_local_ids
705
+ masked_batch["masked_cat_valid_positions"] = masked_cat_valid_positions
706
+ masked_batch["cat_loss_mask"] = cat_loss_mask
707
+
708
+ masked_batch["masked_numeric_values_by_nin"] = masked_numeric_values_by_nin
709
+ masked_batch["masked_numeric_valid_positions_by_nin"] = masked_numeric_valid_positions_by_nin
710
+ masked_batch["numeric_loss_mask_by_nin"] = numeric_loss_mask_by_nin
711
+
712
+ return masked_batch
713
+
714
+
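+ # A minimal usage sketch (hedged): this method's own name is not visible in
+ # this hunk, so `perform_single_feature_mask` below is a hypothetical handle
+ # for it. Masking one known-present feature, e.g. for leave-one-out evaluation:
+ #
+ # masked = dataset.perform_single_feature_mask(
+ # batch, feature_name="soil_colour", assert_not_missing=True,
+ # )
+ # n_masked = int(masked["cat_loss_mask"].sum()) + sum(
+ # int(m.sum()) for m in masked["numeric_loss_mask_by_nin"].values()
+ # )
+ # assert n_masked > 0 # exactly one feature column is supervised per sample
+ #
+ # "soil_colour" is an illustrative feature name, not one from the shipped vocab.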
715
+ def build_train_eval_dataloaders(
716
+ dataset,
717
+ train_ratio=0.8,
718
+ seed=42,
719
+ batch_size=32,
720
+ ):
721
+ n = len(dataset)
722
+
723
+ n_train = int(n * train_ratio)
724
+ n_eval = n - n_train
725
+
726
+ split_generator = torch.Generator().manual_seed(seed)
727
+
728
+ train_ds, eval_ds = torch.utils.data.random_split(
729
+ dataset,
730
+ [n_train, n_eval],
731
+ generator=split_generator
732
+ )
733
+
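+ # Dedicated generator for the training shuffle: returning it lets callers
+ # (e.g. train.py) reseed it for reproducible epoch orderings.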
734
+ train_generator = torch.Generator()
735
+
736
+ train_loader = DataLoader(
737
+ train_ds,
738
+ batch_size=batch_size,
739
+ shuffle=True,
740
+ collate_fn=dataset.collate_fn,
741
+ generator=train_generator,
742
+ )
743
+
744
+ eval_loader = DataLoader(
745
+ eval_ds,
746
+ batch_size=batch_size,
747
+ shuffle=False,
748
+ collate_fn=dataset.collate_fn,
749
+ )
750
+
751
+ return train_loader, eval_loader, train_generator
752
+
753
+
754
+ def debug_print_first_sample(dataset, batch, batch_pos=0):
755
+ """
756
+ Inspect one sample in a batch.
757
+
758
+ This debug function checks the masked_* fields against the original CSV row.
759
+ Positions flagged in loss_mask were actively masked, so they are allowed to differ.
760
+
761
+ Args:
762
+ dataset: SoilFormerDataset
763
+ batch: collated + optionally masked batch
764
+ batch_pos: index inside the batch (not dataset row index)
765
+ """
766
+ import math
767
+
768
+ def numeric_list_close(a, b, atol=1e-6, rtol=1e-5):
769
+ if len(a) != len(b):
770
+ return False
771
+ for x, y in zip(a, b):
772
+ if not math.isclose(float(x), float(y), rel_tol=rtol, abs_tol=atol):
773
+ return False
774
+ return True
775
+
776
+ def normalize_numeric_list(feat_name, vals, is_valid):
777
+ if not is_valid:
778
+ return [0.0] * len(vals)
779
+
780
+ stat_row = dataset.numeric_stats_index.loc[feat_name]
781
+ mean = float(stat_row["mean"])
782
+ std = float(stat_row["std"])
783
+ if std == 0.0:
784
+ std = 1.0
785
+
786
+ return [(float(v) - mean) / std for v in vals]
787
+
788
+ if "row_idx" not in batch:
789
+ raise KeyError("batch must contain 'row_idx' for debug_print_first_sample")
790
+ if "sample_id" not in batch:
791
+ raise KeyError("batch must contain 'sample_id' for debug_print_first_sample")
792
+
793
+ row_idx = int(batch["row_idx"][batch_pos].item())
794
+ row = dataset.df.iloc[row_idx]
795
+ sample_id = batch["sample_id"][batch_pos]
796
+
797
+ print("\n====================================================")
798
+ print("DEBUG SAMPLE")
799
+ print("====================================================")
800
+ print("batch_pos :", batch_pos)
801
+ print("row_idx :", row_idx)
802
+ print("sample_id :", sample_id)
803
+
804
+ # ====================================================
805
+ # categorical
806
+ # ====================================================
807
+ print("\n[CATEGORICAL FEATURES]")
808
+
809
+ cat_ids = batch["masked_cat_local_ids"][batch_pos]
810
+ cat_valids = batch["masked_cat_valid_positions"][batch_pos]
811
+ cat_loss_mask = batch.get("cat_loss_mask", None)
812
+ if cat_loss_mask is not None:
813
+ cat_loss_mask = cat_loss_mask[batch_pos]
814
+
815
+ for i, col in enumerate(dataset.cat_columns):
816
+ raw = row[col]
817
+ raw_str = str(raw)
818
+
819
+ got_id = int(cat_ids[i].item())
820
+ got_valid = bool(cat_valids[i].item())
821
+
822
+ spec = dataset.cat_vocab[col]
823
+ label2id = spec["label2id"]
824
+ mask_id = int(spec["mask_local_id"])
825
+
826
+ if raw == "":
827
+ expected_id = mask_id
828
+ expected_valid = False
829
+ else:
830
+ expected_id = int(label2id[raw])
831
+ expected_valid = True
832
+
833
+ is_loss_position = False
834
+ if cat_loss_mask is not None:
835
+ is_loss_position = bool(cat_loss_mask[i].item())
836
+
837
+ if is_loss_position:
838
+ ok = True
839
+ else:
840
+ ok = (got_id == expected_id) and (got_valid == expected_valid)
841
+
842
+ print(
843
+ f"{i:03d} | {col} | "
844
+ f"raw={raw_str:<60} | "
845
+ f"id={got_id:<6} | expected={expected_id:<6} | "
846
+ f"valid={got_valid} | exp_valid={expected_valid} | "
847
+ f"loss_mask={is_loss_position} | ok={ok}"
848
+ )
849
+
850
+ if not ok:
851
+ raise AssertionError(
852
+ f"\nCategorical mismatch\n"
853
+ f"batch_pos={batch_pos}\n"
854
+ f"row_idx={row_idx}\n"
855
+ f"feature={col}\n"
856
+ f"raw={raw}\n"
857
+ f"id={got_id}, expected={expected_id}\n"
858
+ f"valid={got_valid}, expected={expected_valid}"
859
+ )
860
+
861
+ # ====================================================
862
+ # numeric
863
+ # ====================================================
864
+ print("\n[NUMERIC FEATURES]")
865
+
866
+ numeric_loss_mask_by_nin = batch.get("numeric_loss_mask_by_nin", None)
867
+
868
+ for group in dataset.numeric_groups:
869
+ n_in = int(group["n_in"])
870
+ features = group["feature_names"]
871
+
872
+ values = batch["masked_numeric_values_by_nin"][n_in][batch_pos]
873
+ valids = batch["masked_numeric_valid_positions_by_nin"][n_in][batch_pos]
874
+
875
+ if numeric_loss_mask_by_nin is not None:
876
+ loss_mask = numeric_loss_mask_by_nin[n_in][batch_pos]
877
+ else:
878
+ loss_mask = None
879
+
880
+ print(f"\nGroup n_in={n_in}")
881
+
882
+ for i, feat in enumerate(features):
883
+ raw = row[feat]
884
+ raw_str = str(raw)
885
+
886
+ parsed, expected_valid = parse_numeric_cell(raw, n_in)
887
+ expected_norm = normalize_numeric_list(feat, parsed, expected_valid)
888
+
889
+ tensor_val = values[i].tolist()
890
+ got_valid = bool(valids[i].item())
891
+
892
+ is_loss_position = False
893
+ if loss_mask is not None:
894
+ is_loss_position = bool(loss_mask[i].item())
895
+
896
+ if is_loss_position:
897
+ ok = True
898
+ else:
899
+ value_ok = numeric_list_close(tensor_val, expected_norm)
900
+ valid_ok = (got_valid == expected_valid)
901
+ ok = value_ok and valid_ok
902
+
903
+ print(
904
+ f"{i:03d} | {feat} | "
905
+ f"raw={raw_str:<60} | "
906
+ f"tensor={tensor_val} | expected_norm={expected_norm} | "
907
+ f"valid={got_valid} | exp_valid={expected_valid} | "
908
+ f"loss_mask={is_loss_position} | ok={ok}"
909
+ )
910
+
911
+ if not ok:
912
+ raise AssertionError(
913
+ f"\nNumeric mismatch\n"
914
+ f"batch_pos={batch_pos}\n"
915
+ f"row_idx={row_idx}\n"
916
+ f"feature={feat}\n"
917
+ f"raw={raw}\n"
918
+ f"tensor={tensor_val}\n"
919
+ f"expected={parsed}\n"
920
+ f"valid={got_valid}, expected={expected_valid}"
921
+ )
922
+
923
+ # ====================================================
924
+ # vision
925
+ # ====================================================
926
+ print("\n[VISION]")
927
+
928
+ try:
929
+ relative_path = dataset.photo_map[sample_id]
930
+ expected_path = join_photo_root(dataset.photo_root, relative_path)
931
+
932
+ # Use the same logic as __getitem__: valid only if image can actually be loaded
933
+ _ = dataset.load_image(expected_path)
934
+ expected_valid = True
935
+
936
+ except Exception: # noqa
937
+ expected_path = None
938
+ expected_valid = False
939
+
940
+ got_valid = bool(batch["vision_valid_positions"][batch_pos].item())
941
+ img_shape = tuple(batch["pixel_values"][batch_pos].shape)
942
+
943
+ print("expected_path :", expected_path)
944
+ print("vision_valid :", got_valid)
945
+ print("image_shape :", img_shape)
946
+
947
+ if got_valid != expected_valid:
948
+ raise AssertionError(
949
+ f"\nVision validity mismatch\n"
950
+ f"batch_pos={batch_pos}\n"
951
+ f"row_idx={row_idx}\n"
952
+ f"expected={expected_valid}, got={got_valid}"
953
+ )
954
+
955
+ print("\n====================================================")
956
+ print("DEBUG CHECK PASSED")
957
+ print("====================================================\n")
958
+
959
+
960
+ def main():
961
+ dataset = SoilFormerDataset(
962
+ csv_path="data/tabular_data.csv",
963
+ photo_map_path="data/photo_map.json",
964
+ cat_vocab_path="data/cat_vocab.json",
965
+ numeric_vocab_path="data/numeric_vocab.json",
966
+ numeric_stats_path="data/tabular_meta_numeric_stats.csv",
967
+ photo_root="/Volumes/TOSHIBA EXT",
968
+ image_size=512,
969
+ id_column="id",
970
+ )
971
+
972
+ train_loader, eval_loader, train_generator = build_train_eval_dataloaders(dataset)
973
+
974
+ print("Dataset size:", len(dataset))
975
+
976
+ raw_batch = next(iter(eval_loader))
977
+ batch = dataset.perform_active_mask(
978
+ raw_batch,
979
+ cat_ratio=0.15,
980
+ num_ratio=0.15,
981
+ seed=42,
982
+ )
983
+
984
+ print("\nBatch check")
985
+ if "row_idx" in batch:
986
+ print("row_idx:", batch["row_idx"].shape, batch["row_idx"].dtype)
987
+ if "sample_id" in batch:
988
+ print("sample_id:", len(batch["sample_id"]))
989
+
990
+ print("original_cat_local_ids:", batch["original_cat_local_ids"].shape)
991
+ print("masked_cat_local_ids:", batch["masked_cat_local_ids"].shape)
992
+ print("original_cat_valid_positions:", batch["original_cat_valid_positions"].shape)
993
+ print("masked_cat_valid_positions:", batch["masked_cat_valid_positions"].shape)
994
+ print("cat_loss_mask:", batch["cat_loss_mask"].shape)
995
+
996
+ for k, v in batch["original_numeric_values_by_nin"].items():
997
+ print(f"original_numeric_values_by_nin[{k}]:", v.shape)
998
+
999
+ for k, v in batch["masked_numeric_values_by_nin"].items():
1000
+ print(f"masked_numeric_values_by_nin[{k}]:", v.shape)
1001
+
1002
+ for k, v in batch["original_numeric_valid_positions_by_nin"].items():
1003
+ print(f"original_numeric_valid_positions_by_nin[{k}]:", v.shape)
1004
+
1005
+ for k, v in batch["masked_numeric_valid_positions_by_nin"].items():
1006
+ print(f"masked_numeric_valid_positions_by_nin[{k}]:", v.shape)
1007
+
1008
+ for k, v in batch["numeric_loss_mask_by_nin"].items():
1009
+ print(f"numeric_loss_mask_by_nin[{k}]:", v.shape)
1010
+
1011
+ print("pixel_values:", batch["pixel_values"].shape)
1012
+ print("vision_valid_positions:", batch["vision_valid_positions"].shape)
1013
+
1014
+ print("\nTensor dtype check")
1015
+ print("masked cat ids dtype:", batch["masked_cat_local_ids"].dtype)
1016
+ print("masked numeric dtype:", next(iter(batch["masked_numeric_values_by_nin"].values())).dtype)
1017
+ print("image dtype:", batch["pixel_values"].dtype)
1018
+
1019
+ print("\nLoader test finished successfully")
1020
+
1021
+ debug_print_first_sample(dataset, batch, batch_pos=0)
1022
+
1023
+
1024
+ if __name__ == "__main__":
1025
+ main()
modelling/soilformer.py ADDED
@@ -0,0 +1,696 @@
1
+ # soilformer.py
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F # noqa
12
+
13
+ from decode_categorical import CategoricalDecoder
14
+ from decode_numeric import NumericDecoder
15
+ from embed_categorical import (
16
+ CategoricalEmbedding,
17
+ build_cat_vocab_spec_from_meta,
18
+ get_categorical_feature_names_from_meta,
19
+ save_cat_vocab_json,
20
+ )
21
+ from embed_numeric import (
22
+ NumericEmbedding,
23
+ build_numeric_vocab_spec_from_meta,
24
+ )
25
+ from embed_vision_gemma3n import Gemma3nVisionFeatureExtractor
26
+ from layer import TabularImageGQALayer
27
+ from utils import load_json, save_json, get_dtype
28
+
29
+
30
+ # ============================================================
31
+ # SoilFormer
32
+ # ============================================================
33
+
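+ # A hedged sketch of the config dict this class reads (keys are exactly the
+ # accesses in __init__ below; the values are illustrative, not the shipped
+ # config/config_model.json):
+ #
+ # {
+ # "dtype": "bfloat16",
+ # "cat_hidden_size": 256,
+ # "numeric_hidden_size": 256,
+ # "cat_vocab_json": "data/cat_vocab.json",
+ # "numeric_vocab_json": "data/numeric_vocab.json",
+ # "vision_model_dir": "<path to the vision-only Gemma3n export>",
+ # "vision_num_output_tokens_reduced": 64,
+ # "vision_num_heads_for_token_reduction": 4,
+ # "vision_reducer_bottleneck_dim": 256,
+ # "vision_reducer_project_back": true,
+ # "layer_num_layers": 4,
+ # "layer_num_query_heads": 8,
+ # "layer_num_kv_heads": 2,
+ # "layer_head_dim": 32,
+ # "layer_mlp_ratio": 4.0,
+ # "layer_dropout": 0.0,
+ # "disable_tabular_attention_mask": false
+ # }
+ #
+ # Optional keys (*_middle_size, *_homoscedastic) fall back to defaults via .get().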
34
+ class SoilFormer(nn.Module):
35
+ """
36
+ Full model: embeddings -> TabularImageGQALayer stack -> decoders.
37
+ """
38
+
39
+ def __init__(self, config: Dict, device: Optional[str] = None):
40
+ super().__init__()
41
+ self.config = dict(config)
42
+
43
+ dtype = get_dtype(self.config.get("dtype", "bfloat16"))
44
+ dev = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
45
+
46
+ # ---- Tabular dims
47
+ cat_hidden = int(self.config["cat_hidden_size"])
48
+ num_hidden = int(self.config["numeric_hidden_size"])
49
+ if cat_hidden != num_hidden:
50
+ raise ValueError("Expect cat_hidden_size == numeric_hidden_size for one tabular stream.")
51
+ self.tabular_dim = cat_hidden
52
+
53
+ # ---- Embeddings
54
+ self.embed_cat = CategoricalEmbedding(
55
+ hidden_size=cat_hidden,
56
+ cat_vocab_json=self.config["cat_vocab_json"],
57
+ )
58
+ self.embed_num = NumericEmbedding(
59
+ hidden_size=num_hidden,
60
+ numeric_vocab_json=self.config["numeric_vocab_json"],
61
+ middle_size=self.config.get("numeric_encode_middle_size", None),
62
+ )
63
+
64
+ # ---- Decoders
65
+ self.decode_cat = CategoricalDecoder(
66
+ hidden_size=cat_hidden,
67
+ cat_vocab_json=self.config["cat_vocab_json"],
68
+ middle_size=self.config.get("cat_decode_middle_size", None),
69
+ homoscedastic=self.config.get("cat_homoscedastic", True),
70
+ )
71
+ self.decode_num = NumericDecoder(
72
+ hidden_size=num_hidden,
73
+ numeric_vocab_json=self.config["numeric_vocab_json"],
74
+ middle_size=self.config.get("numeric_decode_middle_size", None),
75
+ homoscedastic=self.config.get("num_homoscedastic", True),
76
+ )
77
+
78
+ # ---- Vision
79
+ self.vision_extractor = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir(
80
+ model_dir=self.config["vision_model_dir"],
81
+ map_location="cpu",
82
+ num_output_tokens_reduced=self.config["vision_num_output_tokens_reduced"],
83
+ num_heads_for_token_reduction=self.config["vision_num_heads_for_token_reduction"],
84
+ reducer_bottleneck_dim=self.config["vision_reducer_bottleneck_dim"],
85
+ reducer_project_back=self.config["vision_reducer_project_back"],
86
+ )
87
+
88
+ # ---- Layers
89
+ L = int(self.config["layer_num_layers"])
90
+ self.layers = nn.ModuleList([
91
+ TabularImageGQALayer(
92
+ tabular_dim=self.tabular_dim,
93
+ vision_dim=self.vision_extractor.get_actual_hidden_dim(),
94
+ num_query_heads=int(self.config["layer_num_query_heads"]),
95
+ num_kv_heads=int(self.config["layer_num_kv_heads"]),
96
+ head_dim=int(self.config["layer_head_dim"]),
97
+ mlp_ratio=float(self.config["layer_mlp_ratio"]),
98
+ dropout=float(self.config["layer_dropout"]),
99
+ )
100
+ for _ in range(L)
101
+ ])
102
+
103
+ # ---- Move
104
+ self.to(device=dev, dtype=dtype)
105
+
106
+ def init_weights(self, std: float = 0.02):
107
+ self.embed_cat.init_weights(std=std)
108
+ self.embed_num.init_weights(std=std)
109
+
110
+ self.decode_cat.init_weights(std=std)
111
+ self.decode_num.init_weights(std=std)
112
+
113
+ self.vision_extractor.init_weights(std=std)
114
+
115
+ for blk in self.layers:
116
+ blk.init_weights(std=std)
117
+
118
+ def forward(
119
+ self,
120
+ cat_local_ids: torch.LongTensor, # [B, M_cat]
121
+ numeric_values_by_nin: Dict[int, torch.Tensor], # {n_in: [B, V, n_in]}
122
+ cat_valid_positions: Optional[torch.Tensor] = None, # [B, M_cat] bool
123
+ numeric_valid_positions_by_nin: Optional[Dict[int, torch.Tensor]] = None, # {n_in: [B,V] bool}
124
+ pixel_values: Optional[torch.Tensor] = None, # [B, 3, H, W]
125
+ vision_valid_positions: Optional[torch.Tensor] = None, # [B] bool OR indices [K]
126
+ ):
127
+ # ----------------------------
128
+ # Embeddings (tabular)
129
+ # ----------------------------
130
+ x_cat, cat_mask = self.embed_cat(
131
+ local_ids=cat_local_ids,
132
+ valid_positions=cat_valid_positions,
133
+ )
134
+
135
+ x_num, num_mask = self.embed_num(
136
+ values_by_nin=numeric_values_by_nin,
137
+ valid_positions_by_nin=numeric_valid_positions_by_nin,
138
+ )
139
+
140
+ x_tab = torch.cat([x_cat, x_num], dim=1) # [B, T_tab, H]
141
+
142
+ B, T_tab, _ = x_tab.shape
143
+ M_cat = x_cat.size(1)
144
+ T_num = x_num.size(1)
145
+
146
+ # ----------------------------
147
+ # Tabular attention mask
148
+ # ----------------------------
149
+ cat_mask = cat_mask.to(device=x_tab.device, dtype=torch.long)
150
+ num_mask = num_mask.to(device=x_tab.device, dtype=torch.long)
151
+
152
+ if self.config["disable_tabular_attention_mask"]:
153
+ attention_mask_tab = torch.ones(B, T_tab, device=x_tab.device, dtype=torch.long)
154
+ else:
155
+ attention_mask_tab = torch.cat([cat_mask, num_mask], dim=1)
156
+ if attention_mask_tab.shape != (B, T_tab):
157
+ raise RuntimeError("Internal attention_mask_tab shape mismatch")
158
+
159
+ # ----------------------------
160
+ # Vision features
161
+ # ----------------------------
162
+ if pixel_values is None:
163
+
164
+ vision_features = None
165
+ vision_mask = None
166
+
167
+ else:
168
+
169
+ vision_features, vision_mask = self.vision_extractor(
170
+ pixel_values=pixel_values,
171
+ valid_positions=vision_valid_positions,
172
+ )
173
+
174
+ if vision_features.shape[0] != B:
175
+ raise ValueError("vision_features batch mismatch with tabular batch")
176
+
177
+ if vision_mask.shape[0] != B or vision_mask.shape[1] != vision_features.shape[1]:
178
+ raise ValueError("vision_mask shape mismatch with vision_features")
179
+
180
+ vision_mask = vision_mask.to(
181
+ device=attention_mask_tab.device,
182
+ dtype=attention_mask_tab.dtype,
183
+ )
184
+
185
+ # ----------------------------
186
+ # Transformer blocks
187
+ # ----------------------------
188
+ for blk in self.layers: # type: TabularImageGQALayer
189
+ x_tab = blk(
190
+ x_tab=x_tab,
191
+ attention_mask=attention_mask_tab,
192
+ vision_features=vision_features,
193
+ vision_mask=vision_mask
194
+ )
195
+
196
+ # ----------------------------
197
+ # Slice outputs
198
+ # ----------------------------
199
+ x_cat_out = x_tab[:, :M_cat, :]
200
+ x_num_out = x_tab[:, M_cat:M_cat + T_num, :]
201
+
202
+ # ----------------------------
203
+ # Decode
204
+ # ----------------------------
205
+ cat_logits_padded, cat_s, valid_class_mask = self.decode_cat(
206
+ x_cat_out,
207
+ return_padded=True,
208
+ )
209
+
210
+ value_by_nin, s_by_nin = self.decode_num(
211
+ x_num_out
212
+ )
213
+
214
+ return cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, x_tab
215
+
216
+ def _checkpoint_state_dict(self) -> Dict[str, torch.Tensor]:
217
+ """
218
+ State dict used for save/load.
219
+
220
+ Excludes pretrained frozen vision weights:
221
+ - vision_extractor.vision_tower.*
222
+ - vision_extractor.embed_vision.*
223
+
224
+ Keeps reducer weights if reducer exists.
225
+ """
226
+ full_sd = self.state_dict()
227
+ out = {}
228
+
229
+ for k, v in full_sd.items():
230
+ if k.startswith("vision_extractor.vision_tower."):
231
+ continue
232
+ if k.startswith("vision_extractor.embed_vision."):
233
+ continue
234
+ out[k] = v
235
+
236
+ return out
237
+
238
+ def save_weights(self, path: str):
239
+ """
240
+ Save model weights needed for SoilFormer training/inference,
241
+ excluding pretrained frozen vision weights.
242
+ """
243
+ payload = {
244
+ "model_state_dict": self._checkpoint_state_dict(),
245
+ "config": self.config,
246
+ }
247
+ torch.save(payload, path)
248
+
249
+ def load_weights(self, path: str, map_location: str = "cpu", strict: bool = True):
250
+ """
251
+ Load weights saved by save_weights().
252
+
253
+ Only the checkpoint-managed subset is loaded:
254
+ - embeddings / decoders / layers
255
+ - vision_extractor.reducer.* (if present)
256
+
257
+ Pretrained frozen vision weights are ignored here and are expected
258
+ to come from vision_model_dir during model construction.
259
+ """
260
+ ckpt = torch.load(path, map_location=map_location)
261
+
262
+ if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
263
+ sd = ckpt["model_state_dict"]
264
+ elif isinstance(ckpt, dict):
265
+ sd = ckpt
266
+ else:
267
+ raise ValueError(f"Unsupported checkpoint format: {path}")
268
+
269
+ expected_sd = self._checkpoint_state_dict()
270
+
271
+ # Only keep keys that belong to the checkpoint-managed subset
272
+ loadable_sd = {k: v for k, v in sd.items() if k in expected_sd}
273
+
274
+ missing = sorted(set(expected_sd.keys()) - set(loadable_sd.keys()))
275
+ unexpected = sorted(set(sd.keys()) - set(expected_sd.keys()))
276
+
277
+ # Actually load
278
+ load_info = self.load_state_dict(loadable_sd, strict=False)
279
+
280
+ # PyTorch may still report missing keys from the full model state_dict;
281
+ # keep only checkpoint-managed ones.
282
+ missing_after_load = [
283
+ k for k in load_info.missing_keys
284
+ if k in expected_sd
285
+ ]
286
+ unexpected_after_load = [
287
+ k for k in load_info.unexpected_keys
288
+ if k in expected_sd
289
+ ]
290
+
291
+ # Merge both sources of mismatch info
292
+ missing_final = sorted(set(missing) | set(missing_after_load))
293
+ unexpected_final = sorted(set(unexpected) | set(unexpected_after_load))
294
+
295
+ if strict and (missing_final or unexpected_final):
296
+ raise RuntimeError(
297
+ "Checkpoint load mismatch.\n"
298
+ f"Missing keys: {missing_final}\n"
299
+ f"Unexpected keys: {unexpected_final}"
300
+ )
301
+
302
+ return {
303
+ "missing_keys": missing_final,
304
+ "unexpected_keys": unexpected_final,
305
+ }
306
+
307
+
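+ # A minimal save/load round trip (sketch; the path is illustrative). The
+ # checkpoint holds only the checkpoint-managed subset, so the rebuilt model
+ # pulls its frozen vision weights from vision_model_dir as usual:
+ #
+ # model.save_weights("soilformer.pt")
+ # model2 = SoilFormer(config) # same config; reloads frozen vision weights
+ # info = model2.load_weights("soilformer.pt", strict=True)
+ # assert not info["missing_keys"] and not info["unexpected_keys"]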
308
+ def loss_function(
309
+ x_cat: torch.Tensor, # [B,M,Cmax] padded logits
310
+ s_cat: torch.Tensor, # [B,M] log-variance
311
+ y_cat: torch.Tensor, # [B,M] class index
312
+ loss_mask_cat: torch.Tensor, # [B,M] 0/1
313
+ valid_class_mask: torch.Tensor, # [M,Cmax] bool
314
+ x_num: Dict[int, torch.Tensor], # {n_in: [B,V,n_in]}
315
+ s_num: Dict[int, torch.Tensor], # {n_in: [B,V]}
316
+ y_num: Dict[int, torch.Tensor], # {n_in: [B,V,n_in]}
317
+ loss_mask_num: Dict[int, torch.Tensor], # {n_in: [B,V]} 0/1
318
+ cat_temperature: float = 1.0,
319
+ reduction: str = "mean", # "mean" or "sum"
320
+ eps: float = 1e-12,
321
+ cat_s_bound: Optional[float] = None,
322
+ num_s_bound: Optional[float] = None,
323
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
324
+ """
325
+ Strict loss for SoilFormer.
326
+
327
+ Categorical:
328
+ - Uses per-column CE over the valid class range only.
329
+ - Does NOT rely on padded logits values.
330
+ - s_cat[b,m] = log sigma^2 for categorical column m.
331
+
332
+ Numeric:
333
+ - Per-variable MSE averaged over n_in dimensions.
334
+ - s_num[n_in][b,v] = log sigma^2 for numeric variable v.
335
+
336
+ Optional soft bound:
337
+ If cat_s_bound or num_s_bound is not None, apply
338
+ s <- bound * tanh(s / bound)
339
+ before using s in heteroscedastic weighting.
340
+
341
+ Returns:
342
+ total_loss: scalar (float32)
343
+ stats: dict with cat_loss, num_loss, cat_base, num_base, counts...
344
+ """
345
+
346
+ def _soft_bound_logvar(s_: torch.Tensor, bound: Optional[float]) -> torch.Tensor:
347
+ if bound is None:
348
+ return s_
349
+ b = float(bound)
350
+ if b <= 0:
351
+ # A non-positive bound disables heteroscedastic weighting (forces s = 0)
352
+ return torch.zeros_like(s_)
353
+ return b * torch.tanh(s_ / b)
354
+
355
+ # ---------------------------------------------------
356
+ # 1) Categorical loss (strict per-column CE)
357
+ # ---------------------------------------------------
358
+ if x_cat.dim() != 3:
359
+ raise ValueError(f"x_cat must be [B,M,Cmax], got {tuple(x_cat.shape)}")
360
+
361
+ B, M, Cmax = x_cat.shape
362
+
363
+ if s_cat.shape != (B, M):
364
+ raise ValueError(f"s_cat must be [B,M]=({B},{M}), got {tuple(s_cat.shape)}")
365
+ if y_cat.shape != (B, M):
366
+ raise ValueError(f"y_cat must be [B,M]=({B},{M}), got {tuple(y_cat.shape)}")
367
+ if loss_mask_cat.shape != (B, M):
368
+ raise ValueError(f"loss_mask_cat must be [B,M]=({B},{M}), got {tuple(loss_mask_cat.shape)}")
369
+ if valid_class_mask.shape != (M, Cmax):
370
+ raise ValueError(
371
+ f"valid_class_mask must be [M,Cmax]=({M},{Cmax}), got {tuple(valid_class_mask.shape)}"
372
+ )
373
+
374
+ x_cat_f = x_cat.float()
375
+ s_cat_f = _soft_bound_logvar(s_cat.float(), cat_s_bound)
376
+ y_cat_l = y_cat.long()
377
+ mcat = loss_mask_cat.float()
378
+ valid_class_mask = valid_class_mask.to(device=x_cat.device, dtype=torch.bool)
379
+
380
+ if cat_temperature != 1.0:
381
+ x_cat_f = x_cat_f / float(cat_temperature)
382
+
383
+ cat_loss_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
384
+ cat_base_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
385
+ cat_correct_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
386
+
387
+ # denominator = number of actively supervised categorical cells
388
+ cat_denom = mcat.sum().clamp_min(float(eps))
389
+
390
+ for m in range(M):
391
+ cm = int(valid_class_mask[m].sum().item()) # real class count for column m
392
+ if cm <= 0:
393
+ raise ValueError(f"Column {m} has no valid classes")
394
+
395
+ logits_m = x_cat_f[:, m, :cm] # [B, C_m]
396
+ target_m = y_cat_l[:, m] # [B]
397
+ s_m = s_cat_f[:, m] # [B]
398
+ mask_m = mcat[:, m] # [B]
399
+
400
+ active = mask_m > 0
401
+ if active.any():
402
+ tgt_active = target_m[active]
403
+ if (tgt_active < 0).any() or (tgt_active >= cm).any():
404
+ raise ValueError(f"y_cat contains invalid class id for categorical column {m}")
405
+
406
+ target_m_safe = target_m.clone()
407
+ target_m_safe[~active] = 0
408
+
409
+ ce_m = F.cross_entropy(
410
+ logits_m,
411
+ target_m_safe,
412
+ reduction="none",
413
+ ) # [B], float32
414
+
415
+ # ---------------------------------------------------
416
+ # accuracy (only count active positions)
417
+ # ---------------------------------------------------
418
+ pred_m = logits_m.argmax(dim=-1) # [B]
419
+ correct_m = (pred_m == target_m_safe) & active # [B]
420
+ cat_correct_acc = cat_correct_acc + correct_m.float().sum()
421
+
422
+ # heteroscedastic weighting: exp(-s) * CE + s
423
+ L_m = torch.exp(-s_m) * ce_m + s_m # [B]
424
+
425
+ cat_loss_acc = cat_loss_acc + (L_m * mask_m).sum()
426
+ cat_base_acc = cat_base_acc + (ce_m * mask_m).sum()
427
+
428
+ if reduction == "mean":
429
+ cat_loss = cat_loss_acc / cat_denom
430
+ cat_base = cat_base_acc / cat_denom
431
+ elif reduction == "sum":
432
+ cat_loss = cat_loss_acc
433
+ cat_base = cat_base_acc
434
+ else:
435
+ raise ValueError(f"Unsupported reduction: {reduction}")
436
+ cat_acc = cat_correct_acc / cat_denom
437
+
438
+ # ---------------------------------------------------
439
+ # 2) Numeric loss (per-variable heteroscedastic MSE)
440
+ # ---------------------------------------------------
441
+ num_loss_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
442
+ num_base_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
443
+ num_denom_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
444
+
445
+ for n_in, x in x_num.items():
446
+ if n_in not in y_num or n_in not in s_num or n_in not in loss_mask_num:
447
+ raise KeyError(f"Missing key n_in={n_in} in y_num/s_num/loss_mask_num")
448
+
449
+ y = y_num[n_in]
450
+ s = s_num[n_in]
451
+ m = loss_mask_num[n_in]
452
+
453
+ if x.shape != y.shape:
454
+ raise ValueError(
455
+ f"x_num[{n_in}] and y_num[{n_in}] shape mismatch: "
456
+ f"{tuple(x.shape)} vs {tuple(y.shape)}"
457
+ )
458
+ if x.dim() != 3:
459
+ raise ValueError(f"x_num[{n_in}] must be [B,V,n_in], got {tuple(x.shape)}")
460
+
461
+ Bb, V, Nin = x.shape
462
+ if Nin != n_in:
463
+ raise ValueError(f"x_num[{n_in}] last dim mismatch: got {Nin}, expected {n_in}")
464
+ if s.shape != (Bb, V):
465
+ raise ValueError(f"s_num[{n_in}] must be [B,V], got {tuple(s.shape)}")
466
+ if m.shape != (Bb, V):
467
+ raise ValueError(f"loss_mask_num[{n_in}] must be [B,V], got {tuple(m.shape)}")
468
+
469
+ x_f = x.float()
470
+ y_f = y.float()
471
+ s_f = _soft_bound_logvar(s.float(), num_s_bound)
472
+ m_f = m.float()
473
+
474
+ # base numeric loss per variable: mean over n_in dims
475
+ mse = (x_f - y_f).pow(2).mean(dim=-1) # [B,V]
476
+
477
+ # heteroscedastic weighting: exp(-s) * mse + s
478
+ L = torch.exp(-s_f) * mse + s_f # [B,V]
479
+
480
+ num_loss_acc = num_loss_acc + (L * m_f).sum()
481
+ num_base_acc = num_base_acc + (mse * m_f).sum()
482
+ num_denom_acc = num_denom_acc + m_f.sum()
483
+
484
+ num_denom = num_denom_acc.clamp_min(float(eps))
485
+
486
+ if reduction == "mean":
487
+ num_loss = num_loss_acc / num_denom
488
+ num_base = num_base_acc / num_denom
489
+ elif reduction == "sum":
490
+ num_loss = num_loss_acc
491
+ num_base = num_base_acc
492
+ else:
493
+ raise ValueError(f"Unsupported reduction: {reduction}")
494
+
495
+ # ---------------------------------------------------
496
+ # 3) Total
497
+ # ---------------------------------------------------
498
+ total = cat_loss + num_loss
499
+
500
+ stats = {
501
+ "total": total.detach(),
502
+ "cat_loss": cat_loss.detach(),
503
+ "num_loss": num_loss.detach(),
504
+ "cat_base": cat_base.detach(),
505
+ "num_base": num_base.detach(),
506
+ "cat_count": cat_denom.detach(),
507
+ "num_count": num_denom.detach(),
508
+ "cat_acc": cat_acc.detach(),
509
+ }
510
+ return total, stats
511
+
512
+
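+ # Sanity check on the heteroscedastic weighting above: for a fixed base loss
+ # b > 0, L(s) = exp(-s) * b + s has dL/ds = 1 - b * exp(-s), which vanishes at
+ # s* = log(b) with L(s*) = 1 + log(b). Each s head therefore learns the
+ # log-variance of its output, and confident (low-s) heads pay more for
+ # residual error.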
513
+ # ============================================================
514
+ # DEMO
515
+ # ============================================================
516
+
517
+ def _demo_main():
518
+ import argparse
519
+
520
+ parser = argparse.ArgumentParser()
521
+ parser.add_argument("--config_json", type=str, default="config/config_model.json")
522
+ parser.add_argument("--batch_size", type=int, default=2)
523
+ parser.add_argument("--with_vision", action="store_true")
524
+ args = parser.parse_args()
525
+
526
+ cfg = load_json(args.config_json)
527
+
528
+ print("===== Loaded config =====")
529
+ print(json.dumps(cfg, ensure_ascii=False, indent=2))
530
+
531
+ # --------------------------------------------------
532
+ # Ensure vocab files exist
533
+ # --------------------------------------------------
534
+ tabular_meta = load_json(cfg["tabular_meta"])
535
+
536
+ if not os.path.isfile(cfg["cat_vocab_json"]):
537
+ cat_names = get_categorical_feature_names_from_meta(tabular_meta)
538
+ vocab = build_cat_vocab_spec_from_meta(tabular_meta, cat_names)
539
+ Path(cfg["cat_vocab_json"]).parent.mkdir(parents=True, exist_ok=True)
540
+ save_cat_vocab_json(vocab, cfg["cat_vocab_json"])
541
+ print(f"[demo] Built cat_vocab_json at {cfg['cat_vocab_json']}")
542
+
543
+ if not os.path.isfile(cfg["numeric_vocab_json"]):
544
+ spec = build_numeric_vocab_spec_from_meta(tabular_meta)
545
+ Path(cfg["numeric_vocab_json"]).parent.mkdir(parents=True, exist_ok=True)
546
+ save_json(spec, cfg["numeric_vocab_json"])
547
+ print(f"[demo] Built numeric_vocab_json at {cfg['numeric_vocab_json']}")
548
+
549
+ # --------------------------------------------------
550
+ # Build model
551
+ # --------------------------------------------------
552
+ model = SoilFormer(cfg)
553
+ model.init_weights()
554
+ model.eval()
555
+
556
+ device = next(model.parameters()).device
557
+ dtype = next(model.parameters()).dtype
558
+
559
+ B = args.batch_size
560
+
561
+ # --------------------------------------------------
562
+ # Build dummy categorical inputs
563
+ # --------------------------------------------------
564
+ cat_spec = load_json(cfg["cat_vocab_json"])
565
+ cat_items = sorted(cat_spec.items(), key=lambda x: x[1]["col_id"])
566
+ M_cat = len(cat_items)
567
+
568
+ cat_local_ids = torch.zeros(B, M_cat, dtype=torch.long, device=device)
569
+ cat_valid_positions = torch.ones(B, M_cat, dtype=torch.bool, device=device)
570
+
571
+ # --------------------------------------------------
572
+ # Build dummy numeric inputs
573
+ # --------------------------------------------------
574
+ num_spec = load_json(cfg["numeric_vocab_json"])
575
+
576
+ numeric_values_by_nin: Dict[int, torch.Tensor] = {}
577
+ numeric_valid_positions_by_nin: Dict[int, torch.Tensor] = {}
578
+
579
+ for g in num_spec["groups"]:
580
+ n_in = int(g["n_in"])
581
+ V = len(g["feature_names"])
582
+
583
+ numeric_values_by_nin[n_in] = torch.randn(B, V, n_in, device=device, dtype=dtype)
584
+ numeric_valid_positions_by_nin[n_in] = torch.ones(B, V, dtype=torch.bool, device=device)
585
+
586
+ # --------------------------------------------------
587
+ # Build dummy vision inputs
588
+ # --------------------------------------------------
589
+ if args.with_vision:
590
+ pixel_values = torch.randn(B, 3, 224, 224, device=device, dtype=dtype)
591
+ vision_valid_positions = torch.ones(B, dtype=torch.bool, device=device)
592
+ else:
593
+ pixel_values = None
594
+ vision_valid_positions = None
595
+
596
+ # --------------------------------------------------
597
+ # Vision debug
598
+ # --------------------------------------------------
599
+ print("\n===== Vision debug =====")
600
+ if pixel_values is None:
601
+ print("pixel_values: None")
602
+ print("vision_features: None")
603
+ print("vision_mask: None")
604
+ else:
605
+ print("pixel_values:", tuple(pixel_values.shape), pixel_values.dtype, pixel_values.device)
606
+ with torch.no_grad():
607
+ vision_features, vision_mask = model.vision_extractor.forward(
608
+ pixel_values=pixel_values,
609
+ valid_positions=vision_valid_positions,
610
+ )
611
+ print("vision_features:", tuple(vision_features.shape), vision_features.dtype, vision_features.device)
612
+ print("vision_mask:", tuple(vision_mask.shape), vision_mask.dtype, vision_mask.device)
613
+
614
+ # --------------------------------------------------
615
+ # Forward
616
+ # --------------------------------------------------
617
+ with torch.no_grad():
618
+ cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, x_tab = model.forward(
619
+ cat_local_ids=cat_local_ids, # noqa
620
+ numeric_values_by_nin=numeric_values_by_nin,
621
+ cat_valid_positions=cat_valid_positions,
622
+ numeric_valid_positions_by_nin=numeric_valid_positions_by_nin,
623
+ pixel_values=pixel_values,
624
+ vision_valid_positions=vision_valid_positions,
625
+ )
626
+
627
+ print("\n===== SoilFormer demo =====")
628
+ print("cat_local_ids:", tuple(cat_local_ids.shape))
629
+ print("cat_valid_positions:", tuple(cat_valid_positions.shape))
630
+ print("numeric_values_by_nin:", {k: tuple(v.shape) for k, v in numeric_values_by_nin.items()})
631
+ print("numeric_valid_positions_by_nin:", {k: tuple(v.shape) for k, v in numeric_valid_positions_by_nin.items()})
632
+ print("x_tab_final:", tuple(x_tab.shape), x_tab.dtype, x_tab.device)
633
+
634
+ print("Categorical outputs:")
635
+ print("cat_logits_padded:", tuple(cat_logits_padded.shape), cat_logits_padded.dtype, cat_logits_padded.device)
636
+ print("cat_s:", tuple(cat_s.shape), cat_s.dtype, cat_s.device)
637
+
638
+ print("Numeric decoded values:", {k: tuple(v.shape) for k, v in value_by_nin.items()})
639
+ print("Numeric decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()})
640
+
641
+ # --------------------------------------------------
642
+ # Loss debug
643
+ # --------------------------------------------------
644
+ print("\n===== Loss debug =====")
645
+
646
+ if cat_logits_padded.dim() != 3:
647
+ raise RuntimeError(f"cat_logits_padded must be [B,M,Cmax], got {tuple(cat_logits_padded.shape)}")
648
+
649
+ B_logits, M_cat2, Cmax2 = cat_logits_padded.shape
650
+ if cat_s.shape != (B_logits, M_cat2):
651
+ raise RuntimeError(f"cat_s shape mismatch: got {tuple(cat_s.shape)} expected {(B_logits, M_cat2)}")
652
+
653
+ # Build dummy categorical targets within valid class ranges
654
+ num_classes = [int(s["num_classes"]) for _, s in cat_items]
655
+ if len(num_classes) != M_cat2:
656
+ raise RuntimeError("M_cat mismatch between vocab and model output")
657
+
658
+ y_cat = torch.zeros(B_logits, M_cat2, dtype=torch.long, device=device)
659
+ for m, cm in enumerate(num_classes):
660
+ y_cat[:, m] = torch.randint(low=0, high=cm, size=(B_logits,), device=device)
661
+
662
+ mask_cat = torch.ones(B_logits, M_cat2, dtype=torch.long, device=device)
663
+
664
+ # Build dummy numeric targets and masks
665
+ y_num = {
666
+ n_in: torch.randn_like(x_pred)
667
+ for n_in, x_pred in value_by_nin.items()
668
+ }
669
+
670
+ mask_num = {
671
+ n_in: torch.ones(x_pred.size(0), x_pred.size(1), dtype=torch.long, device=x_pred.device)
672
+ for n_in, x_pred in value_by_nin.items()
673
+ }
674
+
675
+ total_loss, stats = loss_function(
676
+ x_cat=cat_logits_padded,
677
+ s_cat=cat_s,
678
+ y_cat=y_cat,
679
+ loss_mask_cat=mask_cat,
680
+ x_num=value_by_nin,
681
+ s_num=s_by_nin,
682
+ y_num=y_num,
683
+ loss_mask_num=mask_num,
684
+ reduction="mean",
685
+ valid_class_mask=valid_class_mask
686
+ )
687
+
688
+ print("total_loss:", float(total_loss))
689
+ print("stats:", {k: float(v) for k, v in stats.items()})
690
+
691
+ if not torch.isfinite(total_loss):
692
+ raise RuntimeError("Loss is not finite!")
693
+
694
+
695
+ if __name__ == "__main__":
696
+ _demo_main()
modelling/train.py ADDED
@@ -0,0 +1,552 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+ from pathlib import Path
6
+ from typing import Dict, Optional
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch.optim import AdamW
11
+ from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, LinearLR, SequentialLR
12
+ from tqdm import tqdm
13
+
14
+ from loader import SoilFormerDataset, build_train_eval_dataloaders
15
+ from soilformer import SoilFormer, loss_function
16
+ from utils import get_dtype, load_json, save_json
17
+
18
+ try:
19
+ import wandb
20
+ except ImportError: # pragma: no cover
21
+ wandb = None
22
+
23
+
24
+ def set_seed(seed: int, deterministic: bool = True) -> None:
25
+ random.seed(seed)
26
+ np.random.seed(seed)
27
+ torch.manual_seed(seed)
28
+ if torch.cuda.is_available():
29
+ torch.cuda.manual_seed(seed)
30
+ torch.cuda.manual_seed_all(seed)
31
+
32
+ if deterministic:
33
+ torch.backends.cudnn.deterministic = True
34
+ torch.backends.cudnn.benchmark = False
35
+
36
+
37
+ def resolve_device(device_str: str) -> torch.device:
38
+ device_str = device_str.lower()
39
+
40
+ if device_str == "cuda":
41
+ if not torch.cuda.is_available():
42
+ raise RuntimeError("config requests cuda, but CUDA is not available")
43
+ return torch.device("cuda")
44
+
45
+ if device_str == "mps":
46
+ if not torch.backends.mps.is_available():
47
+ raise RuntimeError("config requests mps, but MPS is not available")
48
+ return torch.device("mps")
49
+
50
+ if device_str == "cpu":
51
+ return torch.device("cpu")
52
+
53
+ raise ValueError(f"Unsupported device: {device_str}")
54
+
55
+
56
+ def move_batch_to_device(batch: Dict, device: torch.device, float_dtype: torch.dtype) -> Dict:
57
+ out = {}
58
+ for key, value in batch.items():
59
+ if isinstance(value, torch.Tensor):
60
+ if value.dtype.is_floating_point:
61
+ out[key] = value.to(device=device, dtype=float_dtype, non_blocking=True)
62
+ else:
63
+ out[key] = value.to(device=device, non_blocking=True)
64
+ elif isinstance(value, dict):
65
+ sub = {}
66
+ for sub_key, sub_value in value.items():
67
+ if isinstance(sub_value, torch.Tensor):
68
+ if sub_value.dtype.is_floating_point:
69
+ sub[sub_key] = sub_value.to(device=device, dtype=float_dtype, non_blocking=True)
70
+ else:
71
+ sub[sub_key] = sub_value.to(device=device, non_blocking=True)
72
+ else:
73
+ sub[sub_key] = sub_value
74
+ out[key] = sub
75
+ else:
76
+ out[key] = value
77
+ return out
78
+
79
+
80
+ def build_scheduler(
81
+ optimizer: torch.optim.Optimizer,
82
+ scheduler_cfg: Dict,
83
+ ):
84
+ scheduler_type = str(scheduler_cfg.get("type", "none")).lower()
85
+
86
+ if scheduler_type == "none":
87
+ return None
88
+
89
+ warmup_epochs = int(scheduler_cfg.get("warmup_epochs", 0))
90
+ warmup_start_factor = float(scheduler_cfg.get("warmup_start_factor", 0.1))
91
+
92
+ if scheduler_type == "cosine":
93
+ total_epochs = int(scheduler_cfg["total_epochs"])
94
+ eta_min = float(scheduler_cfg.get("eta_min", 1e-6))
95
+
96
+ if warmup_epochs > 0:
97
+ t_max = int(scheduler_cfg.get("t_max", total_epochs - warmup_epochs))
98
+ if t_max <= 0:
99
+ raise ValueError(
100
+ f"Invalid cosine scheduler config: total_epochs={total_epochs}, "
101
+ f"warmup_epochs={warmup_epochs}, resulting T_max={t_max}"
102
+ )
103
+ else:
104
+ t_max = int(scheduler_cfg.get("t_max", total_epochs))
105
+
106
+ main_scheduler = CosineAnnealingLR(
107
+ optimizer,
108
+ T_max=t_max,
109
+ eta_min=eta_min,
110
+ )
111
+
112
+ elif scheduler_type == "step":
113
+ step_size = int(scheduler_cfg["step_size"])
114
+ gamma = float(scheduler_cfg.get("gamma", 0.1))
115
+ main_scheduler = StepLR(
116
+ optimizer,
117
+ step_size=step_size,
118
+ gamma=gamma,
119
+ )
120
+
121
+ else:
122
+ raise ValueError(f"Unsupported scheduler type: {scheduler_type}")
123
+
124
+ if warmup_epochs <= 0:
125
+ return main_scheduler
126
+
127
+ warmup_scheduler = LinearLR(
128
+ optimizer,
129
+ start_factor=warmup_start_factor,
130
+ total_iters=warmup_epochs,
131
+ )
132
+
133
+ scheduler = SequentialLR(
134
+ optimizer,
135
+ schedulers=[warmup_scheduler, main_scheduler],
136
+ milestones=[warmup_epochs],
137
+ )
138
+ return scheduler
139
+
140
+
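+ # A hedged example of scheduler_cfg matching the keys read above (values are
+ # illustrative, not the shipped config_train.json): 5 warmup epochs followed
+ # by cosine decay over the remaining epochs:
+ #
+ # {
+ # "type": "cosine",
+ # "total_epochs": 200,
+ # "eta_min": 1e-6,
+ # "warmup_epochs": 5,
+ # "warmup_start_factor": 0.1
+ # }
+ #
+ # With warmup_epochs > 0 the cosine phase runs for total_epochs - warmup_epochs
+ # epochs unless "t_max" is set explicitly.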
141
+ def get_checkpoint_model_state(model: SoilFormer) -> Dict[str, torch.Tensor]:
142
+ if hasattr(model, "_checkpoint_state_dict"):
143
+ return model._checkpoint_state_dict() # noqa
144
+ return model.state_dict()
145
+
146
+
147
+ def load_checkpoint_model_state(model: SoilFormer, state_dict: Dict[str, torch.Tensor]) -> None:
148
+ if hasattr(model, "load_weights"):
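+ # Round-trip through a temporary .pt file so SoilFormer.load_weights' key
+ # filtering (checkpoint-managed subset only) is reused verbatim.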
149
+ payload = {"model_state_dict": state_dict}
150
+ tmp_path = None
151
+ try:
152
+ import tempfile
153
+ with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
154
+ tmp_path = f.name
155
+ torch.save(payload, tmp_path)
156
+ model.load_weights(tmp_path, map_location="cpu", strict=True)
157
+ finally:
158
+ if tmp_path is not None and os.path.exists(tmp_path):
159
+ os.remove(tmp_path)
160
+ return
161
+
162
+ model.load_state_dict(state_dict, strict=True)
163
+
164
+
165
+ def save_checkpoint(
166
+ checkpoint_path: Path,
167
+ model: SoilFormer,
168
+ optimizer: torch.optim.Optimizer,
169
+ scheduler,
170
+ epoch: int,
171
+ global_step: int,
172
+ config_train: Dict,
173
+ config_model: Dict,
174
+ config_data: Dict,
175
+ ) -> None:
176
+ checkpoint = {
177
+ "epoch": epoch,
178
+ "global_step": global_step,
179
+ "model_state_dict": get_checkpoint_model_state(model),
180
+ "optimizer_state_dict": optimizer.state_dict(),
181
+ "scheduler_state_dict": None if scheduler is None else scheduler.state_dict(),
182
+ "config_train": config_train,
183
+ "config_model": config_model,
184
+ "config_data": config_data,
185
+ }
186
+ checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
187
+ torch.save(checkpoint, checkpoint_path)
188
+
189
+
190
+ def rotate_checkpoints(checkpoint_dir: Path, max_saved_checkpoints: int) -> None:
191
+ checkpoint_paths = sorted(checkpoint_dir.glob("checkpoint_epoch_*.pt"))
192
+ if max_saved_checkpoints is None or max_saved_checkpoints <= 0:
193
+ return
194
+ while len(checkpoint_paths) > max_saved_checkpoints:
195
+ oldest = checkpoint_paths.pop(0)
196
+ oldest.unlink(missing_ok=True)
197
+
198
+
199
+ def compute_loss_from_batch(
200
+ model: SoilFormer,
201
+ batch: Dict,
202
+ device: torch.device,
203
+ dtype: torch.dtype,
204
+ cat_s_bound: Optional[float] = None,
205
+ num_s_bound: Optional[float] = None,
206
+ ):
207
+ batch = move_batch_to_device(batch, device=device, float_dtype=dtype)
208
+
209
+ cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, _ = model(
210
+ cat_local_ids=batch["masked_cat_local_ids"],
211
+ numeric_values_by_nin=batch["masked_numeric_values_by_nin"],
212
+ cat_valid_positions=batch["masked_cat_valid_positions"],
213
+ numeric_valid_positions_by_nin=batch["masked_numeric_valid_positions_by_nin"],
214
+ pixel_values=batch["pixel_values"],
215
+ vision_valid_positions=batch["vision_valid_positions"],
216
+ )
217
+
218
+ total_loss, stats = loss_function(
219
+ x_cat=cat_logits_padded,
220
+ s_cat=cat_s,
221
+ y_cat=batch["original_cat_local_ids"],
222
+ loss_mask_cat=batch["cat_loss_mask"],
223
+ valid_class_mask=valid_class_mask,
224
+ x_num=value_by_nin,
225
+ s_num=s_by_nin,
226
+ y_num=batch["original_numeric_values_by_nin"],
227
+ loss_mask_num=batch["numeric_loss_mask_by_nin"],
228
+ reduction="mean",
229
+ cat_s_bound=cat_s_bound,
230
+ num_s_bound=num_s_bound,
231
+ )
232
+
233
+ return total_loss, stats
234
+
235
+
236
+ @torch.no_grad()
237
+ def evaluate(
238
+ model: SoilFormer,
239
+ dataset: SoilFormerDataset,
240
+ eval_loader,
241
+ device: torch.device,
242
+ dtype: torch.dtype,
243
+ cat_mask_ratio: float,
244
+ num_mask_ratio: float,
245
+ active_mask_seed: int,
246
+ show_tqdm: bool,
247
+ epoch: int,
248
+ cat_s_bound: Optional[float] = None,
249
+ num_s_bound: Optional[float] = None,
250
+ ):
251
+ model.eval()
252
+
253
+ totals = {
254
+ "total": 0.0,
255
+ "cat_loss": 0.0,
256
+ "num_loss": 0.0,
257
+ "cat_base": 0.0,
258
+ "num_base": 0.0,
259
+ "cat_acc": 0.0,
260
+ }
261
+ num_batches = 0
262
+
263
+ iterator = eval_loader
264
+ if show_tqdm:
265
+ iterator = tqdm(eval_loader, desc=f"Eval {epoch}", leave=False)
266
+
267
+ for batch_idx, raw_batch in enumerate(iterator):
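+ # Deterministic eval masking: the seed depends only on the batch index, so
+ # the same positions are masked at every evaluation epoch.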
268
+ mask_seed = int(active_mask_seed + batch_idx)
269
+ masked_batch = dataset.perform_active_mask(
270
+ raw_batch,
271
+ cat_ratio=cat_mask_ratio,
272
+ num_ratio=num_mask_ratio,
273
+ seed=mask_seed,
274
+ )
275
+
276
+ _, stats = compute_loss_from_batch(
277
+ model=model,
278
+ batch=masked_batch,
279
+ device=device,
280
+ dtype=dtype,
281
+ cat_s_bound=cat_s_bound,
282
+ num_s_bound=num_s_bound,
283
+ )
284
+
285
+ num_batches += 1
286
+ for key in totals:
287
+ totals[key] += float(stats[key].item())
288
+
289
+ if num_batches == 0:
290
+ raise RuntimeError("Eval dataloader is empty")
291
+
292
+ return {f"eval/{k}": v / num_batches for k, v in totals.items()}
293
+
294
+
295
+ def maybe_init_wandb(config_train: Dict):
296
+ wandb_cfg = config_train["logging"]["wandb"]
297
+ if not bool(wandb_cfg.get("enabled", False)):
298
+ return None
299
+
300
+ if wandb is None:
301
+ raise ImportError("wandb is enabled in config but package is not installed")
302
+
303
+ run = wandb.init(
304
+ project=wandb_cfg["project"],
305
+ entity=wandb_cfg.get("entity"),
306
+ name=wandb_cfg.get("run_name"),
307
+ dir=wandb_cfg.get("dir"),
308
+ config=config_train,
309
+ mode=wandb_cfg.get("mode", "online"),
310
+ )
311
+ return run
312
+
313
+
314
+ def print_parameter_stats(model):
315
+ total = 0
316
+ trainable = 0
317
+
318
+ for p in model.parameters():
319
+ num = p.numel()
320
+ total += num
321
+ if p.requires_grad:
322
+ trainable += num
323
+
324
+ print("\nParameter statistics:")
325
+ print(f"Total parameters: {total:,}")
326
+ print(f"Trainable parameters: {trainable:,}")
327
+ print(f"Frozen parameters: {total - trainable:,}\n")
328
+
329
+
330
+ def main():
331
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config", type=str, default="config/config_train.json")
+     args = parser.parse_args()
+
+     config_train = load_json(args.config)
+     config_paths = config_train["paths"]
+     config_data = load_json(config_paths["config_data_path"])
+     config_model = load_json(config_paths["config_model_path"])
+
+     seed_cfg = config_train["seed"]
+     runtime_cfg = config_train["runtime"]
+     optim_cfg = config_train["optimization"]
+     checkpoint_cfg = config_train["checkpoint"]
+     logging_cfg = config_train["logging"]
+     loss_cfg = config_train["loss"]
+
+     set_seed(int(seed_cfg["seed"]), deterministic=bool(seed_cfg.get("deterministic", True)))
+
+     device = resolve_device(runtime_cfg["device"])
+     dtype = get_dtype(config_model.get("dtype", "bfloat16"))
+
+     output_dir = Path(config_paths["output_dir"])
+     checkpoint_dir = output_dir / "checkpoints"
+     output_dir.mkdir(parents=True, exist_ok=True)
+     checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+     save_json(config_train, str(output_dir / "config_train.snapshot.json"))
+     save_json(config_data, str(output_dir / "config_data.snapshot.json"))
+     save_json(config_model, str(output_dir / "config_model.snapshot.json"))
+
+     dataset = SoilFormerDataset(
+         csv_path=config_data["data_csv_path"],
+         photo_map_path=config_data["photo_map_path"],
+         cat_vocab_path=config_data["cat_vocab_path"],
+         numeric_vocab_path=config_data["numeric_vocab_path"],
+         numeric_stats_path=config_data["numeric_stats_path"],
+         photo_root=config_data["photo_root"],
+         image_size=int(config_data["image_size"]),
+     )
+
+     train_loader, eval_loader, train_generator = build_train_eval_dataloaders(
+         dataset=dataset,
+         train_ratio=float(config_data["train_ratio"]),
+         seed=int(config_data["train_eval_split_seed"]),
+         batch_size=int(config_data["batch_size"]),
+     )
+     print("\nSample statistics:")
+     print("Train samples:", len(train_loader.dataset))
+     print("Eval samples:", len(eval_loader.dataset))
+     train_generator.manual_seed(int(seed_cfg["seed"]))
+
+     model = SoilFormer(config=config_model, device=str(device))
+
+     resume_path = checkpoint_cfg.get("resume_checkpoint_path")
+     if resume_path:
+         checkpoint = torch.load(resume_path, map_location="cpu")
+         load_checkpoint_model_state(model, checkpoint["model_state_dict"])
+     else:
+         model.init_weights(std=float(runtime_cfg.get("init_weight_std", 0.02)))
+         checkpoint = None
+
+     print_parameter_stats(model)
+
+     optimizer = AdamW(
+         [p for p in model.parameters() if p.requires_grad],
+         lr=float(optim_cfg["lr"]),
+         betas=(float(optim_cfg["beta1"]), float(optim_cfg["beta2"])),
+         eps=float(optim_cfg["eps"]),
+         weight_decay=float(optim_cfg["weight_decay"]),
+     )
+
+     scheduler = build_scheduler(
+         optimizer=optimizer,
+         scheduler_cfg=optim_cfg.get("scheduler", {"type": "none"})
+     )
+
+     start_epoch = 1
+     global_step = 0
+
+     if checkpoint is not None:
+         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+         if scheduler is not None and checkpoint.get("scheduler_state_dict") is not None:
+             scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+         start_epoch = int(checkpoint["epoch"]) + 1
+         global_step = int(checkpoint.get("global_step", 0))
+
+     wandb_run = maybe_init_wandb(config_train)
+
+     num_epochs = int(runtime_cfg["num_epochs"])
+     show_tqdm = bool(logging_cfg.get("tqdm", True))
+     cat_mask_ratio = float(config_data["cat_mask_ratio"])
+     num_mask_ratio = float(config_data["num_mask_ratio"])
+     active_mask_seed = int(config_data["active_mask_seed"])
+     max_grad_norm = optim_cfg.get("max_grad_norm")
+     epochs_per_save = int(checkpoint_cfg["epochs_per_save"])
+     max_saved_checkpoints = int(checkpoint_cfg["max_saved_checkpoints"])
+
+     for epoch in range(start_epoch, num_epochs + 1):
+         model.train()
+
+         epoch_totals = {
+             "total": 0.0,
+             "cat_loss": 0.0,
+             "num_loss": 0.0,
+             "cat_base": 0.0,
+             "num_base": 0.0,
+             "cat_acc": 0.0,
+         }
+         num_batches = 0
+
+         iterator = train_loader
+         if show_tqdm:
+             iterator = tqdm(train_loader, desc=f"Train {epoch}", leave=True)
+
+         for batch_idx, raw_batch in enumerate(iterator):
+             global_step += 1
+             mask_seed = int(active_mask_seed + epoch * 1_000_000 + batch_idx)
+             masked_batch = dataset.perform_active_mask(
+                 raw_batch,
+                 cat_ratio=cat_mask_ratio,
+                 num_ratio=num_mask_ratio,
+                 seed=mask_seed,
+             )
+
+             optimizer.zero_grad(set_to_none=True)
+
+             total_loss, stats = compute_loss_from_batch(
+                 model=model,
+                 batch=masked_batch,
+                 device=device,
+                 dtype=dtype,
+                 cat_s_bound=loss_cfg.get("cat_s_bound", None),
+                 num_s_bound=loss_cfg.get("num_s_bound", None),
+             )
+
+             total_loss.backward()
+             if max_grad_norm is not None:
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), float(max_grad_norm))
+             optimizer.step()
+
+             num_batches += 1
+             for key in epoch_totals:
+                 epoch_totals[key] += float(stats[key].item())
+
+             current_lr = float(optimizer.param_groups[0]["lr"])
+             train_step_log = {
+                 "train/step_total": float(stats["total"].item()),
+                 "train/step_cat_loss": float(stats["cat_loss"].item()),
+                 "train/step_num_loss": float(stats["num_loss"].item()),
+                 "train/step_cat_acc": float(stats["cat_acc"].item()),
+                 "train/lr": current_lr,
+                 "epoch": epoch,
+                 "global_step": global_step,
+             }
+
+             if wandb_run is not None:
+                 wandb.log(train_step_log, step=global_step)
+
+             if show_tqdm:
+                 iterator.set_postfix(
+                     loss=f"{train_step_log['train/step_total']:.4f}",
+                     lr=f"{current_lr:.3e}",
+                 )
+
+         if num_batches == 0:
+             raise RuntimeError("Train dataloader is empty")
+
+         train_epoch_log = {f"train/{k}": v / num_batches for k, v in epoch_totals.items()}
+         train_epoch_log["train/lr_epoch_end"] = float(optimizer.param_groups[0]["lr"])
+         train_epoch_log["epoch"] = epoch
+         train_epoch_log["global_step"] = global_step
+
+         eval_log = evaluate(
+             model=model,
+             dataset=dataset,
+             eval_loader=eval_loader,
+             device=device,
+             dtype=dtype,
+             cat_mask_ratio=cat_mask_ratio,
+             num_mask_ratio=num_mask_ratio,
+             active_mask_seed=active_mask_seed,
+             show_tqdm=show_tqdm,
+             epoch=epoch,
+             cat_s_bound=loss_cfg.get("cat_s_bound", None),
+             num_s_bound=loss_cfg.get("num_s_bound", None),
+         )
+         eval_log["epoch"] = epoch
+         eval_log["global_step"] = global_step
+
+         merged_log = {}
+         merged_log.update(train_epoch_log)
+         merged_log.update(eval_log)
+
+         print(json.dumps(merged_log, ensure_ascii=False))
+
+         if wandb_run is not None:
+             wandb.log(merged_log, step=global_step)
+
+         if scheduler is not None:
+             scheduler.step()
+
+         if epochs_per_save > 0 and epoch % epochs_per_save == 0:
+             checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"
+             save_checkpoint(
+                 checkpoint_path=checkpoint_path,
+                 model=model,
+                 optimizer=optimizer,
+                 scheduler=scheduler,
+                 epoch=epoch,
+                 global_step=global_step,
+                 config_train=config_train,
+                 config_model=config_model,
+                 config_data=config_data,
+             )
+             rotate_checkpoints(checkpoint_dir, max_saved_checkpoints)
+
+     if wandb_run is not None:
+         wandb.finish()
+
+
+ if __name__ == "__main__":
+     main()
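
For orientation, main() above reads only a handful of keys from each config file. The sketch below is a minimal, hypothetical config/config_train.json covering exactly those keys with placeholder values; the shipped file may set different values and also carries whatever section maybe_init_wandb expects:

{
  "paths": {
    "config_data_path": "config/config_data.json",
    "config_model_path": "config/config_model.json",
    "output_dir": "output/soilformer_pretrain"
  },
  "seed": {"seed": 42, "deterministic": true},
  "runtime": {"device": "cuda", "num_epochs": 200, "init_weight_std": 0.02},
  "optimization": {
    "lr": 1e-4,
    "beta1": 0.9,
    "beta2": 0.999,
    "eps": 1e-8,
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
    "scheduler": {"type": "none"}
  },
  "checkpoint": {
    "resume_checkpoint_path": null,
    "epochs_per_save": 10,
    "max_saved_checkpoints": 3
  },
  "logging": {"tqdm": true},
  "loss": {"cat_s_bound": null, "num_s_bound": null}
}

With such a file in place, training launches with the argparse default config path, e.g. python -m modelling.train or python modelling/train.py, depending on how the module's imports are laid out.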
modelling/utils.py ADDED
@@ -0,0 +1,132 @@
+ # utils.py
+ # -*- coding: utf-8 -*-
+
+ import json
+ from typing import Dict
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F  # noqa
+
+
+ class GroupedMLP(nn.Module):
+     """
+     Batched per-variable MLP for a fixed n_in bucket.
+
+     Input: X [B, V, n_in]
+     Output: Y [B, V, n_out]
+
+     Per-variable weights (NOT shared across V):
+     - 1-layer: W [V, n_out, n_in], b [V, n_out]
+     - 2-layer: W1 [V, mid, n_in], b1 [V, mid]
+                W2 [V, n_out, mid], b2 [V, n_out]
+     """
+
+     def __init__(
+         self,
+         n_var: int,
+         n_in: int,
+         n_out: int,
+         middle_size: Optional[int] = None,
+         bias: bool = True,
+     ):
+         super().__init__()
+
+         self.n_var = int(n_var)
+         self.n_in = int(n_in)
+         self.n_out = int(n_out)
+         self.middle_size = None if middle_size is None else int(middle_size)
+         self.bias = bias
+
+         if self.middle_size is None:
+             self.W = nn.Parameter(torch.empty(self.n_var, self.n_out, self.n_in))
+
+             if bias:
+                 self.b = nn.Parameter(torch.empty(self.n_var, self.n_out))
+             else:
+                 self.register_parameter("b", None)
+
+             self.W1 = self.b1 = self.W2 = self.b2 = None
+
+         else:
+             mid = self.middle_size
+
+             self.W1 = nn.Parameter(torch.empty(self.n_var, mid, self.n_in))
+             self.W2 = nn.Parameter(torch.empty(self.n_var, self.n_out, mid))
+
+             if bias:
+                 self.b1 = nn.Parameter(torch.empty(self.n_var, mid))
+                 self.b2 = nn.Parameter(torch.empty(self.n_var, self.n_out))
+             else:
+                 self.register_parameter("b1", None)
+                 self.register_parameter("b2", None)
+
+             self.W = self.b = None
+
+     def init_weights(self, std: float = 0.02) -> None:
+         """
+         Initialize weights manually.
+         """
+         if self.middle_size is None:
+             nn.init.normal_(self.W, std=std)
+             if self.bias:
+                 nn.init.zeros_(self.b)
+         else:
+             nn.init.normal_(self.W1, std=std)
+             nn.init.normal_(self.W2, std=std)
+
+             if self.bias:
+                 nn.init.zeros_(self.b1)
+                 nn.init.zeros_(self.b2)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if x.dim() != 3:
+             raise ValueError(f"Expected x [B,V,n_in], got {tuple(x.shape)}")
+
+         B, V, I = x.shape
+
+         if V != self.n_var or I != self.n_in:
+             raise ValueError(
+                 f"Shape mismatch: expected V={self.n_var}, n_in={self.n_in}; got V={V}, n_in={I}"
+             )
+
+         if self.middle_size is None:
+             y = torch.einsum("bvi,voi->bvo", x, self.W)
+             if self.bias:
+                 y = y + self.b.unsqueeze(0)
+             return y
+
+         h = torch.einsum("bvi,vmi->bvm", x, self.W1)
+         if self.bias:
+             h = h + self.b1.unsqueeze(0)
+
+         h = F.gelu(h)
+
+         y = torch.einsum("bvm,vom->bvo", h, self.W2)
+         if self.bias:
+             y = y + self.b2.unsqueeze(0)
+
+         return y
+
+
+ def get_dtype(dtype: Optional[str]) -> torch.dtype:
+     dtype_str = (dtype or "bfloat16").lower()
+     dtype_map = {
+         "bfloat16": torch.bfloat16,
+         "float16": torch.float16,
+         "float32": torch.float32,
+     }
+     if dtype_str not in dtype_map:
+         raise ValueError(f"Unsupported dtype={dtype}. Choose from {list(dtype_map.keys())}")
+     return dtype_map[dtype_str]
+
+
+ def load_json(path: str):
+     with open(path, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+
+ def save_json(obj: Dict, path: str) -> None:
+     with open(path, "w", encoding="utf-8") as f:
+         json.dump(obj, f, ensure_ascii=False, indent=2)  # noqa
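
Because GroupedMLP allocates its parameters with torch.empty, init_weights() must run before the first forward pass (in train.py this presumably happens through model.init_weights). A minimal shape-check sketch, with arbitrary illustrative sizes:

import torch
from modelling.utils import GroupedMLP

# Two-layer variant: per-variable W1 [V, mid, n_in] and W2 [V, n_out, mid].
mlp = GroupedMLP(n_var=8, n_in=4, n_out=16, middle_size=32)
mlp.init_weights(std=0.02)  # required: torch.empty leaves weights uninitialized

x = torch.randn(2, 8, 4)       # [B=2, V=8, n_in=4]
y = mlp(x)                     # one independent MLP per variable, batched
assert y.shape == (2, 8, 16)   # [B, V, n_out]

The einsum contraction applies all V per-variable linear maps in a single batched operation, which is why no Python loop over variables is needed.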
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch~=2.10.0
+ numpy~=2.3.4
+ wandb~=0.25.1
+ tqdm~=4.67.1
+ pandas~=2.3.3
+ requests~=2.32.5
+ pillow~=12.0.0
+ torchvision~=0.25.0
+ safetensors~=0.7.0
+ transformers~=5.2.0
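
The pins above use compatible-release specifiers (~=), so each resolves to the newest patch release within the stated minor version. Assuming a standard pip environment:

pip install -r requirements.txt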
resources/arch.png ADDED

Git LFS Details
  • SHA256: cd0891f93c9b4970faeb6603ebcba7a07f8f41ff35a72de932ba0c3486187259
  • Pointer size: 131 Bytes
  • Size of remote file: 421 kB