Ata Celen
commited on
Commit
·
dcaa3ad
1
Parent(s):
bb12ece
Model Weights added
Browse files- qwen2-vl-3d/config.json +48 -0
- qwen2-vl-3d/generation_config.json +14 -0
- qwen2-vl-3d/model-00001-of-00004.safetensors +3 -0
- qwen2-vl-3d/model-00002-of-00004.safetensors +3 -0
- qwen2-vl-3d/model-00003-of-00004.safetensors +3 -0
- qwen2-vl-3d/model-00004-of-00004.safetensors +3 -0
- qwen2-vl-3d/model.safetensors.index.json +899 -0
- qwen2-vl-3d/optimizer.pt +3 -0
- qwen2-vl-3d/rng_state.pth +3 -0
- qwen2-vl-3d/scheduler.pt +3 -0
- qwen2-vl-3d/trainer_state.json +398 -0
- residual-diffuser/args.json +97 -0
- residual-diffuser/dataset_config.pkl +3 -0
- residual-diffuser/diff.txt +1895 -0
- residual-diffuser/diffusion_config.pkl +3 -0
- residual-diffuser/model_config.pkl +3 -0
- residual-diffuser/render_config.pkl +3 -0
- residual-diffuser/state_58000.pt +3 -0
- residual-diffuser/test_indices.txt +100 -0
- residual-diffuser/train_indices.txt +691 -0
- residual-diffuser/trainer_config.pkl +3 -0
- residual-diffuser/val_indices.txt +87 -0
qwen2-vl-3d/config.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "Qwen/Qwen2-VL-7B-Instruct",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"Qwen2VLSpatialForConditionalGeneration"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 151643,
|
| 8 |
+
"eos_token_id": 151645,
|
| 9 |
+
"hidden_act": "silu",
|
| 10 |
+
"hidden_size": 3584,
|
| 11 |
+
"image_token_id": 151655,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 18944,
|
| 14 |
+
"max_position_embeddings": 32768,
|
| 15 |
+
"max_window_layers": 28,
|
| 16 |
+
"model_type": "qwen2_vl",
|
| 17 |
+
"num_attention_heads": 28,
|
| 18 |
+
"num_hidden_layers": 28,
|
| 19 |
+
"num_key_value_heads": 4,
|
| 20 |
+
"rms_norm_eps": 1e-06,
|
| 21 |
+
"rope_scaling": {
|
| 22 |
+
"mrope_section": [
|
| 23 |
+
16,
|
| 24 |
+
24,
|
| 25 |
+
24
|
| 26 |
+
],
|
| 27 |
+
"rope_type": "default",
|
| 28 |
+
"type": "default"
|
| 29 |
+
},
|
| 30 |
+
"rope_theta": 1000000.0,
|
| 31 |
+
"sliding_window": 32768,
|
| 32 |
+
"tie_word_embeddings": false,
|
| 33 |
+
"torch_dtype": "bfloat16",
|
| 34 |
+
"transformers_version": "4.49.0.dev0",
|
| 35 |
+
"use_cache": true,
|
| 36 |
+
"use_sliding_window": false,
|
| 37 |
+
"video_token_id": 151656,
|
| 38 |
+
"vision_config": {
|
| 39 |
+
"in_chans": 3,
|
| 40 |
+
"model_type": "qwen2_vl",
|
| 41 |
+
"spatial_patch_size": 14,
|
| 42 |
+
"torch_dtype": "bfloat16"
|
| 43 |
+
},
|
| 44 |
+
"vision_end_token_id": 151653,
|
| 45 |
+
"vision_start_token_id": 151652,
|
| 46 |
+
"vision_token_id": 151654,
|
| 47 |
+
"vocab_size": 151660
|
| 48 |
+
}
|
qwen2-vl-3d/generation_config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attn_implementation": "flash_attention_2",
|
| 3 |
+
"bos_token_id": 151643,
|
| 4 |
+
"do_sample": true,
|
| 5 |
+
"eos_token_id": [
|
| 6 |
+
151645,
|
| 7 |
+
151643
|
| 8 |
+
],
|
| 9 |
+
"pad_token_id": 151643,
|
| 10 |
+
"temperature": 0.01,
|
| 11 |
+
"top_k": 1,
|
| 12 |
+
"top_p": 0.001,
|
| 13 |
+
"transformers_version": "4.49.0.dev0"
|
| 14 |
+
}
|
qwen2-vl-3d/model-00001-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ebcaeaea7f52c327c50f91b36d0f3d63cd96b445572c129fdb2760b001bd323
|
| 3 |
+
size 4963764072
|
qwen2-vl-3d/model-00002-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:638919d8742180acc5d1ba5016cb77d9c1d51eb0aaf4966410e82d2dcac48580
|
| 3 |
+
size 4991495816
|
qwen2-vl-3d/model-00003-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ff7bf17a78fd87ad5adf702a460330b1ae89ae2f121175eb6045f3a6a3c905fc
|
| 3 |
+
size 4932751040
|
qwen2-vl-3d/model-00004-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:254fc35d50a309eefb1ef995ae67767137ad0a01b573b6ad30cf66cb6896e6a1
|
| 3 |
+
size 1720377840
|
qwen2-vl-3d/model.safetensors.index.json
ADDED
|
@@ -0,0 +1,899 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 16608291000
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"adapter.bias": "model-00004-of-00004.safetensors",
|
| 7 |
+
"adapter.weight": "model-00004-of-00004.safetensors",
|
| 8 |
+
"diffuser.alphas_cumprod": "model-00004-of-00004.safetensors",
|
| 9 |
+
"diffuser.alphas_cumprod_prev": "model-00004-of-00004.safetensors",
|
| 10 |
+
"diffuser.betas": "model-00004-of-00004.safetensors",
|
| 11 |
+
"diffuser.log_one_minus_alphas_cumprod": "model-00004-of-00004.safetensors",
|
| 12 |
+
"diffuser.model.downs.0.0.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 13 |
+
"diffuser.model.downs.0.0.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 14 |
+
"diffuser.model.downs.0.0.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 15 |
+
"diffuser.model.downs.0.0.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 16 |
+
"diffuser.model.downs.0.0.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 17 |
+
"diffuser.model.downs.0.0.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 18 |
+
"diffuser.model.downs.0.0.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 19 |
+
"diffuser.model.downs.0.0.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 20 |
+
"diffuser.model.downs.0.0.residual_conv.bias": "model-00004-of-00004.safetensors",
|
| 21 |
+
"diffuser.model.downs.0.0.residual_conv.weight": "model-00004-of-00004.safetensors",
|
| 22 |
+
"diffuser.model.downs.0.0.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 23 |
+
"diffuser.model.downs.0.0.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 24 |
+
"diffuser.model.downs.0.1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 25 |
+
"diffuser.model.downs.0.1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 26 |
+
"diffuser.model.downs.0.1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 27 |
+
"diffuser.model.downs.0.1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 28 |
+
"diffuser.model.downs.0.1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 29 |
+
"diffuser.model.downs.0.1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 30 |
+
"diffuser.model.downs.0.1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 31 |
+
"diffuser.model.downs.0.1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 32 |
+
"diffuser.model.downs.0.1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 33 |
+
"diffuser.model.downs.0.1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 34 |
+
"diffuser.model.downs.0.2.conv.bias": "model-00004-of-00004.safetensors",
|
| 35 |
+
"diffuser.model.downs.0.2.conv.weight": "model-00004-of-00004.safetensors",
|
| 36 |
+
"diffuser.model.downs.1.0.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 37 |
+
"diffuser.model.downs.1.0.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 38 |
+
"diffuser.model.downs.1.0.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 39 |
+
"diffuser.model.downs.1.0.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 40 |
+
"diffuser.model.downs.1.0.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 41 |
+
"diffuser.model.downs.1.0.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 42 |
+
"diffuser.model.downs.1.0.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 43 |
+
"diffuser.model.downs.1.0.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 44 |
+
"diffuser.model.downs.1.0.residual_conv.bias": "model-00004-of-00004.safetensors",
|
| 45 |
+
"diffuser.model.downs.1.0.residual_conv.weight": "model-00004-of-00004.safetensors",
|
| 46 |
+
"diffuser.model.downs.1.0.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 47 |
+
"diffuser.model.downs.1.0.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 48 |
+
"diffuser.model.downs.1.1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 49 |
+
"diffuser.model.downs.1.1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 50 |
+
"diffuser.model.downs.1.1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 51 |
+
"diffuser.model.downs.1.1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 52 |
+
"diffuser.model.downs.1.1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 53 |
+
"diffuser.model.downs.1.1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 54 |
+
"diffuser.model.downs.1.1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 55 |
+
"diffuser.model.downs.1.1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 56 |
+
"diffuser.model.downs.1.1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 57 |
+
"diffuser.model.downs.1.1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 58 |
+
"diffuser.model.downs.1.2.conv.bias": "model-00004-of-00004.safetensors",
|
| 59 |
+
"diffuser.model.downs.1.2.conv.weight": "model-00004-of-00004.safetensors",
|
| 60 |
+
"diffuser.model.downs.2.0.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 61 |
+
"diffuser.model.downs.2.0.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 62 |
+
"diffuser.model.downs.2.0.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 63 |
+
"diffuser.model.downs.2.0.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 64 |
+
"diffuser.model.downs.2.0.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 65 |
+
"diffuser.model.downs.2.0.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 66 |
+
"diffuser.model.downs.2.0.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 67 |
+
"diffuser.model.downs.2.0.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 68 |
+
"diffuser.model.downs.2.0.residual_conv.bias": "model-00004-of-00004.safetensors",
|
| 69 |
+
"diffuser.model.downs.2.0.residual_conv.weight": "model-00004-of-00004.safetensors",
|
| 70 |
+
"diffuser.model.downs.2.0.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 71 |
+
"diffuser.model.downs.2.0.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 72 |
+
"diffuser.model.downs.2.1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 73 |
+
"diffuser.model.downs.2.1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 74 |
+
"diffuser.model.downs.2.1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 75 |
+
"diffuser.model.downs.2.1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 76 |
+
"diffuser.model.downs.2.1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 77 |
+
"diffuser.model.downs.2.1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 78 |
+
"diffuser.model.downs.2.1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 79 |
+
"diffuser.model.downs.2.1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 80 |
+
"diffuser.model.downs.2.1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 81 |
+
"diffuser.model.downs.2.1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 82 |
+
"diffuser.model.final_conv.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 83 |
+
"diffuser.model.final_conv.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 84 |
+
"diffuser.model.final_conv.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 85 |
+
"diffuser.model.final_conv.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 86 |
+
"diffuser.model.final_conv.1.bias": "model-00004-of-00004.safetensors",
|
| 87 |
+
"diffuser.model.final_conv.1.weight": "model-00004-of-00004.safetensors",
|
| 88 |
+
"diffuser.model.mid_block1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 89 |
+
"diffuser.model.mid_block1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 90 |
+
"diffuser.model.mid_block1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 91 |
+
"diffuser.model.mid_block1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 92 |
+
"diffuser.model.mid_block1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 93 |
+
"diffuser.model.mid_block1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 94 |
+
"diffuser.model.mid_block1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 95 |
+
"diffuser.model.mid_block1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 96 |
+
"diffuser.model.mid_block1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 97 |
+
"diffuser.model.mid_block1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 98 |
+
"diffuser.model.mid_block2.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 99 |
+
"diffuser.model.mid_block2.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 100 |
+
"diffuser.model.mid_block2.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 101 |
+
"diffuser.model.mid_block2.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 102 |
+
"diffuser.model.mid_block2.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 103 |
+
"diffuser.model.mid_block2.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 104 |
+
"diffuser.model.mid_block2.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 105 |
+
"diffuser.model.mid_block2.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 106 |
+
"diffuser.model.mid_block2.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 107 |
+
"diffuser.model.mid_block2.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 108 |
+
"diffuser.model.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 109 |
+
"diffuser.model.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 110 |
+
"diffuser.model.time_mlp.3.bias": "model-00004-of-00004.safetensors",
|
| 111 |
+
"diffuser.model.time_mlp.3.weight": "model-00004-of-00004.safetensors",
|
| 112 |
+
"diffuser.model.ups.0.0.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 113 |
+
"diffuser.model.ups.0.0.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 114 |
+
"diffuser.model.ups.0.0.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 115 |
+
"diffuser.model.ups.0.0.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 116 |
+
"diffuser.model.ups.0.0.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 117 |
+
"diffuser.model.ups.0.0.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 118 |
+
"diffuser.model.ups.0.0.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 119 |
+
"diffuser.model.ups.0.0.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 120 |
+
"diffuser.model.ups.0.0.residual_conv.bias": "model-00004-of-00004.safetensors",
|
| 121 |
+
"diffuser.model.ups.0.0.residual_conv.weight": "model-00004-of-00004.safetensors",
|
| 122 |
+
"diffuser.model.ups.0.0.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 123 |
+
"diffuser.model.ups.0.0.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 124 |
+
"diffuser.model.ups.0.1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 125 |
+
"diffuser.model.ups.0.1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 126 |
+
"diffuser.model.ups.0.1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 127 |
+
"diffuser.model.ups.0.1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 128 |
+
"diffuser.model.ups.0.1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 129 |
+
"diffuser.model.ups.0.1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 130 |
+
"diffuser.model.ups.0.1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 131 |
+
"diffuser.model.ups.0.1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 132 |
+
"diffuser.model.ups.0.1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 133 |
+
"diffuser.model.ups.0.1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 134 |
+
"diffuser.model.ups.0.2.conv.bias": "model-00004-of-00004.safetensors",
|
| 135 |
+
"diffuser.model.ups.0.2.conv.weight": "model-00004-of-00004.safetensors",
|
| 136 |
+
"diffuser.model.ups.1.0.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 137 |
+
"diffuser.model.ups.1.0.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 138 |
+
"diffuser.model.ups.1.0.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 139 |
+
"diffuser.model.ups.1.0.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 140 |
+
"diffuser.model.ups.1.0.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 141 |
+
"diffuser.model.ups.1.0.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 142 |
+
"diffuser.model.ups.1.0.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 143 |
+
"diffuser.model.ups.1.0.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 144 |
+
"diffuser.model.ups.1.0.residual_conv.bias": "model-00004-of-00004.safetensors",
|
| 145 |
+
"diffuser.model.ups.1.0.residual_conv.weight": "model-00004-of-00004.safetensors",
|
| 146 |
+
"diffuser.model.ups.1.0.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 147 |
+
"diffuser.model.ups.1.0.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 148 |
+
"diffuser.model.ups.1.1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
|
| 149 |
+
"diffuser.model.ups.1.1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
|
| 150 |
+
"diffuser.model.ups.1.1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
|
| 151 |
+
"diffuser.model.ups.1.1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
|
| 152 |
+
"diffuser.model.ups.1.1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
|
| 153 |
+
"diffuser.model.ups.1.1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
|
| 154 |
+
"diffuser.model.ups.1.1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
|
| 155 |
+
"diffuser.model.ups.1.1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
|
| 156 |
+
"diffuser.model.ups.1.1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
|
| 157 |
+
"diffuser.model.ups.1.1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
|
| 158 |
+
"diffuser.model.ups.1.2.conv.bias": "model-00004-of-00004.safetensors",
|
| 159 |
+
"diffuser.model.ups.1.2.conv.weight": "model-00004-of-00004.safetensors",
|
| 160 |
+
"diffuser.posterior_log_variance_clipped": "model-00004-of-00004.safetensors",
|
| 161 |
+
"diffuser.posterior_mean_coef1": "model-00004-of-00004.safetensors",
|
| 162 |
+
"diffuser.posterior_mean_coef2": "model-00004-of-00004.safetensors",
|
| 163 |
+
"diffuser.posterior_variance": "model-00004-of-00004.safetensors",
|
| 164 |
+
"diffuser.sqrt_alphas_cumprod": "model-00004-of-00004.safetensors",
|
| 165 |
+
"diffuser.sqrt_one_minus_alphas_cumprod": "model-00004-of-00004.safetensors",
|
| 166 |
+
"diffuser.sqrt_recip_alphas_cumprod": "model-00004-of-00004.safetensors",
|
| 167 |
+
"diffuser.sqrt_recipm1_alphas_cumprod": "model-00004-of-00004.safetensors",
|
| 168 |
+
"lm_head.weight": "model-00004-of-00004.safetensors",
|
| 169 |
+
"model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
| 170 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 171 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 172 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 173 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 174 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 175 |
+
"model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 176 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 177 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 178 |
+
"model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 179 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 180 |
+
"model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 181 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 182 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 183 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 184 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 185 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 186 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 187 |
+
"model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 188 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 189 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 190 |
+
"model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 191 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 192 |
+
"model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 193 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 194 |
+
"model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 195 |
+
"model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 196 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 197 |
+
"model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 198 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 199 |
+
"model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 200 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 201 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 202 |
+
"model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 203 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 204 |
+
"model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 205 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 206 |
+
"model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 207 |
+
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 208 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 209 |
+
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 210 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 211 |
+
"model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 212 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 213 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 214 |
+
"model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 215 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 216 |
+
"model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 217 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 218 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 219 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 220 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 221 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 222 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 223 |
+
"model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 224 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 225 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 226 |
+
"model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 227 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 228 |
+
"model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 229 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 230 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 231 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 232 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 233 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 234 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 235 |
+
"model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 236 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 237 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 238 |
+
"model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 239 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 240 |
+
"model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 241 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 242 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 243 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 244 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 245 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 246 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 247 |
+
"model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 248 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 249 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 250 |
+
"model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 251 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 252 |
+
"model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 253 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 254 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 255 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 256 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 257 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 258 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 259 |
+
"model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 260 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 261 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 262 |
+
"model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 263 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 264 |
+
"model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 265 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 266 |
+
"model.layers.16.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 267 |
+
"model.layers.16.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 268 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 269 |
+
"model.layers.16.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 270 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 271 |
+
"model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 272 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 273 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 274 |
+
"model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 275 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 276 |
+
"model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 277 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 278 |
+
"model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 279 |
+
"model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 280 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 281 |
+
"model.layers.17.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 282 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 283 |
+
"model.layers.17.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 284 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 285 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 286 |
+
"model.layers.17.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 287 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 288 |
+
"model.layers.17.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 289 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 290 |
+
"model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 291 |
+
"model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 292 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 293 |
+
"model.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 294 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 295 |
+
"model.layers.18.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 296 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 297 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 298 |
+
"model.layers.18.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 299 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 300 |
+
"model.layers.18.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 301 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 302 |
+
"model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 303 |
+
"model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 304 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 305 |
+
"model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 306 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 307 |
+
"model.layers.19.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 308 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 309 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 310 |
+
"model.layers.19.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 311 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 312 |
+
"model.layers.19.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 313 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 314 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 315 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 316 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 317 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 318 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 319 |
+
"model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 320 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 321 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 322 |
+
"model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 323 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 324 |
+
"model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 325 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 326 |
+
"model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 327 |
+
"model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 328 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 329 |
+
"model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 330 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 331 |
+
"model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 332 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 333 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 334 |
+
"model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 335 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 336 |
+
"model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 337 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 338 |
+
"model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 339 |
+
"model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 340 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 341 |
+
"model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 342 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 343 |
+
"model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 344 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 345 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 346 |
+
"model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 347 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 348 |
+
"model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 349 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 350 |
+
"model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 351 |
+
"model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 352 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 353 |
+
"model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 354 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 355 |
+
"model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 356 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 357 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 358 |
+
"model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 359 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 360 |
+
"model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 361 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 362 |
+
"model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 363 |
+
"model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 364 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 365 |
+
"model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 366 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 367 |
+
"model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 368 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 369 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 370 |
+
"model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 371 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 372 |
+
"model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 373 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 374 |
+
"model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 375 |
+
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 376 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 377 |
+
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 378 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 379 |
+
"model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 380 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 381 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 382 |
+
"model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 383 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 384 |
+
"model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 385 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 386 |
+
"model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 387 |
+
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 388 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 389 |
+
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 390 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 391 |
+
"model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 392 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 393 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 394 |
+
"model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 395 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 396 |
+
"model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 397 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 398 |
+
"model.layers.26.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 399 |
+
"model.layers.26.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
| 400 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 401 |
+
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 402 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 403 |
+
"model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 404 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 405 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 406 |
+
"model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 407 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 408 |
+
"model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 409 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 410 |
+
"model.layers.27.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 411 |
+
"model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
| 412 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
| 413 |
+
"model.layers.27.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
| 414 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 415 |
+
"model.layers.27.self_attn.k_proj.bias": "model-00004-of-00004.safetensors",
|
| 416 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
| 417 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
| 418 |
+
"model.layers.27.self_attn.q_proj.bias": "model-00004-of-00004.safetensors",
|
| 419 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
| 420 |
+
"model.layers.27.self_attn.v_proj.bias": "model-00004-of-00004.safetensors",
|
| 421 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
| 422 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 423 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 424 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 425 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 426 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 427 |
+
"model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 428 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 429 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 430 |
+
"model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 431 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 432 |
+
"model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 433 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 434 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 435 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 436 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 437 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 438 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 439 |
+
"model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 440 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 441 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 442 |
+
"model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 443 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 444 |
+
"model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 445 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 446 |
+
"model.layers.5.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 447 |
+
"model.layers.5.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 448 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 449 |
+
"model.layers.5.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 450 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 451 |
+
"model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 452 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 453 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 454 |
+
"model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 455 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 456 |
+
"model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 457 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 458 |
+
"model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 459 |
+
"model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 460 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 461 |
+
"model.layers.6.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 462 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 463 |
+
"model.layers.6.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 464 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 465 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 466 |
+
"model.layers.6.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 467 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 468 |
+
"model.layers.6.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 469 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 470 |
+
"model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 471 |
+
"model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 472 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 473 |
+
"model.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 474 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 475 |
+
"model.layers.7.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 476 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 477 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 478 |
+
"model.layers.7.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 479 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 480 |
+
"model.layers.7.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 481 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 482 |
+
"model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 483 |
+
"model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 484 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 485 |
+
"model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 486 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 487 |
+
"model.layers.8.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 488 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 489 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 490 |
+
"model.layers.8.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 491 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 492 |
+
"model.layers.8.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 493 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 494 |
+
"model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 495 |
+
"model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 496 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 497 |
+
"model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 498 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 499 |
+
"model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 500 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 501 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 502 |
+
"model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 503 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 504 |
+
"model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 505 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 506 |
+
"model.norm.weight": "model-00004-of-00004.safetensors",
|
| 507 |
+
"visual.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 508 |
+
"visual.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 509 |
+
"visual.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 510 |
+
"visual.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 511 |
+
"visual.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 512 |
+
"visual.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 513 |
+
"visual.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 514 |
+
"visual.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 515 |
+
"visual.blocks.0.norm1.bias": "model-00001-of-00004.safetensors",
|
| 516 |
+
"visual.blocks.0.norm1.weight": "model-00001-of-00004.safetensors",
|
| 517 |
+
"visual.blocks.0.norm2.bias": "model-00001-of-00004.safetensors",
|
| 518 |
+
"visual.blocks.0.norm2.weight": "model-00001-of-00004.safetensors",
|
| 519 |
+
"visual.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 520 |
+
"visual.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 521 |
+
"visual.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 522 |
+
"visual.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 523 |
+
"visual.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 524 |
+
"visual.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 525 |
+
"visual.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 526 |
+
"visual.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 527 |
+
"visual.blocks.1.norm1.bias": "model-00001-of-00004.safetensors",
|
| 528 |
+
"visual.blocks.1.norm1.weight": "model-00001-of-00004.safetensors",
|
| 529 |
+
"visual.blocks.1.norm2.bias": "model-00001-of-00004.safetensors",
|
| 530 |
+
"visual.blocks.1.norm2.weight": "model-00001-of-00004.safetensors",
|
| 531 |
+
"visual.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 532 |
+
"visual.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 533 |
+
"visual.blocks.10.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 534 |
+
"visual.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 535 |
+
"visual.blocks.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 536 |
+
"visual.blocks.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 537 |
+
"visual.blocks.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 538 |
+
"visual.blocks.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 539 |
+
"visual.blocks.10.norm1.bias": "model-00001-of-00004.safetensors",
|
| 540 |
+
"visual.blocks.10.norm1.weight": "model-00001-of-00004.safetensors",
|
| 541 |
+
"visual.blocks.10.norm2.bias": "model-00001-of-00004.safetensors",
|
| 542 |
+
"visual.blocks.10.norm2.weight": "model-00001-of-00004.safetensors",
|
| 543 |
+
"visual.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 544 |
+
"visual.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 545 |
+
"visual.blocks.11.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 546 |
+
"visual.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 547 |
+
"visual.blocks.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 548 |
+
"visual.blocks.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 549 |
+
"visual.blocks.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 550 |
+
"visual.blocks.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 551 |
+
"visual.blocks.11.norm1.bias": "model-00001-of-00004.safetensors",
|
| 552 |
+
"visual.blocks.11.norm1.weight": "model-00001-of-00004.safetensors",
|
| 553 |
+
"visual.blocks.11.norm2.bias": "model-00001-of-00004.safetensors",
|
| 554 |
+
"visual.blocks.11.norm2.weight": "model-00001-of-00004.safetensors",
|
| 555 |
+
"visual.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 556 |
+
"visual.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 557 |
+
"visual.blocks.12.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 558 |
+
"visual.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 559 |
+
"visual.blocks.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 560 |
+
"visual.blocks.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 561 |
+
"visual.blocks.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 562 |
+
"visual.blocks.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 563 |
+
"visual.blocks.12.norm1.bias": "model-00001-of-00004.safetensors",
|
| 564 |
+
"visual.blocks.12.norm1.weight": "model-00001-of-00004.safetensors",
|
| 565 |
+
"visual.blocks.12.norm2.bias": "model-00001-of-00004.safetensors",
|
| 566 |
+
"visual.blocks.12.norm2.weight": "model-00001-of-00004.safetensors",
|
| 567 |
+
"visual.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 568 |
+
"visual.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 569 |
+
"visual.blocks.13.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 570 |
+
"visual.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 571 |
+
"visual.blocks.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 572 |
+
"visual.blocks.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 573 |
+
"visual.blocks.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 574 |
+
"visual.blocks.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 575 |
+
"visual.blocks.13.norm1.bias": "model-00001-of-00004.safetensors",
|
| 576 |
+
"visual.blocks.13.norm1.weight": "model-00001-of-00004.safetensors",
|
| 577 |
+
"visual.blocks.13.norm2.bias": "model-00001-of-00004.safetensors",
|
| 578 |
+
"visual.blocks.13.norm2.weight": "model-00001-of-00004.safetensors",
|
| 579 |
+
"visual.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 580 |
+
"visual.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 581 |
+
"visual.blocks.14.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 582 |
+
"visual.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 583 |
+
"visual.blocks.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 584 |
+
"visual.blocks.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 585 |
+
"visual.blocks.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 586 |
+
"visual.blocks.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 587 |
+
"visual.blocks.14.norm1.bias": "model-00001-of-00004.safetensors",
|
| 588 |
+
"visual.blocks.14.norm1.weight": "model-00001-of-00004.safetensors",
|
| 589 |
+
"visual.blocks.14.norm2.bias": "model-00001-of-00004.safetensors",
|
| 590 |
+
"visual.blocks.14.norm2.weight": "model-00001-of-00004.safetensors",
|
| 591 |
+
"visual.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 592 |
+
"visual.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 593 |
+
"visual.blocks.15.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 594 |
+
"visual.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 595 |
+
"visual.blocks.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 596 |
+
"visual.blocks.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 597 |
+
"visual.blocks.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 598 |
+
"visual.blocks.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 599 |
+
"visual.blocks.15.norm1.bias": "model-00001-of-00004.safetensors",
|
| 600 |
+
"visual.blocks.15.norm1.weight": "model-00001-of-00004.safetensors",
|
| 601 |
+
"visual.blocks.15.norm2.bias": "model-00001-of-00004.safetensors",
|
| 602 |
+
"visual.blocks.15.norm2.weight": "model-00001-of-00004.safetensors",
|
| 603 |
+
"visual.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 604 |
+
"visual.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 605 |
+
"visual.blocks.16.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 606 |
+
"visual.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 607 |
+
"visual.blocks.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 608 |
+
"visual.blocks.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 609 |
+
"visual.blocks.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 610 |
+
"visual.blocks.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 611 |
+
"visual.blocks.16.norm1.bias": "model-00001-of-00004.safetensors",
|
| 612 |
+
"visual.blocks.16.norm1.weight": "model-00001-of-00004.safetensors",
|
| 613 |
+
"visual.blocks.16.norm2.bias": "model-00001-of-00004.safetensors",
|
| 614 |
+
"visual.blocks.16.norm2.weight": "model-00001-of-00004.safetensors",
|
| 615 |
+
"visual.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 616 |
+
"visual.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 617 |
+
"visual.blocks.17.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 618 |
+
"visual.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 619 |
+
"visual.blocks.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 620 |
+
"visual.blocks.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 621 |
+
"visual.blocks.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 622 |
+
"visual.blocks.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 623 |
+
"visual.blocks.17.norm1.bias": "model-00001-of-00004.safetensors",
|
| 624 |
+
"visual.blocks.17.norm1.weight": "model-00001-of-00004.safetensors",
|
| 625 |
+
"visual.blocks.17.norm2.bias": "model-00001-of-00004.safetensors",
|
| 626 |
+
"visual.blocks.17.norm2.weight": "model-00001-of-00004.safetensors",
|
| 627 |
+
"visual.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 628 |
+
"visual.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 629 |
+
"visual.blocks.18.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 630 |
+
"visual.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 631 |
+
"visual.blocks.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 632 |
+
"visual.blocks.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 633 |
+
"visual.blocks.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 634 |
+
"visual.blocks.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 635 |
+
"visual.blocks.18.norm1.bias": "model-00001-of-00004.safetensors",
|
| 636 |
+
"visual.blocks.18.norm1.weight": "model-00001-of-00004.safetensors",
|
| 637 |
+
"visual.blocks.18.norm2.bias": "model-00001-of-00004.safetensors",
|
| 638 |
+
"visual.blocks.18.norm2.weight": "model-00001-of-00004.safetensors",
|
| 639 |
+
"visual.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 640 |
+
"visual.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 641 |
+
"visual.blocks.19.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 642 |
+
"visual.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 643 |
+
"visual.blocks.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 644 |
+
"visual.blocks.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 645 |
+
"visual.blocks.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 646 |
+
"visual.blocks.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 647 |
+
"visual.blocks.19.norm1.bias": "model-00001-of-00004.safetensors",
|
| 648 |
+
"visual.blocks.19.norm1.weight": "model-00001-of-00004.safetensors",
|
| 649 |
+
"visual.blocks.19.norm2.bias": "model-00001-of-00004.safetensors",
|
| 650 |
+
"visual.blocks.19.norm2.weight": "model-00001-of-00004.safetensors",
|
| 651 |
+
"visual.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 652 |
+
"visual.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 653 |
+
"visual.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 654 |
+
"visual.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 655 |
+
"visual.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 656 |
+
"visual.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 657 |
+
"visual.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 658 |
+
"visual.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 659 |
+
"visual.blocks.2.norm1.bias": "model-00001-of-00004.safetensors",
|
| 660 |
+
"visual.blocks.2.norm1.weight": "model-00001-of-00004.safetensors",
|
| 661 |
+
"visual.blocks.2.norm2.bias": "model-00001-of-00004.safetensors",
|
| 662 |
+
"visual.blocks.2.norm2.weight": "model-00001-of-00004.safetensors",
|
| 663 |
+
"visual.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 664 |
+
"visual.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 665 |
+
"visual.blocks.20.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 666 |
+
"visual.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 667 |
+
"visual.blocks.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 668 |
+
"visual.blocks.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 669 |
+
"visual.blocks.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 670 |
+
"visual.blocks.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 671 |
+
"visual.blocks.20.norm1.bias": "model-00001-of-00004.safetensors",
|
| 672 |
+
"visual.blocks.20.norm1.weight": "model-00001-of-00004.safetensors",
|
| 673 |
+
"visual.blocks.20.norm2.bias": "model-00001-of-00004.safetensors",
|
| 674 |
+
"visual.blocks.20.norm2.weight": "model-00001-of-00004.safetensors",
|
| 675 |
+
"visual.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 676 |
+
"visual.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 677 |
+
"visual.blocks.21.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 678 |
+
"visual.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 679 |
+
"visual.blocks.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 680 |
+
"visual.blocks.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 681 |
+
"visual.blocks.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 682 |
+
"visual.blocks.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 683 |
+
"visual.blocks.21.norm1.bias": "model-00001-of-00004.safetensors",
|
| 684 |
+
"visual.blocks.21.norm1.weight": "model-00001-of-00004.safetensors",
|
| 685 |
+
"visual.blocks.21.norm2.bias": "model-00001-of-00004.safetensors",
|
| 686 |
+
"visual.blocks.21.norm2.weight": "model-00001-of-00004.safetensors",
|
| 687 |
+
"visual.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 688 |
+
"visual.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 689 |
+
"visual.blocks.22.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 690 |
+
"visual.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 691 |
+
"visual.blocks.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 692 |
+
"visual.blocks.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 693 |
+
"visual.blocks.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 694 |
+
"visual.blocks.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 695 |
+
"visual.blocks.22.norm1.bias": "model-00001-of-00004.safetensors",
|
| 696 |
+
"visual.blocks.22.norm1.weight": "model-00001-of-00004.safetensors",
|
| 697 |
+
"visual.blocks.22.norm2.bias": "model-00001-of-00004.safetensors",
|
| 698 |
+
"visual.blocks.22.norm2.weight": "model-00001-of-00004.safetensors",
|
| 699 |
+
"visual.blocks.23.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 700 |
+
"visual.blocks.23.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 701 |
+
"visual.blocks.23.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 702 |
+
"visual.blocks.23.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 703 |
+
"visual.blocks.23.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 704 |
+
"visual.blocks.23.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 705 |
+
"visual.blocks.23.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 706 |
+
"visual.blocks.23.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 707 |
+
"visual.blocks.23.norm1.bias": "model-00001-of-00004.safetensors",
|
| 708 |
+
"visual.blocks.23.norm1.weight": "model-00001-of-00004.safetensors",
|
| 709 |
+
"visual.blocks.23.norm2.bias": "model-00001-of-00004.safetensors",
|
| 710 |
+
"visual.blocks.23.norm2.weight": "model-00001-of-00004.safetensors",
|
| 711 |
+
"visual.blocks.24.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 712 |
+
"visual.blocks.24.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 713 |
+
"visual.blocks.24.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 714 |
+
"visual.blocks.24.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 715 |
+
"visual.blocks.24.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 716 |
+
"visual.blocks.24.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 717 |
+
"visual.blocks.24.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 718 |
+
"visual.blocks.24.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 719 |
+
"visual.blocks.24.norm1.bias": "model-00001-of-00004.safetensors",
|
| 720 |
+
"visual.blocks.24.norm1.weight": "model-00001-of-00004.safetensors",
|
| 721 |
+
"visual.blocks.24.norm2.bias": "model-00001-of-00004.safetensors",
|
| 722 |
+
"visual.blocks.24.norm2.weight": "model-00001-of-00004.safetensors",
|
| 723 |
+
"visual.blocks.25.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 724 |
+
"visual.blocks.25.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 725 |
+
"visual.blocks.25.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 726 |
+
"visual.blocks.25.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 727 |
+
"visual.blocks.25.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 728 |
+
"visual.blocks.25.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 729 |
+
"visual.blocks.25.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 730 |
+
"visual.blocks.25.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 731 |
+
"visual.blocks.25.norm1.bias": "model-00001-of-00004.safetensors",
|
| 732 |
+
"visual.blocks.25.norm1.weight": "model-00001-of-00004.safetensors",
|
| 733 |
+
"visual.blocks.25.norm2.bias": "model-00001-of-00004.safetensors",
|
| 734 |
+
"visual.blocks.25.norm2.weight": "model-00001-of-00004.safetensors",
|
| 735 |
+
"visual.blocks.26.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 736 |
+
"visual.blocks.26.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 737 |
+
"visual.blocks.26.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 738 |
+
"visual.blocks.26.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 739 |
+
"visual.blocks.26.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 740 |
+
"visual.blocks.26.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 741 |
+
"visual.blocks.26.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 742 |
+
"visual.blocks.26.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 743 |
+
"visual.blocks.26.norm1.bias": "model-00001-of-00004.safetensors",
|
| 744 |
+
"visual.blocks.26.norm1.weight": "model-00001-of-00004.safetensors",
|
| 745 |
+
"visual.blocks.26.norm2.bias": "model-00001-of-00004.safetensors",
|
| 746 |
+
"visual.blocks.26.norm2.weight": "model-00001-of-00004.safetensors",
|
| 747 |
+
"visual.blocks.27.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 748 |
+
"visual.blocks.27.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 749 |
+
"visual.blocks.27.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 750 |
+
"visual.blocks.27.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 751 |
+
"visual.blocks.27.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 752 |
+
"visual.blocks.27.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 753 |
+
"visual.blocks.27.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 754 |
+
"visual.blocks.27.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 755 |
+
"visual.blocks.27.norm1.bias": "model-00001-of-00004.safetensors",
|
| 756 |
+
"visual.blocks.27.norm1.weight": "model-00001-of-00004.safetensors",
|
| 757 |
+
"visual.blocks.27.norm2.bias": "model-00001-of-00004.safetensors",
|
| 758 |
+
"visual.blocks.27.norm2.weight": "model-00001-of-00004.safetensors",
|
| 759 |
+
"visual.blocks.28.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 760 |
+
"visual.blocks.28.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 761 |
+
"visual.blocks.28.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 762 |
+
"visual.blocks.28.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 763 |
+
"visual.blocks.28.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 764 |
+
"visual.blocks.28.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 765 |
+
"visual.blocks.28.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 766 |
+
"visual.blocks.28.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 767 |
+
"visual.blocks.28.norm1.bias": "model-00001-of-00004.safetensors",
|
| 768 |
+
"visual.blocks.28.norm1.weight": "model-00001-of-00004.safetensors",
|
| 769 |
+
"visual.blocks.28.norm2.bias": "model-00001-of-00004.safetensors",
|
| 770 |
+
"visual.blocks.28.norm2.weight": "model-00001-of-00004.safetensors",
|
| 771 |
+
"visual.blocks.29.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 772 |
+
"visual.blocks.29.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 773 |
+
"visual.blocks.29.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 774 |
+
"visual.blocks.29.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 775 |
+
"visual.blocks.29.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 776 |
+
"visual.blocks.29.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 777 |
+
"visual.blocks.29.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 778 |
+
"visual.blocks.29.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 779 |
+
"visual.blocks.29.norm1.bias": "model-00001-of-00004.safetensors",
|
| 780 |
+
"visual.blocks.29.norm1.weight": "model-00001-of-00004.safetensors",
|
| 781 |
+
"visual.blocks.29.norm2.bias": "model-00001-of-00004.safetensors",
|
| 782 |
+
"visual.blocks.29.norm2.weight": "model-00001-of-00004.safetensors",
|
| 783 |
+
"visual.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 784 |
+
"visual.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 785 |
+
"visual.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 786 |
+
"visual.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 787 |
+
"visual.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 788 |
+
"visual.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 789 |
+
"visual.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 790 |
+
"visual.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 791 |
+
"visual.blocks.3.norm1.bias": "model-00001-of-00004.safetensors",
|
| 792 |
+
"visual.blocks.3.norm1.weight": "model-00001-of-00004.safetensors",
|
| 793 |
+
"visual.blocks.3.norm2.bias": "model-00001-of-00004.safetensors",
|
| 794 |
+
"visual.blocks.3.norm2.weight": "model-00001-of-00004.safetensors",
|
| 795 |
+
"visual.blocks.30.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 796 |
+
"visual.blocks.30.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 797 |
+
"visual.blocks.30.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 798 |
+
"visual.blocks.30.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 799 |
+
"visual.blocks.30.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 800 |
+
"visual.blocks.30.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 801 |
+
"visual.blocks.30.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 802 |
+
"visual.blocks.30.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 803 |
+
"visual.blocks.30.norm1.bias": "model-00001-of-00004.safetensors",
|
| 804 |
+
"visual.blocks.30.norm1.weight": "model-00001-of-00004.safetensors",
|
| 805 |
+
"visual.blocks.30.norm2.bias": "model-00001-of-00004.safetensors",
|
| 806 |
+
"visual.blocks.30.norm2.weight": "model-00001-of-00004.safetensors",
|
| 807 |
+
"visual.blocks.31.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 808 |
+
"visual.blocks.31.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 809 |
+
"visual.blocks.31.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 810 |
+
"visual.blocks.31.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 811 |
+
"visual.blocks.31.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 812 |
+
"visual.blocks.31.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 813 |
+
"visual.blocks.31.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 814 |
+
"visual.blocks.31.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 815 |
+
"visual.blocks.31.norm1.bias": "model-00001-of-00004.safetensors",
|
| 816 |
+
"visual.blocks.31.norm1.weight": "model-00001-of-00004.safetensors",
|
| 817 |
+
"visual.blocks.31.norm2.bias": "model-00001-of-00004.safetensors",
|
| 818 |
+
"visual.blocks.31.norm2.weight": "model-00001-of-00004.safetensors",
|
| 819 |
+
"visual.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 820 |
+
"visual.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 821 |
+
"visual.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 822 |
+
"visual.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 823 |
+
"visual.blocks.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 824 |
+
"visual.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 825 |
+
"visual.blocks.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 826 |
+
"visual.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 827 |
+
"visual.blocks.4.norm1.bias": "model-00001-of-00004.safetensors",
|
| 828 |
+
"visual.blocks.4.norm1.weight": "model-00001-of-00004.safetensors",
|
| 829 |
+
"visual.blocks.4.norm2.bias": "model-00001-of-00004.safetensors",
|
| 830 |
+
"visual.blocks.4.norm2.weight": "model-00001-of-00004.safetensors",
|
| 831 |
+
"visual.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 832 |
+
"visual.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 833 |
+
"visual.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 834 |
+
"visual.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 835 |
+
"visual.blocks.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 836 |
+
"visual.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 837 |
+
"visual.blocks.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 838 |
+
"visual.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 839 |
+
"visual.blocks.5.norm1.bias": "model-00001-of-00004.safetensors",
|
| 840 |
+
"visual.blocks.5.norm1.weight": "model-00001-of-00004.safetensors",
|
| 841 |
+
"visual.blocks.5.norm2.bias": "model-00001-of-00004.safetensors",
|
| 842 |
+
"visual.blocks.5.norm2.weight": "model-00001-of-00004.safetensors",
|
| 843 |
+
"visual.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 844 |
+
"visual.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 845 |
+
"visual.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 846 |
+
"visual.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 847 |
+
"visual.blocks.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 848 |
+
"visual.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 849 |
+
"visual.blocks.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 850 |
+
"visual.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 851 |
+
"visual.blocks.6.norm1.bias": "model-00001-of-00004.safetensors",
|
| 852 |
+
"visual.blocks.6.norm1.weight": "model-00001-of-00004.safetensors",
|
| 853 |
+
"visual.blocks.6.norm2.bias": "model-00001-of-00004.safetensors",
|
| 854 |
+
"visual.blocks.6.norm2.weight": "model-00001-of-00004.safetensors",
|
| 855 |
+
"visual.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 856 |
+
"visual.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 857 |
+
"visual.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 858 |
+
"visual.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 859 |
+
"visual.blocks.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 860 |
+
"visual.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 861 |
+
"visual.blocks.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 862 |
+
"visual.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 863 |
+
"visual.blocks.7.norm1.bias": "model-00001-of-00004.safetensors",
|
| 864 |
+
"visual.blocks.7.norm1.weight": "model-00001-of-00004.safetensors",
|
| 865 |
+
"visual.blocks.7.norm2.bias": "model-00001-of-00004.safetensors",
|
| 866 |
+
"visual.blocks.7.norm2.weight": "model-00001-of-00004.safetensors",
|
| 867 |
+
"visual.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 868 |
+
"visual.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 869 |
+
"visual.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 870 |
+
"visual.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 871 |
+
"visual.blocks.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 872 |
+
"visual.blocks.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 873 |
+
"visual.blocks.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 874 |
+
"visual.blocks.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 875 |
+
"visual.blocks.8.norm1.bias": "model-00001-of-00004.safetensors",
|
| 876 |
+
"visual.blocks.8.norm1.weight": "model-00001-of-00004.safetensors",
|
| 877 |
+
"visual.blocks.8.norm2.bias": "model-00001-of-00004.safetensors",
|
| 878 |
+
"visual.blocks.8.norm2.weight": "model-00001-of-00004.safetensors",
|
| 879 |
+
"visual.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 880 |
+
"visual.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 881 |
+
"visual.blocks.9.attn.qkv.bias": "model-00001-of-00004.safetensors",
|
| 882 |
+
"visual.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 883 |
+
"visual.blocks.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 884 |
+
"visual.blocks.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 885 |
+
"visual.blocks.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 886 |
+
"visual.blocks.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 887 |
+
"visual.blocks.9.norm1.bias": "model-00001-of-00004.safetensors",
|
| 888 |
+
"visual.blocks.9.norm1.weight": "model-00001-of-00004.safetensors",
|
| 889 |
+
"visual.blocks.9.norm2.bias": "model-00001-of-00004.safetensors",
|
| 890 |
+
"visual.blocks.9.norm2.weight": "model-00001-of-00004.safetensors",
|
| 891 |
+
"visual.merger.ln_q.bias": "model-00001-of-00004.safetensors",
|
| 892 |
+
"visual.merger.ln_q.weight": "model-00001-of-00004.safetensors",
|
| 893 |
+
"visual.merger.mlp.0.bias": "model-00001-of-00004.safetensors",
|
| 894 |
+
"visual.merger.mlp.0.weight": "model-00001-of-00004.safetensors",
|
| 895 |
+
"visual.merger.mlp.2.bias": "model-00001-of-00004.safetensors",
|
| 896 |
+
"visual.merger.mlp.2.weight": "model-00001-of-00004.safetensors",
|
| 897 |
+
"visual.patch_embed.proj.weight": "model-00001-of-00004.safetensors"
|
| 898 |
+
}
|
| 899 |
+
}
|
qwen2-vl-3d/optimizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fc71ef5f1c3be4d29c372baeea36b661da41256d91477608542b21e214e357dc
|
| 3 |
+
size 165361722
|
qwen2-vl-3d/rng_state.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3a50d9ce198e2733053d8d4a9703a034370d3ce84ca4a64a639d3035f21bb559
|
| 3 |
+
size 14244
|
qwen2-vl-3d/scheduler.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b46bd00f15da7bf12c17740f5c0095e4932c178c4baf4ffcf2aaa705fd44155
|
| 3 |
+
size 1064
|
qwen2-vl-3d/trainer_state.json
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_metric": 1.1060227155685425,
|
| 3 |
+
"best_model_checkpoint": "qwen2-7b-instruct-trl-sft-housetour-l64-adapter/checkpoint-500",
|
| 4 |
+
"epoch": 3.4246575342465753,
|
| 5 |
+
"eval_steps": 500,
|
| 6 |
+
"global_step": 500,
|
| 7 |
+
"is_hyper_param_search": false,
|
| 8 |
+
"is_local_process_zero": true,
|
| 9 |
+
"is_world_process_zero": true,
|
| 10 |
+
"log_history": [
|
| 11 |
+
{
|
| 12 |
+
"epoch": 0.00684931506849315,
|
| 13 |
+
"grad_norm": 18.83251953125,
|
| 14 |
+
"learning_rate": 1.1363636363636363e-08,
|
| 15 |
+
"loss": 2.0308,
|
| 16 |
+
"step": 1
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"epoch": 0.0684931506849315,
|
| 20 |
+
"grad_norm": 16.91636848449707,
|
| 21 |
+
"learning_rate": 1.1363636363636363e-07,
|
| 22 |
+
"loss": 2.2756,
|
| 23 |
+
"step": 10
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"epoch": 0.136986301369863,
|
| 27 |
+
"grad_norm": 16.575881958007812,
|
| 28 |
+
"learning_rate": 2.2727272727272726e-07,
|
| 29 |
+
"loss": 2.2648,
|
| 30 |
+
"step": 20
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"epoch": 0.2054794520547945,
|
| 34 |
+
"grad_norm": 10.142265319824219,
|
| 35 |
+
"learning_rate": 3.4090909090909085e-07,
|
| 36 |
+
"loss": 2.103,
|
| 37 |
+
"step": 30
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"epoch": 0.273972602739726,
|
| 41 |
+
"grad_norm": 10.252799034118652,
|
| 42 |
+
"learning_rate": 4.545454545454545e-07,
|
| 43 |
+
"loss": 1.9579,
|
| 44 |
+
"step": 40
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"epoch": 0.3424657534246575,
|
| 48 |
+
"grad_norm": 10.1475830078125,
|
| 49 |
+
"learning_rate": 5.681818181818182e-07,
|
| 50 |
+
"loss": 1.7427,
|
| 51 |
+
"step": 50
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"epoch": 0.410958904109589,
|
| 55 |
+
"grad_norm": 6.072550296783447,
|
| 56 |
+
"learning_rate": 6.818181818181817e-07,
|
| 57 |
+
"loss": 1.6655,
|
| 58 |
+
"step": 60
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"epoch": 0.4794520547945205,
|
| 62 |
+
"grad_norm": 4.325073719024658,
|
| 63 |
+
"learning_rate": 7.954545454545454e-07,
|
| 64 |
+
"loss": 1.5283,
|
| 65 |
+
"step": 70
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"epoch": 0.547945205479452,
|
| 69 |
+
"grad_norm": 1.543124794960022,
|
| 70 |
+
"learning_rate": 9.09090909090909e-07,
|
| 71 |
+
"loss": 1.3892,
|
| 72 |
+
"step": 80
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"epoch": 0.6164383561643836,
|
| 76 |
+
"grad_norm": 1.4148296117782593,
|
| 77 |
+
"learning_rate": 9.999987694108851e-07,
|
| 78 |
+
"loss": 1.2604,
|
| 79 |
+
"step": 90
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"epoch": 0.684931506849315,
|
| 83 |
+
"grad_norm": 1.1791110038757324,
|
| 84 |
+
"learning_rate": 9.999556994278908e-07,
|
| 85 |
+
"loss": 1.1968,
|
| 86 |
+
"step": 100
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"epoch": 0.7534246575342466,
|
| 90 |
+
"grad_norm": 0.6725325584411621,
|
| 91 |
+
"learning_rate": 9.99851106046421e-07,
|
| 92 |
+
"loss": 1.2527,
|
| 93 |
+
"step": 110
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"epoch": 0.821917808219178,
|
| 97 |
+
"grad_norm": 0.6848694682121277,
|
| 98 |
+
"learning_rate": 9.996850021374967e-07,
|
| 99 |
+
"loss": 1.1948,
|
| 100 |
+
"step": 120
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"epoch": 0.8904109589041096,
|
| 104 |
+
"grad_norm": 0.7549965977668762,
|
| 105 |
+
"learning_rate": 9.994574081414829e-07,
|
| 106 |
+
"loss": 1.1981,
|
| 107 |
+
"step": 130
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"epoch": 0.958904109589041,
|
| 111 |
+
"grad_norm": 0.7096678614616394,
|
| 112 |
+
"learning_rate": 9.991683520655733e-07,
|
| 113 |
+
"loss": 1.1431,
|
| 114 |
+
"step": 140
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"epoch": 1.0273972602739727,
|
| 118 |
+
"grad_norm": 0.5900934934616089,
|
| 119 |
+
"learning_rate": 9.988178694803437e-07,
|
| 120 |
+
"loss": 1.1278,
|
| 121 |
+
"step": 150
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"epoch": 1.095890410958904,
|
| 125 |
+
"grad_norm": 0.5475574135780334,
|
| 126 |
+
"learning_rate": 9.98406003515375e-07,
|
| 127 |
+
"loss": 1.0992,
|
| 128 |
+
"step": 160
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"epoch": 1.1643835616438356,
|
| 132 |
+
"grad_norm": 0.7482232451438904,
|
| 133 |
+
"learning_rate": 9.979328048539456e-07,
|
| 134 |
+
"loss": 1.1284,
|
| 135 |
+
"step": 170
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"epoch": 1.2328767123287672,
|
| 139 |
+
"grad_norm": 0.5973280072212219,
|
| 140 |
+
"learning_rate": 9.973983317267942e-07,
|
| 141 |
+
"loss": 1.0852,
|
| 142 |
+
"step": 180
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"epoch": 1.3013698630136985,
|
| 146 |
+
"grad_norm": 0.6261045932769775,
|
| 147 |
+
"learning_rate": 9.968026499049549e-07,
|
| 148 |
+
"loss": 1.1337,
|
| 149 |
+
"step": 190
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"epoch": 1.36986301369863,
|
| 153 |
+
"grad_norm": 0.5769844055175781,
|
| 154 |
+
"learning_rate": 9.961458326916622e-07,
|
| 155 |
+
"loss": 1.0631,
|
| 156 |
+
"step": 200
|
| 157 |
+
},
|
| 158 |
+
{
|
| 159 |
+
"epoch": 1.4383561643835616,
|
| 160 |
+
"grad_norm": 0.5357353687286377,
|
| 161 |
+
"learning_rate": 9.95427960913332e-07,
|
| 162 |
+
"loss": 1.0776,
|
| 163 |
+
"step": 210
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
"epoch": 1.5068493150684932,
|
| 167 |
+
"grad_norm": 0.4853924810886383,
|
| 168 |
+
"learning_rate": 9.946491229096141e-07,
|
| 169 |
+
"loss": 1.1915,
|
| 170 |
+
"step": 220
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"epoch": 1.5753424657534247,
|
| 174 |
+
"grad_norm": 0.46966347098350525,
|
| 175 |
+
"learning_rate": 9.93809414522522e-07,
|
| 176 |
+
"loss": 1.1294,
|
| 177 |
+
"step": 230
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"epoch": 1.643835616438356,
|
| 181 |
+
"grad_norm": 0.4522402286529541,
|
| 182 |
+
"learning_rate": 9.929089390846387e-07,
|
| 183 |
+
"loss": 1.1039,
|
| 184 |
+
"step": 240
|
| 185 |
+
},
|
| 186 |
+
{
|
| 187 |
+
"epoch": 1.7123287671232876,
|
| 188 |
+
"grad_norm": 0.5600599646568298,
|
| 189 |
+
"learning_rate": 9.919478074064001e-07,
|
| 190 |
+
"loss": 1.1376,
|
| 191 |
+
"step": 250
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"epoch": 1.7808219178082192,
|
| 195 |
+
"grad_norm": 0.6039386987686157,
|
| 196 |
+
"learning_rate": 9.9092613776246e-07,
|
| 197 |
+
"loss": 0.9726,
|
| 198 |
+
"step": 260
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"epoch": 1.8493150684931505,
|
| 202 |
+
"grad_norm": 0.49978339672088623,
|
| 203 |
+
"learning_rate": 9.89844055877135e-07,
|
| 204 |
+
"loss": 1.0931,
|
| 205 |
+
"step": 270
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"epoch": 1.9178082191780823,
|
| 209 |
+
"grad_norm": 0.5846579074859619,
|
| 210 |
+
"learning_rate": 9.887016949089332e-07,
|
| 211 |
+
"loss": 0.9936,
|
| 212 |
+
"step": 280
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"epoch": 1.9863013698630136,
|
| 216 |
+
"grad_norm": 0.7952700257301331,
|
| 217 |
+
"learning_rate": 9.874991954341681e-07,
|
| 218 |
+
"loss": 0.9754,
|
| 219 |
+
"step": 290
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"epoch": 2.0547945205479454,
|
| 223 |
+
"grad_norm": 0.4901112914085388,
|
| 224 |
+
"learning_rate": 9.862367054296588e-07,
|
| 225 |
+
"loss": 1.0696,
|
| 226 |
+
"step": 300
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"epoch": 2.1232876712328768,
|
| 230 |
+
"grad_norm": 0.5387308597564697,
|
| 231 |
+
"learning_rate": 9.84914380254522e-07,
|
| 232 |
+
"loss": 1.0019,
|
| 233 |
+
"step": 310
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"epoch": 2.191780821917808,
|
| 237 |
+
"grad_norm": 0.5762762427330017,
|
| 238 |
+
"learning_rate": 9.83532382631052e-07,
|
| 239 |
+
"loss": 1.033,
|
| 240 |
+
"step": 320
|
| 241 |
+
},
|
| 242 |
+
{
|
| 243 |
+
"epoch": 2.26027397260274,
|
| 244 |
+
"grad_norm": 0.4994489848613739,
|
| 245 |
+
"learning_rate": 9.82090882624698e-07,
|
| 246 |
+
"loss": 0.9562,
|
| 247 |
+
"step": 330
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"epoch": 2.328767123287671,
|
| 251 |
+
"grad_norm": 0.5436204671859741,
|
| 252 |
+
"learning_rate": 9.805900576231357e-07,
|
| 253 |
+
"loss": 0.9257,
|
| 254 |
+
"step": 340
|
| 255 |
+
},
|
| 256 |
+
{
|
| 257 |
+
"epoch": 2.3972602739726026,
|
| 258 |
+
"grad_norm": 0.6050807237625122,
|
| 259 |
+
"learning_rate": 9.790300923144372e-07,
|
| 260 |
+
"loss": 1.0112,
|
| 261 |
+
"step": 350
|
| 262 |
+
},
|
| 263 |
+
{
|
| 264 |
+
"epoch": 2.4657534246575343,
|
| 265 |
+
"grad_norm": 0.41216209530830383,
|
| 266 |
+
"learning_rate": 9.77411178664346e-07,
|
| 267 |
+
"loss": 0.9853,
|
| 268 |
+
"step": 360
|
| 269 |
+
},
|
| 270 |
+
{
|
| 271 |
+
"epoch": 2.5342465753424657,
|
| 272 |
+
"grad_norm": 0.4330998957157135,
|
| 273 |
+
"learning_rate": 9.75733515892652e-07,
|
| 274 |
+
"loss": 1.0301,
|
| 275 |
+
"step": 370
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"epoch": 2.602739726027397,
|
| 279 |
+
"grad_norm": 0.4711519479751587,
|
| 280 |
+
"learning_rate": 9.739973104486777e-07,
|
| 281 |
+
"loss": 0.9389,
|
| 282 |
+
"step": 380
|
| 283 |
+
},
|
| 284 |
+
{
|
| 285 |
+
"epoch": 2.671232876712329,
|
| 286 |
+
"grad_norm": 0.5700051784515381,
|
| 287 |
+
"learning_rate": 9.722027759858714e-07,
|
| 288 |
+
"loss": 0.9781,
|
| 289 |
+
"step": 390
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"epoch": 2.73972602739726,
|
| 293 |
+
"grad_norm": 0.36050090193748474,
|
| 294 |
+
"learning_rate": 9.703501333355166e-07,
|
| 295 |
+
"loss": 1.0553,
|
| 296 |
+
"step": 400
|
| 297 |
+
},
|
| 298 |
+
{
|
| 299 |
+
"epoch": 2.808219178082192,
|
| 300 |
+
"grad_norm": 0.41387563943862915,
|
| 301 |
+
"learning_rate": 9.68439610479557e-07,
|
| 302 |
+
"loss": 1.0056,
|
| 303 |
+
"step": 410
|
| 304 |
+
},
|
| 305 |
+
{
|
| 306 |
+
"epoch": 2.8767123287671232,
|
| 307 |
+
"grad_norm": 0.4400821030139923,
|
| 308 |
+
"learning_rate": 9.664714425225413e-07,
|
| 309 |
+
"loss": 0.9877,
|
| 310 |
+
"step": 420
|
| 311 |
+
},
|
| 312 |
+
{
|
| 313 |
+
"epoch": 2.9452054794520546,
|
| 314 |
+
"grad_norm": 0.4176802635192871,
|
| 315 |
+
"learning_rate": 9.644458716626911e-07,
|
| 316 |
+
"loss": 0.9547,
|
| 317 |
+
"step": 430
|
| 318 |
+
},
|
| 319 |
+
{
|
| 320 |
+
"epoch": 3.0136986301369864,
|
| 321 |
+
"grad_norm": 0.3389221727848053,
|
| 322 |
+
"learning_rate": 9.623631471620979e-07,
|
| 323 |
+
"loss": 1.026,
|
| 324 |
+
"step": 440
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"epoch": 3.0821917808219177,
|
| 328 |
+
"grad_norm": 0.36593514680862427,
|
| 329 |
+
"learning_rate": 9.602235253160481e-07,
|
| 330 |
+
"loss": 1.0277,
|
| 331 |
+
"step": 450
|
| 332 |
+
},
|
| 333 |
+
{
|
| 334 |
+
"epoch": 3.1506849315068495,
|
| 335 |
+
"grad_norm": 0.3649665117263794,
|
| 336 |
+
"learning_rate": 9.580272694214854e-07,
|
| 337 |
+
"loss": 0.9359,
|
| 338 |
+
"step": 460
|
| 339 |
+
},
|
| 340 |
+
{
|
| 341 |
+
"epoch": 3.219178082191781,
|
| 342 |
+
"grad_norm": 0.34383371472358704,
|
| 343 |
+
"learning_rate": 9.557746497446085e-07,
|
| 344 |
+
"loss": 0.9152,
|
| 345 |
+
"step": 470
|
| 346 |
+
},
|
| 347 |
+
{
|
| 348 |
+
"epoch": 3.287671232876712,
|
| 349 |
+
"grad_norm": 0.34904035925865173,
|
| 350 |
+
"learning_rate": 9.53465943487614e-07,
|
| 351 |
+
"loss": 0.9315,
|
| 352 |
+
"step": 480
|
| 353 |
+
},
|
| 354 |
+
{
|
| 355 |
+
"epoch": 3.356164383561644,
|
| 356 |
+
"grad_norm": 0.34317487478256226,
|
| 357 |
+
"learning_rate": 9.511014347545837e-07,
|
| 358 |
+
"loss": 0.9726,
|
| 359 |
+
"step": 490
|
| 360 |
+
},
|
| 361 |
+
{
|
| 362 |
+
"epoch": 3.4246575342465753,
|
| 363 |
+
"grad_norm": 0.30464446544647217,
|
| 364 |
+
"learning_rate": 9.48681414516524e-07,
|
| 365 |
+
"loss": 0.9659,
|
| 366 |
+
"step": 500
|
| 367 |
+
},
|
| 368 |
+
{
|
| 369 |
+
"epoch": 3.4246575342465753,
|
| 370 |
+
"eval_loss": 1.1060227155685425,
|
| 371 |
+
"eval_runtime": 641.2039,
|
| 372 |
+
"eval_samples_per_second": 0.203,
|
| 373 |
+
"eval_steps_per_second": 0.203,
|
| 374 |
+
"step": 500
|
| 375 |
+
}
|
| 376 |
+
],
|
| 377 |
+
"logging_steps": 10,
|
| 378 |
+
"max_steps": 2920,
|
| 379 |
+
"num_input_tokens_seen": 0,
|
| 380 |
+
"num_train_epochs": 20,
|
| 381 |
+
"save_steps": 500,
|
| 382 |
+
"stateful_callbacks": {
|
| 383 |
+
"TrainerControl": {
|
| 384 |
+
"args": {
|
| 385 |
+
"should_epoch_stop": false,
|
| 386 |
+
"should_evaluate": false,
|
| 387 |
+
"should_log": false,
|
| 388 |
+
"should_save": true,
|
| 389 |
+
"should_training_stop": false
|
| 390 |
+
},
|
| 391 |
+
"attributes": {}
|
| 392 |
+
}
|
| 393 |
+
},
|
| 394 |
+
"total_flos": 3.5300059477526943e+18,
|
| 395 |
+
"train_batch_size": 1,
|
| 396 |
+
"trial_name": null,
|
| 397 |
+
"trial_params": null
|
| 398 |
+
}
|
residual-diffuser/args.json
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"action_weight": 1,
|
| 3 |
+
"add_extras": {
|
| 4 |
+
"_string": "<bound method Parser.add_extras of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
|
| 5 |
+
"_type": "python_object (type = method)",
|
| 6 |
+
"_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaD6GlFKULg=="
|
| 7 |
+
},
|
| 8 |
+
"batch_size": 1,
|
| 9 |
+
"bucket": null,
|
| 10 |
+
"clip_denoised": true,
|
| 11 |
+
"commit": "13b4d404bdbd9d0d74c08d5a51ee39f84e765830 maze2d",
|
| 12 |
+
"config": "config.maze2d",
|
| 13 |
+
"dataset": "maze2d-large-v1",
|
| 14 |
+
"device": "cuda",
|
| 15 |
+
"diffusion": "models.GaussianDiffusion",
|
| 16 |
+
"dim_mults": {
|
| 17 |
+
"_type": "tuple",
|
| 18 |
+
"_value": [
|
| 19 |
+
1,
|
| 20 |
+
4,
|
| 21 |
+
8
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
"ema_decay": 0.995,
|
| 25 |
+
"eval_fstrings": {
|
| 26 |
+
"_string": "<bound method Parser.eval_fstrings of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
|
| 27 |
+
"_type": "python_object (type = method)",
|
| 28 |
+
"_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaBuGlFKULg=="
|
| 29 |
+
},
|
| 30 |
+
"exp_name": "diffusion/H384_T16",
|
| 31 |
+
"generate_exp_name": {
|
| 32 |
+
"_string": "<bound method Parser.generate_exp_name of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
|
| 33 |
+
"_type": "python_object (type = method)",
|
| 34 |
+
"_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaEmGlFKULg=="
|
| 35 |
+
},
|
| 36 |
+
"get_commit": {
|
| 37 |
+
"_string": "<bound method Parser.get_commit of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
|
| 38 |
+
"_type": "python_object (type = method)",
|
| 39 |
+
"_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaBKGlFKULg=="
|
| 40 |
+
},
|
| 41 |
+
"gradient_accumulate_every": 8,
|
| 42 |
+
"horizon": 384,
|
| 43 |
+
"learning_rate": 5e-06,
|
| 44 |
+
"loader": "datasets.GoalDataset",
|
| 45 |
+
"logbase": "logs",
|
| 46 |
+
"loss_discount": 1,
|
| 47 |
+
"loss_type": "spline",
|
| 48 |
+
"loss_weights": null,
|
| 49 |
+
"max_path_length": 40000,
|
| 50 |
+
"mkdir": {
|
| 51 |
+
"_string": "<bound method Parser.mkdir of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
|
| 52 |
+
"_type": "python_object (type = method)",
|
| 53 |
+
"_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaA+GlFKULg=="
|
| 54 |
+
},
|
| 55 |
+
"model": "models.TemporalUnet",
|
| 56 |
+
"n_diffusion_steps": 16,
|
| 57 |
+
"n_reference": 50,
|
| 58 |
+
"n_samples": 10,
|
| 59 |
+
"n_saves": 50,
|
| 60 |
+
"n_steps_per_epoch": 60000,
|
| 61 |
+
"n_train_steps": 60000.0,
|
| 62 |
+
"normalizer": "LimitsNormalizer",
|
| 63 |
+
"predict_epsilon": true,
|
| 64 |
+
"prefix": "diffusion/",
|
| 65 |
+
"preprocess_fns": [
|
| 66 |
+
"maze2d_set_terminals"
|
| 67 |
+
],
|
| 68 |
+
"read_config": {
|
| 69 |
+
"_string": "<bound method Parser.read_config of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
|
| 70 |
+
"_type": "python_object (type = method)",
|
| 71 |
+
"_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaCuGlFKULg=="
|
| 72 |
+
},
|
| 73 |
+
"renderer": "utils.Maze2dRenderer",
|
| 74 |
+
"reproducibility": {
|
| 75 |
+
"command_line": "python scripts/train_tour.py",
|
| 76 |
+
"git_has_uncommitted_changes": true,
|
| 77 |
+
"git_root": "/local/home/atcelen/work/diffuser",
|
| 78 |
+
"git_url": "https://github.com/jannerm/diffuser/tree/13b4d404bdbd9d0d74c08d5a51ee39f84e765830",
|
| 79 |
+
"time": "Thu Jun 26 17:05:14 2025"
|
| 80 |
+
},
|
| 81 |
+
"sample_freq": 1000,
|
| 82 |
+
"save_diff": {
|
| 83 |
+
"_string": "<bound method Parser.save_diff of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
|
| 84 |
+
"_type": "python_object (type = method)",
|
| 85 |
+
"_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaFWGlFKULg=="
|
| 86 |
+
},
|
| 87 |
+
"save_freq": 2000,
|
| 88 |
+
"save_parallel": false,
|
| 89 |
+
"savepath": "logs/maze2d-large-v1/diffusion/H384_T16",
|
| 90 |
+
"set_seed": {
|
| 91 |
+
"_string": "<bound method Parser.set_seed of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
|
| 92 |
+
"_type": "python_object (type = method)",
|
| 93 |
+
"_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaAiGlFKULg=="
|
| 94 |
+
},
|
| 95 |
+
"termination_penalty": null,
|
| 96 |
+
"use_padding": false
|
| 97 |
+
}
|
residual-diffuser/dataset_config.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c441212970d6518583ad1a3d0dcb8a68690102f1dbca6bc3fbdbe3ec5b8430f8
|
| 3 |
+
size 280
|
residual-diffuser/diff.txt
ADDED
|
@@ -0,0 +1,1895 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diff --git a/config/locomotion.py b/config/locomotion.py
|
| 2 |
+
deleted file mode 100644
|
| 3 |
+
index 4410bb1..0000000
|
| 4 |
+
--- a/config/locomotion.py
|
| 5 |
+
+++ /dev/null
|
| 6 |
+
@@ -1,70 +0,0 @@
|
| 7 |
+
-import socket
|
| 8 |
+
-
|
| 9 |
+
-from diffuser.utils import watch
|
| 10 |
+
-
|
| 11 |
+
-#------------------------ base ------------------------#
|
| 12 |
+
-
|
| 13 |
+
-## automatically make experiment names for planning
|
| 14 |
+
-## by labelling folders with these args
|
| 15 |
+
-
|
| 16 |
+
-diffusion_args_to_watch = [
|
| 17 |
+
- ('prefix', ''),
|
| 18 |
+
- ('horizon', 'H'),
|
| 19 |
+
- ('n_diffusion_steps', 'T'),
|
| 20 |
+
-]
|
| 21 |
+
-
|
| 22 |
+
-base = {
|
| 23 |
+
- 'diffusion': {
|
| 24 |
+
- ## model
|
| 25 |
+
- 'model': 'models.TemporalUnet',
|
| 26 |
+
- 'diffusion': 'models.GaussianDiffusion',
|
| 27 |
+
- 'horizon': 32,
|
| 28 |
+
- 'n_diffusion_steps': 100,
|
| 29 |
+
- 'action_weight': 10,
|
| 30 |
+
- 'loss_weights': None,
|
| 31 |
+
- 'loss_discount': 1,
|
| 32 |
+
- 'predict_epsilon': False,
|
| 33 |
+
- 'dim_mults': (1, 4, 8),
|
| 34 |
+
- 'renderer': 'utils.MuJoCoRenderer',
|
| 35 |
+
-
|
| 36 |
+
- ## dataset
|
| 37 |
+
- 'loader': 'datasets.SequenceDataset',
|
| 38 |
+
- 'normalizer': 'LimitsNormalizer',
|
| 39 |
+
- 'preprocess_fns': [],
|
| 40 |
+
- 'clip_denoised': True,
|
| 41 |
+
- 'use_padding': True,
|
| 42 |
+
- 'max_path_length': 1000,
|
| 43 |
+
-
|
| 44 |
+
- ## serialization
|
| 45 |
+
- 'logbase': 'logs',
|
| 46 |
+
- 'prefix': 'diffusion/',
|
| 47 |
+
- 'exp_name': watch(diffusion_args_to_watch),
|
| 48 |
+
-
|
| 49 |
+
- ## training
|
| 50 |
+
- 'n_steps_per_epoch': 10000,
|
| 51 |
+
- 'loss_type': 'l2',
|
| 52 |
+
- 'n_train_steps': 1e6,
|
| 53 |
+
- 'batch_size': 32,
|
| 54 |
+
- 'learning_rate': 2e-4,
|
| 55 |
+
- 'gradient_accumulate_every': 2,
|
| 56 |
+
- 'ema_decay': 0.995,
|
| 57 |
+
- 'save_freq': 1000,
|
| 58 |
+
- 'sample_freq': 1000,
|
| 59 |
+
- 'n_saves': 5,
|
| 60 |
+
- 'save_parallel': False,
|
| 61 |
+
- 'n_reference': 8,
|
| 62 |
+
- 'n_samples': 2,
|
| 63 |
+
- 'bucket': None,
|
| 64 |
+
- 'device': 'cuda',
|
| 65 |
+
- },
|
| 66 |
+
-}
|
| 67 |
+
-
|
| 68 |
+
-#------------------------ overrides ------------------------#
|
| 69 |
+
-
|
| 70 |
+
-## put environment-specific overrides here
|
| 71 |
+
-
|
| 72 |
+
-halfcheetah_medium_expert_v2 = {
|
| 73 |
+
- 'diffusion': {
|
| 74 |
+
- 'horizon': 16,
|
| 75 |
+
- },
|
| 76 |
+
-}
|
| 77 |
+
diff --git a/config/maze2d.py b/config/maze2d.py
|
| 78 |
+
index a06ac7f..0a8d22a 100644
|
| 79 |
+
--- a/config/maze2d.py
|
| 80 |
+
+++ b/config/maze2d.py
|
| 81 |
+
@@ -34,11 +34,11 @@ base = {
|
| 82 |
+
'model': 'models.TemporalUnet',
|
| 83 |
+
'diffusion': 'models.GaussianDiffusion',
|
| 84 |
+
'horizon': 256,
|
| 85 |
+
- 'n_diffusion_steps': 256,
|
| 86 |
+
+ 'n_diffusion_steps': 512,
|
| 87 |
+
'action_weight': 1,
|
| 88 |
+
'loss_weights': None,
|
| 89 |
+
'loss_discount': 1,
|
| 90 |
+
- 'predict_epsilon': False,
|
| 91 |
+
+ 'predict_epsilon': True,
|
| 92 |
+
'dim_mults': (1, 4, 8),
|
| 93 |
+
'renderer': 'utils.Maze2dRenderer',
|
| 94 |
+
|
| 95 |
+
@@ -57,14 +57,14 @@ base = {
|
| 96 |
+
'exp_name': watch(diffusion_args_to_watch),
|
| 97 |
+
|
| 98 |
+
## training
|
| 99 |
+
- 'n_steps_per_epoch': 10000,
|
| 100 |
+
- 'loss_type': 'l2',
|
| 101 |
+
- 'n_train_steps': 2e6,
|
| 102 |
+
- 'batch_size': 32,
|
| 103 |
+
- 'learning_rate': 2e-4,
|
| 104 |
+
- 'gradient_accumulate_every': 2,
|
| 105 |
+
+ 'n_steps_per_epoch': 60000,
|
| 106 |
+
+ 'loss_type': 'spline',
|
| 107 |
+
+ 'n_train_steps': 6e4,
|
| 108 |
+
+ 'batch_size': 1,
|
| 109 |
+
+ 'learning_rate': 5e-6,
|
| 110 |
+
+ 'gradient_accumulate_every': 8,
|
| 111 |
+
'ema_decay': 0.995,
|
| 112 |
+
- 'save_freq': 1000,
|
| 113 |
+
+ 'save_freq': 2000,
|
| 114 |
+
'sample_freq': 1000,
|
| 115 |
+
'n_saves': 50,
|
| 116 |
+
'save_parallel': False,
|
| 117 |
+
@@ -89,7 +89,6 @@ base = {
|
| 118 |
+
'prefix': 'plans/release',
|
| 119 |
+
'exp_name': watch(plan_args_to_watch),
|
| 120 |
+
'suffix': '0',
|
| 121 |
+
-
|
| 122 |
+
'conditional': False,
|
| 123 |
+
|
| 124 |
+
## loading
|
| 125 |
+
@@ -122,10 +121,10 @@ maze2d_umaze_v1 = {
|
| 126 |
+
maze2d_large_v1 = {
|
| 127 |
+
'diffusion': {
|
| 128 |
+
'horizon': 384,
|
| 129 |
+
- 'n_diffusion_steps': 256,
|
| 130 |
+
+ 'n_diffusion_steps': 16,
|
| 131 |
+
},
|
| 132 |
+
'plan': {
|
| 133 |
+
'horizon': 384,
|
| 134 |
+
- 'n_diffusion_steps': 256,
|
| 135 |
+
+ 'n_diffusion_steps': 16,
|
| 136 |
+
},
|
| 137 |
+
}
|
| 138 |
+
diff --git a/diffuser/datasets/buffer.py b/diffuser/datasets/buffer.py
|
| 139 |
+
index 1ad2106..5991f01 100644
|
| 140 |
+
--- a/diffuser/datasets/buffer.py
|
| 141 |
+
+++ b/diffuser/datasets/buffer.py
|
| 142 |
+
@@ -9,7 +9,7 @@ class ReplayBuffer:
|
| 143 |
+
|
| 144 |
+
def __init__(self, max_n_episodes, max_path_length, termination_penalty):
|
| 145 |
+
self._dict = {
|
| 146 |
+
- 'path_lengths': np.zeros(max_n_episodes, dtype=np.int),
|
| 147 |
+
+ 'path_lengths': np.zeros(max_n_episodes, dtype=np.int_),
|
| 148 |
+
}
|
| 149 |
+
self._count = 0
|
| 150 |
+
self.max_n_episodes = max_n_episodes
|
| 151 |
+
diff --git a/diffuser/datasets/sequence.py b/diffuser/datasets/sequence.py
|
| 152 |
+
index 356c540..73c1b04 100644
|
| 153 |
+
--- a/diffuser/datasets/sequence.py
|
| 154 |
+
+++ b/diffuser/datasets/sequence.py
|
| 155 |
+
@@ -83,6 +83,7 @@ class SequenceDataset(torch.utils.data.Dataset):
|
| 156 |
+
actions = self.fields.normed_actions[path_ind, start:end]
|
| 157 |
+
|
| 158 |
+
conditions = self.get_conditions(observations)
|
| 159 |
+
+
|
| 160 |
+
trajectories = np.concatenate([actions, observations], axis=-1)
|
| 161 |
+
batch = Batch(trajectories, conditions)
|
| 162 |
+
return batch
|
| 163 |
+
diff --git a/diffuser/models/diffusion.py b/diffuser/models/diffusion.py
|
| 164 |
+
index fae4cfd..461680a 100644
|
| 165 |
+
--- a/diffuser/models/diffusion.py
|
| 166 |
+
+++ b/diffuser/models/diffusion.py
|
| 167 |
+
@@ -2,6 +2,7 @@ import numpy as np
|
| 168 |
+
import torch
|
| 169 |
+
from torch import nn
|
| 170 |
+
import pdb
|
| 171 |
+
+import matplotlib.pyplot as plt
|
| 172 |
+
|
| 173 |
+
import diffuser.utils as utils
|
| 174 |
+
from .helpers import (
|
| 175 |
+
@@ -9,6 +10,7 @@ from .helpers import (
|
| 176 |
+
extract,
|
| 177 |
+
apply_conditioning,
|
| 178 |
+
Losses,
|
| 179 |
+
+ catmull_rom_spline_with_rotation,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
class GaussianDiffusion(nn.Module):
|
| 183 |
+
@@ -26,6 +28,7 @@ class GaussianDiffusion(nn.Module):
|
| 184 |
+
betas = cosine_beta_schedule(n_timesteps)
|
| 185 |
+
alphas = 1. - betas
|
| 186 |
+
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
| 187 |
+
+ print(f"Alphas Cumprod: {alphas_cumprod}")
|
| 188 |
+
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
|
| 189 |
+
|
| 190 |
+
self.n_timesteps = int(n_timesteps)
|
| 191 |
+
@@ -73,7 +76,7 @@ class GaussianDiffusion(nn.Module):
|
| 192 |
+
'''
|
| 193 |
+
self.action_weight = action_weight
|
| 194 |
+
|
| 195 |
+
- dim_weights = torch.ones(self.transition_dim, dtype=torch.float32)
|
| 196 |
+
+ dim_weights = torch.ones(self.transition_dim, dtype=torch.float64)
|
| 197 |
+
|
| 198 |
+
## set loss coefficients for dimensions of observation
|
| 199 |
+
if weights_dict is None: weights_dict = {}
|
| 200 |
+
@@ -97,18 +100,16 @@ class GaussianDiffusion(nn.Module):
|
| 201 |
+
otherwise, model predicts x0 directly
|
| 202 |
+
'''
|
| 203 |
+
if self.predict_epsilon:
|
| 204 |
+
- return (
|
| 205 |
+
- extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
| 206 |
+
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
| 207 |
+
- )
|
| 208 |
+
+ return noise
|
| 209 |
+
else:
|
| 210 |
+
return noise
|
| 211 |
+
|
| 212 |
+
def q_posterior(self, x_start, x_t, t):
|
| 213 |
+
posterior_mean = (
|
| 214 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
| 215 |
+
- extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 216 |
+
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t[:, :, self.action_dim:]
|
| 217 |
+
)
|
| 218 |
+
+
|
| 219 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
| 220 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
| 221 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 222 |
+
@@ -129,7 +130,7 @@ class GaussianDiffusion(nn.Module):
|
| 223 |
+
def p_sample(self, x, cond, t):
|
| 224 |
+
b, *_, device = *x.shape, x.device
|
| 225 |
+
model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t)
|
| 226 |
+
- noise = torch.randn_like(x)
|
| 227 |
+
+ noise = torch.randn_like(x[:, :, self.action_dim:])
|
| 228 |
+
# no noise when t == 0
|
| 229 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
| 230 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
| 231 |
+
@@ -139,22 +140,59 @@ class GaussianDiffusion(nn.Module):
|
| 232 |
+
device = self.betas.device
|
| 233 |
+
|
| 234 |
+
batch_size = shape[0]
|
| 235 |
+
- x = torch.randn(shape, device=device)
|
| 236 |
+
- x = apply_conditioning(x, cond, self.action_dim)
|
| 237 |
+
+ # x = torch.randn(shape, device=device, dtype=torch.float64)
|
| 238 |
+
+ # Extract known indices and values
|
| 239 |
+
+ known_indices = np.array(list(cond.keys()), dtype=int)
|
| 240 |
+
+
|
| 241 |
+
+ # candidate_no x batch_size x dim
|
| 242 |
+
+ known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0)
|
| 243 |
+
+ known_values = np.moveaxis(known_values, 0, 1)
|
| 244 |
+
+
|
| 245 |
+
+ # Sort the timepoints
|
| 246 |
+
+ sorted_indices = np.argsort(known_indices)
|
| 247 |
+
+ known_indices = known_indices[sorted_indices]
|
| 248 |
+
+ known_values = known_values[:, sorted_indices]
|
| 249 |
+
+
|
| 250 |
+
+ # Build the structured spline guess
|
| 251 |
+
+ catmull_spline_trajectory = np.array([
|
| 252 |
+
+ catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, shape[1])
|
| 253 |
+
+ for b in range(batch_size)
|
| 254 |
+
+ ])
|
| 255 |
+
+ catmull_spline_trajectory = torch.tensor(
|
| 256 |
+
+ catmull_spline_trajectory,
|
| 257 |
+
+ dtype=torch.float64,
|
| 258 |
+
+ device=device
|
| 259 |
+
+ )
|
| 260 |
+
+
|
| 261 |
+
+
|
| 262 |
+
+ if self.predict_epsilon:
|
| 263 |
+
+ x = torch.randn((shape[0], shape[1], self.observation_dim), device=device, dtype=torch.float64)
|
| 264 |
+
+ cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()}
|
| 265 |
+
+ is_cond = torch.zeros((shape[0], shape[1], 1), device=device, dtype=torch.float64)
|
| 266 |
+
+ is_cond[:, known_indices, :] = 1.0
|
| 267 |
+
|
| 268 |
+
if return_diffusion: diffusion = [x]
|
| 269 |
+
|
| 270 |
+
- progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent()
|
| 271 |
+
+ # progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent()
|
| 272 |
+
for i in reversed(range(0, self.n_timesteps)):
|
| 273 |
+
+ if self.predict_epsilon:
|
| 274 |
+
+ x = torch.cat([catmull_spline_trajectory, is_cond, x], dim=-1)
|
| 275 |
+
+
|
| 276 |
+
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
|
| 277 |
+
- x = self.p_sample(x, cond, timesteps)
|
| 278 |
+
- x = apply_conditioning(x, cond, self.action_dim)
|
| 279 |
+
+ x = self.p_sample(x, cond_residual, timesteps)
|
| 280 |
+
+
|
| 281 |
+
+ x = apply_conditioning(x, cond_residual, 0)
|
| 282 |
+
|
| 283 |
+
- progress.update({'t': i})
|
| 284 |
+
+ if return_diffusion: diffusion.append(x)
|
| 285 |
+
|
| 286 |
+
- if return_diffusion: diffusion.append(x)
|
| 287 |
+
+ x = catmull_spline_trajectory + x
|
| 288 |
+
|
| 289 |
+
- progress.close()
|
| 290 |
+
+
|
| 291 |
+
+
|
| 292 |
+
+ # Normalize the quaternions
|
| 293 |
+
+ # x[:, :, 3:7] = x[:, :, 3:7] / torch.norm(x[:, :, 3:7], dim=-1, keepdim=True)
|
| 294 |
+
+
|
| 295 |
+
+ # progress.close()
|
| 296 |
+
|
| 297 |
+
if return_diffusion:
|
| 298 |
+
return x, torch.stack(diffusion, dim=1)
|
| 299 |
+
@@ -167,7 +205,7 @@ class GaussianDiffusion(nn.Module):
|
| 300 |
+
conditions : [ (time, state), ... ]
|
| 301 |
+
'''
|
| 302 |
+
device = self.betas.device
|
| 303 |
+
- batch_size = len(cond[0])
|
| 304 |
+
+ batch_size = len(next(iter(cond.values())))
|
| 305 |
+
horizon = horizon or self.horizon
|
| 306 |
+
shape = (batch_size, horizon, self.transition_dim)
|
| 307 |
+
|
| 308 |
+
@@ -175,38 +213,106 @@ class GaussianDiffusion(nn.Module):
|
| 309 |
+
|
| 310 |
+
#------------------------------------------ training ------------------------------------------#
|
| 311 |
+
|
| 312 |
+
- def q_sample(self, x_start, t, noise=None):
|
| 313 |
+
+ def q_sample(self, x_start, t, spline=None, noise=None):
|
| 314 |
+
+ x_start_noise = x_start[:, : , :-1]
|
| 315 |
+
+ x_start_is_cond = x_start[:, :, [-1]]
|
| 316 |
+
+
|
| 317 |
+
+ if spline is None:
|
| 318 |
+
+ spline = torch.randn_like(x_start_noise)
|
| 319 |
+
if noise is None:
|
| 320 |
+
- noise = torch.randn_like(x_start)
|
| 321 |
+
+ noise = torch.randn_like(x_start_noise)
|
| 322 |
+
|
| 323 |
+
- sample = (
|
| 324 |
+
- extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
| 325 |
+
- extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
| 326 |
+
- )
|
| 327 |
+
+ alpha = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
|
| 328 |
+
+ oneminusalpha = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
|
| 329 |
+
+
|
| 330 |
+
+ # Weighted combination of x_0 and the spline
|
| 331 |
+
+ out = alpha * x_start_noise + oneminusalpha * noise
|
| 332 |
+
+
|
| 333 |
+
+ # Concatenate the binary feature and the spline as the conditioning
|
| 334 |
+
+ out = torch.cat([spline, x_start_is_cond, out], dim=-1)
|
| 335 |
+
|
| 336 |
+
- return sample
|
| 337 |
+
+ return out
|
| 338 |
+
|
| 339 |
+
def p_losses(self, x_start, cond, t):
|
| 340 |
+
- noise = torch.randn_like(x_start)
|
| 341 |
+
+ batch_size, horizon, _ = x_start.shape
|
| 342 |
+
+ # Extract known indices and values
|
| 343 |
+
+ known_indices = np.array(list(cond.keys()), dtype=int)
|
| 344 |
+
+
|
| 345 |
+
+ # candidate_no x batch_size x dim
|
| 346 |
+
+ known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0)
|
| 347 |
+
+ known_values = np.moveaxis(known_values, 0, 1)
|
| 348 |
+
+
|
| 349 |
+
+ # Sort the timepoints
|
| 350 |
+
+ sorted_indices = np.argsort(known_indices)
|
| 351 |
+
+ known_indices = known_indices[sorted_indices]
|
| 352 |
+
+ known_values = known_values[:, sorted_indices]
|
| 353 |
+
+
|
| 354 |
+
+ # Build your structured guess
|
| 355 |
+
+ catmull_spline_trajectory = np.array([
|
| 356 |
+
+ catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, horizon)
|
| 357 |
+
+ for b in range(batch_size)
|
| 358 |
+
+ ])
|
| 359 |
+
+ catmull_spline_trajectory = torch.tensor(
|
| 360 |
+
+ catmull_spline_trajectory,
|
| 361 |
+
+ dtype=torch.float64,
|
| 362 |
+
+ device=x_start.device
|
| 363 |
+
+ )
|
| 364 |
+
|
| 365 |
+
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
| 366 |
+
- x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)
|
| 367 |
+
+ # Plot the quaternions
|
| 368 |
+
+ # plt.plot(x_start[0, :, 3].cpu().numpy())
|
| 369 |
+
+ # plt.plot(catmull_spline_trajectory[0, :, 3].cpu().numpy())
|
| 370 |
+
+ # plt.legend(["x_start", "catmull_spline"])
|
| 371 |
+
+ # plt.show()
|
| 372 |
+
+ # raise Exception
|
| 373 |
+
|
| 374 |
+
- x_recon = self.model(x_noisy, cond, t)
|
| 375 |
+
- x_recon = apply_conditioning(x_recon, cond, self.action_dim)
|
| 376 |
+
|
| 377 |
+
- assert noise.shape == x_recon.shape
|
| 378 |
+
+ if not self.predict_epsilon:
|
| 379 |
+
+ # Forward diffuse with the structured trajectory
|
| 380 |
+
+ x_noisy = self.q_sample(
|
| 381 |
+
+ x_start,
|
| 382 |
+
+ t,
|
| 383 |
+
+ spline=catmull_spline_trajectory,
|
| 384 |
+
+ )
|
| 385 |
+
+ x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)
|
| 386 |
+
|
| 387 |
+
- if self.predict_epsilon:
|
| 388 |
+
- loss, info = self.loss_fn(x_recon, noise)
|
| 389 |
+
+ # Reverse pass guess
|
| 390 |
+
+ x_recon = self.model(x_noisy, cond, t)
|
| 391 |
+
+ x_recon = apply_conditioning(x_recon, cond, self.action_dim)
|
| 392 |
+
+
|
| 393 |
+
+ # Then x_recon is the predicted x_0, compare to the true x_0
|
| 394 |
+
+ loss, info = self.loss_fn(x_recon, x_start, cond)
|
| 395 |
+
else:
|
| 396 |
+
- loss, info = self.loss_fn(x_recon, x_start)
|
| 397 |
+
+ residual = x_start.clone()
|
| 398 |
+
+
|
| 399 |
+
+ residual[:, :, :-1] -= catmull_spline_trajectory
|
| 400 |
+
+
|
| 401 |
+
+
|
| 402 |
+
+ cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()}
|
| 403 |
+
+
|
| 404 |
+
+ x_noisy = self.q_sample(
|
| 405 |
+
+ residual,
|
| 406 |
+
+ t,
|
| 407 |
+
+ spline=catmull_spline_trajectory,
|
| 408 |
+
+ )
|
| 409 |
+
+ x_noisy = apply_conditioning(x_noisy, cond_residual, self.action_dim)
|
| 410 |
+
+
|
| 411 |
+
+ # Reverse pass guess
|
| 412 |
+
+ x_recon = self.model(x_noisy, cond, t)
|
| 413 |
+
+ x_recon = apply_conditioning(x_recon, cond_residual, 0)
|
| 414 |
+
+
|
| 415 |
+
+ x_recon = x_recon + catmull_spline_trajectory
|
| 416 |
+
+
|
| 417 |
+
+ loss, info = self.loss_fn(x_recon, x_start[:, :, :-1], cond)
|
| 418 |
+
|
| 419 |
+
return loss, info
|
| 420 |
+
|
| 421 |
+
def loss(self, x, cond):
|
| 422 |
+
batch_size = len(x)
|
| 423 |
+
t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
|
| 424 |
+
+ # t = torch.randint(1, 2, (batch_size,), device=x.device).long()
|
| 425 |
+
+ # x = x.double()
|
| 426 |
+
+ # cond = {k: v.double() for k, v in cond.items()}
|
| 427 |
+
+ # print(f"Time: {t.item()}")
|
| 428 |
+
return self.p_losses(x, cond, t)
|
| 429 |
+
|
| 430 |
+
def forward(self, cond, *args, **kwargs):
|
| 431 |
+
diff --git a/diffuser/models/helpers.py b/diffuser/models/helpers.py
|
| 432 |
+
index d39f35d..9f43ef8 100644
|
| 433 |
+
--- a/diffuser/models/helpers.py
|
| 434 |
+
+++ b/diffuser/models/helpers.py
|
| 435 |
+
@@ -1,11 +1,11 @@
|
| 436 |
+
import math
|
| 437 |
+
+import json
|
| 438 |
+
import numpy as np
|
| 439 |
+
import torch
|
| 440 |
+
import torch.nn as nn
|
| 441 |
+
import torch.nn.functional as F
|
| 442 |
+
-import einops
|
| 443 |
+
from einops.layers.torch import Rearrange
|
| 444 |
+
-import pdb
|
| 445 |
+
+from pytorch3d.transforms import quaternion_to_matrix, quaternion_to_axis_angle
|
| 446 |
+
|
| 447 |
+
import diffuser.utils as utils
|
| 448 |
+
|
| 449 |
+
@@ -30,7 +30,7 @@ class SinusoidalPosEmb(nn.Module):
|
| 450 |
+
class Downsample1d(nn.Module):
|
| 451 |
+
def __init__(self, dim):
|
| 452 |
+
super().__init__()
|
| 453 |
+
- self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
| 454 |
+
+ self.conv = nn.Conv1d(dim, dim, 3, 2, 1).to(torch.float64)
|
| 455 |
+
|
| 456 |
+
def forward(self, x):
|
| 457 |
+
return self.conv(x)
|
| 458 |
+
@@ -38,7 +38,7 @@ class Downsample1d(nn.Module):
|
| 459 |
+
class Upsample1d(nn.Module):
|
| 460 |
+
def __init__(self, dim):
|
| 461 |
+
super().__init__()
|
| 462 |
+
- self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
| 463 |
+
+ self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1).to(torch.float64)
|
| 464 |
+
|
| 465 |
+
def forward(self, x):
|
| 466 |
+
return self.conv(x)
|
| 467 |
+
@@ -52,9 +52,9 @@ class Conv1dBlock(nn.Module):
|
| 468 |
+
super().__init__()
|
| 469 |
+
|
| 470 |
+
self.block = nn.Sequential(
|
| 471 |
+
- nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
| 472 |
+
+ nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2).to(torch.float64),
|
| 473 |
+
Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
| 474 |
+
- nn.GroupNorm(n_groups, out_channels),
|
| 475 |
+
+ nn.GroupNorm(n_groups, out_channels).to(torch.float64),
|
| 476 |
+
Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
| 477 |
+
nn.Mish(),
|
| 478 |
+
)
|
| 479 |
+
@@ -72,7 +72,7 @@ def extract(a, t, x_shape):
|
| 480 |
+
out = a.gather(-1, t)
|
| 481 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
| 482 |
+
|
| 483 |
+
-def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
|
| 484 |
+
+def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float64):
|
| 485 |
+
"""
|
| 486 |
+
cosine schedule
|
| 487 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
| 488 |
+
@@ -157,9 +157,979 @@ class ValueL2(ValueLoss):
|
| 489 |
+
def _loss(self, pred, targ):
|
| 490 |
+
return F.mse_loss(pred, targ, reduction='none')
|
| 491 |
+
|
| 492 |
+
+class GeodesicL2Loss(nn.Module):
|
| 493 |
+
+ def __init__(self, *args):
|
| 494 |
+
+ super().__init__()
|
| 495 |
+
+ pass
|
| 496 |
+
+
|
| 497 |
+
+ def _loss(self, pred, targ):
|
| 498 |
+
+ # Compute L2 loss for the first three dimensions
|
| 499 |
+
+ l2_loss = F.mse_loss(pred[..., :3], targ[..., :3], reduction='mean')
|
| 500 |
+
+
|
| 501 |
+
+ # Normalize to unit quaternions for the last four dimensions
|
| 502 |
+
+ pred_quat = pred[..., 3:] / pred[..., 3:].norm(dim=-1, keepdim=True)
|
| 503 |
+
+ targ_quat = targ[..., 3:] / targ[..., 3:].norm(dim=-1, keepdim=True)
|
| 504 |
+
+
|
| 505 |
+
+ assert not torch.isnan(pred_quat).any(), "Pred Quat has NaNs"
|
| 506 |
+
+ assert not torch.isnan(targ_quat).any(), "Targ Quat has NaNs"
|
| 507 |
+
+
|
| 508 |
+
+ # Compute dot product for the quaternions
|
| 509 |
+
+ dot_product = torch.sum(pred_quat * targ_quat, dim=-1)
|
| 510 |
+
+ dot_product = torch.clamp(torch.abs(dot_product), -1.0, 1.0)
|
| 511 |
+
+
|
| 512 |
+
+ # Compute geodesic loss for the quaternions
|
| 513 |
+
+ geodesic_loss = 2 * torch.acos(dot_product).mean()
|
| 514 |
+
+
|
| 515 |
+
+ assert not torch.isnan(geodesic_loss).any(), "Geodesic Loss has NaNs"
|
| 516 |
+
+ assert not torch.isnan(l2_loss).any(), "L2 Loss has NaNs"
|
| 517 |
+
+
|
| 518 |
+
+ return l2_loss + geodesic_loss, l2_loss, geodesic_loss
|
| 519 |
+
+
|
| 520 |
+
+ def forward(self, pred, targ):
|
| 521 |
+
+ loss, l2, geodesic = self._loss(pred, targ)
|
| 522 |
+
+
|
| 523 |
+
+ info = {
|
| 524 |
+
+ 'l2': l2.item(),
|
| 525 |
+
+ 'geodesic': geodesic.item(),
|
| 526 |
+
+ }
|
| 527 |
+
+
|
| 528 |
+
+ return loss, info
|
| 529 |
+
+
|
| 530 |
+
+class RotationTranslationLoss(nn.Module):
|
| 531 |
+
+ def __init__(self, *args):
|
| 532 |
+
+ super().__init__()
|
| 533 |
+
+ pass
|
| 534 |
+
+
|
| 535 |
+
+ def _loss(self, pred, targ, cond=None):
|
| 536 |
+
+
|
| 537 |
+
+ # Make sure the dtype is float64
|
| 538 |
+
+ pred = pred.to(torch.float64)
|
| 539 |
+
+ targ = targ.to(torch.float64)
|
| 540 |
+
+
|
| 541 |
+
+ eps = 1e-8
|
| 542 |
+
+
|
| 543 |
+
+ pred_trans = pred[..., :3]
|
| 544 |
+
+ pred_quat = pred[..., 3:7]
|
| 545 |
+
+ targ_trans = targ[..., :3]
|
| 546 |
+
+ targ_quat = targ[..., 3:7]
|
| 547 |
+
+
|
| 548 |
+
+ l2_loss = F.mse_loss(pred_trans, targ_trans, reduction='mean')
|
| 549 |
+
+
|
| 550 |
+
+ # Calculate the geodesic loss
|
| 551 |
+
+ pred_n = pred_quat.norm(dim=-1, keepdim=True).clamp(min=eps)
|
| 552 |
+
+ targ_n = targ_quat.norm(dim=-1, keepdim=True).clamp(min=eps)
|
| 553 |
+
+
|
| 554 |
+
+ pred_quat_norm = pred_quat / pred_n
|
| 555 |
+
+ targ_quat_norm = targ_quat / targ_n
|
| 556 |
+
+
|
| 557 |
+
+
|
| 558 |
+
+ dot_product = torch.sum(pred_quat_norm * targ_quat_norm, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps)
|
| 559 |
+
+ quaternion_dist = 1 - (dot_product ** 2).mean()
|
| 560 |
+
+
|
| 561 |
+
+ # Calculate the rotation error
|
| 562 |
+
+ pred_rot = quaternion_to_matrix(pred_quat_norm).reshape(-1, 3, 3)
|
| 563 |
+
+ targ_rot = quaternion_to_matrix(targ_quat_norm).reshape(-1, 3, 3)
|
| 564 |
+
+
|
| 565 |
+
+ r2r1 = pred_rot @ targ_rot.permute(0, 2, 1)
|
| 566 |
+
+ trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1)
|
| 567 |
+
+ trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps)
|
| 568 |
+
+ geodesic_loss = torch.acos(trace).mean()
|
| 569 |
+
+
|
| 570 |
+
+ # Add a smoothness and acceleration term to the positions and quaternions
|
| 571 |
+
+ alpha = 1.0
|
| 572 |
+
+ smoothness_loss = F.mse_loss(pred[:, 1:, :7].reshape(-1, 7), pred[:, :-1, :7].reshape(-1, 7), reduction='mean')
|
| 573 |
+
+ acceleration_loss = F.mse_loss(pred[:, 2:, :7].reshape(-1, 7), 2 * pred[:, 1:-1, :7].reshape(-1, 7) - pred[:, :-2, :7].reshape(-1, 7), reduction='mean')
|
| 574 |
+
+
|
| 575 |
+
+ l2_multiplier = 10.0
|
| 576 |
+
+
|
| 577 |
+
+ loss = l2_multiplier * l2_loss + quaternion_dist + geodesic_loss + alpha * (smoothness_loss + acceleration_loss)
|
| 578 |
+
+
|
| 579 |
+
+ dtw = DynamicTimeWarpingLoss()
|
| 580 |
+
+ dtw_loss, _ = dtw.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
|
| 581 |
+
+
|
| 582 |
+
+ hausdorff = HausdorffDistanceLoss()
|
| 583 |
+
+ hausdorff_loss, _ = hausdorff.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
|
| 584 |
+
+
|
| 585 |
+
+ frec = FrechetDistanceLoss()
|
| 586 |
+
+ frechet_loss, _ = frec.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
|
| 587 |
+
+
|
| 588 |
+
+ chamfer = ChamferDistanceLoss()
|
| 589 |
+
+ chamfer_loss, _ = chamfer.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
|
| 590 |
+
+
|
| 591 |
+
+ return loss, l2_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss
|
| 592 |
+
+
|
| 593 |
+
+
|
| 594 |
+
+ def forward(self, pred, targ, cond=None):
|
| 595 |
+
+ loss, err_t, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond)
|
| 596 |
+
+
|
| 597 |
+
+ info = {
|
| 598 |
+
+ 'rot. error': err_r.item(),
|
| 599 |
+
+ 'geodesic error': err_geo.item(),
|
| 600 |
+
+ 'trans. error': err_t.item(),
|
| 601 |
+
+ 'dtw': err_dtw.item(),
|
| 602 |
+
+ 'hausdorff': err_hausdorff.item(),
|
| 603 |
+
+ 'frechet': err_frechet.item(),
|
| 604 |
+
+ 'chamfer': err_chamfer.item(),
|
| 605 |
+
+ }
|
| 606 |
+
+
|
| 607 |
+
+ return loss, info
|
| 608 |
+
+
|
| 609 |
+
+class SplineLoss(nn.Module):
|
| 610 |
+
+ def __init__(self, *args):
|
| 611 |
+
+ super().__init__()
|
| 612 |
+
+ self.scales = json.load(open('scene_scale.json'))
|
| 613 |
+
+
|
| 614 |
+
+ def compute_spline_coeffs(self, trans):
|
| 615 |
+
+ p0 = trans[:, :-3, :]
|
| 616 |
+
+ p1 = trans[:, 1:-2, :]
|
| 617 |
+
+ p2 = trans[:, 2:-1, :]
|
| 618 |
+
+ p3 = trans[:, 3:, :]
|
| 619 |
+
+
|
| 620 |
+
+ # Tangent approximations
|
| 621 |
+
+ m1 = 0.5 * (-p0 + p2)
|
| 622 |
+
+ m2 = 0.5 * (-p1 + p3)
|
| 623 |
+
+
|
| 624 |
+
+ # Cubic spline coefficients for each dimension
|
| 625 |
+
+ a = (2 * p1 - 2 * p2 + m1 + m2)
|
| 626 |
+
+ b = (-3 * p1 + 3 * p2 - 2 * m1 - m2)
|
| 627 |
+
+ c = (m1)
|
| 628 |
+
+ d = (p1)
|
| 629 |
+
+
|
| 630 |
+
+ return torch.stack([a, b, c, d], dim=-1)
|
| 631 |
+
+
|
| 632 |
+
+ def q_normalize(self, q):
|
| 633 |
+
+ return q / q.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-12)
|
| 634 |
+
+
|
| 635 |
+
+ def q_conjugate(self, q):
|
| 636 |
+
+ w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
|
| 637 |
+
+ return torch.stack([w, -x, -y, -z], dim=-1)
|
| 638 |
+
+
|
| 639 |
+
+ def q_multiply(self, q1, q2):
|
| 640 |
+
+ """
|
| 641 |
+
+ q1*q2.
|
| 642 |
+
+ """
|
| 643 |
+
+ w1, x1, y1, z1 = q1.unbind(-1)
|
| 644 |
+
+ w2, x2, y2, z2 = q2.unbind(-1)
|
| 645 |
+
+ w = w1*w2 - x1*x2 - y1*y2 - z1*z2
|
| 646 |
+
+ x = w1*x2 + x1*w2 + y1*z2 - z1*y2
|
| 647 |
+
+ y = w1*y2 - x1*z2 + y1*w2 + z1*x2
|
| 648 |
+
+ z = w1*z2 + x1*y2 - y1*x2 + z1*w2
|
| 649 |
+
+ return torch.stack([w, x, y, z], dim=-1)
|
| 650 |
+
+
|
| 651 |
+
+ def q_inverse(self, q):
|
| 652 |
+
+ return self.q_conjugate(self.q_normalize(q))
|
| 653 |
+
+
|
| 654 |
+
+ def q_log(self, q):
|
| 655 |
+
+ """
|
| 656 |
+
+ Quaternion logarithm for a unit quaternion
|
| 657 |
+
+ Only returns the imaginary part
|
| 658 |
+
+ """
|
| 659 |
+
+ q = self.q_normalize(q)
|
| 660 |
+
+ w = q[..., 0]
|
| 661 |
+
+ xyz = q[..., 1:] # shape [..., 3]
|
| 662 |
+
+ mag_v = xyz.norm(p=2, dim=-1)
|
| 663 |
+
+ eps = 1e-12
|
| 664 |
+
+ angle = torch.acos(w.clamp(-1.0 + eps, 1.0 - eps))
|
| 665 |
+
+
|
| 666 |
+
+ # We do a safe-guard against zero for sin(angle)
|
| 667 |
+
+ small_mask = (mag_v < 1e-12) | (angle < 1e-12)
|
| 668 |
+
+ # Where small_mask is True => near identity => log(q) ~ 0
|
| 669 |
+
+ log_val = torch.zeros_like(xyz)
|
| 670 |
+
+
|
| 671 |
+
+ # Normal case
|
| 672 |
+
+ scale = angle / mag_v.clamp(min=1e-12)
|
| 673 |
+
+ normal_case = scale.unsqueeze(-1) * xyz
|
| 674 |
+
+
|
| 675 |
+
+ log_val = torch.where(
|
| 676 |
+
+ small_mask.unsqueeze(-1),
|
| 677 |
+
+ torch.zeros_like(xyz),
|
| 678 |
+
+ normal_case
|
| 679 |
+
+ )
|
| 680 |
+
+ return log_val
|
| 681 |
+
+
|
| 682 |
+
+ def q_exp(self, v):
|
| 683 |
+
+ """
|
| 684 |
+
+ Quaternion exponential
|
| 685 |
+
+ """
|
| 686 |
+
+ norm_v = v.norm(p=2, dim=-1)
|
| 687 |
+
+ small_mask = norm_v < 1e-12
|
| 688 |
+
+
|
| 689 |
+
+ w = torch.cos(norm_v)
|
| 690 |
+
+ sin_v = torch.sin(norm_v)
|
| 691 |
+
+ scale = torch.where(
|
| 692 |
+
+ small_mask,
|
| 693 |
+
+ torch.zeros_like(norm_v), # if zero, sin(0)/0 => 0
|
| 694 |
+
+ sin_v / norm_v.clamp(min=1e-12)
|
| 695 |
+
+ )
|
| 696 |
+
+ xyz = scale.unsqueeze(-1) * v
|
| 697 |
+
+
|
| 698 |
+
+ # For small angles, we approximate cos(norm_v) ~ 1, sin(norm_v)/norm_v ~ 1
|
| 699 |
+
+ w = torch.where(
|
| 700 |
+
+ small_mask,
|
| 701 |
+
+ torch.ones_like(w),
|
| 702 |
+
+ w
|
| 703 |
+
+ )
|
| 704 |
+
+ return torch.cat([w.unsqueeze(-1), xyz], dim=-1)
|
| 705 |
+
+
|
| 706 |
+
+ def q_slerp(self, q1, q2, t):
|
| 707 |
+
+ """
|
| 708 |
+
+ Spherical linear interpolation from q1 to q2 at t in [0,1].
|
| 709 |
+
+ Both q1, q2 assumed normalized.
|
| 710 |
+
+ q1, q2, t can be 1D or broadcastable shapes, but typically 1D.
|
| 711 |
+
+ """
|
| 712 |
+
+ q1 = self.q_normalize(q1)
|
| 713 |
+
+ q2 = self.q_normalize(q2)
|
| 714 |
+
+ dot = (q1 * q2).sum(dim=-1, keepdim=True) # the dot product
|
| 715 |
+
+
|
| 716 |
+
+ eps = 1e-12
|
| 717 |
+
+ dot = dot.clamp(-1.0 + eps, 1.0 - eps)
|
| 718 |
+
+
|
| 719 |
+
+ flip_mask = dot < 0.0
|
| 720 |
+
+ if flip_mask.any():
|
| 721 |
+
+ q2 = torch.where(flip_mask, -q2, q2)
|
| 722 |
+
+ dot = torch.where(flip_mask, -dot, dot)
|
| 723 |
+
+
|
| 724 |
+
+ # If they're very close, do a simple linear interpolation
|
| 725 |
+
+ close_mask = dot.squeeze(-1) > 0.9995
|
| 726 |
+
+ # Using an epsilon to avoid potential issues close to 1.0
|
| 727 |
+
+
|
| 728 |
+
+ # Branch 1: Very close
|
| 729 |
+
+ # linear LERP
|
| 730 |
+
+ lerp_val = (1.0 - t) * q1 + t * q2
|
| 731 |
+
+ lerp_val = self.q_normalize(lerp_val)
|
| 732 |
+
+
|
| 733 |
+
+ # Branch 2: Standard SLERP
|
| 734 |
+
+ theta_0 = torch.acos(dot)
|
| 735 |
+
+ sin_theta_0 = torch.sin(theta_0)
|
| 736 |
+
+ theta = theta_0 * t
|
| 737 |
+
+ s1 = torch.sin(theta_0 - theta) / sin_theta_0.clamp(min=1e-12)
|
| 738 |
+
+ s2 = torch.sin(theta) / sin_theta_0.clamp(min=1e-12)
|
| 739 |
+
+ slerp_val = s1 * q1 + s2 * q2
|
| 740 |
+
+ slerp_val = self.q_normalize(slerp_val)
|
| 741 |
+
+
|
| 742 |
+
+ # Combine
|
| 743 |
+
+ return torch.where(
|
| 744 |
+
+ close_mask.unsqueeze(-1),
|
| 745 |
+
+ lerp_val,
|
| 746 |
+
+ slerp_val
|
| 747 |
+
+ )
|
| 748 |
+
+
|
| 749 |
+
+ def compute_uniform_tangent(self, q_im1, q_i, q_ip1):
|
| 750 |
+
+ """
|
| 751 |
+
+ Computes a 'Catmull–Rom-like' tangent T_i for quaternion q_i,
|
| 752 |
+
+ given neighbors q_im1, q_i, q_ip1.
|
| 753 |
+
+
|
| 754 |
+
+ T_i = q_i * exp( -0.25 * [ log(q_i^-1 q_ip1) + log(q_i^-1 q_im1) ] )
|
| 755 |
+
+ """
|
| 756 |
+
+ q_im1 = self.q_normalize(q_im1)
|
| 757 |
+
+ q_i = self.q_normalize(q_i)
|
| 758 |
+
+ q_ip1 = self.q_normalize(q_ip1)
|
| 759 |
+
+
|
| 760 |
+
+ inv_qi = self.q_inverse(q_i)
|
| 761 |
+
+ r1 = self.q_multiply(inv_qi, q_ip1)
|
| 762 |
+
+ r2 = self.q_multiply(inv_qi, q_im1)
|
| 763 |
+
+
|
| 764 |
+
+ lr1 = self.q_log(r1)
|
| 765 |
+
+ lr2 = self.q_log(r2)
|
| 766 |
+
+
|
| 767 |
+
+ m = -0.25 * (lr1 + lr2)
|
| 768 |
+
+ exp_m = self.q_exp(m)
|
| 769 |
+
+ return self.q_multiply(q_i, exp_m)
|
| 770 |
+
+
|
| 771 |
+
+ def compute_all_uniform_tangents(self, quats):
|
| 772 |
+
+ """
|
| 773 |
+
+ Vectorized version that computes tangents T_i for all keyframe quaternions at once.
|
| 774 |
+
+ quats shape: [N,4], N >= 2
|
| 775 |
+
+ Returns shape [N,4].
|
| 776 |
+
+ """
|
| 777 |
+
+ q_im1 = torch.cat([quats[[0]], quats[:-1]], dim=0) # q_im1[0] = q0
|
| 778 |
+
+ q_ip1 = torch.cat([quats[1:], quats[[-1]]], dim=0) # q_ip1[N-1]= q_{N-1}
|
| 779 |
+
+
|
| 780 |
+
+ return self.compute_uniform_tangent(q_im1, quats, q_ip1)
|
| 781 |
+
+
|
| 782 |
+
+ def squad(self, q0, a, b, q1, t):
|
| 783 |
+
+ """
|
| 784 |
+
+ Shoemake's "squad" interpolation for quaternion splines:
|
| 785 |
+
+ squad(q0, a, b, q1; t) = slerp( slerp(q0, q1; t),
|
| 786 |
+
+ slerp(a, b; t),
|
| 787 |
+
+ 2t(1-t) )
|
| 788 |
+
+ where a, b are tangential control quaternions for q0, q1.
|
| 789 |
+
+ """
|
| 790 |
+
+ s1 = self.q_slerp(q0, q1, t)
|
| 791 |
+
+ s2 = self.q_slerp(a, b, t)
|
| 792 |
+
+ alpha = 2.0*t*(1.0 - t)
|
| 793 |
+
+ return self.q_slerp(s1, s2, alpha)
|
| 794 |
+
+
|
| 795 |
+
+ def uniform_cr_spline(self, quats, num_samples_per_segment=10):
|
| 796 |
+
+ """
|
| 797 |
+
+ Given a list of keyframe quaternions quats (each a torch 1D tensor [4]),
|
| 798 |
+
+ compute a "Uniform Catmull–Rom–like" quaternion spline through them.
|
| 799 |
+
+
|
| 800 |
+
+ Returns:
|
| 801 |
+
+ A list (Python list) of interpolated quaternions (torch tensors),
|
| 802 |
+
+ including all segment endpoints.
|
| 803 |
+
+
|
| 804 |
+
+ Each interior qi gets a tangent T_i using neighbors q_{i-1}, q_i, q_{i+1}.
|
| 805 |
+
+ For boundary tangents, we replicate the end quaternions.
|
| 806 |
+
+ """
|
| 807 |
+
+ n = quats.shape[0]
|
| 808 |
+
+ if n < 2:
|
| 809 |
+
+ return quats.unsqueeze(0) # not enough quats to interpolate
|
| 810 |
+
+
|
| 811 |
+
+ # Precompute tangents
|
| 812 |
+
+ tangents = self.compute_all_uniform_tangents(quats)
|
| 813 |
+
+
|
| 814 |
+
+ # Interpolate each segment [qi, q_{i+1}]
|
| 815 |
+
+ q0 = quats[:-1].unsqueeze(1)
|
| 816 |
+
+ q1 = quats[1:].unsqueeze(1)
|
| 817 |
+
+ a = tangents[:-1].unsqueeze(1)
|
| 818 |
+
+ b = tangents[1:].unsqueeze(1)
|
| 819 |
+
+
|
| 820 |
+
+ t_vals = torch.linspace(0.0, 1.0, num_samples_per_segment, device=quats.device, dtype=quats.dtype)
|
| 821 |
+
+ t_vals = t_vals.view(1, -1, 1)
|
| 822 |
+
+
|
| 823 |
+
+ out = self.squad(q0, a, b, q1, t_vals)
|
| 824 |
+
+ return out
|
| 825 |
+
+
|
| 826 |
+
+
|
| 827 |
+
+ def forward(self, pred, targ, cond=None, scene_id=None, norm_params=None):
|
| 828 |
+
+ loss, err_t, err_smooth, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond, scene_id, norm_params)
|
| 829 |
+
+
|
| 830 |
+
+ info = {
|
| 831 |
+
+ 'trans. error': err_t.item(),
|
| 832 |
+
+ 'smoothness error': err_smooth.item(),
|
| 833 |
+
+ # 'dtw': err_dtw.item(),
|
| 834 |
+
+ # 'hausdorff': err_hausdorff.item(),
|
| 835 |
+
+ # 'frechet': err_frechet.item(),
|
| 836 |
+
+ # 'chamfer': err_chamfer.item(),
|
| 837 |
+
+ 'quat. dist.': err_r.item(),
|
| 838 |
+
+ 'geodesic dist.': err_geo.item(),
|
| 839 |
+
+ }
|
| 840 |
+
+
|
| 841 |
+
+ return loss, info
|
| 842 |
+
+
|
| 843 |
+
+ def _loss(self, pred, targ, cond=None, scene_id=None, norm_params=None):
|
| 844 |
+
+ def poly_eval(coeffs, x):
|
| 845 |
+
+ """
|
| 846 |
+
+ Evaluates a polynomial (with highest-degree term first) at points x.
|
| 847 |
+
+ coeffs: 2D tensor of shape [num_polynomials, degree + 1], highest-degree term first.
|
| 848 |
+
+ x: 1D tensor of points at which to evaluate the polynomial.
|
| 849 |
+
+ Returns:
|
| 850 |
+
+ 2D tensor of shape [num_polynomials, len(x)], containing p(x).
|
| 851 |
+
+ """
|
| 852 |
+
+ x_powers = torch.stack([x**i for i in range(coeffs.shape[-1] - 1, -1, -1)], dim=-1)
|
| 853 |
+
+ x_powers = x_powers.to(torch.float64).to(coeffs.device)
|
| 854 |
+
+ y = torch.matmul(coeffs, x_powers.T)
|
| 855 |
+
+ return y
|
| 856 |
+
+
|
| 857 |
+
+ # Make sure the dtype is float64
|
| 858 |
+
+ pred = pred.to(torch.float64)
|
| 859 |
+
+ targ = targ.to(torch.float64)
|
| 860 |
+
+
|
| 861 |
+
+ # Rescale the translations
|
| 862 |
+
+ if scene_id is not None and norm_params is not None:
|
| 863 |
+
+ scene_id = scene_id.item()
|
| 864 |
+
+ scene_scale = self.scales[str(scene_id)]
|
| 865 |
+
+ scene_scale = norm_params['scale'][0] * scene_scale
|
| 866 |
+
+ pred[..., :3] = pred[..., :3] * scene_scale
|
| 867 |
+
+ targ[..., :3] = targ[..., :3] * scene_scale
|
| 868 |
+
+ # print(pred[..., :3].max(), targ[..., :3].max())
|
| 869 |
+
+
|
| 870 |
+
+ # We only consider interpolated points for loss calculation
|
| 871 |
+
+ candidate_idxs = sorted(cond.keys())
|
| 872 |
+
+ pred = pred[:, candidate_idxs[0] : candidate_idxs[-1] + 1, :]
|
| 873 |
+
+ targ = targ[:, candidate_idxs[0] : candidate_idxs[-1] + 1, :]
|
| 874 |
+
+
|
| 875 |
+
+ pred_trans = pred[..., :3]
|
| 876 |
+
+ pred_quat = pred[..., 3:7]
|
| 877 |
+
+ targ_trans = targ[..., :3]
|
| 878 |
+
+ targ_quat = targ[..., 3:7]
|
| 879 |
+
+
|
| 880 |
+
+ pred_coeffs = self.compute_spline_coeffs(pred_trans)
|
| 881 |
+
+ targ_coeffs = self.compute_spline_coeffs(targ_trans)
|
| 882 |
+
+
|
| 883 |
+
+ n_points = 2000
|
| 884 |
+
+
|
| 885 |
+
+ # Distribute sample points among intervals
|
| 886 |
+
+ dists = torch.norm(targ_trans[:, 1:, :] - targ_trans[:, :-1, :], dim=-1).reshape(-1)
|
| 887 |
+
+ dists_c = torch.zeros(len(candidate_idxs) - 1, device=pred.device)
|
| 888 |
+
+ for i in range(len(candidate_idxs) - 1):
|
| 889 |
+
+ dists_c[i] = dists[candidate_idxs[i]:candidate_idxs[i+1]].sum()
|
| 890 |
+
+
|
| 891 |
+
+ weights_c = dists_c / dists_c.sum()
|
| 892 |
+
+ scaled_c = weights_c * n_points
|
| 893 |
+
+ points_c = torch.floor(scaled_c).int()
|
| 894 |
+
+
|
| 895 |
+
+ while points_c.sum() < n_points:
|
| 896 |
+
+ idx = torch.argmax(scaled_c - points_c)
|
| 897 |
+
+ points_c[idx] += 1
|
| 898 |
+
+
|
| 899 |
+
+ # Calculate the spline loss
|
| 900 |
+
+ sample_points = 50
|
| 901 |
+
+ x = torch.linspace(0, 1, sample_points, device=pred.device)
|
| 902 |
+
+ pred_spline = poly_eval(pred_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3)
|
| 903 |
+
+ targ_spline = poly_eval(targ_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3)
|
| 904 |
+
+
|
| 905 |
+
+ indexes = []
|
| 906 |
+
+ start_idx = candidate_idxs[0]
|
| 907 |
+
+ for c, (idx_i0, idx_i1) in enumerate(zip(candidate_idxs[:-1], candidate_idxs[1:])):
|
| 908 |
+
+ p = points_c[c]
|
| 909 |
+
+ total_dist = dists_c[c]
|
| 910 |
+
+ dist_arr = dists[idx_i0 - start_idx : idx_i1 - start_idx]
|
| 911 |
+
+
|
| 912 |
+
+ step_distances = (dist_arr / sample_points).repeat_interleave(sample_points)
|
| 913 |
+
+ cumul_distances = step_distances.cumsum(dim=0)
|
| 914 |
+
+
|
| 915 |
+
+ dist_per_pick = total_dist / p
|
| 916 |
+
+ pick_targets = torch.arange(1, p + 1, device=dists.device) * dist_per_pick
|
| 917 |
+
+
|
| 918 |
+
+ pick_idxs = torch.searchsorted(cumul_distances, pick_targets, right=True)
|
| 919 |
+
+ pick_idxs = torch.clamp(pick_idxs, max=len(cumul_distances) - 1)
|
| 920 |
+
+
|
| 921 |
+
+
|
| 922 |
+
+ indexes_1d = torch.zeros_like(step_distances)
|
| 923 |
+
+ indexes_1d[pick_idxs] = 1
|
| 924 |
+
+
|
| 925 |
+
+ indexes_2d = indexes_1d.view(len(dist_arr), sample_points)
|
| 926 |
+
+
|
| 927 |
+
+ indexes.append(indexes_2d)
|
| 928 |
+
+
|
| 929 |
+
+ indexes = torch.cat(indexes)[1: -1] # The first and last candidates don't have spline representations
|
| 930 |
+
+
|
| 931 |
+
+ indexes_trans = torch.stack([indexes for _ in range(3)], dim=-1)
|
| 932 |
+
+ indexes_quat = torch.stack([indexes for _ in range(4)], dim=-1)
|
| 933 |
+
+
|
| 934 |
+
+ indexes_trans = indexes_trans.to(torch.bool)
|
| 935 |
+
+ indexes_quat = indexes_quat.to(torch.bool)
|
| 936 |
+
+
|
| 937 |
+
+ pred_trans_selected_values = pred_spline[indexes_trans]
|
| 938 |
+
+ targ_trans_selected_values = targ_spline[indexes_trans]
|
| 939 |
+
+
|
| 940 |
+
+ pred_trans_selected_values = pred_trans_selected_values.reshape(-1, 3)
|
| 941 |
+
+ targ_trans_selected_values = targ_trans_selected_values.reshape(-1, 3)
|
| 942 |
+
+
|
| 943 |
+
+ # Calculate the loss for quaternions
|
| 944 |
+
+ pred_quat = pred_quat / pred_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 945 |
+
+ targ_quat = targ_quat / targ_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 946 |
+
+
|
| 947 |
+
+ targ_quat_spline = self.uniform_cr_spline(targ_quat.reshape(-1, 4), num_samples_per_segment=sample_points)
|
| 948 |
+
+ pred_quat_spline = self.uniform_cr_spline(pred_quat.reshape(-1, 4), num_samples_per_segment=sample_points)
|
| 949 |
+
+
|
| 950 |
+
+
|
| 951 |
+
+ targ_quat_spline = targ_quat_spline[1:-1]
|
| 952 |
+
+ pred_quat_spline = pred_quat_spline[1:-1]
|
| 953 |
+
+
|
| 954 |
+
+
|
| 955 |
+
+ pred_quat_selected_values = pred_quat_spline[indexes_quat]
|
| 956 |
+
+ targ_quat_selected_values = targ_quat_spline[indexes_quat]
|
| 957 |
+
+
|
| 958 |
+
+ pred_quat_selected_values = pred_quat_selected_values.reshape(-1, 4)
|
| 959 |
+
+ targ_quat_selected_values = targ_quat_selected_values.reshape(-1, 4)
|
| 960 |
+
+
|
| 961 |
+
+ # Calculate the geodesic loss
|
| 962 |
+
+ pred_rot = quaternion_to_matrix(pred_quat_selected_values).reshape(-1, 3, 3)
|
| 963 |
+
+ targ_rot = quaternion_to_matrix(targ_quat_selected_values).reshape(-1, 3, 3)
|
| 964 |
+
+
|
| 965 |
+
+ eps = 1e-12
|
| 966 |
+
+ r2r1 = pred_rot @ targ_rot.permute(0, 2, 1)
|
| 967 |
+
+ trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1)
|
| 968 |
+
+ trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps)
|
| 969 |
+
+ geodesic_loss = torch.acos(trace).mean()
|
| 970 |
+
+
|
| 971 |
+
+ # Calculate the rotation error
|
| 972 |
+
+ dot_product = torch.sum(pred_quat_selected_values * targ_quat_selected_values, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps)
|
| 973 |
+
+ quaternion_dist = 1 - (dot_product ** 2).mean()
|
| 974 |
+
+
|
| 975 |
+
+ # Calculate the L2 loss
|
| 976 |
+
+ l2_loss = F.mse_loss(pred_trans_selected_values, targ_trans_selected_values, reduction='mean')
|
| 977 |
+
+
|
| 978 |
+
+ # Calculate the smoothness loss for translation and quaternion
|
| 979 |
+
+ smoothness_multiplier = 10 ** 2 # Empirically determined multiplier for smoothness loss
|
| 980 |
+
+ weight_acceleration = 0.1
|
| 981 |
+
+ weight_jerk = 0.05
|
| 982 |
+
+
|
| 983 |
+
+ pos_acc = pred_trans_selected_values[2:, :] - 2 * pred_trans_selected_values[1:-1, :] + pred_trans_selected_values[:-2, :]
|
| 984 |
+
+ pos_jerk = pred_trans_selected_values[3:, :] - 3 * pred_trans_selected_values[2:-1, :] + 3 * pred_trans_selected_values[1:-2, :] - pred_trans_selected_values[:-3, :]
|
| 985 |
+
+
|
| 986 |
+
+ pos_acceleration_loss = torch.mean(pos_acc ** 2)
|
| 987 |
+
+ pos_jerk_loss = torch.mean(pos_jerk ** 2)
|
| 988 |
+
+
|
| 989 |
+
+ q0 = pred_quat_selected_values[:-1, :]
|
| 990 |
+
+ q1 = pred_quat_selected_values[1:, :]
|
| 991 |
+
+ sign = torch.where((q0 * q1).sum(dim=-1) < 0, -1.0, 1.0)
|
| 992 |
+
+ q1 = sign.unsqueeze(-1) * q1
|
| 993 |
+
+
|
| 994 |
+
+ dq = self.q_multiply(q1, self.q_inverse(q0))
|
| 995 |
+
+ theta = 2 * torch.acos(torch.clamp(dq[..., 0], -1.0 + 1e-8, 1.0 - 1e-8))
|
| 996 |
+
+
|
| 997 |
+
+ rot_acc = theta[2:] - 2*theta[1:-1] + theta[:-2]
|
| 998 |
+
+ rot_jerk = theta[3:] - 3*theta[2:-1] + 3*theta[1:-2] - theta[:-3]
|
| 999 |
+
+
|
| 1000 |
+
+ rot_acceleration_loss = torch.mean(rot_acc ** 2)
|
| 1001 |
+
+ rot_jerk_loss = torch.mean(rot_jerk ** 2)
|
| 1002 |
+
+
|
| 1003 |
+
+ alpha_rot = 0.1 # <-- tune this (e.g. 0.1 … 10)
|
| 1004 |
+
+
|
| 1005 |
+
+
|
| 1006 |
+
+ acceleration_loss = pos_acceleration_loss + alpha_rot * rot_acceleration_loss
|
| 1007 |
+
+ jerk_loss = pos_jerk_loss + alpha_rot * rot_jerk_loss
|
| 1008 |
+
+
|
| 1009 |
+
+ smoothness_loss = (
|
| 1010 |
+
+ weight_acceleration * acceleration_loss
|
| 1011 |
+
+ + weight_jerk * jerk_loss
|
| 1012 |
+
+ ) * smoothness_multiplier
|
| 1013 |
+
+
|
| 1014 |
+
+
|
| 1015 |
+
+ # Calculate the spline loss
|
| 1016 |
+
+ l2_multiplier = 10.0
|
| 1017 |
+
+ spline_loss = l2_multiplier * (l2_loss + smoothness_loss) + geodesic_loss + quaternion_dist
|
| 1018 |
+
+
|
| 1019 |
+
+ dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss = None, None, None, None
|
| 1020 |
+
+
|
| 1021 |
+
+ # Uncomment these lines if you want to use the other losses
|
| 1022 |
+
+ '''
|
| 1023 |
+
+ dtw = DynamicTimeWarpingLoss()
|
| 1024 |
+
+ dtw_loss, _ = dtw.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
|
| 1025 |
+
+
|
| 1026 |
+
+ hausdorff = HausdorffDistanceLoss()
|
| 1027 |
+
+ hausdorff_loss, _ = hausdorff.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
|
| 1028 |
+
+
|
| 1029 |
+
+ frec = FrechetDistanceLoss()
|
| 1030 |
+
+ frechet_loss, _ = frec.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
|
| 1031 |
+
+
|
| 1032 |
+
+ chamfer = ChamferDistanceLoss()
|
| 1033 |
+
+ chamfer_loss, _ = chamfer.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
|
| 1034 |
+
+ '''
|
| 1035 |
+
+
|
| 1036 |
+
+ return spline_loss, l2_multiplier * l2_loss, l2_multiplier * smoothness_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss
|
| 1037 |
+
+
|
| 1038 |
+
+
|
| 1039 |
+
+class DynamicTimeWarpingLoss(nn.Module):
|
| 1040 |
+
+ def __init__(self):
|
| 1041 |
+
+ super().__init__()
|
| 1042 |
+
+
|
| 1043 |
+
+ def _dtw_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor:
|
| 1044 |
+
+ """
|
| 1045 |
+
+ Computes the DTW distance between two 2D tensors (T x D),
|
| 1046 |
+
+ where T is sequence length and D is feature dimension.
|
| 1047 |
+
+ """
|
| 1048 |
+
+ # seq1, seq2 shapes: (time_steps, feature_dim)
|
| 1049 |
+
+ n, m = seq1.size(0), seq2.size(0)
|
| 1050 |
+
+
|
| 1051 |
+
+ # Cost matrix (pairwise distances between all elements)
|
| 1052 |
+
+ cost = torch.zeros(n, m, device=seq1.device, dtype=seq1.dtype)
|
| 1053 |
+
+ for i in range(n):
|
| 1054 |
+
+ for j in range(m):
|
| 1055 |
+
+ cost[i, j] = torch.norm(seq1[i] - seq2[j], p=2)
|
| 1056 |
+
+
|
| 1057 |
+
+ # Accumulated cost matrix
|
| 1058 |
+
+ dist = torch.full((n + 1, m + 1), float('inf'),
|
| 1059 |
+
+ device=seq1.device, dtype=seq1.dtype)
|
| 1060 |
+
+ dist[0, 0] = 0.0
|
| 1061 |
+
+
|
| 1062 |
+
+ # Populate the DP table
|
| 1063 |
+
+ for i in range(1, n + 1):
|
| 1064 |
+
+ for j in range(1, m + 1):
|
| 1065 |
+
+ dist[i, j] = cost[i - 1, j - 1] + torch.min(
|
| 1066 |
+
+ torch.min(
|
| 1067 |
+
+ dist[i - 1, j], # Insertion
|
| 1068 |
+
+ dist[i, j - 1], # Deletion
|
| 1069 |
+
+ ),
|
| 1070 |
+
+ dist[i - 1, j - 1]# Match
|
| 1071 |
+
+ )
|
| 1072 |
+
+
|
| 1073 |
+
+ return dist[n, m]
|
| 1074 |
+
+
|
| 1075 |
+
+ def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
| 1076 |
+
+ """
|
| 1077 |
+
+ Compute the average DTW loss over a batch of sequences.
|
| 1078 |
+
+
|
| 1079 |
+
+ pred, targ shapes: (batch_size, T, D)
|
| 1080 |
+
+ """
|
| 1081 |
+
+ # Ensure shapes match in batch dimension
|
| 1082 |
+
+ assert pred.size(0) == targ.size(0), "Batch sizes must match."
|
| 1083 |
+
+
|
| 1084 |
+
+ # Compute DTW distance per sample in the batch
|
| 1085 |
+
+ distances = []
|
| 1086 |
+
+ for b in range(pred.size(0)):
|
| 1087 |
+
+ seq1 = pred[b]
|
| 1088 |
+
+ seq2 = targ[b]
|
| 1089 |
+
+ dtw_val = self._dtw_distance(seq1, seq2)
|
| 1090 |
+
+ distances.append(dtw_val)
|
| 1091 |
+
+
|
| 1092 |
+
+ # Stack and take mean to get scalar loss
|
| 1093 |
+
+ dtw_loss = torch.stack(distances).mean()
|
| 1094 |
+
+ return dtw_loss
|
| 1095 |
+
+
|
| 1096 |
+
+ def forward(self, pred: torch.Tensor, targ: torch.Tensor):
|
| 1097 |
+
+ """
|
| 1098 |
+
+ Returns a tuple: (loss, info_dict),
|
| 1099 |
+
+ where loss is a scalar tensor and info_dict is a dictionary
|
| 1100 |
+
+ of extra information (e.g., loss components).
|
| 1101 |
+
+ """
|
| 1102 |
+
+ loss = self._loss(pred, targ)
|
| 1103 |
+
+
|
| 1104 |
+
+ info = {
|
| 1105 |
+
+ 'dtw': loss.item()
|
| 1106 |
+
+ }
|
| 1107 |
+
+
|
| 1108 |
+
+ return loss, info
|
| 1109 |
+
+
|
| 1110 |
+
+class HausdorffDistanceLoss(nn.Module):
|
| 1111 |
+
+ def __init__(self):
|
| 1112 |
+
+ super().__init__()
|
| 1113 |
+
+
|
| 1114 |
+
+ def _hausdorff_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor:
|
| 1115 |
+
+ """
|
| 1116 |
+
+ Computes the Hausdorff distance between two 2D tensors (N x D),
|
| 1117 |
+
+ where N is the number of points and D is the feature dimension.
|
| 1118 |
+
+
|
| 1119 |
+
+ The Hausdorff distance H(A,B) between two sets A and B is defined as:
|
| 1120 |
+
+ H(A, B) = max( h(A, B), h(B, A) ),
|
| 1121 |
+
+ where
|
| 1122 |
+
+ h(A, B) = max_{a in A} min_{b in B} d(a, b).
|
| 1123 |
+
+
|
| 1124 |
+
+ Here, d(a, b) is the Euclidean distance between points a and b.
|
| 1125 |
+
+ """
|
| 1126 |
+
+ # set1, set2 shapes: (num_points, feature_dim)
|
| 1127 |
+
+ n, m = set1.size(0), set2.size(0)
|
| 1128 |
+
+
|
| 1129 |
+
+ # Compute pairwise distances
|
| 1130 |
+
+ cost = torch.zeros(n, m, device=set1.device, dtype=set1.dtype)
|
| 1131 |
+
+ for i in range(n):
|
| 1132 |
+
+ for j in range(m):
|
| 1133 |
+
+ cost[i, j] = torch.norm(set1[i] - set2[j], p=2)
|
| 1134 |
+
+
|
| 1135 |
+
+ # Forward direction: for each point in set1, find distance to closest point in set2
|
| 1136 |
+
+ forward_min = cost.min(dim=1)[0] # Shape (n,)
|
| 1137 |
+
+ forward_hausdorff = forward_min.max() # max over n
|
| 1138 |
+
+
|
| 1139 |
+
+ # Backward direction: for each point in set2, find distance to closest point in set1
|
| 1140 |
+
+ backward_min = cost.min(dim=0)[0] # Shape (m,)
|
| 1141 |
+
+ backward_hausdorff = backward_min.max() # max over m
|
| 1142 |
+
+
|
| 1143 |
+
+ # Hausdorff distance is the max of the two
|
| 1144 |
+
+ hausdorff_dist = torch.max(forward_hausdorff, backward_hausdorff)
|
| 1145 |
+
+ return hausdorff_dist
|
| 1146 |
+
+
|
| 1147 |
+
+ def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
| 1148 |
+
+ """
|
| 1149 |
+
+ Compute the average Hausdorff distance over a batch of point sets.
|
| 1150 |
+
+
|
| 1151 |
+
+ pred, targ shapes: (batch_size, N, D)
|
| 1152 |
+
+ """
|
| 1153 |
+
+ # Ensure shapes match in batch dimension
|
| 1154 |
+
+ assert pred.size(0) == targ.size(0), "Batch sizes must match."
|
| 1155 |
+
+
|
| 1156 |
+
+ distances = []
|
| 1157 |
+
+ for b in range(pred.size(0)):
|
| 1158 |
+
+ set1 = pred[b]
|
| 1159 |
+
+ set2 = targ[b]
|
| 1160 |
+
+ h_dist = self._hausdorff_distance(set1, set2)
|
| 1161 |
+
+ distances.append(h_dist)
|
| 1162 |
+
+
|
| 1163 |
+
+ # Stack and take mean to get scalar loss
|
| 1164 |
+
+ hausdorff_loss = torch.stack(distances).mean()
|
| 1165 |
+
+ return hausdorff_loss
|
| 1166 |
+
+
|
| 1167 |
+
+ def forward(self, pred: torch.Tensor, targ: torch.Tensor):
|
| 1168 |
+
+ """
|
| 1169 |
+
+ Returns a tuple: (loss, info_dict),
|
| 1170 |
+
+ where loss is a scalar tensor and info_dict is a dictionary
|
| 1171 |
+
+ of extra information (e.g., distance components).
|
| 1172 |
+
+ """
|
| 1173 |
+
+ loss = self._loss(pred, targ)
|
| 1174 |
+
+
|
| 1175 |
+
+ info = {
|
| 1176 |
+
+ 'hausdorff': loss.item()
|
| 1177 |
+
+ }
|
| 1178 |
+
+
|
| 1179 |
+
+ return loss, info
|
| 1180 |
+
+
|
| 1181 |
+
+class FrechetDistanceLoss(nn.Module):
|
| 1182 |
+
+ def __init__(self):
|
| 1183 |
+
+ super().__init__()
|
| 1184 |
+
+
|
| 1185 |
+
+ def _frechet_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor:
|
| 1186 |
+
+ """
|
| 1187 |
+
+ Computes the (discrete) Fr��chet distance between two 2D tensors (T x D),
|
| 1188 |
+
+ where T is the sequence length and D is the feature dimension.
|
| 1189 |
+
+
|
| 1190 |
+
+ The Fréchet distance between two curves in discrete form can be computed
|
| 1191 |
+
+ by filling in a DP table “ca” where:
|
| 1192 |
+
+
|
| 1193 |
+
+ ca[i, j] = max( d(seq1[i], seq2[j]),
|
| 1194 |
+
+ min(ca[i-1, j], ca[i, j-1], ca[i-1, j-1]) )
|
| 1195 |
+
+
|
| 1196 |
+
+ with boundary conditions handled appropriately.
|
| 1197 |
+
+ Here, d(seq1[i], seq2[j]) is the Euclidean distance.
|
| 1198 |
+
+ """
|
| 1199 |
+
+ n, m = seq1.size(0), seq2.size(0)
|
| 1200 |
+
+
|
| 1201 |
+
+ # Cost matrix (pairwise distances between all elements)
|
| 1202 |
+
+ cost = torch.zeros(n, m, device=seq1.device, dtype=seq1.dtype)
|
| 1203 |
+
+ for i in range(n):
|
| 1204 |
+
+ for j in range(m):
|
| 1205 |
+
+ cost[i, j] = torch.norm(seq1[i] - seq2[j], p=2)
|
| 1206 |
+
+
|
| 1207 |
+
+ # DP matrix for the Fréchet distance
|
| 1208 |
+
+ ca = torch.full((n, m), float('inf'), device=seq1.device, dtype=seq1.dtype)
|
| 1209 |
+
+ ca[0, 0] = cost[0, 0]
|
| 1210 |
+
+
|
| 1211 |
+
+ # Initialize first row
|
| 1212 |
+
+ for i in range(1, n):
|
| 1213 |
+
+ ca[i, 0] = torch.max(ca[i - 1, 0], cost[i, 0])
|
| 1214 |
+
+
|
| 1215 |
+
+ # Initialize first column
|
| 1216 |
+
+ for j in range(1, m):
|
| 1217 |
+
+ ca[0, j] = torch.max(ca[0, j - 1], cost[0, j])
|
| 1218 |
+
+
|
| 1219 |
+
+ # Populate the DP table
|
| 1220 |
+
+ for i in range(1, n):
|
| 1221 |
+
+ for j in range(1, m):
|
| 1222 |
+
+ ca[i, j] = torch.max(
|
| 1223 |
+
+ cost[i, j],
|
| 1224 |
+
+ torch.min(
|
| 1225 |
+
+ torch.min(
|
| 1226 |
+
+ ca[i - 1, j],
|
| 1227 |
+
+ ca[i, j - 1],
|
| 1228 |
+
+ ),
|
| 1229 |
+
+ ca[i - 1, j - 1]
|
| 1230 |
+
+ )
|
| 1231 |
+
+ )
|
| 1232 |
+
+
|
| 1233 |
+
+ return ca[n - 1, m - 1]
|
| 1234 |
+
+
|
| 1235 |
+
+ def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
| 1236 |
+
+ """
|
| 1237 |
+
+ Compute the average Fréchet distance over a batch of sequences.
|
| 1238 |
+
+
|
| 1239 |
+
+ pred, targ shapes: (batch_size, T, D)
|
| 1240 |
+
+ """
|
| 1241 |
+
+ # Ensure shapes match in batch dimension
|
| 1242 |
+
+ assert pred.size(0) == targ.size(0), "Batch sizes must match."
|
| 1243 |
+
+
|
| 1244 |
+
+ distances = []
|
| 1245 |
+
+ for b in range(pred.size(0)):
|
| 1246 |
+
+ seq1 = pred[b]
|
| 1247 |
+
+ seq2 = targ[b]
|
| 1248 |
+
+ fd_val = self._frechet_distance(seq1, seq2)
|
| 1249 |
+
+ distances.append(fd_val)
|
| 1250 |
+
+
|
| 1251 |
+
+ # Stack and take mean to get scalar loss
|
| 1252 |
+
+ frechet_loss = torch.stack(distances).mean()
|
| 1253 |
+
+ return frechet_loss
|
| 1254 |
+
+
|
| 1255 |
+
+ def forward(self, pred: torch.Tensor, targ: torch.Tensor):
|
| 1256 |
+
+ """
|
| 1257 |
+
+ Returns a tuple: (loss, info_dict),
|
| 1258 |
+
+ where loss is a scalar tensor and info_dict is a dictionary
|
| 1259 |
+
+ of extra information (e.g., distance components).
|
| 1260 |
+
+ """
|
| 1261 |
+
+ loss = self._loss(pred, targ)
|
| 1262 |
+
+ info = {
|
| 1263 |
+
+ 'frechet': loss.item()
|
| 1264 |
+
+ }
|
| 1265 |
+
+ return loss, info
|
| 1266 |
+
+
|
| 1267 |
+
+class ChamferDistanceLoss(nn.Module):
|
| 1268 |
+
+ def __init__(self):
|
| 1269 |
+
+ super().__init__()
|
| 1270 |
+
+
|
| 1271 |
+
+ def _chamfer_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor:
|
| 1272 |
+
+ """
|
| 1273 |
+
+ Computes the symmetrical Chamfer distance between
|
| 1274 |
+
+ two 2D tensors (N x D), where N is the number of points
|
| 1275 |
+
+ and D is the feature dimension.
|
| 1276 |
+
+
|
| 1277 |
+
+ The Chamfer distance between two point sets A and B is often defined as:
|
| 1278 |
+
+
|
| 1279 |
+
+ d_chamfer(A, B) = 1/|A| ∑_{a ∈ A} min_{b ∈ B} ‖a - b‖₂
|
| 1280 |
+
+ + 1/|B| ∑_{b ∈ B} min_{a ∈ A} ‖b - a‖₂,
|
| 1281 |
+
+
|
| 1282 |
+
+ where ‖·‖₂ is the Euclidean distance.
|
| 1283 |
+
+ """
|
| 1284 |
+
+ # set1, set2 shapes: (num_points, feature_dim)
|
| 1285 |
+
+ n, m = set1.size(0), set2.size(0)
|
| 1286 |
+
+
|
| 1287 |
+
+ cost = torch.zeros(n, m, device=set1.device, dtype=set1.dtype)
|
| 1288 |
+
+ for i in range(n):
|
| 1289 |
+
+ for j in range(m):
|
| 1290 |
+
+ cost[i, j] = torch.norm(set1[i] - set2[j], p=2)
|
| 1291 |
+
+
|
| 1292 |
+
+ # For each point in set1, find distance to the closest point in set2
|
| 1293 |
+
+ forward_min = cost.min(dim=1)[0] # shape: (n,)
|
| 1294 |
+
+ forward_mean = forward_min.mean()
|
| 1295 |
+
+
|
| 1296 |
+
+ # For each point in set2, find distance to the closest point in set1
|
| 1297 |
+
+ backward_min = cost.min(dim=0)[0] # shape: (m,)
|
| 1298 |
+
+ backward_mean = backward_min.mean()
|
| 1299 |
+
+
|
| 1300 |
+
+ chamfer_dist = forward_mean + backward_mean
|
| 1301 |
+
+ return chamfer_dist
|
| 1302 |
+
+
|
| 1303 |
+
+ def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
| 1304 |
+
+ """
|
| 1305 |
+
+ Compute the average Chamfer distance over a batch of point sets.
|
| 1306 |
+
+
|
| 1307 |
+
+ pred, targ shapes: (batch_size, N, D)
|
| 1308 |
+
+ """
|
| 1309 |
+
+ # Ensure shapes match in batch dimension
|
| 1310 |
+
+ assert pred.size(0) == targ.size(0), "Batch sizes must match."
|
| 1311 |
+
+
|
| 1312 |
+
+ distances = []
|
| 1313 |
+
+ for b in range(pred.size(0)):
|
| 1314 |
+
+ set1 = pred[b]
|
| 1315 |
+
+ set2 = targ[b]
|
| 1316 |
+
+ distance_val = self._chamfer_distance(set1, set2)
|
| 1317 |
+
+ distances.append(distance_val)
|
| 1318 |
+
+
|
| 1319 |
+
+ # Combine into a single scalar
|
| 1320 |
+
+ chamfer_loss = torch.stack(distances).mean()
|
| 1321 |
+
+ return chamfer_loss
|
| 1322 |
+
+
|
| 1323 |
+
+ def forward(self, pred: torch.Tensor, targ: torch.Tensor):
|
| 1324 |
+
+ """
|
| 1325 |
+
+ Returns a tuple: (loss, info_dict),
|
| 1326 |
+
+ where 'loss' is a scalar tensor and 'info_dict' is a dictionary
|
| 1327 |
+
+ of extra information (e.g., distance components).
|
| 1328 |
+
+ """
|
| 1329 |
+
+ loss = self._loss(pred, targ)
|
| 1330 |
+
+ info = {
|
| 1331 |
+
+ 'chamfer': loss.item()
|
| 1332 |
+
+ }
|
| 1333 |
+
+ return loss, info
|
| 1334 |
+
+
|
| 1335 |
+
+
|
| 1336 |
+
+def slerp(q1, q2, t):
|
| 1337 |
+
+ """Spherical linear interpolation between two quaternions."""
|
| 1338 |
+
+ q1 = q1 / np.linalg.norm(q1)
|
| 1339 |
+
+ q2 = q2 / np.linalg.norm(q2)
|
| 1340 |
+
+ dot = np.dot(q1, q2)
|
| 1341 |
+
+
|
| 1342 |
+
+ if dot < 0.0:
|
| 1343 |
+
+ q2 = -q2
|
| 1344 |
+
+ dot = -dot
|
| 1345 |
+
+ # If dot is very close to 1, use linear interpolation
|
| 1346 |
+
+
|
| 1347 |
+
+ if dot > 0.9995:
|
| 1348 |
+
+ result = q1 + t * (q2 - q1)
|
| 1349 |
+
+ result = result / np.linalg.norm(result)
|
| 1350 |
+
+ return result
|
| 1351 |
+
+
|
| 1352 |
+
+ theta_0 = np.arccos(dot)
|
| 1353 |
+
+ theta = theta_0 * t
|
| 1354 |
+
+
|
| 1355 |
+
+ q3 = q2 - q1 * dot
|
| 1356 |
+
+ q3 = q3 / np.linalg.norm(q3)
|
| 1357 |
+
+ return q1 * np.cos(theta) + q3 * np.sin(theta)
|
| 1358 |
+
+
|
| 1359 |
+
+def catmull_rom_spline_with_rotation(control_points, timepoints, horizon):
|
| 1360 |
+
+ """Compute Catmull-Rom spline for both position and quaternion rotation."""
|
| 1361 |
+
+ spline_points = []
|
| 1362 |
+
+ # Extrapolate the initial points
|
| 1363 |
+
+ if timepoints[0] != 0:
|
| 1364 |
+
+ for t in range(timepoints[0]):
|
| 1365 |
+
+ x = control_points[0][0]
|
| 1366 |
+
+ y = control_points[0][1]
|
| 1367 |
+
+ z = control_points[0][2]
|
| 1368 |
+
+ q = control_points[0][3:7]
|
| 1369 |
+
+ spline_points.append(np.concatenate([np.array([x, y, z]), q]))
|
| 1370 |
+
+
|
| 1371 |
+
+ #Linear interpolate between 0th and 1th control points
|
| 1372 |
+
+ for t in np.linspace(0, 1, timepoints[1] - timepoints[0] + 1):
|
| 1373 |
+
+ x = control_points[0][0] + t * (control_points[1][0] - control_points[0][0])
|
| 1374 |
+
+ y = control_points[0][1] + t * (control_points[1][1] - control_points[0][1])
|
| 1375 |
+
+ z = control_points[0][2] + t * (control_points[1][2] - control_points[0][2])
|
| 1376 |
+
+ q = slerp(control_points[0][3:7], control_points[1][3:7], t)
|
| 1377 |
+
+ spline_points.append(np.concatenate([np.array([x, y, z]), q]))
|
| 1378 |
+
+
|
| 1379 |
+
+
|
| 1380 |
+
+ # Iterate over the control points
|
| 1381 |
+
+ for i in range(1, len(control_points) - 2):
|
| 1382 |
+
+ P0 = control_points[i-1][:3]
|
| 1383 |
+
+ P1 = control_points[i][:3]
|
| 1384 |
+
+ P2 = control_points[i+1][:3]
|
| 1385 |
+
+ P3 = control_points[i+2][:3]
|
| 1386 |
+
+ Q0 = control_points[i-1][3:7]
|
| 1387 |
+
+ Q1 = control_points[i][3:7]
|
| 1388 |
+
+ Q2 = control_points[i+1][3:7]
|
| 1389 |
+
+ Q3 = control_points[i+2][3:7]
|
| 1390 |
+
+
|
| 1391 |
+
+ # Interpolate position (using Catmull-Rom spline)
|
| 1392 |
+
+ for idx, t in enumerate(np.linspace(0, 1, timepoints[i+1] - timepoints[i] + 1)):
|
| 1393 |
+
+ if idx == 0:
|
| 1394 |
+
+ continue
|
| 1395 |
+
+
|
| 1396 |
+
+ x = 0.5 * ((2 * P1[0]) + (-P0[0] + P2[0]) * t +
|
| 1397 |
+
+ (2 * P0[0] - 5 * P1[0] + 4 * P2[0] - P3[0]) * t**2 +
|
| 1398 |
+
+ (-P0[0] + 3 * P1[0] - 3 * P2[0] + P3[0]) * t**3)
|
| 1399 |
+
+ y = 0.5 * ((2 * P1[1]) + (-P0[1] + P2[1]) * t +
|
| 1400 |
+
+ (2 * P0[1] - 5 * P1[1] + 4 * P2[1] - P3[1]) * t**2 +
|
| 1401 |
+
+ (-P0[1] + 3 * P1[1] - 3 * P2[1] + P3[1]) * t**3)
|
| 1402 |
+
+ z = 0.5 * ((2 * P1[2]) + (-P0[2] + P2[2]) * t +
|
| 1403 |
+
+ (2 * P0[2] - 5 * P1[2] + 4 * P2[2] - P3[2]) * t**2 +
|
| 1404 |
+
+ (-P0[2] + 3 * P1[2] - 3 * P2[2] + P3[2]) * t**3)
|
| 1405 |
+
+ q = slerp(Q1, Q2, t)
|
| 1406 |
+
+ spline_points.append(np.concatenate([np.array([x, y, z]), q]))
|
| 1407 |
+
+
|
| 1408 |
+
+ #Linear interpolate between 2nd last and last control points
|
| 1409 |
+
+ for idx, t in enumerate(np.linspace(0, 1, timepoints[-1] - timepoints[-2] + 1)):
|
| 1410 |
+
+ if idx == 0:
|
| 1411 |
+
+ continue
|
| 1412 |
+
+ x = control_points[-2][0] + t * (control_points[-1][0] - control_points[-2][0])
|
| 1413 |
+
+ y = control_points[-2][1] + t * (control_points[-1][1] - control_points[-2][1])
|
| 1414 |
+
+ z = control_points[-2][2] + t * (control_points[-1][2] - control_points[-2][2])
|
| 1415 |
+
+ q = slerp(control_points[-2][3:7], control_points[-1][3:7], t)
|
| 1416 |
+
+ spline_points.append(np.concatenate([np.array([x, y, z]), q]))
|
| 1417 |
+
+
|
| 1418 |
+
+ # Extrapolate the rest of the points
|
| 1419 |
+
+ if timepoints[-1] != horizon:
|
| 1420 |
+
+ for t in range(timepoints[-1] + 1, horizon):
|
| 1421 |
+
+ x = control_points[-1][0]
|
| 1422 |
+
+ y = control_points[-1][1]
|
| 1423 |
+
+ z = control_points[-1][2]
|
| 1424 |
+
+ q = control_points[-1][3:7]
|
| 1425 |
+
+ spline_points.append(np.concatenate([np.array([x, y, z]), q]))
|
| 1426 |
+
+
|
| 1427 |
+
+ stacked_spline_points = np.stack(spline_points, axis=0)
|
| 1428 |
+
+
|
| 1429 |
+
+ if control_points.shape[1] != 7:
|
| 1430 |
+
+ stacked_spline_points = np.concatenate([stacked_spline_points, np.zeros((stacked_spline_points.shape[0], 1))], axis=1)
|
| 1431 |
+
+
|
| 1432 |
+
+
|
| 1433 |
+
+ return stacked_spline_points
|
| 1434 |
+
+
|
| 1435 |
+
+def catmull_rom_loss(trajectories, conditions, loss_fc):
|
| 1436 |
+
+ '''
|
| 1437 |
+
+ loss for catmull-rom interpolation
|
| 1438 |
+
+ '''
|
| 1439 |
+
+ batch_size, horizon, transition = trajectories.shape
|
| 1440 |
+
+
|
| 1441 |
+
+ # Extract known indices and values
|
| 1442 |
+
+ known_indices = np.array(list(conditions.keys()), dtype=int)
|
| 1443 |
+
+
|
| 1444 |
+
+ # candidate_no x batch_size x dim
|
| 1445 |
+
+ known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0)
|
| 1446 |
+
+ known_values = np.moveaxis(known_values, 0, 1)
|
| 1447 |
+
+
|
| 1448 |
+
+ # Sort the timepoints
|
| 1449 |
+
+ sorted_indices = np.argsort(known_indices)
|
| 1450 |
+
+ known_indices = known_indices[sorted_indices]
|
| 1451 |
+
+ known_values = known_values[:, sorted_indices]
|
| 1452 |
+
+ spline_points = np.array([catmull_rom_spline_with_rotation(known_values[b], known_indices, horizon) for b in range(batch_size)])
|
| 1453 |
+
+
|
| 1454 |
+
+ # Convert to tensor and move to the same device as trajectories
|
| 1455 |
+
+ spline_points = torch.tensor(spline_points, dtype=torch.float64, device=trajectories.device)
|
| 1456 |
+
+ assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}"
|
| 1457 |
+
+ return loss_fc(spline_points, trajectories)
|
| 1458 |
+
+
|
| 1459 |
+
Losses = {
|
| 1460 |
+
'l1': WeightedL1,
|
| 1461 |
+
'l2': WeightedL2,
|
| 1462 |
+
'value_l1': ValueL1,
|
| 1463 |
+
'value_l2': ValueL2,
|
| 1464 |
+
+ 'geodesic_l2': GeodesicL2Loss,
|
| 1465 |
+
+ 'rotation_translation': RotationTranslationLoss,
|
| 1466 |
+
+ 'spline': SplineLoss,
|
| 1467 |
+
}
|
| 1468 |
+
diff --git a/diffuser/models/temporal.py b/diffuser/models/temporal.py
|
| 1469 |
+
index e0b9e5c..0f7854a 100644
|
| 1470 |
+
--- a/diffuser/models/temporal.py
|
| 1471 |
+
+++ b/diffuser/models/temporal.py
|
| 1472 |
+
@@ -17,18 +17,18 @@ class ResidualTemporalBlock(nn.Module):
|
| 1473 |
+
super().__init__()
|
| 1474 |
+
|
| 1475 |
+
self.blocks = nn.ModuleList([
|
| 1476 |
+
- Conv1dBlock(inp_channels, out_channels, kernel_size),
|
| 1477 |
+
- Conv1dBlock(out_channels, out_channels, kernel_size),
|
| 1478 |
+
+ Conv1dBlock(inp_channels, out_channels, kernel_size).to(dtype=torch.float64),
|
| 1479 |
+
+ Conv1dBlock(out_channels, out_channels, kernel_size).to(dtype=torch.float64),
|
| 1480 |
+
])
|
| 1481 |
+
|
| 1482 |
+
self.time_mlp = nn.Sequential(
|
| 1483 |
+
nn.Mish(),
|
| 1484 |
+
- nn.Linear(embed_dim, out_channels),
|
| 1485 |
+
+ nn.Linear(embed_dim, out_channels).to(dtype=torch.float64),
|
| 1486 |
+
Rearrange('batch t -> batch t 1'),
|
| 1487 |
+
- )
|
| 1488 |
+
+ ).to(dtype=torch.float64)
|
| 1489 |
+
|
| 1490 |
+
- self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
|
| 1491 |
+
- if inp_channels != out_channels else nn.Identity()
|
| 1492 |
+
+ self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1).to(dtype=torch.float64) \
|
| 1493 |
+
+ if inp_channels != out_channels else nn.Identity().to(dtype=torch.float64)
|
| 1494 |
+
|
| 1495 |
+
def forward(self, x, t):
|
| 1496 |
+
'''
|
| 1497 |
+
@@ -37,7 +37,8 @@ class ResidualTemporalBlock(nn.Module):
|
| 1498 |
+
returns:
|
| 1499 |
+
out : [ batch_size x out_channels x horizon ]
|
| 1500 |
+
'''
|
| 1501 |
+
- out = self.blocks[0](x) + self.time_mlp(t)
|
| 1502 |
+
+
|
| 1503 |
+
+ out = self.blocks[0](x) + self.time_mlp(t.double())
|
| 1504 |
+
out = self.blocks[1](out)
|
| 1505 |
+
return out + self.residual_conv(x)
|
| 1506 |
+
|
| 1507 |
+
@@ -49,11 +50,11 @@ class TemporalUnet(nn.Module):
|
| 1508 |
+
transition_dim,
|
| 1509 |
+
cond_dim,
|
| 1510 |
+
dim=32,
|
| 1511 |
+
- dim_mults=(1, 2, 4, 8),
|
| 1512 |
+
+ dim_mults=(1, 2, 4),
|
| 1513 |
+
):
|
| 1514 |
+
super().__init__()
|
| 1515 |
+
|
| 1516 |
+
- dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
|
| 1517 |
+
+ dims = [(transition_dim + cond_dim), *map(lambda m: dim * m, dim_mults)]
|
| 1518 |
+
in_out = list(zip(dims[:-1], dims[1:]))
|
| 1519 |
+
print(f'[ models/temporal ] Channel dimensions: {in_out}')
|
| 1520 |
+
|
| 1521 |
+
@@ -100,7 +101,7 @@ class TemporalUnet(nn.Module):
|
| 1522 |
+
|
| 1523 |
+
self.final_conv = nn.Sequential(
|
| 1524 |
+
Conv1dBlock(dim, dim, kernel_size=5),
|
| 1525 |
+
- nn.Conv1d(dim, transition_dim, 1),
|
| 1526 |
+
+ nn.Conv1d(dim, transition_dim, 1).to(dtype=torch.float64),
|
| 1527 |
+
)
|
| 1528 |
+
|
| 1529 |
+
def forward(self, x, cond, time):
|
| 1530 |
+
@@ -129,7 +130,6 @@ class TemporalUnet(nn.Module):
|
| 1531 |
+
x = upsample(x)
|
| 1532 |
+
|
| 1533 |
+
x = self.final_conv(x)
|
| 1534 |
+
-
|
| 1535 |
+
x = einops.rearrange(x, 'b t h -> b h t')
|
| 1536 |
+
return x
|
| 1537 |
+
|
| 1538 |
+
diff --git a/diffuser/utils/arrays.py b/diffuser/utils/arrays.py
|
| 1539 |
+
index c3a9d24..96a7093 100644
|
| 1540 |
+
--- a/diffuser/utils/arrays.py
|
| 1541 |
+
+++ b/diffuser/utils/arrays.py
|
| 1542 |
+
@@ -54,7 +54,7 @@ def batchify(batch):
|
| 1543 |
+
1) converting np arrays to torch tensors and
|
| 1544 |
+
2) and ensuring that everything has a batch dimension
|
| 1545 |
+
'''
|
| 1546 |
+
- fn = lambda x: to_torch(x[None])
|
| 1547 |
+
+ fn = lambda x: to_torch(x[None], dtype=torch.float64)
|
| 1548 |
+
|
| 1549 |
+
batched_vals = []
|
| 1550 |
+
for field in batch._fields:
|
| 1551 |
+
diff --git a/diffuser/utils/serialization.py b/diffuser/utils/serialization.py
|
| 1552 |
+
index 6cc9db9..039eb64 100644
|
| 1553 |
+
--- a/diffuser/utils/serialization.py
|
| 1554 |
+
+++ b/diffuser/utils/serialization.py
|
| 1555 |
+
@@ -19,7 +19,7 @@ def mkdir(savepath):
|
| 1556 |
+
return False
|
| 1557 |
+
|
| 1558 |
+
def get_latest_epoch(loadpath):
|
| 1559 |
+
- states = glob.glob1(os.path.join(*loadpath), 'state_*')
|
| 1560 |
+
+ states = glob.glob1(os.path.join(loadpath), 'state_*')
|
| 1561 |
+
latest_epoch = -1
|
| 1562 |
+
for state in states:
|
| 1563 |
+
epoch = int(state.replace('state_', '').replace('.pt', ''))
|
| 1564 |
+
diff --git a/diffuser/utils/training.py b/diffuser/utils/training.py
|
| 1565 |
+
index be3556e..c21e0f0 100644
|
| 1566 |
+
--- a/diffuser/utils/training.py
|
| 1567 |
+
+++ b/diffuser/utils/training.py
|
| 1568 |
+
@@ -4,16 +4,24 @@ import numpy as np
|
| 1569 |
+
import torch
|
| 1570 |
+
import einops
|
| 1571 |
+
import pdb
|
| 1572 |
+
+from tqdm import tqdm
|
| 1573 |
+
+import wandb
|
| 1574 |
+
+from pytorch3d.transforms import axis_angle_to_quaternion
|
| 1575 |
+
|
| 1576 |
+
from .arrays import batch_to_device, to_np, to_device, apply_dict
|
| 1577 |
+
from .timer import Timer
|
| 1578 |
+
from .cloud import sync_logs
|
| 1579 |
+
+from ..models.helpers import catmull_rom_spline_with_rotation
|
| 1580 |
+
|
| 1581 |
+
def cycle(dl):
|
| 1582 |
+
while True:
|
| 1583 |
+
for data in dl:
|
| 1584 |
+
yield data
|
| 1585 |
+
|
| 1586 |
+
+def assert_no_nan_weights(model):
|
| 1587 |
+
+ for name, param in model.named_parameters():
|
| 1588 |
+
+ assert not torch.isnan(param).any(), f"NaN detected in parameter: {name}"
|
| 1589 |
+
+
|
| 1590 |
+
class EMA():
|
| 1591 |
+
'''
|
| 1592 |
+
empirical moving average
|
| 1593 |
+
@@ -71,13 +79,35 @@ class Trainer(object):
|
| 1594 |
+
self.gradient_accumulate_every = gradient_accumulate_every
|
| 1595 |
+
|
| 1596 |
+
self.dataset = dataset
|
| 1597 |
+
- self.dataloader = cycle(torch.utils.data.DataLoader(
|
| 1598 |
+
- self.dataset, batch_size=train_batch_size, num_workers=1, shuffle=True, pin_memory=True
|
| 1599 |
+
+ dataset_size = len(self.dataset)
|
| 1600 |
+
+
|
| 1601 |
+
+ # Read the indices from the .txt file
|
| 1602 |
+
+ with open(os.path.join(results_folder, 'train_indices.txt'), 'r') as f:
|
| 1603 |
+
+ self.train_indices = f.read()
|
| 1604 |
+
+ self.train_indices = [int(i) for i in self.train_indices.split('\n') if i]
|
| 1605 |
+
+
|
| 1606 |
+
+ with open(os.path.join(results_folder, 'val_indices.txt'), 'r') as f:
|
| 1607 |
+
+ self.val_indices = f.read()
|
| 1608 |
+
+ self.val_indices = [int(i) for i in self.val_indices.split('\n') if i]
|
| 1609 |
+
+
|
| 1610 |
+
+
|
| 1611 |
+
+ self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)
|
| 1612 |
+
+ self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)
|
| 1613 |
+
+ self.train_dataloader = cycle(torch.utils.data.DataLoader(
|
| 1614 |
+
+ self.train_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False
|
| 1615 |
+
+ ))
|
| 1616 |
+
+
|
| 1617 |
+
+ self.val_dataloader = cycle(torch.utils.data.DataLoader(
|
| 1618 |
+
+ self.val_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False
|
| 1619 |
+
))
|
| 1620 |
+
+
|
| 1621 |
+
self.dataloader_vis = cycle(torch.utils.data.DataLoader(
|
| 1622 |
+
self.dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True
|
| 1623 |
+
))
|
| 1624 |
+
self.renderer = renderer
|
| 1625 |
+
+
|
| 1626 |
+
+
|
| 1627 |
+
+
|
| 1628 |
+
self.optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=train_lr)
|
| 1629 |
+
|
| 1630 |
+
self.logdir = results_folder
|
| 1631 |
+
@@ -88,6 +118,8 @@ class Trainer(object):
|
| 1632 |
+
self.reset_parameters()
|
| 1633 |
+
self.step = 0
|
| 1634 |
+
|
| 1635 |
+
+ self.log_to_wandb = False
|
| 1636 |
+
+
|
| 1637 |
+
def reset_parameters(self):
|
| 1638 |
+
self.ema_model.load_state_dict(self.model.state_dict())
|
| 1639 |
+
|
| 1640 |
+
@@ -102,36 +134,129 @@ class Trainer(object):
|
| 1641 |
+
#-----------------------------------------------------------------------------#
|
| 1642 |
+
|
| 1643 |
+
def train(self, n_train_steps):
|
| 1644 |
+
-
|
| 1645 |
+
+ # Save the indices as .txt files
|
| 1646 |
+
+ with open(os.path.join(self.logdir, 'train_indices.txt'), 'w') as f:
|
| 1647 |
+
+ for idx in self.train_indices:
|
| 1648 |
+
+ f.write(f"{idx}\n")
|
| 1649 |
+
+ with open(os.path.join(self.logdir, 'val_indices.txt'), 'w') as f:
|
| 1650 |
+
+ for idx in self.val_indices:
|
| 1651 |
+
+ f.write(f"{idx}\n")
|
| 1652 |
+
+
|
| 1653 |
+
timer = Timer()
|
| 1654 |
+
- for step in range(n_train_steps):
|
| 1655 |
+
+ torch.autograd.set_detect_anomaly(True)
|
| 1656 |
+
+
|
| 1657 |
+
+ # Setup wandb
|
| 1658 |
+
+ if self.log_to_wandb:
|
| 1659 |
+
+ wandb.init(
|
| 1660 |
+
+ project='trajectory-generation',
|
| 1661 |
+
+ config={'lr': self.optimizer.param_groups[0]['lr'], 'batch_size': self.batch_size, 'gradient_accumulate_every': self.gradient_accumulate_every},
|
| 1662 |
+
+ )
|
| 1663 |
+
+
|
| 1664 |
+
+ for step in tqdm(range(n_train_steps)):
|
| 1665 |
+
+
|
| 1666 |
+
+ mean_train_loss = 0.0
|
| 1667 |
+
for i in range(self.gradient_accumulate_every):
|
| 1668 |
+
- batch = next(self.dataloader)
|
| 1669 |
+
+ batch = next(self.train_dataloader)
|
| 1670 |
+
batch = batch_to_device(batch)
|
| 1671 |
+
-
|
| 1672 |
+
- loss, infos = self.model.loss(*batch)
|
| 1673 |
+
+
|
| 1674 |
+
+ loss, infos = self.model.loss(x=batch.trajectories, cond=batch.conditions)
|
| 1675 |
+
loss = loss / self.gradient_accumulate_every
|
| 1676 |
+
+ mean_train_loss += loss.item()
|
| 1677 |
+
loss.backward()
|
| 1678 |
+
|
| 1679 |
+
+ if self.log_to_wandb:
|
| 1680 |
+
+ wandb.log({
|
| 1681 |
+
+ 'step': self.step,
|
| 1682 |
+
+ 'train/loss': mean_train_loss
|
| 1683 |
+
+ })
|
| 1684 |
+
+
|
| 1685 |
+
+ # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 1686 |
+
+
|
| 1687 |
+
self.optimizer.step()
|
| 1688 |
+
self.optimizer.zero_grad()
|
| 1689 |
+
|
| 1690 |
+
+ assert_no_nan_weights(self.model)
|
| 1691 |
+
+
|
| 1692 |
+
if self.step % self.update_ema_every == 0:
|
| 1693 |
+
self.step_ema()
|
| 1694 |
+
|
| 1695 |
+
if self.step % self.save_freq == 0:
|
| 1696 |
+
- label = self.step // self.label_freq * self.label_freq
|
| 1697 |
+
+ label = self.step
|
| 1698 |
+
+ print(f'Saving model at step {self.step}...')
|
| 1699 |
+
self.save(label)
|
| 1700 |
+
|
| 1701 |
+
if self.step % self.log_freq == 0:
|
| 1702 |
+
- infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()])
|
| 1703 |
+
- print(f'{self.step}: {loss:8.4f} | {infos_str} | t: {timer():8.4f}')
|
| 1704 |
+
+ val_losses = []
|
| 1705 |
+
+ lin_int_losses = []
|
| 1706 |
+
+
|
| 1707 |
+
+ val_infos_list = []
|
| 1708 |
+
+ lin_int_infos_list = []
|
| 1709 |
+
+
|
| 1710 |
+
+ catmull_losses = []
|
| 1711 |
+
+ catmull_infos_list = []
|
| 1712 |
+
+
|
| 1713 |
+
+ for _ in range(len(self.val_indices)):
|
| 1714 |
+
+ val_batch = next(self.val_dataloader)
|
| 1715 |
+
+ val_batch = batch_to_device(val_batch)
|
| 1716 |
+
+
|
| 1717 |
+
+ traj = self.model.forward(val_batch.conditions, horizon=val_batch.trajectories.shape[1])
|
| 1718 |
+
+ val_loss, val_infos = self.model.loss_fn(traj, val_batch.trajectories, cond=val_batch.conditions)
|
| 1719 |
+
+
|
| 1720 |
+
+ val_losses.append(val_loss.item())
|
| 1721 |
+
+ val_infos_list.append({key: val for key, val in val_infos.items()})
|
| 1722 |
+
+
|
| 1723 |
+
+
|
| 1724 |
+
+ (lin_int_loss, lin_int_infos), lin_int_traj = self.linear_interpolation_loss(
|
| 1725 |
+
+ val_batch.trajectories, val_batch.conditions, self.model.loss_fn
|
| 1726 |
+
+ )
|
| 1727 |
+
+ lin_int_losses.append(lin_int_loss.item())
|
| 1728 |
+
+ lin_int_infos_list.append({key: val for key, val in lin_int_infos.items()})
|
| 1729 |
+
+
|
| 1730 |
+
+ (catmull_loss, catmull_infos), catmull_traj = self.catmull_rom_loss(
|
| 1731 |
+
+ val_batch.trajectories, val_batch.conditions, self.model.loss_fn
|
| 1732 |
+
+ )
|
| 1733 |
+
+
|
| 1734 |
+
+ catmull_losses.append(catmull_loss.item())
|
| 1735 |
+
+ catmull_infos_list.append(catmull_infos)
|
| 1736 |
+
+
|
| 1737 |
+
+ avg_val_loss = np.mean(val_losses)
|
| 1738 |
+
+ avg_lin_int_loss = np.mean(lin_int_losses)
|
| 1739 |
+
+
|
| 1740 |
+
+ val_infos = {key: np.mean([info[key] for info in val_infos_list]) for key in val_infos_list[0].keys()}
|
| 1741 |
+
+ lin_int_infos = {key: np.mean([info[key] for info in lin_int_infos_list]) for key in lin_int_infos_list[0].keys()}
|
| 1742 |
+
|
| 1743 |
+
- if self.step == 0 and self.sample_freq:
|
| 1744 |
+
- self.render_reference(self.n_reference)
|
| 1745 |
+
+ avg_catmull_loss = np.mean(catmull_losses)
|
| 1746 |
+
+ catmull_infos = {key: np.mean([info[key] for info in catmull_infos_list]) for key in catmull_infos_list[0].keys()}
|
| 1747 |
+
|
| 1748 |
+
- if self.sample_freq and self.step % self.sample_freq == 0:
|
| 1749 |
+
- self.render_samples(n_samples=self.n_samples)
|
| 1750 |
+
+ val_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in val_infos.items()])
|
| 1751 |
+
+ lin_int_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in lin_int_infos.items()])
|
| 1752 |
+
+ catmull_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in catmull_infos.items()])
|
| 1753 |
+
+
|
| 1754 |
+
+
|
| 1755 |
+
+ infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()])
|
| 1756 |
+
+ print("Learning Rate: ", self.optimizer.param_groups[0]['lr'])
|
| 1757 |
+
+ print(f'Step {self.step}: {loss * self.gradient_accumulate_every:8.4f} | {infos_str} | t: {timer():8.4f}')
|
| 1758 |
+
+ print(f'Validation - {self.step}: {avg_val_loss:8.4f} | {val_infos_str} | t: {timer():8.4f}')
|
| 1759 |
+
+ print(f'Linear Interpolation Loss - {self.step}: {avg_lin_int_loss:8.4f} | {lin_int_infos_str} | t: {timer():8.4f}')
|
| 1760 |
+
+ print(f'Catmull Rom Loss - {self.step}: {avg_catmull_loss:8.4f} | {catmull_infos_str} | t: {timer():8.4f}')
|
| 1761 |
+
+ print()
|
| 1762 |
+
+
|
| 1763 |
+
+ if self.log_to_wandb:
|
| 1764 |
+
+ wandb.log({
|
| 1765 |
+
+ 'step': self.step,
|
| 1766 |
+
+ 'val/loss': avg_val_loss,
|
| 1767 |
+
+ 'val/linear_interp/loss': avg_lin_int_loss,
|
| 1768 |
+
+ 'val/linear_interp/quaternion dist.': lin_int_infos['quat. dist.'],
|
| 1769 |
+
+ 'val/linear_interp/euclidean dist.': lin_int_infos['trans. error'],
|
| 1770 |
+
+ 'val/linear_interp/geodesic loss': lin_int_infos['geodesic dist.'],
|
| 1771 |
+
+ 'val/catmull_rom/loss': avg_catmull_loss,
|
| 1772 |
+
+ 'val/catmull_rom/quaternion dist.': catmull_infos['quat. dist.'],
|
| 1773 |
+
+ 'val/catmull_rom/euclidean dist.': catmull_infos['trans. error'],
|
| 1774 |
+
+ 'val/catmull_rom/geodesic loss': catmull_infos['geodesic dist.'],
|
| 1775 |
+
+ 'val/quaternion dist.': val_infos['quat. dist.'],
|
| 1776 |
+
+ 'val/euclidean dist.': val_infos['trans. error'],
|
| 1777 |
+
+ 'val/geodesic loss': val_infos['geodesic dist.'],
|
| 1778 |
+
+ })
|
| 1779 |
+
|
| 1780 |
+
self.step += 1
|
| 1781 |
+
|
| 1782 |
+
@@ -186,15 +311,6 @@ class Trainer(object):
|
| 1783 |
+
normed_observations = trajectories[:, :, self.dataset.action_dim:]
|
| 1784 |
+
observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations')
|
| 1785 |
+
|
| 1786 |
+
- # from diffusion.datasets.preprocessing import blocks_cumsum_quat
|
| 1787 |
+
- # # observations = conditions + blocks_cumsum_quat(deltas)
|
| 1788 |
+
- # observations = conditions + deltas.cumsum(axis=1)
|
| 1789 |
+
-
|
| 1790 |
+
- #### @TODO: remove block-stacking specific stuff
|
| 1791 |
+
- # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka
|
| 1792 |
+
- # observations = blocks_add_kuka(observations)
|
| 1793 |
+
- ####
|
| 1794 |
+
-
|
| 1795 |
+
savepath = os.path.join(self.logdir, f'_sample-reference.png')
|
| 1796 |
+
self.renderer.composite(savepath, observations)
|
| 1797 |
+
|
| 1798 |
+
@@ -225,9 +341,6 @@ class Trainer(object):
|
| 1799 |
+
# [ 1 x 1 x observation_dim ]
|
| 1800 |
+
normed_conditions = to_np(batch.conditions[0])[:,None]
|
| 1801 |
+
|
| 1802 |
+
- # from diffusion.datasets.preprocessing import blocks_cumsum_quat
|
| 1803 |
+
- # observations = conditions + blocks_cumsum_quat(deltas)
|
| 1804 |
+
- # observations = conditions + deltas.cumsum(axis=1)
|
| 1805 |
+
|
| 1806 |
+
## [ n_samples x (horizon + 1) x observation_dim ]
|
| 1807 |
+
normed_observations = np.concatenate([
|
| 1808 |
+
@@ -238,10 +351,70 @@ class Trainer(object):
|
| 1809 |
+
## [ n_samples x (horizon + 1) x observation_dim ]
|
| 1810 |
+
observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations')
|
| 1811 |
+
|
| 1812 |
+
- #### @TODO: remove block-stacking specific stuff
|
| 1813 |
+
- # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka
|
| 1814 |
+
- # observations = blocks_add_kuka(observations)
|
| 1815 |
+
- ####
|
| 1816 |
+
-
|
| 1817 |
+
savepath = os.path.join(self.logdir, f'sample-{self.step}-{i}.png')
|
| 1818 |
+
self.renderer.composite(savepath, observations)
|
| 1819 |
+
+
|
| 1820 |
+
+ def linear_interpolation_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None):
|
| 1821 |
+
+ batch_size, horizon, transition = trajectories.shape
|
| 1822 |
+
+
|
| 1823 |
+
+ # Extract known indices and values
|
| 1824 |
+
+ known_indices = np.array(list(conditions.keys()), dtype=int)
|
| 1825 |
+
+ # candidate_no x batch_size x dim
|
| 1826 |
+
+ known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0)
|
| 1827 |
+
+ known_values = np.moveaxis(known_values, 0, 1)
|
| 1828 |
+
+
|
| 1829 |
+
+ # Create time steps for interpolation
|
| 1830 |
+
+ time_steps = np.linspace(0, horizon, num=horizon)
|
| 1831 |
+
+
|
| 1832 |
+
+ # Perform interpolation across all dimensions at once
|
| 1833 |
+
+ linear_int_arr = np.array([[
|
| 1834 |
+
+ np.interp(time_steps, known_indices, known_values[b, :, dim])
|
| 1835 |
+
+ for dim in range(transition)]
|
| 1836 |
+
+ for b in range(batch_size)]
|
| 1837 |
+
+ ).T # Transpose to match shape (horizon, transition)
|
| 1838 |
+
+
|
| 1839 |
+
+ # Convert to tensor and move to the same device as trajectories
|
| 1840 |
+
+ linear_int_arr = np.transpose(linear_int_arr, axes=[2, 0, 1])
|
| 1841 |
+
+ linear_int_tensor = torch.tensor(linear_int_arr, dtype=torch.float64, device=trajectories.device)
|
| 1842 |
+
+
|
| 1843 |
+
+ return loss_fc(linear_int_tensor, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), linear_int_tensor
|
| 1844 |
+
+
|
| 1845 |
+
+
|
| 1846 |
+
+ def catmull_rom_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None):
|
| 1847 |
+
+ '''
|
| 1848 |
+
+ loss for catmull-rom interpolation
|
| 1849 |
+
+ '''
|
| 1850 |
+
+
|
| 1851 |
+
+ batch_size, horizon, transition = trajectories.shape
|
| 1852 |
+
+
|
| 1853 |
+
+ # Extract known indices and values
|
| 1854 |
+
+ known_indices = np.array(list(conditions.keys()), dtype=int)
|
| 1855 |
+
+ # candidate_no x batch_size x dim
|
| 1856 |
+
+ known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0)
|
| 1857 |
+
+ known_values = np.moveaxis(known_values, 0, 1)
|
| 1858 |
+
+
|
| 1859 |
+
+ # Sort the timepoints
|
| 1860 |
+
+ sorted_indices = np.argsort(known_indices)
|
| 1861 |
+
+ known_indices = known_indices[sorted_indices]
|
| 1862 |
+
+ known_values = known_values[:, sorted_indices]
|
| 1863 |
+
+
|
| 1864 |
+
+ spline_points = np.array([catmull_rom_spline_with_rotation(known_values[b], known_indices, horizon) for b in range(batch_size)])
|
| 1865 |
+
+
|
| 1866 |
+
+ # Convert to tensor and move to the same device as trajectories
|
| 1867 |
+
+ spline_points = torch.tensor(spline_points, dtype=torch.float64, device=trajectories.device)
|
| 1868 |
+
+
|
| 1869 |
+
+ assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}"
|
| 1870 |
+
+
|
| 1871 |
+
+ return loss_fc(spline_points, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), spline_points
|
| 1872 |
+
+
|
| 1873 |
+
+
|
| 1874 |
+
+
|
| 1875 |
+
+
|
| 1876 |
+
+
|
| 1877 |
+
+
|
| 1878 |
+
+
|
| 1879 |
+
+
|
| 1880 |
+
+
|
| 1881 |
+
+
|
| 1882 |
+
+
|
| 1883 |
+
+
|
| 1884 |
+
diff --git a/scripts/train.py b/scripts/train.py
|
| 1885 |
+
index 2c5f299..6728d6f 100644
|
| 1886 |
+
--- a/scripts/train.py
|
| 1887 |
+
+++ b/scripts/train.py
|
| 1888 |
+
@@ -108,6 +108,7 @@ utils.report_parameters(model)
|
| 1889 |
+
|
| 1890 |
+
print('Testing forward...', end=' ', flush=True)
|
| 1891 |
+
batch = utils.batchify(dataset[0])
|
| 1892 |
+
+
|
| 1893 |
+
loss, _ = diffusion.loss(*batch)
|
| 1894 |
+
loss.backward()
|
| 1895 |
+
print('✓')
|
residual-diffuser/diffusion_config.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:02206556f60d5d7911ade8ae3a68cc6c59c8ce65aa6f16481125122ea83a827b
|
| 3 |
+
size 316
|
residual-diffuser/model_config.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4c7d03e458df1b0f5eec375eaedbc0daaab6a95a996e158fbbfcf4128a25fc1e
|
| 3 |
+
size 202
|
residual-diffuser/render_config.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ddf10f1f4223e218c38e6190698e58a382fa33344e2d587bcde53f4700444683
|
| 3 |
+
size 156
|
residual-diffuser/state_58000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf0e30c9d06d00f9e8651f621640ab063f766fef6f85d0dd4912fe150ec6f083
|
| 3 |
+
size 59009153
|
residual-diffuser/test_indices.txt
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
369
|
| 2 |
+
764
|
| 3 |
+
474
|
| 4 |
+
312
|
| 5 |
+
857
|
| 6 |
+
384
|
| 7 |
+
323
|
| 8 |
+
548
|
| 9 |
+
796
|
| 10 |
+
212
|
| 11 |
+
595
|
| 12 |
+
388
|
| 13 |
+
444
|
| 14 |
+
120
|
| 15 |
+
598
|
| 16 |
+
302
|
| 17 |
+
633
|
| 18 |
+
688
|
| 19 |
+
653
|
| 20 |
+
20
|
| 21 |
+
665
|
| 22 |
+
67
|
| 23 |
+
130
|
| 24 |
+
56
|
| 25 |
+
822
|
| 26 |
+
160
|
| 27 |
+
169
|
| 28 |
+
30
|
| 29 |
+
623
|
| 30 |
+
200
|
| 31 |
+
520
|
| 32 |
+
13
|
| 33 |
+
273
|
| 34 |
+
296
|
| 35 |
+
411
|
| 36 |
+
530
|
| 37 |
+
367
|
| 38 |
+
579
|
| 39 |
+
788
|
| 40 |
+
387
|
| 41 |
+
8
|
| 42 |
+
216
|
| 43 |
+
738
|
| 44 |
+
527
|
| 45 |
+
35
|
| 46 |
+
713
|
| 47 |
+
416
|
| 48 |
+
422
|
| 49 |
+
492
|
| 50 |
+
680
|
| 51 |
+
757
|
| 52 |
+
435
|
| 53 |
+
218
|
| 54 |
+
643
|
| 55 |
+
489
|
| 56 |
+
481
|
| 57 |
+
54
|
| 58 |
+
760
|
| 59 |
+
558
|
| 60 |
+
485
|
| 61 |
+
666
|
| 62 |
+
619
|
| 63 |
+
806
|
| 64 |
+
724
|
| 65 |
+
742
|
| 66 |
+
452
|
| 67 |
+
445
|
| 68 |
+
137
|
| 69 |
+
165
|
| 70 |
+
260
|
| 71 |
+
855
|
| 72 |
+
95
|
| 73 |
+
191
|
| 74 |
+
736
|
| 75 |
+
71
|
| 76 |
+
860
|
| 77 |
+
210
|
| 78 |
+
176
|
| 79 |
+
662
|
| 80 |
+
480
|
| 81 |
+
583
|
| 82 |
+
34
|
| 83 |
+
471
|
| 84 |
+
772
|
| 85 |
+
393
|
| 86 |
+
466
|
| 87 |
+
469
|
| 88 |
+
111
|
| 89 |
+
687
|
| 90 |
+
125
|
| 91 |
+
231
|
| 92 |
+
123
|
| 93 |
+
366
|
| 94 |
+
304
|
| 95 |
+
262
|
| 96 |
+
97
|
| 97 |
+
597
|
| 98 |
+
177
|
| 99 |
+
636
|
| 100 |
+
350
|
residual-diffuser/train_indices.txt
ADDED
|
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
749
|
| 2 |
+
792
|
| 3 |
+
22
|
| 4 |
+
3
|
| 5 |
+
775
|
| 6 |
+
420
|
| 7 |
+
83
|
| 8 |
+
284
|
| 9 |
+
635
|
| 10 |
+
376
|
| 11 |
+
700
|
| 12 |
+
754
|
| 13 |
+
575
|
| 14 |
+
115
|
| 15 |
+
122
|
| 16 |
+
751
|
| 17 |
+
826
|
| 18 |
+
695
|
| 19 |
+
263
|
| 20 |
+
577
|
| 21 |
+
856
|
| 22 |
+
336
|
| 23 |
+
249
|
| 24 |
+
150
|
| 25 |
+
226
|
| 26 |
+
248
|
| 27 |
+
478
|
| 28 |
+
617
|
| 29 |
+
535
|
| 30 |
+
10
|
| 31 |
+
329
|
| 32 |
+
46
|
| 33 |
+
821
|
| 34 |
+
206
|
| 35 |
+
807
|
| 36 |
+
147
|
| 37 |
+
345
|
| 38 |
+
766
|
| 39 |
+
768
|
| 40 |
+
254
|
| 41 |
+
164
|
| 42 |
+
188
|
| 43 |
+
133
|
| 44 |
+
437
|
| 45 |
+
716
|
| 46 |
+
532
|
| 47 |
+
391
|
| 48 |
+
426
|
| 49 |
+
105
|
| 50 |
+
728
|
| 51 |
+
463
|
| 52 |
+
864
|
| 53 |
+
5
|
| 54 |
+
178
|
| 55 |
+
640
|
| 56 |
+
774
|
| 57 |
+
837
|
| 58 |
+
309
|
| 59 |
+
348
|
| 60 |
+
850
|
| 61 |
+
205
|
| 62 |
+
314
|
| 63 |
+
346
|
| 64 |
+
385
|
| 65 |
+
423
|
| 66 |
+
425
|
| 67 |
+
707
|
| 68 |
+
163
|
| 69 |
+
415
|
| 70 |
+
412
|
| 71 |
+
599
|
| 72 |
+
503
|
| 73 |
+
490
|
| 74 |
+
319
|
| 75 |
+
693
|
| 76 |
+
274
|
| 77 |
+
156
|
| 78 |
+
316
|
| 79 |
+
135
|
| 80 |
+
721
|
| 81 |
+
153
|
| 82 |
+
72
|
| 83 |
+
162
|
| 84 |
+
765
|
| 85 |
+
684
|
| 86 |
+
138
|
| 87 |
+
90
|
| 88 |
+
834
|
| 89 |
+
229
|
| 90 |
+
673
|
| 91 |
+
195
|
| 92 |
+
94
|
| 93 |
+
569
|
| 94 |
+
270
|
| 95 |
+
786
|
| 96 |
+
342
|
| 97 |
+
745
|
| 98 |
+
356
|
| 99 |
+
80
|
| 100 |
+
685
|
| 101 |
+
677
|
| 102 |
+
544
|
| 103 |
+
220
|
| 104 |
+
672
|
| 105 |
+
668
|
| 106 |
+
361
|
| 107 |
+
7
|
| 108 |
+
31
|
| 109 |
+
203
|
| 110 |
+
142
|
| 111 |
+
539
|
| 112 |
+
421
|
| 113 |
+
378
|
| 114 |
+
16
|
| 115 |
+
816
|
| 116 |
+
443
|
| 117 |
+
696
|
| 118 |
+
276
|
| 119 |
+
151
|
| 120 |
+
472
|
| 121 |
+
616
|
| 122 |
+
305
|
| 123 |
+
681
|
| 124 |
+
93
|
| 125 |
+
491
|
| 126 |
+
849
|
| 127 |
+
183
|
| 128 |
+
300
|
| 129 |
+
675
|
| 130 |
+
753
|
| 131 |
+
294
|
| 132 |
+
649
|
| 133 |
+
691
|
| 134 |
+
175
|
| 135 |
+
467
|
| 136 |
+
144
|
| 137 |
+
501
|
| 138 |
+
858
|
| 139 |
+
779
|
| 140 |
+
869
|
| 141 |
+
246
|
| 142 |
+
867
|
| 143 |
+
333
|
| 144 |
+
258
|
| 145 |
+
414
|
| 146 |
+
84
|
| 147 |
+
486
|
| 148 |
+
632
|
| 149 |
+
0
|
| 150 |
+
519
|
| 151 |
+
830
|
| 152 |
+
600
|
| 153 |
+
541
|
| 154 |
+
52
|
| 155 |
+
631
|
| 156 |
+
198
|
| 157 |
+
626
|
| 158 |
+
278
|
| 159 |
+
552
|
| 160 |
+
547
|
| 161 |
+
235
|
| 162 |
+
559
|
| 163 |
+
528
|
| 164 |
+
353
|
| 165 |
+
86
|
| 166 |
+
88
|
| 167 |
+
718
|
| 168 |
+
234
|
| 169 |
+
828
|
| 170 |
+
295
|
| 171 |
+
829
|
| 172 |
+
646
|
| 173 |
+
874
|
| 174 |
+
564
|
| 175 |
+
525
|
| 176 |
+
810
|
| 177 |
+
682
|
| 178 |
+
250
|
| 179 |
+
861
|
| 180 |
+
217
|
| 181 |
+
748
|
| 182 |
+
113
|
| 183 |
+
740
|
| 184 |
+
505
|
| 185 |
+
770
|
| 186 |
+
787
|
| 187 |
+
611
|
| 188 |
+
642
|
| 189 |
+
550
|
| 190 |
+
63
|
| 191 |
+
567
|
| 192 |
+
549
|
| 193 |
+
21
|
| 194 |
+
588
|
| 195 |
+
524
|
| 196 |
+
752
|
| 197 |
+
747
|
| 198 |
+
48
|
| 199 |
+
53
|
| 200 |
+
285
|
| 201 |
+
335
|
| 202 |
+
66
|
| 203 |
+
453
|
| 204 |
+
124
|
| 205 |
+
360
|
| 206 |
+
815
|
| 207 |
+
390
|
| 208 |
+
781
|
| 209 |
+
543
|
| 210 |
+
201
|
| 211 |
+
283
|
| 212 |
+
141
|
| 213 |
+
434
|
| 214 |
+
230
|
| 215 |
+
613
|
| 216 |
+
193
|
| 217 |
+
608
|
| 218 |
+
508
|
| 219 |
+
199
|
| 220 |
+
732
|
| 221 |
+
741
|
| 222 |
+
222
|
| 223 |
+
76
|
| 224 |
+
555
|
| 225 |
+
261
|
| 226 |
+
96
|
| 227 |
+
436
|
| 228 |
+
282
|
| 229 |
+
45
|
| 230 |
+
589
|
| 231 |
+
11
|
| 232 |
+
459
|
| 233 |
+
382
|
| 234 |
+
57
|
| 235 |
+
877
|
| 236 |
+
70
|
| 237 |
+
537
|
| 238 |
+
801
|
| 239 |
+
129
|
| 240 |
+
722
|
| 241 |
+
494
|
| 242 |
+
823
|
| 243 |
+
26
|
| 244 |
+
377
|
| 245 |
+
326
|
| 246 |
+
820
|
| 247 |
+
310
|
| 248 |
+
77
|
| 249 |
+
876
|
| 250 |
+
565
|
| 251 |
+
504
|
| 252 |
+
406
|
| 253 |
+
686
|
| 254 |
+
811
|
| 255 |
+
482
|
| 256 |
+
499
|
| 257 |
+
507
|
| 258 |
+
458
|
| 259 |
+
386
|
| 260 |
+
847
|
| 261 |
+
658
|
| 262 |
+
708
|
| 263 |
+
100
|
| 264 |
+
60
|
| 265 |
+
607
|
| 266 |
+
817
|
| 267 |
+
663
|
| 268 |
+
428
|
| 269 |
+
859
|
| 270 |
+
313
|
| 271 |
+
68
|
| 272 |
+
32
|
| 273 |
+
267
|
| 274 |
+
701
|
| 275 |
+
139
|
| 276 |
+
349
|
| 277 |
+
487
|
| 278 |
+
289
|
| 279 |
+
225
|
| 280 |
+
840
|
| 281 |
+
375
|
| 282 |
+
186
|
| 283 |
+
875
|
| 284 |
+
832
|
| 285 |
+
381
|
| 286 |
+
667
|
| 287 |
+
777
|
| 288 |
+
515
|
| 289 |
+
298
|
| 290 |
+
862
|
| 291 |
+
773
|
| 292 |
+
509
|
| 293 |
+
715
|
| 294 |
+
449
|
| 295 |
+
664
|
| 296 |
+
358
|
| 297 |
+
121
|
| 298 |
+
172
|
| 299 |
+
594
|
| 300 |
+
39
|
| 301 |
+
553
|
| 302 |
+
468
|
| 303 |
+
370
|
| 304 |
+
424
|
| 305 |
+
570
|
| 306 |
+
614
|
| 307 |
+
477
|
| 308 |
+
645
|
| 309 |
+
400
|
| 310 |
+
179
|
| 311 |
+
441
|
| 312 |
+
767
|
| 313 |
+
865
|
| 314 |
+
55
|
| 315 |
+
497
|
| 316 |
+
288
|
| 317 |
+
704
|
| 318 |
+
551
|
| 319 |
+
809
|
| 320 |
+
498
|
| 321 |
+
354
|
| 322 |
+
82
|
| 323 |
+
408
|
| 324 |
+
157
|
| 325 |
+
460
|
| 326 |
+
98
|
| 327 |
+
145
|
| 328 |
+
439
|
| 329 |
+
591
|
| 330 |
+
556
|
| 331 |
+
211
|
| 332 |
+
606
|
| 333 |
+
9
|
| 334 |
+
538
|
| 335 |
+
719
|
| 336 |
+
720
|
| 337 |
+
641
|
| 338 |
+
240
|
| 339 |
+
841
|
| 340 |
+
17
|
| 341 |
+
112
|
| 342 |
+
465
|
| 343 |
+
733
|
| 344 |
+
372
|
| 345 |
+
338
|
| 346 |
+
268
|
| 347 |
+
219
|
| 348 |
+
365
|
| 349 |
+
624
|
| 350 |
+
114
|
| 351 |
+
87
|
| 352 |
+
585
|
| 353 |
+
128
|
| 354 |
+
14
|
| 355 |
+
612
|
| 356 |
+
803
|
| 357 |
+
215
|
| 358 |
+
836
|
| 359 |
+
297
|
| 360 |
+
255
|
| 361 |
+
795
|
| 362 |
+
495
|
| 363 |
+
730
|
| 364 |
+
29
|
| 365 |
+
838
|
| 366 |
+
99
|
| 367 |
+
161
|
| 368 |
+
814
|
| 369 |
+
51
|
| 370 |
+
794
|
| 371 |
+
36
|
| 372 |
+
170
|
| 373 |
+
171
|
| 374 |
+
644
|
| 375 |
+
271
|
| 376 |
+
281
|
| 377 |
+
131
|
| 378 |
+
603
|
| 379 |
+
514
|
| 380 |
+
562
|
| 381 |
+
389
|
| 382 |
+
180
|
| 383 |
+
239
|
| 384 |
+
118
|
| 385 |
+
593
|
| 386 |
+
407
|
| 387 |
+
574
|
| 388 |
+
79
|
| 389 |
+
714
|
| 390 |
+
655
|
| 391 |
+
394
|
| 392 |
+
622
|
| 393 |
+
166
|
| 394 |
+
802
|
| 395 |
+
780
|
| 396 |
+
639
|
| 397 |
+
101
|
| 398 |
+
19
|
| 399 |
+
253
|
| 400 |
+
674
|
| 401 |
+
763
|
| 402 |
+
269
|
| 403 |
+
427
|
| 404 |
+
364
|
| 405 |
+
689
|
| 406 |
+
526
|
| 407 |
+
800
|
| 408 |
+
303
|
| 409 |
+
207
|
| 410 |
+
630
|
| 411 |
+
184
|
| 412 |
+
168
|
| 413 |
+
290
|
| 414 |
+
392
|
| 415 |
+
251
|
| 416 |
+
561
|
| 417 |
+
506
|
| 418 |
+
65
|
| 419 |
+
450
|
| 420 |
+
651
|
| 421 |
+
399
|
| 422 |
+
484
|
| 423 |
+
190
|
| 424 |
+
851
|
| 425 |
+
28
|
| 426 |
+
637
|
| 427 |
+
202
|
| 428 |
+
657
|
| 429 |
+
656
|
| 430 |
+
868
|
| 431 |
+
808
|
| 432 |
+
568
|
| 433 |
+
223
|
| 434 |
+
236
|
| 435 |
+
108
|
| 436 |
+
580
|
| 437 |
+
853
|
| 438 |
+
75
|
| 439 |
+
58
|
| 440 |
+
327
|
| 441 |
+
671
|
| 442 |
+
510
|
| 443 |
+
578
|
| 444 |
+
797
|
| 445 |
+
536
|
| 446 |
+
110
|
| 447 |
+
602
|
| 448 |
+
196
|
| 449 |
+
798
|
| 450 |
+
790
|
| 451 |
+
185
|
| 452 |
+
670
|
| 453 |
+
448
|
| 454 |
+
523
|
| 455 |
+
692
|
| 456 |
+
430
|
| 457 |
+
140
|
| 458 |
+
383
|
| 459 |
+
252
|
| 460 |
+
531
|
| 461 |
+
746
|
| 462 |
+
456
|
| 463 |
+
557
|
| 464 |
+
522
|
| 465 |
+
690
|
| 466 |
+
782
|
| 467 |
+
286
|
| 468 |
+
92
|
| 469 |
+
618
|
| 470 |
+
723
|
| 471 |
+
605
|
| 472 |
+
49
|
| 473 |
+
563
|
| 474 |
+
243
|
| 475 |
+
64
|
| 476 |
+
455
|
| 477 |
+
804
|
| 478 |
+
299
|
| 479 |
+
292
|
| 480 |
+
529
|
| 481 |
+
835
|
| 482 |
+
2
|
| 483 |
+
831
|
| 484 |
+
197
|
| 485 |
+
479
|
| 486 |
+
102
|
| 487 |
+
470
|
| 488 |
+
47
|
| 489 |
+
762
|
| 490 |
+
518
|
| 491 |
+
703
|
| 492 |
+
842
|
| 493 |
+
337
|
| 494 |
+
678
|
| 495 |
+
698
|
| 496 |
+
189
|
| 497 |
+
247
|
| 498 |
+
410
|
| 499 |
+
213
|
| 500 |
+
401
|
| 501 |
+
277
|
| 502 |
+
280
|
| 503 |
+
173
|
| 504 |
+
328
|
| 505 |
+
818
|
| 506 |
+
744
|
| 507 |
+
238
|
| 508 |
+
315
|
| 509 |
+
676
|
| 510 |
+
872
|
| 511 |
+
756
|
| 512 |
+
244
|
| 513 |
+
291
|
| 514 |
+
727
|
| 515 |
+
155
|
| 516 |
+
208
|
| 517 |
+
25
|
| 518 |
+
339
|
| 519 |
+
755
|
| 520 |
+
844
|
| 521 |
+
592
|
| 522 |
+
321
|
| 523 |
+
521
|
| 524 |
+
546
|
| 525 |
+
44
|
| 526 |
+
242
|
| 527 |
+
759
|
| 528 |
+
769
|
| 529 |
+
652
|
| 530 |
+
181
|
| 531 |
+
275
|
| 532 |
+
50
|
| 533 |
+
833
|
| 534 |
+
371
|
| 535 |
+
584
|
| 536 |
+
758
|
| 537 |
+
433
|
| 538 |
+
279
|
| 539 |
+
627
|
| 540 |
+
107
|
| 541 |
+
15
|
| 542 |
+
109
|
| 543 |
+
854
|
| 544 |
+
227
|
| 545 |
+
596
|
| 546 |
+
395
|
| 547 |
+
182
|
| 548 |
+
778
|
| 549 |
+
648
|
| 550 |
+
825
|
| 551 |
+
628
|
| 552 |
+
18
|
| 553 |
+
717
|
| 554 |
+
398
|
| 555 |
+
601
|
| 556 |
+
566
|
| 557 |
+
625
|
| 558 |
+
447
|
| 559 |
+
660
|
| 560 |
+
647
|
| 561 |
+
866
|
| 562 |
+
735
|
| 563 |
+
462
|
| 564 |
+
590
|
| 565 |
+
351
|
| 566 |
+
659
|
| 567 |
+
330
|
| 568 |
+
634
|
| 569 |
+
126
|
| 570 |
+
334
|
| 571 |
+
324
|
| 572 |
+
783
|
| 573 |
+
516
|
| 574 |
+
500
|
| 575 |
+
743
|
| 576 |
+
739
|
| 577 |
+
784
|
| 578 |
+
457
|
| 579 |
+
709
|
| 580 |
+
318
|
| 581 |
+
726
|
| 582 |
+
192
|
| 583 |
+
697
|
| 584 |
+
69
|
| 585 |
+
204
|
| 586 |
+
669
|
| 587 |
+
461
|
| 588 |
+
413
|
| 589 |
+
650
|
| 590 |
+
362
|
| 591 |
+
824
|
| 592 |
+
127
|
| 593 |
+
871
|
| 594 |
+
805
|
| 595 |
+
355
|
| 596 |
+
442
|
| 597 |
+
347
|
| 598 |
+
209
|
| 599 |
+
117
|
| 600 |
+
306
|
| 601 |
+
332
|
| 602 |
+
379
|
| 603 |
+
42
|
| 604 |
+
152
|
| 605 |
+
512
|
| 606 |
+
638
|
| 607 |
+
106
|
| 608 |
+
187
|
| 609 |
+
194
|
| 610 |
+
159
|
| 611 |
+
89
|
| 612 |
+
712
|
| 613 |
+
119
|
| 614 |
+
307
|
| 615 |
+
214
|
| 616 |
+
403
|
| 617 |
+
705
|
| 618 |
+
582
|
| 619 |
+
586
|
| 620 |
+
264
|
| 621 |
+
502
|
| 622 |
+
488
|
| 623 |
+
409
|
| 624 |
+
621
|
| 625 |
+
233
|
| 626 |
+
340
|
| 627 |
+
863
|
| 628 |
+
737
|
| 629 |
+
432
|
| 630 |
+
24
|
| 631 |
+
576
|
| 632 |
+
454
|
| 633 |
+
464
|
| 634 |
+
85
|
| 635 |
+
404
|
| 636 |
+
517
|
| 637 |
+
451
|
| 638 |
+
513
|
| 639 |
+
483
|
| 640 |
+
363
|
| 641 |
+
317
|
| 642 |
+
573
|
| 643 |
+
620
|
| 644 |
+
74
|
| 645 |
+
609
|
| 646 |
+
706
|
| 647 |
+
301
|
| 648 |
+
259
|
| 649 |
+
812
|
| 650 |
+
679
|
| 651 |
+
396
|
| 652 |
+
610
|
| 653 |
+
476
|
| 654 |
+
417
|
| 655 |
+
827
|
| 656 |
+
405
|
| 657 |
+
325
|
| 658 |
+
581
|
| 659 |
+
6
|
| 660 |
+
846
|
| 661 |
+
418
|
| 662 |
+
132
|
| 663 |
+
710
|
| 664 |
+
272
|
| 665 |
+
571
|
| 666 |
+
368
|
| 667 |
+
533
|
| 668 |
+
839
|
| 669 |
+
402
|
| 670 |
+
380
|
| 671 |
+
245
|
| 672 |
+
174
|
| 673 |
+
228
|
| 674 |
+
266
|
| 675 |
+
143
|
| 676 |
+
785
|
| 677 |
+
587
|
| 678 |
+
661
|
| 679 |
+
344
|
| 680 |
+
38
|
| 681 |
+
848
|
| 682 |
+
154
|
| 683 |
+
265
|
| 684 |
+
771
|
| 685 |
+
791
|
| 686 |
+
761
|
| 687 |
+
493
|
| 688 |
+
542
|
| 689 |
+
359
|
| 690 |
+
62
|
| 691 |
+
149
|
residual-diffuser/trainer_config.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b582355f277afd95fac231f750f74b3e2e3e301b0420e40098d54537790de3ce
|
| 3 |
+
size 381
|
residual-diffuser/val_indices.txt
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
136
|
| 2 |
+
604
|
| 3 |
+
683
|
| 4 |
+
146
|
| 5 |
+
654
|
| 6 |
+
373
|
| 7 |
+
511
|
| 8 |
+
73
|
| 9 |
+
343
|
| 10 |
+
311
|
| 11 |
+
819
|
| 12 |
+
331
|
| 13 |
+
750
|
| 14 |
+
33
|
| 15 |
+
341
|
| 16 |
+
694
|
| 17 |
+
237
|
| 18 |
+
572
|
| 19 |
+
91
|
| 20 |
+
134
|
| 21 |
+
357
|
| 22 |
+
540
|
| 23 |
+
729
|
| 24 |
+
776
|
| 25 |
+
789
|
| 26 |
+
61
|
| 27 |
+
852
|
| 28 |
+
699
|
| 29 |
+
12
|
| 30 |
+
232
|
| 31 |
+
734
|
| 32 |
+
43
|
| 33 |
+
27
|
| 34 |
+
545
|
| 35 |
+
158
|
| 36 |
+
224
|
| 37 |
+
438
|
| 38 |
+
629
|
| 39 |
+
104
|
| 40 |
+
554
|
| 41 |
+
429
|
| 42 |
+
534
|
| 43 |
+
873
|
| 44 |
+
322
|
| 45 |
+
496
|
| 46 |
+
446
|
| 47 |
+
287
|
| 48 |
+
397
|
| 49 |
+
1
|
| 50 |
+
293
|
| 51 |
+
725
|
| 52 |
+
81
|
| 53 |
+
440
|
| 54 |
+
419
|
| 55 |
+
702
|
| 56 |
+
870
|
| 57 |
+
308
|
| 58 |
+
793
|
| 59 |
+
103
|
| 60 |
+
843
|
| 61 |
+
78
|
| 62 |
+
256
|
| 63 |
+
475
|
| 64 |
+
560
|
| 65 |
+
711
|
| 66 |
+
813
|
| 67 |
+
431
|
| 68 |
+
374
|
| 69 |
+
731
|
| 70 |
+
23
|
| 71 |
+
167
|
| 72 |
+
37
|
| 73 |
+
4
|
| 74 |
+
352
|
| 75 |
+
116
|
| 76 |
+
148
|
| 77 |
+
59
|
| 78 |
+
845
|
| 79 |
+
221
|
| 80 |
+
257
|
| 81 |
+
40
|
| 82 |
+
615
|
| 83 |
+
473
|
| 84 |
+
320
|
| 85 |
+
241
|
| 86 |
+
799
|
| 87 |
+
41
|