jrhuebers commited on
Commit
b8f4b1f
·
verified ·
1 Parent(s): 4ab3747

Upload FIM-ODE base model

Browse files
base_model/checkpoints/best-model/best-model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad7078ae75a5c4417ec0da095a242cb6cdcd874f1c2dd9191efab9c124800fe4
3
+ size 52002366
base_model/checkpoints/best-model/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_attn_implementation_autoset": true,
3
+ "model_config": {
4
+ "attention_map": "softmax",
5
+ "attention_method": "linear",
6
+ "dim_embed": 256,
7
+ "dim_feedforward": 1024,
8
+ "dim_ffn_u_model": 1024,
9
+ "dim_hidden_u_model": 256,
10
+ "dim_max_trajectory": 3,
11
+ "dropout": 0.1,
12
+ "num_context_encoder_layers": 2,
13
+ "num_heads": 8,
14
+ "num_res_layer_u_model": 6,
15
+ "num_res_layers_functional_decoder": 8,
16
+ "use_bias_for_projection": true,
17
+ "use_bias_in_attention": true,
18
+ "use_query_residual_in_attention": true
19
+ },
20
+ "train_config": {
21
+ "corruption_model_type": "odeformer",
22
+ "loss_filter_nans": true,
23
+ "loss_type": "l1",
24
+ "max_sigma_trajectory_noise": 0.06,
25
+ "max_subsampling_ration": 0.5,
26
+ "train_type": "vector_field",
27
+ "train_with_normalized_head": true
28
+ },
29
+ "transformers_version": "4.46.0"
30
+ }
base_model/checkpoints/best-model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5df648066cb57306c558f4faa89399103cd24279c477c8db7676b182171b2d36
3
+ size 51907384
base_model/checkpoints/best-model/optimizers-checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f94f664040a5bfecb084568f6a99c9f89e502b7a8ec2c7b161f13b6987e09d30
3
+ size 19288
base_model/checkpoints/best-model/train-state-checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:008b081ce417e75defa6ab6cde494d1c6e7351244dc39622671b2e7b0662333e
3
+ size 643246
base_model/model_architecture.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ==============================================================================================================
2
+ Layer (type:depth-idx) Output Shape Param #
3
+ ==============================================================================================================
4
+ TrainingWrapper -- --
5
+ ├─FimOdeon: 1-1 -- --
6
+ │ └─TrajectoryEncoder: 2-1 -- 896
7
+ │ │ └─TransformerEncoder: 3-1 [1, 1194, 256] 1,579,520
8
+ │ └─Sequential: 2-2 -- --
9
+ │ │ └─Linear: 3-2 [1, 2400, 256] 1,024
10
+ │ │ └─ReLU: 3-3 [1, 2400, 256] --
11
+ │ │ └─Linear: 3-4 [1, 2400, 256] 65,792
12
+ │ └─AttentionOperator: 2-3 -- --
13
+ │ │ └─ModuleList: 3-5 -- 6,318,080
14
+ │ │ └─MLP: 3-6 [1, 2400, 3] 132,355
15
+ ├─UncertaintyEstimator: 1-2 -- --
16
+ │ └─AttentionOperator: 2-4 -- --
17
+ │ │ └─ModuleList: 3-7 -- 4,738,560
18
+ │ │ └─MLP: 3-8 [1, 2400, 1] 131,841
19
+ ==============================================================================================================
20
+ Total params: 12,968,068
21
+ Trainable params: 12,968,068
22
+ Non-trainable params: 0
23
+ Total mult-adds (Units.MEGABYTES): 12.97
24
+ ==============================================================================================================
25
+ Input size (MB): 0.09
26
+ Forward/backward pass size (MB): 771.15
27
+ Params size (MB): 51.87
28
+ Estimated Total Size (MB): 823.11
29
+ ==============================================================================================================
base_model/train_parameters.yaml ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ dataset:
3
+ add_dim_keys:
4
+ test: !!python/tuple
5
+ - drift_at_observations
6
+ train: !!python/tuple
7
+ - drift_at_observations
8
+ validation: !!python/tuple
9
+ - drift_at_observations
10
+ add_paths_keys:
11
+ test: !!python/tuple
12
+ - drift_at_observations
13
+ train: !!python/tuple
14
+ - drift_at_observations
15
+ validation: !!python/tuple
16
+ - drift_at_observations
17
+ batch_size:
18
+ test: 32
19
+ train: 64
20
+ validation: 32
21
+ data_dirs:
22
+ test: !!python/tuple
23
+ - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_3
24
+ - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_2
25
+ - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/test/test_deg_1
26
+ train: !!python/tuple
27
+ - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_3
28
+ - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_2
29
+ - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/train/train_deg_1
30
+ validation: !!python/tuple
31
+ - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_3
32
+ - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_2
33
+ - /lustre/mlnvme/data/s78mmaue_hpc-demo2/data_generation/data/123_600k_with_obs_drift/0/data/processed/train/30k_drift_deg_3_ablation_studies/degree_and_monomial_survival_uniform/validation/val_deg_1
34
+ dataset_name:
35
+ test: HeterogeneousFIMSDEDataset
36
+ train: StreamingFIMSDEDataset
37
+ validation: StreamingFIMSDEDataset
38
+ files_to_load:
39
+ drift_at_locations: drift_at_locations.h5
40
+ drift_at_observations: drift_at_observations.h5
41
+ locations: locations.h5
42
+ obs_mask: obs_mask.h5
43
+ obs_times: obs_times.h5
44
+ obs_values: obs_values.h5
45
+ max_dim: 3
46
+ name: FIMSDEDataloaderIterableDataset
47
+ num_locations:
48
+ test: null
49
+ train: 2000
50
+ validation: 10000
51
+ num_observations:
52
+ test: null
53
+ train: !!python/tuple
54
+ - 0
55
+ - 1801
56
+ validation: !!python/tuple
57
+ - 1799
58
+ - 1801
59
+ num_workers:
60
+ test: 0
61
+ train: 7
62
+ validation: 5
63
+ shard:
64
+ test: false
65
+ train: true
66
+ validation: true
67
+ shuffle_elements: true
68
+ shuffle_locations:
69
+ test: false
70
+ train: true
71
+ validation: true
72
+ shuffle_paths: true
73
+
74
+ distributed:
75
+ activation_chekpoint: false
76
+ checkpoint_type: full_state
77
+ enabled: true
78
+ min_num_params: 1e5
79
+ sharding_strategy: NO_SHARD
80
+ wrap_policy: SIZE_BAZED
81
+
82
+ experiment:
83
+ device_map: cuda
84
+ name: big_model_l1_600k_examples
85
+ name_add_date: true
86
+ seed: 10
87
+
88
+ model:
89
+ model_config:
90
+ attention_map: softmax
91
+ attention_method: linear
92
+ dim_embed: 256
93
+ dim_feedforward: 1024
94
+ dim_ffn_u_model: 1024
95
+ dim_hidden_u_model: 256
96
+ dim_max_trajectory: 3
97
+ dropout: 0.1
98
+ num_context_encoder_layers: 2
99
+ num_heads: 8
100
+ num_res_layer_u_model: 6
101
+ num_res_layers_functional_decoder: 8
102
+ use_bias_for_projection: true
103
+ use_bias_in_attention: true
104
+ use_query_residual_in_attention: true
105
+ model_type: TrainingWrapper
106
+ train_config:
107
+ corruption_model_type: odeformer
108
+ loss_filter_nans: true
109
+ loss_type: l1
110
+ max_sigma_trajectory_noise: 0.06
111
+ max_subsampling_ration: 0.5
112
+ train_type: vector_field
113
+ train_with_normalized_head: true
114
+
115
+ optimizers: !!python/tuple
116
+ - optimizer_d:
117
+ gradient_norm_clipping: 10
118
+ lr: 1.0e-05
119
+ name: torch.optim.AdamW
120
+ weight_decay: 0.0001
121
+
122
+ trainer:
123
+ best_metric: loss
124
+ debug_iterations: null
125
+ detect_anomaly: false
126
+ epochs: 2500
127
+ experiment_dir: ./results/
128
+ gradient_accumulation_steps: 1
129
+ logging_format: RANK_%(rank)s - %(asctime)s - %(name)s - %(levelname)s - %(message)s
130
+ name: Trainer
131
+ precision: bf16mixed
132
+ save_every: 1
133
+ schedulers: !!python/tuple
134
+ - beta: 1.0
135
+ label: drift_loss_scale
136
+ name: fim.utils.param_scheduler.ConstantScheduler