Ata Celen committed
Commit dcaa3ad · 1 Parent(s): bb12ece

Model Weights added

qwen2-vl-3d/config.json ADDED
@@ -0,0 +1,48 @@
+ {
+   "_name_or_path": "Qwen/Qwen2-VL-7B-Instruct",
+   "architectures": [
+     "Qwen2VLSpatialForConditionalGeneration"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "hidden_act": "silu",
+   "hidden_size": 3584,
+   "image_token_id": 151655,
+   "initializer_range": 0.02,
+   "intermediate_size": 18944,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 28,
+   "model_type": "qwen2_vl",
+   "num_attention_heads": 28,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 4,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": {
+     "mrope_section": [
+       16,
+       24,
+       24
+     ],
+     "rope_type": "default",
+     "type": "default"
+   },
+   "rope_theta": 1000000.0,
+   "sliding_window": 32768,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.49.0.dev0",
+   "use_cache": true,
+   "use_sliding_window": false,
+   "video_token_id": 151656,
+   "vision_config": {
+     "in_chans": 3,
+     "model_type": "qwen2_vl",
+     "spatial_patch_size": 14,
+     "torch_dtype": "bfloat16"
+   },
+   "vision_end_token_id": 151653,
+   "vision_start_token_id": 151652,
+   "vision_token_id": 151654,
+   "vocab_size": 151660
+ }
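
A minimal sketch of inspecting this config with transformers, assuming the repo has been cloned to ./qwen2-vl-3d. Note that Qwen2VLSpatialForConditionalGeneration is this project's custom class, not part of stock transformers; AutoConfig only resolves the standard "model_type": "qwen2_vl", so the config can be loaded without the project code.

```python
# Minimal sketch, assuming a local clone at ./qwen2-vl-3d.
# AutoConfig maps "model_type": "qwen2_vl" to the stock Qwen2VLConfig.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("./qwen2-vl-3d")
print(config.hidden_size)          # 3584
print(config.num_attention_heads)  # 28 query heads
print(config.num_key_value_heads)  # 4 KV heads (grouped-query attention)
print(config.rope_scaling)         # mrope_section [16, 24, 24]: temporal/height/width splits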
qwen2-vl-3d/generation_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "attn_implementation": "flash_attention_2",
+   "bos_token_id": 151643,
+   "do_sample": true,
+   "eos_token_id": [
+     151645,
+     151643
+   ],
+   "pad_token_id": 151643,
+   "temperature": 0.01,
+   "top_k": 1,
+   "top_p": 0.001,
+   "transformers_version": "4.49.0.dev0"
+ }
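
Despite "do_sample": true, the "top_k": 1 setting restricts sampling to the single most probable token, so decoding is effectively greedy (the near-zero temperature and top_p reinforce this). A small sketch inspecting these defaults, assuming the same local path as above:

```python
# Sketch: load and inspect the decoding defaults. With top_k=1 only the
# argmax token survives filtering, so sampling degenerates to greedy search.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("./qwen2-vl-3d")
print(gen_cfg.do_sample, gen_cfg.temperature, gen_cfg.top_k, gen_cfg.top_p)
# True 0.01 1 0.001
print(gen_cfg.eos_token_id)  # [151645, 151643]: either token ends generation
```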
qwen2-vl-3d/model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ebcaeaea7f52c327c50f91b36d0f3d63cd96b445572c129fdb2760b001bd323
+ size 4963764072
qwen2-vl-3d/model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:638919d8742180acc5d1ba5016cb77d9c1d51eb0aaf4966410e82d2dcac48580
+ size 4991495816
qwen2-vl-3d/model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff7bf17a78fd87ad5adf702a460330b1ae89ae2f121175eb6045f3a6a3c905fc
+ size 4932751040
qwen2-vl-3d/model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:254fc35d50a309eefb1ef995ae67767137ad0a01b573b6ad30cf66cb6896e6a1
+ size 1720377840
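
The four .safetensors entries are Git LFS pointer files: the repository itself stores only the oid (a SHA-256 of the blob) and its size in bytes, while the actual weights (roughly 16.6 GB across the four shards) live in LFS storage. A sketch for verifying a downloaded shard against its pointer, with the expected values hard-coded from shard 1 above:

```python
# Sketch: verify a downloaded shard against the oid/size recorded in its
# LFS pointer (values copied from model-00001-of-00004 above).
import hashlib
import os

path = "qwen2-vl-3d/model-00001-of-00004.safetensors"
expected_oid = "2ebcaeaea7f52c327c50f91b36d0f3d63cd96b445572c129fdb2760b001bd323"
expected_size = 4963764072

assert os.path.getsize(path) == expected_size, "size mismatch"
sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # stream in 1 MiB chunks
        sha.update(chunk)
assert sha.hexdigest() == expected_oid, "hash mismatch"
print("shard OK")
```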
qwen2-vl-3d/model.safetensors.index.json ADDED
@@ -0,0 +1,899 @@
+ {
+   "metadata": {
+     "total_size": 16608291000
+   },
+   "weight_map": {
+     "adapter.bias": "model-00004-of-00004.safetensors",
+     "adapter.weight": "model-00004-of-00004.safetensors",
+     "diffuser.alphas_cumprod": "model-00004-of-00004.safetensors",
+     "diffuser.alphas_cumprod_prev": "model-00004-of-00004.safetensors",
+     "diffuser.betas": "model-00004-of-00004.safetensors",
+     "diffuser.log_one_minus_alphas_cumprod": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.residual_conv.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.residual_conv.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.0.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.2.conv.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.0.2.conv.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.residual_conv.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.residual_conv.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.0.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.2.conv.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.1.2.conv.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.residual_conv.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.residual_conv.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.0.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.downs.2.1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.final_conv.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.final_conv.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.final_conv.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.final_conv.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.final_conv.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.final_conv.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block2.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block2.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block2.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block2.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block2.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block2.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block2.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block2.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block2.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.mid_block2.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.time_mlp.3.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.time_mlp.3.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.residual_conv.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.residual_conv.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.0.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.2.conv.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.0.2.conv.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.residual_conv.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.residual_conv.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.0.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.1.blocks.0.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.1.blocks.0.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.1.blocks.0.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.1.blocks.0.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.1.blocks.1.block.0.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.1.blocks.1.block.0.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.1.blocks.1.block.2.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.1.blocks.1.block.2.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.1.time_mlp.1.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.1.time_mlp.1.weight": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.2.conv.bias": "model-00004-of-00004.safetensors",
+     "diffuser.model.ups.1.2.conv.weight": "model-00004-of-00004.safetensors",
+     "diffuser.posterior_log_variance_clipped": "model-00004-of-00004.safetensors",
+     "diffuser.posterior_mean_coef1": "model-00004-of-00004.safetensors",
+     "diffuser.posterior_mean_coef2": "model-00004-of-00004.safetensors",
+     "diffuser.posterior_variance": "model-00004-of-00004.safetensors",
+     "diffuser.sqrt_alphas_cumprod": "model-00004-of-00004.safetensors",
+     "diffuser.sqrt_one_minus_alphas_cumprod": "model-00004-of-00004.safetensors",
+     "diffuser.sqrt_recip_alphas_cumprod": "model-00004-of-00004.safetensors",
+     "diffuser.sqrt_recipm1_alphas_cumprod": "model-00004-of-00004.safetensors",
+     "lm_head.weight": "model-00004-of-00004.safetensors",
+     "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
+     "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.16.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.16.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.16.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.16.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.16.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.17.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.17.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.17.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.17.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.17.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.17.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.17.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.17.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.17.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.18.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.18.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.18.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.18.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.18.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.18.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.18.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.19.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.19.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.19.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+     "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.26.input_layernorm.weight": "model-00004-of-00004.safetensors",
+     "model.layers.26.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
+     "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.26.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
+     "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+     "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+     "model.layers.27.input_layernorm.weight": "model-00004-of-00004.safetensors",
+     "model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
+     "model.layers.27.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
+     "model.layers.27.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
+     "model.layers.27.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
+     "model.layers.27.self_attn.k_proj.bias": "model-00004-of-00004.safetensors",
+     "model.layers.27.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
+     "model.layers.27.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
+     "model.layers.27.self_attn.q_proj.bias": "model-00004-of-00004.safetensors",
+     "model.layers.27.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
+     "model.layers.27.self_attn.v_proj.bias": "model-00004-of-00004.safetensors",
+     "model.layers.27.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
+     "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+     "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.5.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.5.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.5.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.5.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
+     "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+     "model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.6.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.6.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.6.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.6.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.6.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.6.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.6.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.6.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.6.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.7.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.7.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.7.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.7.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.7.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.7.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.7.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.8.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.8.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.8.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+     "model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+     "model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+     "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+     "model.norm.weight": "model-00004-of-00004.safetensors",
+     "visual.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.0.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.1.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.10.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.11.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.12.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.13.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.14.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.15.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.16.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.17.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.18.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.19.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.2.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.norm1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.norm1.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.norm2.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.20.norm2.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.21.attn.qkv.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
+     "visual.blocks.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
+     "visual.blocks.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
+ "visual.blocks.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
682
+ "visual.blocks.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
683
+ "visual.blocks.21.norm1.bias": "model-00001-of-00004.safetensors",
684
+ "visual.blocks.21.norm1.weight": "model-00001-of-00004.safetensors",
685
+ "visual.blocks.21.norm2.bias": "model-00001-of-00004.safetensors",
686
+ "visual.blocks.21.norm2.weight": "model-00001-of-00004.safetensors",
687
+ "visual.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
688
+ "visual.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
689
+ "visual.blocks.22.attn.qkv.bias": "model-00001-of-00004.safetensors",
690
+ "visual.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
691
+ "visual.blocks.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
692
+ "visual.blocks.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
693
+ "visual.blocks.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
694
+ "visual.blocks.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
695
+ "visual.blocks.22.norm1.bias": "model-00001-of-00004.safetensors",
696
+ "visual.blocks.22.norm1.weight": "model-00001-of-00004.safetensors",
697
+ "visual.blocks.22.norm2.bias": "model-00001-of-00004.safetensors",
698
+ "visual.blocks.22.norm2.weight": "model-00001-of-00004.safetensors",
699
+ "visual.blocks.23.attn.proj.bias": "model-00001-of-00004.safetensors",
700
+ "visual.blocks.23.attn.proj.weight": "model-00001-of-00004.safetensors",
701
+ "visual.blocks.23.attn.qkv.bias": "model-00001-of-00004.safetensors",
702
+ "visual.blocks.23.attn.qkv.weight": "model-00001-of-00004.safetensors",
703
+ "visual.blocks.23.mlp.fc1.bias": "model-00001-of-00004.safetensors",
704
+ "visual.blocks.23.mlp.fc1.weight": "model-00001-of-00004.safetensors",
705
+ "visual.blocks.23.mlp.fc2.bias": "model-00001-of-00004.safetensors",
706
+ "visual.blocks.23.mlp.fc2.weight": "model-00001-of-00004.safetensors",
707
+ "visual.blocks.23.norm1.bias": "model-00001-of-00004.safetensors",
708
+ "visual.blocks.23.norm1.weight": "model-00001-of-00004.safetensors",
709
+ "visual.blocks.23.norm2.bias": "model-00001-of-00004.safetensors",
710
+ "visual.blocks.23.norm2.weight": "model-00001-of-00004.safetensors",
711
+ "visual.blocks.24.attn.proj.bias": "model-00001-of-00004.safetensors",
712
+ "visual.blocks.24.attn.proj.weight": "model-00001-of-00004.safetensors",
713
+ "visual.blocks.24.attn.qkv.bias": "model-00001-of-00004.safetensors",
714
+ "visual.blocks.24.attn.qkv.weight": "model-00001-of-00004.safetensors",
715
+ "visual.blocks.24.mlp.fc1.bias": "model-00001-of-00004.safetensors",
716
+ "visual.blocks.24.mlp.fc1.weight": "model-00001-of-00004.safetensors",
717
+ "visual.blocks.24.mlp.fc2.bias": "model-00001-of-00004.safetensors",
718
+ "visual.blocks.24.mlp.fc2.weight": "model-00001-of-00004.safetensors",
719
+ "visual.blocks.24.norm1.bias": "model-00001-of-00004.safetensors",
720
+ "visual.blocks.24.norm1.weight": "model-00001-of-00004.safetensors",
721
+ "visual.blocks.24.norm2.bias": "model-00001-of-00004.safetensors",
722
+ "visual.blocks.24.norm2.weight": "model-00001-of-00004.safetensors",
723
+ "visual.blocks.25.attn.proj.bias": "model-00001-of-00004.safetensors",
724
+ "visual.blocks.25.attn.proj.weight": "model-00001-of-00004.safetensors",
725
+ "visual.blocks.25.attn.qkv.bias": "model-00001-of-00004.safetensors",
726
+ "visual.blocks.25.attn.qkv.weight": "model-00001-of-00004.safetensors",
727
+ "visual.blocks.25.mlp.fc1.bias": "model-00001-of-00004.safetensors",
728
+ "visual.blocks.25.mlp.fc1.weight": "model-00001-of-00004.safetensors",
729
+ "visual.blocks.25.mlp.fc2.bias": "model-00001-of-00004.safetensors",
730
+ "visual.blocks.25.mlp.fc2.weight": "model-00001-of-00004.safetensors",
731
+ "visual.blocks.25.norm1.bias": "model-00001-of-00004.safetensors",
732
+ "visual.blocks.25.norm1.weight": "model-00001-of-00004.safetensors",
733
+ "visual.blocks.25.norm2.bias": "model-00001-of-00004.safetensors",
734
+ "visual.blocks.25.norm2.weight": "model-00001-of-00004.safetensors",
735
+ "visual.blocks.26.attn.proj.bias": "model-00001-of-00004.safetensors",
736
+ "visual.blocks.26.attn.proj.weight": "model-00001-of-00004.safetensors",
737
+ "visual.blocks.26.attn.qkv.bias": "model-00001-of-00004.safetensors",
738
+ "visual.blocks.26.attn.qkv.weight": "model-00001-of-00004.safetensors",
739
+ "visual.blocks.26.mlp.fc1.bias": "model-00001-of-00004.safetensors",
740
+ "visual.blocks.26.mlp.fc1.weight": "model-00001-of-00004.safetensors",
741
+ "visual.blocks.26.mlp.fc2.bias": "model-00001-of-00004.safetensors",
742
+ "visual.blocks.26.mlp.fc2.weight": "model-00001-of-00004.safetensors",
743
+ "visual.blocks.26.norm1.bias": "model-00001-of-00004.safetensors",
744
+ "visual.blocks.26.norm1.weight": "model-00001-of-00004.safetensors",
745
+ "visual.blocks.26.norm2.bias": "model-00001-of-00004.safetensors",
746
+ "visual.blocks.26.norm2.weight": "model-00001-of-00004.safetensors",
747
+ "visual.blocks.27.attn.proj.bias": "model-00001-of-00004.safetensors",
748
+ "visual.blocks.27.attn.proj.weight": "model-00001-of-00004.safetensors",
749
+ "visual.blocks.27.attn.qkv.bias": "model-00001-of-00004.safetensors",
750
+ "visual.blocks.27.attn.qkv.weight": "model-00001-of-00004.safetensors",
751
+ "visual.blocks.27.mlp.fc1.bias": "model-00001-of-00004.safetensors",
752
+ "visual.blocks.27.mlp.fc1.weight": "model-00001-of-00004.safetensors",
753
+ "visual.blocks.27.mlp.fc2.bias": "model-00001-of-00004.safetensors",
754
+ "visual.blocks.27.mlp.fc2.weight": "model-00001-of-00004.safetensors",
755
+ "visual.blocks.27.norm1.bias": "model-00001-of-00004.safetensors",
756
+ "visual.blocks.27.norm1.weight": "model-00001-of-00004.safetensors",
757
+ "visual.blocks.27.norm2.bias": "model-00001-of-00004.safetensors",
758
+ "visual.blocks.27.norm2.weight": "model-00001-of-00004.safetensors",
759
+ "visual.blocks.28.attn.proj.bias": "model-00001-of-00004.safetensors",
760
+ "visual.blocks.28.attn.proj.weight": "model-00001-of-00004.safetensors",
761
+ "visual.blocks.28.attn.qkv.bias": "model-00001-of-00004.safetensors",
762
+ "visual.blocks.28.attn.qkv.weight": "model-00001-of-00004.safetensors",
763
+ "visual.blocks.28.mlp.fc1.bias": "model-00001-of-00004.safetensors",
764
+ "visual.blocks.28.mlp.fc1.weight": "model-00001-of-00004.safetensors",
765
+ "visual.blocks.28.mlp.fc2.bias": "model-00001-of-00004.safetensors",
766
+ "visual.blocks.28.mlp.fc2.weight": "model-00001-of-00004.safetensors",
767
+ "visual.blocks.28.norm1.bias": "model-00001-of-00004.safetensors",
768
+ "visual.blocks.28.norm1.weight": "model-00001-of-00004.safetensors",
769
+ "visual.blocks.28.norm2.bias": "model-00001-of-00004.safetensors",
770
+ "visual.blocks.28.norm2.weight": "model-00001-of-00004.safetensors",
771
+ "visual.blocks.29.attn.proj.bias": "model-00001-of-00004.safetensors",
772
+ "visual.blocks.29.attn.proj.weight": "model-00001-of-00004.safetensors",
773
+ "visual.blocks.29.attn.qkv.bias": "model-00001-of-00004.safetensors",
774
+ "visual.blocks.29.attn.qkv.weight": "model-00001-of-00004.safetensors",
775
+ "visual.blocks.29.mlp.fc1.bias": "model-00001-of-00004.safetensors",
776
+ "visual.blocks.29.mlp.fc1.weight": "model-00001-of-00004.safetensors",
777
+ "visual.blocks.29.mlp.fc2.bias": "model-00001-of-00004.safetensors",
778
+ "visual.blocks.29.mlp.fc2.weight": "model-00001-of-00004.safetensors",
779
+ "visual.blocks.29.norm1.bias": "model-00001-of-00004.safetensors",
780
+ "visual.blocks.29.norm1.weight": "model-00001-of-00004.safetensors",
781
+ "visual.blocks.29.norm2.bias": "model-00001-of-00004.safetensors",
782
+ "visual.blocks.29.norm2.weight": "model-00001-of-00004.safetensors",
783
+ "visual.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
784
+ "visual.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
785
+ "visual.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
786
+ "visual.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
787
+ "visual.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
788
+ "visual.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
789
+ "visual.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
790
+ "visual.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
791
+ "visual.blocks.3.norm1.bias": "model-00001-of-00004.safetensors",
792
+ "visual.blocks.3.norm1.weight": "model-00001-of-00004.safetensors",
793
+ "visual.blocks.3.norm2.bias": "model-00001-of-00004.safetensors",
794
+ "visual.blocks.3.norm2.weight": "model-00001-of-00004.safetensors",
795
+ "visual.blocks.30.attn.proj.bias": "model-00001-of-00004.safetensors",
796
+ "visual.blocks.30.attn.proj.weight": "model-00001-of-00004.safetensors",
797
+ "visual.blocks.30.attn.qkv.bias": "model-00001-of-00004.safetensors",
798
+ "visual.blocks.30.attn.qkv.weight": "model-00001-of-00004.safetensors",
799
+ "visual.blocks.30.mlp.fc1.bias": "model-00001-of-00004.safetensors",
800
+ "visual.blocks.30.mlp.fc1.weight": "model-00001-of-00004.safetensors",
801
+ "visual.blocks.30.mlp.fc2.bias": "model-00001-of-00004.safetensors",
802
+ "visual.blocks.30.mlp.fc2.weight": "model-00001-of-00004.safetensors",
803
+ "visual.blocks.30.norm1.bias": "model-00001-of-00004.safetensors",
804
+ "visual.blocks.30.norm1.weight": "model-00001-of-00004.safetensors",
805
+ "visual.blocks.30.norm2.bias": "model-00001-of-00004.safetensors",
806
+ "visual.blocks.30.norm2.weight": "model-00001-of-00004.safetensors",
807
+ "visual.blocks.31.attn.proj.bias": "model-00001-of-00004.safetensors",
808
+ "visual.blocks.31.attn.proj.weight": "model-00001-of-00004.safetensors",
809
+ "visual.blocks.31.attn.qkv.bias": "model-00001-of-00004.safetensors",
810
+ "visual.blocks.31.attn.qkv.weight": "model-00001-of-00004.safetensors",
811
+ "visual.blocks.31.mlp.fc1.bias": "model-00001-of-00004.safetensors",
812
+ "visual.blocks.31.mlp.fc1.weight": "model-00001-of-00004.safetensors",
813
+ "visual.blocks.31.mlp.fc2.bias": "model-00001-of-00004.safetensors",
814
+ "visual.blocks.31.mlp.fc2.weight": "model-00001-of-00004.safetensors",
815
+ "visual.blocks.31.norm1.bias": "model-00001-of-00004.safetensors",
816
+ "visual.blocks.31.norm1.weight": "model-00001-of-00004.safetensors",
817
+ "visual.blocks.31.norm2.bias": "model-00001-of-00004.safetensors",
818
+ "visual.blocks.31.norm2.weight": "model-00001-of-00004.safetensors",
819
+ "visual.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
820
+ "visual.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
821
+ "visual.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
822
+ "visual.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
823
+ "visual.blocks.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
824
+ "visual.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
825
+ "visual.blocks.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
826
+ "visual.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
827
+ "visual.blocks.4.norm1.bias": "model-00001-of-00004.safetensors",
828
+ "visual.blocks.4.norm1.weight": "model-00001-of-00004.safetensors",
829
+ "visual.blocks.4.norm2.bias": "model-00001-of-00004.safetensors",
830
+ "visual.blocks.4.norm2.weight": "model-00001-of-00004.safetensors",
831
+ "visual.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
832
+ "visual.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
833
+ "visual.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
834
+ "visual.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
835
+ "visual.blocks.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
836
+ "visual.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
837
+ "visual.blocks.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
838
+ "visual.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
839
+ "visual.blocks.5.norm1.bias": "model-00001-of-00004.safetensors",
840
+ "visual.blocks.5.norm1.weight": "model-00001-of-00004.safetensors",
841
+ "visual.blocks.5.norm2.bias": "model-00001-of-00004.safetensors",
842
+ "visual.blocks.5.norm2.weight": "model-00001-of-00004.safetensors",
843
+ "visual.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
844
+ "visual.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
845
+ "visual.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
846
+ "visual.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
847
+ "visual.blocks.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
848
+ "visual.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
849
+ "visual.blocks.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
850
+ "visual.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
851
+ "visual.blocks.6.norm1.bias": "model-00001-of-00004.safetensors",
852
+ "visual.blocks.6.norm1.weight": "model-00001-of-00004.safetensors",
853
+ "visual.blocks.6.norm2.bias": "model-00001-of-00004.safetensors",
854
+ "visual.blocks.6.norm2.weight": "model-00001-of-00004.safetensors",
855
+ "visual.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
856
+ "visual.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
857
+ "visual.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
858
+ "visual.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
859
+ "visual.blocks.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
860
+ "visual.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
861
+ "visual.blocks.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
862
+ "visual.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
863
+ "visual.blocks.7.norm1.bias": "model-00001-of-00004.safetensors",
864
+ "visual.blocks.7.norm1.weight": "model-00001-of-00004.safetensors",
865
+ "visual.blocks.7.norm2.bias": "model-00001-of-00004.safetensors",
866
+ "visual.blocks.7.norm2.weight": "model-00001-of-00004.safetensors",
867
+ "visual.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
868
+ "visual.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
869
+ "visual.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
870
+ "visual.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
871
+ "visual.blocks.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
872
+ "visual.blocks.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
873
+ "visual.blocks.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
874
+ "visual.blocks.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
875
+ "visual.blocks.8.norm1.bias": "model-00001-of-00004.safetensors",
876
+ "visual.blocks.8.norm1.weight": "model-00001-of-00004.safetensors",
877
+ "visual.blocks.8.norm2.bias": "model-00001-of-00004.safetensors",
878
+ "visual.blocks.8.norm2.weight": "model-00001-of-00004.safetensors",
879
+ "visual.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
880
+ "visual.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
881
+ "visual.blocks.9.attn.qkv.bias": "model-00001-of-00004.safetensors",
882
+ "visual.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
883
+ "visual.blocks.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
884
+ "visual.blocks.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
885
+ "visual.blocks.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
886
+ "visual.blocks.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
887
+ "visual.blocks.9.norm1.bias": "model-00001-of-00004.safetensors",
888
+ "visual.blocks.9.norm1.weight": "model-00001-of-00004.safetensors",
889
+ "visual.blocks.9.norm2.bias": "model-00001-of-00004.safetensors",
890
+ "visual.blocks.9.norm2.weight": "model-00001-of-00004.safetensors",
891
+ "visual.merger.ln_q.bias": "model-00001-of-00004.safetensors",
892
+ "visual.merger.ln_q.weight": "model-00001-of-00004.safetensors",
893
+ "visual.merger.mlp.0.bias": "model-00001-of-00004.safetensors",
894
+ "visual.merger.mlp.0.weight": "model-00001-of-00004.safetensors",
895
+ "visual.merger.mlp.2.bias": "model-00001-of-00004.safetensors",
896
+ "visual.merger.mlp.2.weight": "model-00001-of-00004.safetensors",
897
+ "visual.patch_embed.proj.weight": "model-00001-of-00004.safetensors"
898
+ }
899
+ }
qwen2-vl-3d/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fc71ef5f1c3be4d29c372baeea36b661da41256d91477608542b21e214e357dc
+ size 165361722
qwen2-vl-3d/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3a50d9ce198e2733053d8d4a9703a034370d3ce84ca4a64a639d3035f21bb559
+ size 14244
qwen2-vl-3d/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b46bd00f15da7bf12c17740f5c0095e4932c178c4baf4ffcf2aaa705fd44155
+ size 1064
qwen2-vl-3d/trainer_state.json ADDED
@@ -0,0 +1,398 @@
+ {
+ "best_metric": 1.1060227155685425,
+ "best_model_checkpoint": "qwen2-7b-instruct-trl-sft-housetour-l64-adapter/checkpoint-500",
+ "epoch": 3.4246575342465753,
+ "eval_steps": 500,
+ "global_step": 500,
+ "is_hyper_param_search": false,
+ "is_local_process_zero": true,
+ "is_world_process_zero": true,
+ "log_history": [
+ {
+ "epoch": 0.00684931506849315,
+ "grad_norm": 18.83251953125,
+ "learning_rate": 1.1363636363636363e-08,
+ "loss": 2.0308,
+ "step": 1
+ },
+ {
+ "epoch": 0.0684931506849315,
+ "grad_norm": 16.91636848449707,
+ "learning_rate": 1.1363636363636363e-07,
+ "loss": 2.2756,
+ "step": 10
+ },
+ {
+ "epoch": 0.136986301369863,
+ "grad_norm": 16.575881958007812,
+ "learning_rate": 2.2727272727272726e-07,
+ "loss": 2.2648,
+ "step": 20
+ },
+ {
+ "epoch": 0.2054794520547945,
+ "grad_norm": 10.142265319824219,
+ "learning_rate": 3.4090909090909085e-07,
+ "loss": 2.103,
+ "step": 30
+ },
+ {
+ "epoch": 0.273972602739726,
+ "grad_norm": 10.252799034118652,
+ "learning_rate": 4.545454545454545e-07,
+ "loss": 1.9579,
+ "step": 40
+ },
+ {
+ "epoch": 0.3424657534246575,
+ "grad_norm": 10.1475830078125,
+ "learning_rate": 5.681818181818182e-07,
+ "loss": 1.7427,
+ "step": 50
+ },
+ {
+ "epoch": 0.410958904109589,
+ "grad_norm": 6.072550296783447,
+ "learning_rate": 6.818181818181817e-07,
+ "loss": 1.6655,
+ "step": 60
+ },
+ {
+ "epoch": 0.4794520547945205,
+ "grad_norm": 4.325073719024658,
+ "learning_rate": 7.954545454545454e-07,
+ "loss": 1.5283,
+ "step": 70
+ },
+ {
+ "epoch": 0.547945205479452,
+ "grad_norm": 1.543124794960022,
+ "learning_rate": 9.09090909090909e-07,
+ "loss": 1.3892,
+ "step": 80
+ },
+ {
+ "epoch": 0.6164383561643836,
+ "grad_norm": 1.4148296117782593,
+ "learning_rate": 9.999987694108851e-07,
+ "loss": 1.2604,
+ "step": 90
+ },
+ {
+ "epoch": 0.684931506849315,
+ "grad_norm": 1.1791110038757324,
+ "learning_rate": 9.999556994278908e-07,
+ "loss": 1.1968,
+ "step": 100
+ },
+ {
+ "epoch": 0.7534246575342466,
+ "grad_norm": 0.6725325584411621,
+ "learning_rate": 9.99851106046421e-07,
+ "loss": 1.2527,
+ "step": 110
+ },
+ {
+ "epoch": 0.821917808219178,
+ "grad_norm": 0.6848694682121277,
+ "learning_rate": 9.996850021374967e-07,
+ "loss": 1.1948,
+ "step": 120
+ },
+ {
+ "epoch": 0.8904109589041096,
+ "grad_norm": 0.7549965977668762,
+ "learning_rate": 9.994574081414829e-07,
+ "loss": 1.1981,
+ "step": 130
+ },
+ {
+ "epoch": 0.958904109589041,
+ "grad_norm": 0.7096678614616394,
+ "learning_rate": 9.991683520655733e-07,
+ "loss": 1.1431,
+ "step": 140
+ },
+ {
+ "epoch": 1.0273972602739727,
+ "grad_norm": 0.5900934934616089,
+ "learning_rate": 9.988178694803437e-07,
+ "loss": 1.1278,
+ "step": 150
+ },
+ {
+ "epoch": 1.095890410958904,
+ "grad_norm": 0.5475574135780334,
+ "learning_rate": 9.98406003515375e-07,
+ "loss": 1.0992,
+ "step": 160
+ },
+ {
+ "epoch": 1.1643835616438356,
+ "grad_norm": 0.7482232451438904,
+ "learning_rate": 9.979328048539456e-07,
+ "loss": 1.1284,
+ "step": 170
+ },
+ {
+ "epoch": 1.2328767123287672,
+ "grad_norm": 0.5973280072212219,
+ "learning_rate": 9.973983317267942e-07,
+ "loss": 1.0852,
+ "step": 180
+ },
+ {
+ "epoch": 1.3013698630136985,
+ "grad_norm": 0.6261045932769775,
+ "learning_rate": 9.968026499049549e-07,
+ "loss": 1.1337,
+ "step": 190
+ },
+ {
+ "epoch": 1.36986301369863,
+ "grad_norm": 0.5769844055175781,
+ "learning_rate": 9.961458326916622e-07,
+ "loss": 1.0631,
+ "step": 200
+ },
+ {
+ "epoch": 1.4383561643835616,
+ "grad_norm": 0.5357353687286377,
+ "learning_rate": 9.95427960913332e-07,
+ "loss": 1.0776,
+ "step": 210
+ },
+ {
+ "epoch": 1.5068493150684932,
+ "grad_norm": 0.4853924810886383,
+ "learning_rate": 9.946491229096141e-07,
+ "loss": 1.1915,
+ "step": 220
+ },
+ {
+ "epoch": 1.5753424657534247,
+ "grad_norm": 0.46966347098350525,
+ "learning_rate": 9.93809414522522e-07,
+ "loss": 1.1294,
+ "step": 230
+ },
+ {
+ "epoch": 1.643835616438356,
+ "grad_norm": 0.4522402286529541,
+ "learning_rate": 9.929089390846387e-07,
+ "loss": 1.1039,
+ "step": 240
+ },
+ {
+ "epoch": 1.7123287671232876,
+ "grad_norm": 0.5600599646568298,
+ "learning_rate": 9.919478074064001e-07,
+ "loss": 1.1376,
+ "step": 250
+ },
+ {
+ "epoch": 1.7808219178082192,
+ "grad_norm": 0.6039386987686157,
+ "learning_rate": 9.9092613776246e-07,
+ "loss": 0.9726,
+ "step": 260
+ },
+ {
+ "epoch": 1.8493150684931505,
+ "grad_norm": 0.49978339672088623,
+ "learning_rate": 9.89844055877135e-07,
+ "loss": 1.0931,
+ "step": 270
+ },
+ {
+ "epoch": 1.9178082191780823,
+ "grad_norm": 0.5846579074859619,
+ "learning_rate": 9.887016949089332e-07,
+ "loss": 0.9936,
+ "step": 280
+ },
+ {
+ "epoch": 1.9863013698630136,
+ "grad_norm": 0.7952700257301331,
+ "learning_rate": 9.874991954341681e-07,
+ "loss": 0.9754,
+ "step": 290
+ },
+ {
+ "epoch": 2.0547945205479454,
+ "grad_norm": 0.4901112914085388,
+ "learning_rate": 9.862367054296588e-07,
+ "loss": 1.0696,
+ "step": 300
+ },
+ {
+ "epoch": 2.1232876712328768,
+ "grad_norm": 0.5387308597564697,
+ "learning_rate": 9.84914380254522e-07,
+ "loss": 1.0019,
+ "step": 310
+ },
+ {
+ "epoch": 2.191780821917808,
+ "grad_norm": 0.5762762427330017,
+ "learning_rate": 9.83532382631052e-07,
+ "loss": 1.033,
+ "step": 320
+ },
+ {
+ "epoch": 2.26027397260274,
+ "grad_norm": 0.4994489848613739,
+ "learning_rate": 9.82090882624698e-07,
+ "loss": 0.9562,
+ "step": 330
+ },
+ {
+ "epoch": 2.328767123287671,
+ "grad_norm": 0.5436204671859741,
+ "learning_rate": 9.805900576231357e-07,
+ "loss": 0.9257,
+ "step": 340
+ },
+ {
+ "epoch": 2.3972602739726026,
+ "grad_norm": 0.6050807237625122,
+ "learning_rate": 9.790300923144372e-07,
+ "loss": 1.0112,
+ "step": 350
+ },
+ {
+ "epoch": 2.4657534246575343,
+ "grad_norm": 0.41216209530830383,
+ "learning_rate": 9.77411178664346e-07,
+ "loss": 0.9853,
+ "step": 360
+ },
+ {
+ "epoch": 2.5342465753424657,
+ "grad_norm": 0.4330998957157135,
+ "learning_rate": 9.75733515892652e-07,
+ "loss": 1.0301,
+ "step": 370
+ },
+ {
+ "epoch": 2.602739726027397,
+ "grad_norm": 0.4711519479751587,
+ "learning_rate": 9.739973104486777e-07,
+ "loss": 0.9389,
+ "step": 380
+ },
+ {
+ "epoch": 2.671232876712329,
+ "grad_norm": 0.5700051784515381,
+ "learning_rate": 9.722027759858714e-07,
+ "loss": 0.9781,
+ "step": 390
+ },
+ {
+ "epoch": 2.73972602739726,
+ "grad_norm": 0.36050090193748474,
+ "learning_rate": 9.703501333355166e-07,
+ "loss": 1.0553,
+ "step": 400
+ },
+ {
+ "epoch": 2.808219178082192,
+ "grad_norm": 0.41387563943862915,
+ "learning_rate": 9.68439610479557e-07,
+ "loss": 1.0056,
+ "step": 410
+ },
+ {
+ "epoch": 2.8767123287671232,
+ "grad_norm": 0.4400821030139923,
+ "learning_rate": 9.664714425225413e-07,
+ "loss": 0.9877,
+ "step": 420
+ },
+ {
+ "epoch": 2.9452054794520546,
+ "grad_norm": 0.4176802635192871,
+ "learning_rate": 9.644458716626911e-07,
+ "loss": 0.9547,
+ "step": 430
+ },
+ {
+ "epoch": 3.0136986301369864,
+ "grad_norm": 0.3389221727848053,
+ "learning_rate": 9.623631471620979e-07,
+ "loss": 1.026,
+ "step": 440
+ },
+ {
+ "epoch": 3.0821917808219177,
+ "grad_norm": 0.36593514680862427,
+ "learning_rate": 9.602235253160481e-07,
+ "loss": 1.0277,
+ "step": 450
+ },
+ {
+ "epoch": 3.1506849315068495,
+ "grad_norm": 0.3649665117263794,
+ "learning_rate": 9.580272694214854e-07,
+ "loss": 0.9359,
+ "step": 460
+ },
+ {
+ "epoch": 3.219178082191781,
+ "grad_norm": 0.34383371472358704,
+ "learning_rate": 9.557746497446085e-07,
+ "loss": 0.9152,
+ "step": 470
+ },
+ {
+ "epoch": 3.287671232876712,
+ "grad_norm": 0.34904035925865173,
+ "learning_rate": 9.53465943487614e-07,
+ "loss": 0.9315,
+ "step": 480
+ },
+ {
+ "epoch": 3.356164383561644,
+ "grad_norm": 0.34317487478256226,
+ "learning_rate": 9.511014347545837e-07,
+ "loss": 0.9726,
+ "step": 490
+ },
+ {
+ "epoch": 3.4246575342465753,
+ "grad_norm": 0.30464446544647217,
+ "learning_rate": 9.48681414516524e-07,
+ "loss": 0.9659,
+ "step": 500
+ },
+ {
+ "epoch": 3.4246575342465753,
+ "eval_loss": 1.1060227155685425,
+ "eval_runtime": 641.2039,
+ "eval_samples_per_second": 0.203,
+ "eval_steps_per_second": 0.203,
+ "step": 500
+ }
+ ],
+ "logging_steps": 10,
+ "max_steps": 2920,
+ "num_input_tokens_seen": 0,
+ "num_train_epochs": 20,
+ "save_steps": 500,
+ "stateful_callbacks": {
+ "TrainerControl": {
+ "args": {
+ "should_epoch_stop": false,
+ "should_evaluate": false,
+ "should_log": false,
+ "should_save": true,
+ "should_training_stop": false
+ },
+ "attributes": {}
+ }
+ },
+ "total_flos": 3.5300059477526943e+18,
+ "train_batch_size": 1,
+ "trial_name": null,
+ "trial_params": null
+ }
residual-diffuser/args.json ADDED
@@ -0,0 +1,97 @@
+ {
+ "action_weight": 1,
+ "add_extras": {
+ "_string": "<bound method Parser.add_extras of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
+ "_type": "python_object (type = method)",
+ "_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaD6GlFKULg=="
+ },
+ "batch_size": 1,
+ "bucket": null,
+ "clip_denoised": true,
+ "commit": "13b4d404bdbd9d0d74c08d5a51ee39f84e765830 maze2d",
+ "config": "config.maze2d",
+ "dataset": "maze2d-large-v1",
+ "device": "cuda",
+ "diffusion": "models.GaussianDiffusion",
+ "dim_mults": {
+ "_type": "tuple",
+ "_value": [
+ 1,
+ 4,
+ 8
+ ]
+ },
+ "ema_decay": 0.995,
+ "eval_fstrings": {
+ "_string": "<bound method Parser.eval_fstrings of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
+ "_type": "python_object (type = method)",
+ "_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaBuGlFKULg=="
+ },
+ "exp_name": "diffusion/H384_T16",
+ "generate_exp_name": {
+ "_string": "<bound method Parser.generate_exp_name of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
+ "_type": "python_object (type = method)",
+ "_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaEmGlFKULg=="
+ },
+ "get_commit": {
+ "_string": "<bound method Parser.get_commit of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
+ "_type": "python_object (type = method)",
+ "_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaBKGlFKULg=="
+ },
+ "gradient_accumulate_every": 8,
+ "horizon": 384,
+ "learning_rate": 5e-06,
+ "loader": "datasets.GoalDataset",
+ "logbase": "logs",
+ "loss_discount": 1,
+ "loss_type": "spline",
+ "loss_weights": null,
+ "max_path_length": 40000,
+ "mkdir": {
+ "_string": "<bound method Parser.mkdir of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
+ "_type": "python_object (type = method)",
+ "_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaA+GlFKULg=="
+ },
+ "model": "models.TemporalUnet",
+ "n_diffusion_steps": 16,
+ "n_reference": 50,
+ "n_samples": 10,
+ "n_saves": 50,
+ "n_steps_per_epoch": 60000,
+ "n_train_steps": 60000.0,
+ "normalizer": "LimitsNormalizer",
+ "predict_epsilon": true,
+ "prefix": "diffusion/",
+ "preprocess_fns": [
+ "maze2d_set_terminals"
+ ],
+ "read_config": {
+ "_string": "<bound method Parser.read_config of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
+ "_type": "python_object (type = method)",
+ "_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaCuGlFKULg=="
+ },
+ "renderer": "utils.Maze2dRenderer",
+ "reproducibility": {
+ "command_line": "python scripts/train_tour.py",
+ "git_has_uncommitted_changes": true,
+ "git_root": "/local/home/atcelen/work/diffuser",
+ "git_url": "https://github.com/jannerm/diffuser/tree/13b4d404bdbd9d0d74c08d5a51ee39f84e765830",
+ "time": "Thu Jun 26 17:05:14 2025"
+ },
+ "sample_freq": 1000,
+ "save_diff": {
+ "_string": "<bound method Parser.save_diff of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
+ "_type": "python_object (type = method)",
+ "_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaFWGlFKULg=="
+ },
+ "save_freq": 2000,
+ "save_parallel": false,
+ "savepath": "logs/maze2d-large-v1/diffusion/H384_T16",
+ "set_seed": {
+ "_string": "<bound method Parser.set_seed of Parser(prog='train_tour.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)>",
+ "_type": "python_object (type = method)",
+ "_value": "gASVlAQAAAAAAACMCGJ1aWx0aW5zlIwHZ2V0YXR0cpSTlIwIX19tYWluX1+UjAZQYXJzZXKUk5QpgZR9lCiMCHNldF9zZWVklGgCaAZoCIaUUpSMCHNhdmVwYXRolIwnbG9ncy9tYXplMmQtbGFyZ2UtdjEvZGlmZnVzaW9uL0gzODRfVDE2lIwKbm9ybWFsaXplcpSMEExpbWl0c05vcm1hbGl6ZXKUjAVta2RpcpRoAmgGaA+GlFKUjApnZXRfY29tbWl0lGgCaAZoEoaUUpSMCGV4cF9uYW1llIwSZGlmZnVzaW9uL0gzODRfVDE2lIwLc2FtcGxlX2ZyZXGUTegDjAdob3Jpem9ulE2AAYwGY29uZmlnlIwNY29uZmlnLm1hemUyZJSMDWV2YWxfZnN0cmluZ3OUaAJoBmgbhpRSlIwKYmF0Y2hfc2l6ZZRLAYwGZGV2aWNllIwEY3VkYZSMC25fcmVmZXJlbmNllEsyjA1zYXZlX3BhcmFsbGVslImMGWdyYWRpZW50X2FjY3VtdWxhdGVfZXZlcnmUSwiMDWFjdGlvbl93ZWlnaHSUSwGMCWRpZmZ1c2lvbpSMGG1vZGVscy5HYXVzc2lhbkRpZmZ1c2lvbpSMBmJ1Y2tldJROjAZwcmVmaXiUjApkaWZmdXNpb24vlIwNbl90cmFpbl9zdGVwc5RHQO1MAAAAAACMC3JlYWRfY29uZmlnlGgCaAZoK4aUUpSMDWxvc3NfZGlzY291bnSUSwGMEW5fc3RlcHNfcGVyX2Vwb2NolE1g6owMbG9zc193ZWlnaHRzlE6MD21heF9wYXRoX2xlbmd0aJRNQJyMDWNsaXBfZGVub2lzZWSUiIwFbW9kZWyUjBNtb2RlbHMuVGVtcG9yYWxVbmV0lIwJZW1hX2RlY2F5lEc/79cKPXCj14wPcHJlZGljdF9lcHNpbG9ulIiMB2xvZ2Jhc2WUjARsb2dzlIwRbl9kaWZmdXNpb25fc3RlcHOUSxCMB2RhdGFzZXSUjA9tYXplMmQtbGFyZ2UtdjGUjA1sZWFybmluZ19yYXRllEc+1Pi1iONo8YwLdXNlX3BhZGRpbmeUiYwKYWRkX2V4dHJhc5RoAmgGaD6GlFKUjAlkaW1fbXVsdHOUSwFLBEsIh5SMCWxvc3NfdHlwZZSMBnNwbGluZZSMCW5fc2FtcGxlc5RLCowIcmVuZGVyZXKUjBR1dGlscy5NYXplMmRSZW5kZXJlcpSMCXNhdmVfZnJlcZRN0AeMEWdlbmVyYXRlX2V4cF9uYW1llGgCaAZoSYaUUpSMBmxvYWRlcpSMFGRhdGFzZXRzLkdvYWxEYXRhc2V0lIwGY29tbWl0lIwvMTNiNGQ0MDRiZGJkOWQwZDc0YzA4ZDVhNTFlZTM5Zjg0ZTc2NTgzMCBtYXplMmSUjAduX3NhdmVzlEsyjBN0ZXJtaW5hdGlvbl9wZW5hbHR5lE6MDnByZXByb2Nlc3NfZm5zlF2UjBRtYXplMmRfc2V0X3Rlcm1pbmFsc5RhjAlzYXZlX2RpZmaUaAJoBmhVhpRSlHViaAiGlFKULg=="
+ },
+ "termination_penalty": null,
+ "use_padding": false
+ }
residual-diffuser/dataset_config.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c441212970d6518583ad1a3d0dcb8a68690102f1dbca6bc3fbdbe3ec5b8430f8
+ size 280
residual-diffuser/diff.txt ADDED
@@ -0,0 +1,1895 @@
+ diff --git a/config/locomotion.py b/config/locomotion.py
+ deleted file mode 100644
+ index 4410bb1..0000000
+ --- a/config/locomotion.py
+ +++ /dev/null
+ @@ -1,70 +0,0 @@
+ -import socket
+ -
+ -from diffuser.utils import watch
+ -
+ -#------------------------ base ------------------------#
+ -
+ -## automatically make experiment names for planning
+ -## by labelling folders with these args
+ -
+ -diffusion_args_to_watch = [
+ - ('prefix', ''),
+ - ('horizon', 'H'),
+ - ('n_diffusion_steps', 'T'),
+ -]
+ -
+ -base = {
+ - 'diffusion': {
+ - ## model
+ - 'model': 'models.TemporalUnet',
+ - 'diffusion': 'models.GaussianDiffusion',
+ - 'horizon': 32,
+ - 'n_diffusion_steps': 100,
+ - 'action_weight': 10,
+ - 'loss_weights': None,
+ - 'loss_discount': 1,
+ - 'predict_epsilon': False,
+ - 'dim_mults': (1, 4, 8),
+ - 'renderer': 'utils.MuJoCoRenderer',
+ -
+ - ## dataset
+ - 'loader': 'datasets.SequenceDataset',
+ - 'normalizer': 'LimitsNormalizer',
+ - 'preprocess_fns': [],
+ - 'clip_denoised': True,
+ - 'use_padding': True,
+ - 'max_path_length': 1000,
+ -
+ - ## serialization
+ - 'logbase': 'logs',
+ - 'prefix': 'diffusion/',
+ - 'exp_name': watch(diffusion_args_to_watch),
+ -
+ - ## training
+ - 'n_steps_per_epoch': 10000,
+ - 'loss_type': 'l2',
+ - 'n_train_steps': 1e6,
+ - 'batch_size': 32,
+ - 'learning_rate': 2e-4,
+ - 'gradient_accumulate_every': 2,
+ - 'ema_decay': 0.995,
+ - 'save_freq': 1000,
+ - 'sample_freq': 1000,
+ - 'n_saves': 5,
+ - 'save_parallel': False,
+ - 'n_reference': 8,
+ - 'n_samples': 2,
+ - 'bucket': None,
+ - 'device': 'cuda',
+ - },
+ -}
+ -
+ -#------------------------ overrides ------------------------#
+ -
+ -## put environment-specific overrides here
+ -
+ -halfcheetah_medium_expert_v2 = {
+ - 'diffusion': {
+ - 'horizon': 16,
+ - },
+ -}
+ diff --git a/config/maze2d.py b/config/maze2d.py
+ index a06ac7f..0a8d22a 100644
+ --- a/config/maze2d.py
+ +++ b/config/maze2d.py
+ @@ -34,11 +34,11 @@ base = {
+ 'model': 'models.TemporalUnet',
+ 'diffusion': 'models.GaussianDiffusion',
+ 'horizon': 256,
+ - 'n_diffusion_steps': 256,
+ + 'n_diffusion_steps': 512,
+ 'action_weight': 1,
+ 'loss_weights': None,
+ 'loss_discount': 1,
+ - 'predict_epsilon': False,
+ + 'predict_epsilon': True,
+ 'dim_mults': (1, 4, 8),
+ 'renderer': 'utils.Maze2dRenderer',
+
+ @@ -57,14 +57,14 @@ base = {
+ 'exp_name': watch(diffusion_args_to_watch),
+
+ ## training
+ - 'n_steps_per_epoch': 10000,
+ - 'loss_type': 'l2',
+ - 'n_train_steps': 2e6,
+ - 'batch_size': 32,
+ - 'learning_rate': 2e-4,
+ - 'gradient_accumulate_every': 2,
+ + 'n_steps_per_epoch': 60000,
+ + 'loss_type': 'spline',
+ + 'n_train_steps': 6e4,
+ + 'batch_size': 1,
+ + 'learning_rate': 5e-6,
+ + 'gradient_accumulate_every': 8,
+ 'ema_decay': 0.995,
+ - 'save_freq': 1000,
+ + 'save_freq': 2000,
+ 'sample_freq': 1000,
+ 'n_saves': 50,
+ 'save_parallel': False,
+ @@ -89,7 +89,6 @@ base = {
+ 'prefix': 'plans/release',
+ 'exp_name': watch(plan_args_to_watch),
+ 'suffix': '0',
+ -
+ 'conditional': False,
+
+ ## loading
+ @@ -122,10 +121,10 @@ maze2d_umaze_v1 = {
+ maze2d_large_v1 = {
+ 'diffusion': {
+ 'horizon': 384,
+ - 'n_diffusion_steps': 256,
+ + 'n_diffusion_steps': 16,
+ },
+ 'plan': {
+ 'horizon': 384,
+ - 'n_diffusion_steps': 256,
+ + 'n_diffusion_steps': 16,
+ },
+ }
+ diff --git a/diffuser/datasets/buffer.py b/diffuser/datasets/buffer.py
+ index 1ad2106..5991f01 100644
+ --- a/diffuser/datasets/buffer.py
+ +++ b/diffuser/datasets/buffer.py
+ @@ -9,7 +9,7 @@ class ReplayBuffer:
+
+ def __init__(self, max_n_episodes, max_path_length, termination_penalty):
+ self._dict = {
+ - 'path_lengths': np.zeros(max_n_episodes, dtype=np.int),
+ + 'path_lengths': np.zeros(max_n_episodes, dtype=np.int_),
+ }
+ self._count = 0
+ self.max_n_episodes = max_n_episodes
+ diff --git a/diffuser/datasets/sequence.py b/diffuser/datasets/sequence.py
+ index 356c540..73c1b04 100644
+ --- a/diffuser/datasets/sequence.py
+ +++ b/diffuser/datasets/sequence.py
+ @@ -83,6 +83,7 @@ class SequenceDataset(torch.utils.data.Dataset):
+ actions = self.fields.normed_actions[path_ind, start:end]
+
+ conditions = self.get_conditions(observations)
+ +
+ trajectories = np.concatenate([actions, observations], axis=-1)
+ batch = Batch(trajectories, conditions)
+ return batch
+ diff --git a/diffuser/models/diffusion.py b/diffuser/models/diffusion.py
+ index fae4cfd..461680a 100644
+ --- a/diffuser/models/diffusion.py
+ +++ b/diffuser/models/diffusion.py
+ @@ -2,6 +2,7 @@ import numpy as np
+ import torch
+ from torch import nn
+ import pdb
+ +import matplotlib.pyplot as plt
+
+ import diffuser.utils as utils
+ from .helpers import (
+ @@ -9,6 +10,7 @@ from .helpers import (
+ extract,
+ apply_conditioning,
+ Losses,
+ + catmull_rom_spline_with_rotation,
+ )
+
+ class GaussianDiffusion(nn.Module):
+ @@ -26,6 +28,7 @@ class GaussianDiffusion(nn.Module):
+ betas = cosine_beta_schedule(n_timesteps)
+ alphas = 1. - betas
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
+ + print(f"Alphas Cumprod: {alphas_cumprod}")
+ alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
+
+ self.n_timesteps = int(n_timesteps)
+ @@ -73,7 +76,7 @@ class GaussianDiffusion(nn.Module):
+ '''
+ self.action_weight = action_weight
+
+ - dim_weights = torch.ones(self.transition_dim, dtype=torch.float32)
+ + dim_weights = torch.ones(self.transition_dim, dtype=torch.float64)
+
+ ## set loss coefficients for dimensions of observation
+ if weights_dict is None: weights_dict = {}
+ @@ -97,18 +100,16 @@ class GaussianDiffusion(nn.Module):
+ otherwise, model predicts x0 directly
+ '''
+ if self.predict_epsilon:
+ - return (
+ - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ - )
+ + return noise
+ else:
+ return noise
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ - extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t[:, :, self.action_dim:]
+ )
+ +
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+ @@ -129,7 +130,7 @@ class GaussianDiffusion(nn.Module):
+ def p_sample(self, x, cond, t):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t)
+ - noise = torch.randn_like(x)
+ + noise = torch.randn_like(x[:, :, self.action_dim:])
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+ @@ -139,22 +140,59 @@ class GaussianDiffusion(nn.Module):
+ device = self.betas.device
+
+ batch_size = shape[0]
+ - x = torch.randn(shape, device=device)
+ - x = apply_conditioning(x, cond, self.action_dim)
+ + # x = torch.randn(shape, device=device, dtype=torch.float64)
+ + # Extract known indices and values
+ + known_indices = np.array(list(cond.keys()), dtype=int)
+ +
+ + # candidate_no x batch_size x dim
+ + known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0)
243
+ + known_values = np.moveaxis(known_values, 0, 1)
244
+ +
245
+ + # Sort the timepoints
246
+ + sorted_indices = np.argsort(known_indices)
247
+ + known_indices = known_indices[sorted_indices]
248
+ + known_values = known_values[:, sorted_indices]
249
+ +
250
+ + # Build the structured spline guess
251
+ + catmull_spline_trajectory = np.array([
252
+ + catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, shape[1])
253
+ + for b in range(batch_size)
254
+ + ])
255
+ + catmull_spline_trajectory = torch.tensor(
256
+ + catmull_spline_trajectory,
257
+ + dtype=torch.float64,
258
+ + device=device
259
+ + )
260
+ +
261
+ +
262
+ + if self.predict_epsilon:
263
+ + x = torch.randn((shape[0], shape[1], self.observation_dim), device=device, dtype=torch.float64)
264
+ + cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()}
265
+ + is_cond = torch.zeros((shape[0], shape[1], 1), device=device, dtype=torch.float64)
266
+ + is_cond[:, known_indices, :] = 1.0
267
+
268
+ if return_diffusion: diffusion = [x]
269
+
270
+ - progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent()
271
+ + # progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent()
272
+ for i in reversed(range(0, self.n_timesteps)):
273
+ + if self.predict_epsilon:
274
+ + x = torch.cat([catmull_spline_trajectory, is_cond, x], dim=-1)
275
+ +
276
+ timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
277
+ - x = self.p_sample(x, cond, timesteps)
278
+ - x = apply_conditioning(x, cond, self.action_dim)
279
+ + x = self.p_sample(x, cond_residual, timesteps)
280
+ +
281
+ + x = apply_conditioning(x, cond_residual, 0)
282
+
283
+ - progress.update({'t': i})
284
+ + if return_diffusion: diffusion.append(x)
285
+
286
+ - if return_diffusion: diffusion.append(x)
287
+ + x = catmull_spline_trajectory + x
288
+
289
+ - progress.close()
290
+ +
291
+ +
292
+ + # Normalize the quaternions
293
+ + # x[:, :, 3:7] = x[:, :, 3:7] / torch.norm(x[:, :, 3:7], dim=-1, keepdim=True)
294
+ +
295
+ + # progress.close()
296
+
297
+ if return_diffusion:
298
+ return x, torch.stack(diffusion, dim=1)
299
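The rewritten `p_sample_loop` changes what the reverse process denoises. Instead of a full trajectory, the model (in the `predict_epsilon` branch, which this commit repurposes) works on the residual around a Catmull-Rom spline fitted through the conditioning keyframes: each step feeds the network the concatenation [spline, is_cond flag, residual], the residual is pinned to zero at keyframes via `cond_residual`, and the final trajectory is spline + residual. Stripped of bookkeeping, the loop reduces to roughly this sketch (helper callables stand in for the methods above; not the exact code):

    import torch

    def sample_residual(p_sample, apply_conditioning, spline, is_cond,
                        cond_residual, n_timesteps):
        # spline:  [B, H, obs_dim] Catmull-Rom guess through the keyframes
        # is_cond: [B, H, 1] binary mask marking conditioned timesteps
        B, H, obs_dim = spline.shape
        x = torch.randn(B, H, obs_dim, dtype=spline.dtype, device=spline.device)
        for i in reversed(range(n_timesteps)):
            net_in = torch.cat([spline, is_cond, x], dim=-1)
            t = torch.full((B,), i, device=spline.device, dtype=torch.long)
            x = p_sample(net_in, cond_residual, t)        # denoise the residual
            x = apply_conditioning(x, cond_residual, 0)   # zero residual at keyframes
        return spline + x                                 # back to trajectory space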
+ @@ -167,7 +205,7 @@ class GaussianDiffusion(nn.Module):
300
+ conditions : [ (time, state), ... ]
301
+ '''
302
+ device = self.betas.device
303
+ - batch_size = len(cond[0])
304
+ + batch_size = len(next(iter(cond.values())))
305
+ horizon = horizon or self.horizon
306
+ shape = (batch_size, horizon, self.transition_dim)
307
+
308
+ @@ -175,38 +213,106 @@ class GaussianDiffusion(nn.Module):
309
+
310
+ #------------------------------------------ training ------------------------------------------#
311
+
312
+ - def q_sample(self, x_start, t, noise=None):
313
+ + def q_sample(self, x_start, t, spline=None, noise=None):
314
+ + x_start_noise = x_start[:, :, :-1]
315
+ + x_start_is_cond = x_start[:, :, [-1]]
316
+ +
317
+ + if spline is None:
318
+ + spline = torch.randn_like(x_start_noise)
319
+ if noise is None:
320
+ - noise = torch.randn_like(x_start)
321
+ + noise = torch.randn_like(x_start_noise)
322
+
323
+ - sample = (
324
+ - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
325
+ - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
326
+ - )
327
+ + alpha = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
328
+ + oneminusalpha = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
329
+ +
330
+ + # Weighted combination of x_0 and the spline
331
+ + out = alpha * x_start_noise + oneminusalpha * noise
332
+ +
333
+ + # Concatenate the binary feature and the spline as the conditioning
334
+ + out = torch.cat([spline, x_start_is_cond, out], dim=-1)
335
+
336
+ - return sample
337
+ + return out
338
+
339
+ def p_losses(self, x_start, cond, t):
340
+ - noise = torch.randn_like(x_start)
341
+ + batch_size, horizon, _ = x_start.shape
342
+ + # Extract known indices and values
343
+ + known_indices = np.array(list(cond.keys()), dtype=int)
344
+ +
345
+ + # candidate_no x batch_size x dim
346
+ + known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0)
347
+ + known_values = np.moveaxis(known_values, 0, 1)
348
+ +
349
+ + # Sort the timepoints
350
+ + sorted_indices = np.argsort(known_indices)
351
+ + known_indices = known_indices[sorted_indices]
352
+ + known_values = known_values[:, sorted_indices]
353
+ +
354
+ + # Build the structured spline guess
355
+ + catmull_spline_trajectory = np.array([
356
+ + catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, horizon)
357
+ + for b in range(batch_size)
358
+ + ])
359
+ + catmull_spline_trajectory = torch.tensor(
360
+ + catmull_spline_trajectory,
361
+ + dtype=torch.float64,
362
+ + device=x_start.device
363
+ + )
364
+
365
+ - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
366
+ - x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)
367
+ + # Plot the quaternions
368
+ + # plt.plot(x_start[0, :, 3].cpu().numpy())
369
+ + # plt.plot(catmull_spline_trajectory[0, :, 3].cpu().numpy())
370
+ + # plt.legend(["x_start", "catmull_spline"])
371
+ + # plt.show()
372
+ + # raise Exception
373
+
374
+ - x_recon = self.model(x_noisy, cond, t)
375
+ - x_recon = apply_conditioning(x_recon, cond, self.action_dim)
376
+
377
+ - assert noise.shape == x_recon.shape
378
+ + if not self.predict_epsilon:
379
+ + # Forward diffuse with the structured trajectory
380
+ + x_noisy = self.q_sample(
381
+ + x_start,
382
+ + t,
383
+ + spline=catmull_spline_trajectory,
384
+ + )
385
+ + x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)
386
+
387
+ - if self.predict_epsilon:
388
+ - loss, info = self.loss_fn(x_recon, noise)
389
+ + # Reverse pass guess
390
+ + x_recon = self.model(x_noisy, cond, t)
391
+ + x_recon = apply_conditioning(x_recon, cond, self.action_dim)
392
+ +
393
+ + # Then x_recon is the predicted x_0, compare to the true x_0
394
+ + loss, info = self.loss_fn(x_recon, x_start, cond)
395
+ else:
396
+ - loss, info = self.loss_fn(x_recon, x_start)
397
+ + residual = x_start.clone()
398
+ +
399
+ + residual[:, :, :-1] -= catmull_spline_trajectory
400
+ +
401
+ +
402
+ + cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()}
403
+ +
404
+ + x_noisy = self.q_sample(
405
+ + residual,
406
+ + t,
407
+ + spline=catmull_spline_trajectory,
408
+ + )
409
+ + x_noisy = apply_conditioning(x_noisy, cond_residual, self.action_dim)
410
+ +
411
+ + # Reverse pass guess
412
+ + x_recon = self.model(x_noisy, cond, t)
413
+ + x_recon = apply_conditioning(x_recon, cond_residual, 0)
414
+ +
415
+ + x_recon = x_recon + catmull_spline_trajectory
416
+ +
417
+ + loss, info = self.loss_fn(x_recon, x_start[:, :, :-1], cond)
418
+
419
+ return loss, info
420
+
421
+ def loss(self, x, cond):
422
+ batch_size = len(x)
423
+ t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
424
+ + # t = torch.randint(1, 2, (batch_size,), device=x.device).long()
425
+ + # x = x.double()
426
+ + # cond = {k: v.double() for k, v in cond.items()}
427
+ + # print(f"Time: {t.item()}")
428
+ return self.p_losses(x, cond, t)
429
+
430
+ def forward(self, cond, *args, **kwargs):
431
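Training mirrors the sampling change. In the repurposed `predict_epsilon` branch of `p_losses`, the regression target is the residual between the trajectory and the spline prior, the residual is forward-diffused by `q_sample`, and the spline is added back before scoring against the ground truth. Schematically (a restatement of the code above, not new math):

    r_0   = x_0 - spline                                    # residual target
    r_t   = sqrt(alpha_bar_t) * r_0 + sqrt(1 - alpha_bar_t) * eps
    x_hat = spline + model([spline, is_cond, r_t], t)       # residual pinned to 0 at keyframes
    loss  = loss_fn(x_hat, x_0, cond)                       # e.g. SplineLoss below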
+ diff --git a/diffuser/models/helpers.py b/diffuser/models/helpers.py
432
+ index d39f35d..9f43ef8 100644
433
+ --- a/diffuser/models/helpers.py
434
+ +++ b/diffuser/models/helpers.py
435
+ @@ -1,11 +1,11 @@
436
+ import math
437
+ +import json
438
+ import numpy as np
439
+ import torch
440
+ import torch.nn as nn
441
+ import torch.nn.functional as F
442
+ -import einops
443
+ from einops.layers.torch import Rearrange
444
+ -import pdb
445
+ +from pytorch3d.transforms import quaternion_to_matrix, quaternion_to_axis_angle
446
+
447
+ import diffuser.utils as utils
448
+
449
+ @@ -30,7 +30,7 @@ class SinusoidalPosEmb(nn.Module):
450
+ class Downsample1d(nn.Module):
451
+ def __init__(self, dim):
452
+ super().__init__()
453
+ - self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
454
+ + self.conv = nn.Conv1d(dim, dim, 3, 2, 1).to(torch.float64)
455
+
456
+ def forward(self, x):
457
+ return self.conv(x)
458
+ @@ -38,7 +38,7 @@ class Downsample1d(nn.Module):
459
+ class Upsample1d(nn.Module):
460
+ def __init__(self, dim):
461
+ super().__init__()
462
+ - self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
463
+ + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1).to(torch.float64)
464
+
465
+ def forward(self, x):
466
+ return self.conv(x)
467
+ @@ -52,9 +52,9 @@ class Conv1dBlock(nn.Module):
468
+ super().__init__()
469
+
470
+ self.block = nn.Sequential(
471
+ - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
472
+ + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2).to(torch.float64),
473
+ Rearrange('batch channels horizon -> batch channels 1 horizon'),
474
+ - nn.GroupNorm(n_groups, out_channels),
475
+ + nn.GroupNorm(n_groups, out_channels).to(torch.float64),
476
+ Rearrange('batch channels 1 horizon -> batch channels horizon'),
477
+ nn.Mish(),
478
+ )
479
+ @@ -72,7 +72,7 @@ def extract(a, t, x_shape):
480
+ out = a.gather(-1, t)
481
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
482
+
483
+ -def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
484
+ +def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float64):
485
+ """
486
+ cosine schedule
487
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
488
+ @@ -157,9 +157,979 @@ class ValueL2(ValueLoss):
489
+ def _loss(self, pred, targ):
490
+ return F.mse_loss(pred, targ, reduction='none')
491
+
492
+ +class GeodesicL2Loss(nn.Module):
493
+ + def __init__(self, *args):
494
+ + super().__init__()
495
+ + pass
496
+ +
497
+ + def _loss(self, pred, targ):
498
+ + # Compute L2 loss for the first three dimensions
499
+ + l2_loss = F.mse_loss(pred[..., :3], targ[..., :3], reduction='mean')
500
+ +
501
+ + # Normalize to unit quaternions for the last four dimensions
502
+ + pred_quat = pred[..., 3:] / pred[..., 3:].norm(dim=-1, keepdim=True)
503
+ + targ_quat = targ[..., 3:] / targ[..., 3:].norm(dim=-1, keepdim=True)
504
+ +
505
+ + assert not torch.isnan(pred_quat).any(), "Pred Quat has NaNs"
506
+ + assert not torch.isnan(targ_quat).any(), "Targ Quat has NaNs"
507
+ +
508
+ + # Compute dot product for the quaternions
509
+ + dot_product = torch.sum(pred_quat * targ_quat, dim=-1)
510
+ + dot_product = torch.clamp(torch.abs(dot_product), -1.0, 1.0)
511
+ +
512
+ + # Compute geodesic loss for the quaternions
513
+ + geodesic_loss = 2 * torch.acos(dot_product).mean()
514
+ +
515
+ + assert not torch.isnan(geodesic_loss).any(), "Geodesic Loss has NaNs"
516
+ + assert not torch.isnan(l2_loss).any(), "L2 Loss has NaNs"
517
+ +
518
+ + return l2_loss + geodesic_loss, l2_loss, geodesic_loss
519
+ +
520
+ + def forward(self, pred, targ):
521
+ + loss, l2, geodesic = self._loss(pred, targ)
522
+ +
523
+ + info = {
524
+ + 'l2': l2.item(),
525
+ + 'geodesic': geodesic.item(),
526
+ + }
527
+ +
528
+ + return loss, info
529
+ +
530
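`GeodesicL2Loss` scores rotations by the geodesic angle theta = 2 * arccos(|<q_pred, q_targ>|); the absolute value folds the quaternion double cover (q and -q encode the same rotation). One caveat: arccos has an unbounded derivative as the dot product approaches 1, so perfectly matching quaternions sit at a non-differentiable point; the classes further down clamp strictly inside (-1, 1) for this reason. A small numeric check with made-up values:

    import torch

    q_id  = torch.tensor([1.0, 0.0, 0.0, 0.0])                 # identity
    q_z90 = torch.tensor([0.70710678, 0.0, 0.0, 0.70710678])   # 90 deg about z
    dot = torch.abs((q_id * q_z90).sum()).clamp(max=1.0 - 1e-7)
    theta = 2 * torch.acos(dot)   # ~ pi/2, i.e. the 90 deg geodesic angle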
+ +class RotationTranslationLoss(nn.Module):
531
+ + def __init__(self, *args):
532
+ + super().__init__()
533
+ + pass
534
+ +
535
+ + def _loss(self, pred, targ, cond=None):
536
+ +
537
+ + # Make sure the dtype is float64
538
+ + pred = pred.to(torch.float64)
539
+ + targ = targ.to(torch.float64)
540
+ +
541
+ + eps = 1e-8
542
+ +
543
+ + pred_trans = pred[..., :3]
544
+ + pred_quat = pred[..., 3:7]
545
+ + targ_trans = targ[..., :3]
546
+ + targ_quat = targ[..., 3:7]
547
+ +
548
+ + l2_loss = F.mse_loss(pred_trans, targ_trans, reduction='mean')
549
+ +
550
+ + # Calculate the geodesic loss
551
+ + pred_n = pred_quat.norm(dim=-1, keepdim=True).clamp(min=eps)
552
+ + targ_n = targ_quat.norm(dim=-1, keepdim=True).clamp(min=eps)
553
+ +
554
+ + pred_quat_norm = pred_quat / pred_n
555
+ + targ_quat_norm = targ_quat / targ_n
556
+ +
557
+ +
558
+ + dot_product = torch.sum(pred_quat_norm * targ_quat_norm, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps)
559
+ + quaternion_dist = 1 - (dot_product ** 2).mean()
560
+ +
561
+ + # Calculate the rotation error
562
+ + pred_rot = quaternion_to_matrix(pred_quat_norm).reshape(-1, 3, 3)
563
+ + targ_rot = quaternion_to_matrix(targ_quat_norm).reshape(-1, 3, 3)
564
+ +
565
+ + r2r1 = pred_rot @ targ_rot.permute(0, 2, 1)
566
+ + trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1)
567
+ + trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps)
568
+ + geodesic_loss = torch.acos(trace).mean()
569
+ +
570
+ + # Add a smoothness and acceleration term to the positions and quaternions
571
+ + alpha = 1.0
572
+ + smoothness_loss = F.mse_loss(pred[:, 1:, :7].reshape(-1, 7), pred[:, :-1, :7].reshape(-1, 7), reduction='mean')
573
+ + acceleration_loss = F.mse_loss(pred[:, 2:, :7].reshape(-1, 7), 2 * pred[:, 1:-1, :7].reshape(-1, 7) - pred[:, :-2, :7].reshape(-1, 7), reduction='mean')
574
+ +
575
+ + l2_multiplier = 10.0
576
+ +
577
+ + loss = l2_multiplier * l2_loss + quaternion_dist + geodesic_loss + alpha * (smoothness_loss + acceleration_loss)
578
+ +
579
+ + dtw = DynamicTimeWarpingLoss()
580
+ + dtw_loss, _ = dtw.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
581
+ +
582
+ + hausdorff = HausdorffDistanceLoss()
583
+ + hausdorff_loss, _ = hausdorff.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
584
+ +
585
+ + frec = FrechetDistanceLoss()
586
+ + frechet_loss, _ = frec.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
587
+ +
588
+ + chamfer = ChamferDistanceLoss()
589
+ + chamfer_loss, _ = chamfer.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
590
+ +
591
+ + return loss, l2_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss
592
+ +
593
+ +
594
+ + def forward(self, pred, targ, cond=None):
595
+ + loss, err_t, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond)
596
+ +
597
+ + info = {
598
+ + 'rot. error': err_r.item(),
599
+ + 'geodesic error': err_geo.item(),
600
+ + 'trans. error': err_t.item(),
601
+ + 'dtw': err_dtw.item(),
602
+ + 'hausdorff': err_hausdorff.item(),
603
+ + 'frechet': err_frechet.item(),
604
+ + 'chamfer': err_chamfer.item(),
605
+ + }
606
+ +
607
+ + return loss, info
608
+ +
609
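`RotationTranslationLoss` computes the rotation geodesic through rotation matrices instead: for R = R_pred @ R_targ^T the angle is arccos((trace(R) - 1) / 2), clamped away from +/-1 to keep the gradient finite. The quaternion term 1 - <q1, q2>^2 is a cheaper proxy for the same quantity, since <q1, q2> = cos(theta/2) implies 1 - <q1, q2>^2 = (1 - cos(theta)) / 2, which is monotone in theta on [0, pi] and avoids arccos entirely. A sanity check of the trace formula (assuming pytorch3d, which this diff already imports):

    import torch
    from pytorch3d.transforms import quaternion_to_matrix

    q1 = torch.tensor([[1.0, 0.0, 0.0, 0.0]])                # identity (w-first)
    q2 = torch.tensor([[0.70710678, 0.0, 0.0, 0.70710678]])  # 90 deg about z
    R = quaternion_to_matrix(q1) @ quaternion_to_matrix(q2).transpose(-1, -2)
    trace = torch.diagonal(R, dim1=-2, dim2=-1).sum(-1)
    theta = torch.acos(((trace - 1) / 2).clamp(-1 + 1e-12, 1 - 1e-12))  # ~ pi/2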
+ +class SplineLoss(nn.Module):
610
+ + def __init__(self, *args):
611
+ + super().__init__()
612
+ + self.scales = json.load(open('scene_scale.json'))
613
+ +
614
+ + def compute_spline_coeffs(self, trans):
615
+ + p0 = trans[:, :-3, :]
616
+ + p1 = trans[:, 1:-2, :]
617
+ + p2 = trans[:, 2:-1, :]
618
+ + p3 = trans[:, 3:, :]
619
+ +
620
+ + # Tangent approximations
621
+ + m1 = 0.5 * (-p0 + p2)
622
+ + m2 = 0.5 * (-p1 + p3)
623
+ +
624
+ + # Cubic spline coefficients for each dimension
625
+ + a = (2 * p1 - 2 * p2 + m1 + m2)
626
+ + b = (-3 * p1 + 3 * p2 - 2 * m1 - m2)
627
+ + c = (m1)
628
+ + d = (p1)
629
+ +
630
+ + return torch.stack([a, b, c, d], dim=-1)
631
+ +
632
+ + def q_normalize(self, q):
633
+ + return q / q.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-12)
634
+ +
635
+ + def q_conjugate(self, q):
636
+ + w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
637
+ + return torch.stack([w, -x, -y, -z], dim=-1)
638
+ +
639
+ + def q_multiply(self, q1, q2):
640
+ + """
641
+ + Hamilton product q1*q2 (w-first convention).
642
+ + """
643
+ + w1, x1, y1, z1 = q1.unbind(-1)
644
+ + w2, x2, y2, z2 = q2.unbind(-1)
645
+ + w = w1*w2 - x1*x2 - y1*y2 - z1*z2
646
+ + x = w1*x2 + x1*w2 + y1*z2 - z1*y2
647
+ + y = w1*y2 - x1*z2 + y1*w2 + z1*x2
648
+ + z = w1*z2 + x1*y2 - y1*x2 + z1*w2
649
+ + return torch.stack([w, x, y, z], dim=-1)
650
+ +
651
+ + def q_inverse(self, q):
652
+ + return self.q_conjugate(self.q_normalize(q))
653
+ +
654
+ + def q_log(self, q):
655
+ + """
656
+ + Quaternion logarithm for a unit quaternion
657
+ + Only returns the imaginary part
658
+ + """
659
+ + q = self.q_normalize(q)
660
+ + w = q[..., 0]
661
+ + xyz = q[..., 1:] # shape [..., 3]
662
+ + mag_v = xyz.norm(p=2, dim=-1)
663
+ + eps = 1e-12
664
+ + angle = torch.acos(w.clamp(-1.0 + eps, 1.0 - eps))
665
+ +
666
+ + # We do a safe-guard against zero for sin(angle)
667
+ + small_mask = (mag_v < 1e-12) | (angle < 1e-12)
668
+ + # Where small_mask is True => near identity => log(q) ~ 0
669
+ + log_val = torch.zeros_like(xyz)
670
+ +
671
+ + # Normal case
672
+ + scale = angle / mag_v.clamp(min=1e-12)
673
+ + normal_case = scale.unsqueeze(-1) * xyz
674
+ +
675
+ + log_val = torch.where(
676
+ + small_mask.unsqueeze(-1),
677
+ + torch.zeros_like(xyz),
678
+ + normal_case
679
+ + )
680
+ + return log_val
681
+ +
682
+ + def q_exp(self, v):
683
+ + """
684
+ + Quaternion exponential
685
+ + """
686
+ + norm_v = v.norm(p=2, dim=-1)
687
+ + small_mask = norm_v < 1e-12
688
+ +
689
+ + w = torch.cos(norm_v)
690
+ + sin_v = torch.sin(norm_v)
691
+ + scale = torch.where(
692
+ + small_mask,
693
+ + torch.zeros_like(norm_v), # if zero, sin(0)/0 => 0
694
+ + sin_v / norm_v.clamp(min=1e-12)
695
+ + )
696
+ + xyz = scale.unsqueeze(-1) * v
697
+ +
698
+ + # For small angles, we approximate cos(norm_v) ~ 1, sin(norm_v)/norm_v ~ 1
699
+ + w = torch.where(
700
+ + small_mask,
701
+ + torch.ones_like(w),
702
+ + w
703
+ + )
704
+ + return torch.cat([w.unsqueeze(-1), xyz], dim=-1)
705
+ +
706
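`q_log` and `q_exp` are inverses of each other: log maps a unit quaternion to a rotation vector of half the rotation angle (arccos(w) times the normalized imaginary part), and exp rebuilds [cos(theta), sin(theta) * axis]. Only their mutual consistency matters for the tangent construction below, where every tangent is q_i * exp(-0.25 * (log(...) + log(...))). A quick round-trip check; `__new__` is used here only to borrow the methods without `__init__`'s JSON dependency:

    import torch

    sl = SplineLoss.__new__(SplineLoss)   # skip __init__ (it reads scene_scale.json)
    q = torch.tensor([0.70710678, 0.70710678, 0.0, 0.0])  # 90 deg about x
    v = sl.q_log(q)        # ~ [pi/4, 0, 0]: half-angle rotation vector
    q_back = sl.q_exp(v)   # ~ q again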
+ + def q_slerp(self, q1, q2, t):
707
+ + """
708
+ + Spherical linear interpolation from q1 to q2 at t in [0,1].
709
+ + Both q1, q2 assumed normalized.
710
+ + q1, q2, t can be 1D or broadcastable shapes, but typically 1D.
711
+ + """
712
+ + q1 = self.q_normalize(q1)
713
+ + q2 = self.q_normalize(q2)
714
+ + dot = (q1 * q2).sum(dim=-1, keepdim=True) # the dot product
715
+ +
716
+ + eps = 1e-12
717
+ + dot = dot.clamp(-1.0 + eps, 1.0 - eps)
718
+ +
719
+ + flip_mask = dot < 0.0
720
+ + if flip_mask.any():
721
+ + q2 = torch.where(flip_mask, -q2, q2)
722
+ + dot = torch.where(flip_mask, -dot, dot)
723
+ +
724
+ + # If they're very close, do a simple linear interpolation
725
+ + close_mask = dot.squeeze(-1) > 0.9995
726
+ + # Using an epsilon to avoid potential issues close to 1.0
727
+ +
728
+ + # Branch 1: Very close
729
+ + # linear LERP
730
+ + lerp_val = (1.0 - t) * q1 + t * q2
731
+ + lerp_val = self.q_normalize(lerp_val)
732
+ +
733
+ + # Branch 2: Standard SLERP
734
+ + theta_0 = torch.acos(dot)
735
+ + sin_theta_0 = torch.sin(theta_0)
736
+ + theta = theta_0 * t
737
+ + s1 = torch.sin(theta_0 - theta) / sin_theta_0.clamp(min=1e-12)
738
+ + s2 = torch.sin(theta) / sin_theta_0.clamp(min=1e-12)
739
+ + slerp_val = s1 * q1 + s2 * q2
740
+ + slerp_val = self.q_normalize(slerp_val)
741
+ +
742
+ + # Combine
743
+ + return torch.where(
744
+ + close_mask.unsqueeze(-1),
745
+ + lerp_val,
746
+ + slerp_val
747
+ + )
748
+ +
749
+ + def compute_uniform_tangent(self, q_im1, q_i, q_ip1):
750
+ + """
751
+ + Computes a 'Catmull–Rom-like' tangent T_i for quaternion q_i,
752
+ + given neighbors q_im1, q_i, q_ip1.
753
+ +
754
+ + T_i = q_i * exp( -0.25 * [ log(q_i^-1 q_ip1) + log(q_i^-1 q_im1) ] )
755
+ + """
756
+ + q_im1 = self.q_normalize(q_im1)
757
+ + q_i = self.q_normalize(q_i)
758
+ + q_ip1 = self.q_normalize(q_ip1)
759
+ +
760
+ + inv_qi = self.q_inverse(q_i)
761
+ + r1 = self.q_multiply(inv_qi, q_ip1)
762
+ + r2 = self.q_multiply(inv_qi, q_im1)
763
+ +
764
+ + lr1 = self.q_log(r1)
765
+ + lr2 = self.q_log(r2)
766
+ +
767
+ + m = -0.25 * (lr1 + lr2)
768
+ + exp_m = self.q_exp(m)
769
+ + return self.q_multiply(q_i, exp_m)
770
+ +
771
+ + def compute_all_uniform_tangents(self, quats):
772
+ + """
773
+ + Vectorized version that computes tangents T_i for all keyframe quaternions at once.
774
+ + quats shape: [N,4], N >= 2
775
+ + Returns shape [N,4].
776
+ + """
777
+ + q_im1 = torch.cat([quats[[0]], quats[:-1]], dim=0) # q_im1[0] = q0
778
+ + q_ip1 = torch.cat([quats[1:], quats[[-1]]], dim=0) # q_ip1[N-1]= q_{N-1}
779
+ +
780
+ + return self.compute_uniform_tangent(q_im1, quats, q_ip1)
781
+ +
782
+ + def squad(self, q0, a, b, q1, t):
783
+ + """
784
+ + Shoemake's "squad" interpolation for quaternion splines:
785
+ + squad(q0, a, b, q1; t) = slerp( slerp(q0, q1; t),
786
+ + slerp(a, b; t),
787
+ + 2t(1-t) )
788
+ + where a, b are tangential control quaternions for q0, q1.
789
+ + """
790
+ + s1 = self.q_slerp(q0, q1, t)
791
+ + s2 = self.q_slerp(a, b, t)
792
+ + alpha = 2.0*t*(1.0 - t)
793
+ + return self.q_slerp(s1, s2, alpha)
794
+ +
795
+ + def uniform_cr_spline(self, quats, num_samples_per_segment=10):
796
+ + """
797
+ + Given a list of keyframe quaternions quats (each a torch 1D tensor [4]),
798
+ + compute a "Uniform Catmull–Rom–like" quaternion spline through them.
799
+ +
800
+ + Returns:
801
+ + A list (Python list) of interpolated quaternions (torch tensors),
802
+ + including all segment endpoints.
803
+ +
804
+ + Each interior qi gets a tangent T_i using neighbors q_{i-1}, q_i, q_{i+1}.
805
+ + For boundary tangents, we replicate the end quaternions.
806
+ + """
807
+ + n = quats.shape[0]
808
+ + if n < 2:
809
+ + return quats.unsqueeze(0) # not enough quats to interpolate
810
+ +
811
+ + # Precompute tangents
812
+ + tangents = self.compute_all_uniform_tangents(quats)
813
+ +
814
+ + # Interpolate each segment [qi, q_{i+1}]
815
+ + q0 = quats[:-1].unsqueeze(1)
816
+ + q1 = quats[1:].unsqueeze(1)
817
+ + a = tangents[:-1].unsqueeze(1)
818
+ + b = tangents[1:].unsqueeze(1)
819
+ +
820
+ + t_vals = torch.linspace(0.0, 1.0, num_samples_per_segment, device=quats.device, dtype=quats.dtype)
821
+ + t_vals = t_vals.view(1, -1, 1)
822
+ +
823
+ + out = self.squad(q0, a, b, q1, t_vals)
824
+ + return out
825
+ +
826
+ +
827
+ + def forward(self, pred, targ, cond=None, scene_id=None, norm_params=None):
828
+ + loss, err_t, err_smooth, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond, scene_id, norm_params)
829
+ +
830
+ + info = {
831
+ + 'trans. error': err_t.item(),
832
+ + 'smoothness error': err_smooth.item(),
833
+ + # 'dtw': err_dtw.item(),
834
+ + # 'hausdorff': err_hausdorff.item(),
835
+ + # 'frechet': err_frechet.item(),
836
+ + # 'chamfer': err_chamfer.item(),
837
+ + 'quat. dist.': err_r.item(),
838
+ + 'geodesic dist.': err_geo.item(),
839
+ + }
840
+ +
841
+ + return loss, info
842
+ +
843
+ + def _loss(self, pred, targ, cond=None, scene_id=None, norm_params=None):
844
+ + def poly_eval(coeffs, x):
845
+ + """
846
+ + Evaluates a polynomial (with highest-degree term first) at points x.
847
+ + coeffs: 2D tensor of shape [num_polynomials, degree + 1], highest-degree term first.
848
+ + x: 1D tensor of points at which to evaluate the polynomial.
849
+ + Returns:
850
+ + 2D tensor of shape [num_polynomials, len(x)], containing p(x).
851
+ + """
852
+ + x_powers = torch.stack([x**i for i in range(coeffs.shape[-1] - 1, -1, -1)], dim=-1)
853
+ + x_powers = x_powers.to(torch.float64).to(coeffs.device)
854
+ + y = torch.matmul(coeffs, x_powers.T)
855
+ + return y
856
+ +
857
+ + # Make sure the dtype is float64
858
+ + pred = pred.to(torch.float64)
859
+ + targ = targ.to(torch.float64)
860
+ +
861
+ + # Rescale the translations
862
+ + if scene_id is not None and norm_params is not None:
863
+ + scene_id = scene_id.item()
864
+ + scene_scale = self.scales[str(scene_id)]
865
+ + scene_scale = norm_params['scale'][0] * scene_scale
866
+ + pred[..., :3] = pred[..., :3] * scene_scale
867
+ + targ[..., :3] = targ[..., :3] * scene_scale
868
+ + # print(pred[..., :3].max(), targ[..., :3].max())
869
+ +
870
+ + # We only consider interpolated points for loss calculation
871
+ + candidate_idxs = sorted(cond.keys())
872
+ + pred = pred[:, candidate_idxs[0] : candidate_idxs[-1] + 1, :]
873
+ + targ = targ[:, candidate_idxs[0] : candidate_idxs[-1] + 1, :]
874
+ +
875
+ + pred_trans = pred[..., :3]
876
+ + pred_quat = pred[..., 3:7]
877
+ + targ_trans = targ[..., :3]
878
+ + targ_quat = targ[..., 3:7]
879
+ +
880
+ + pred_coeffs = self.compute_spline_coeffs(pred_trans)
881
+ + targ_coeffs = self.compute_spline_coeffs(targ_trans)
882
+ +
883
+ + n_points = 2000
884
+ +
885
+ + # Distribute sample points among intervals
886
+ + dists = torch.norm(targ_trans[:, 1:, :] - targ_trans[:, :-1, :], dim=-1).reshape(-1)
887
+ + dists_c = torch.zeros(len(candidate_idxs) - 1, device=pred.device)
888
+ + for i in range(len(candidate_idxs) - 1):
889
+ + dists_c[i] = dists[candidate_idxs[i]:candidate_idxs[i+1]].sum()
890
+ +
891
+ + weights_c = dists_c / dists_c.sum()
892
+ + scaled_c = weights_c * n_points
893
+ + points_c = torch.floor(scaled_c).int()
894
+ +
895
+ + while points_c.sum() < n_points:
896
+ + idx = torch.argmax(scaled_c - points_c)
897
+ + points_c[idx] += 1
898
+ +
899
+ + # Calculate the spline loss
900
+ + sample_points = 50
901
+ + x = torch.linspace(0, 1, sample_points, device=pred.device)
902
+ + pred_spline = poly_eval(pred_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3)
903
+ + targ_spline = poly_eval(targ_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3)
904
+ +
905
+ + indexes = []
906
+ + start_idx = candidate_idxs[0]
907
+ + for c, (idx_i0, idx_i1) in enumerate(zip(candidate_idxs[:-1], candidate_idxs[1:])):
908
+ + p = points_c[c]
909
+ + total_dist = dists_c[c]
910
+ + dist_arr = dists[idx_i0 - start_idx : idx_i1 - start_idx]
911
+ +
912
+ + step_distances = (dist_arr / sample_points).repeat_interleave(sample_points)
913
+ + cumul_distances = step_distances.cumsum(dim=0)
914
+ +
915
+ + dist_per_pick = total_dist / p
916
+ + pick_targets = torch.arange(1, p + 1, device=dists.device) * dist_per_pick
917
+ +
918
+ + pick_idxs = torch.searchsorted(cumul_distances, pick_targets, right=True)
919
+ + pick_idxs = torch.clamp(pick_idxs, max=len(cumul_distances) - 1)
920
+ +
921
+ +
922
+ + indexes_1d = torch.zeros_like(step_distances)
923
+ + indexes_1d[pick_idxs] = 1
924
+ +
925
+ + indexes_2d = indexes_1d.view(len(dist_arr), sample_points)
926
+ +
927
+ + indexes.append(indexes_2d)
928
+ +
929
+ + indexes = torch.cat(indexes)[1: -1] # The first and last candidates don't have spline representations
930
+ +
931
+ + indexes_trans = torch.stack([indexes for _ in range(3)], dim=-1)
932
+ + indexes_quat = torch.stack([indexes for _ in range(4)], dim=-1)
933
+ +
934
+ + indexes_trans = indexes_trans.to(torch.bool)
935
+ + indexes_quat = indexes_quat.to(torch.bool)
936
+ +
937
+ + pred_trans_selected_values = pred_spline[indexes_trans]
938
+ + targ_trans_selected_values = targ_spline[indexes_trans]
939
+ +
940
+ + pred_trans_selected_values = pred_trans_selected_values.reshape(-1, 3)
941
+ + targ_trans_selected_values = targ_trans_selected_values.reshape(-1, 3)
942
+ +
943
+ + # Calculate the loss for quaternions
944
+ + pred_quat = pred_quat / pred_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8)
945
+ + targ_quat = targ_quat / targ_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8)
946
+ +
947
+ + targ_quat_spline = self.uniform_cr_spline(targ_quat.reshape(-1, 4), num_samples_per_segment=sample_points)
948
+ + pred_quat_spline = self.uniform_cr_spline(pred_quat.reshape(-1, 4), num_samples_per_segment=sample_points)
949
+ +
950
+ +
951
+ + targ_quat_spline = targ_quat_spline[1:-1]
952
+ + pred_quat_spline = pred_quat_spline[1:-1]
953
+ +
954
+ +
955
+ + pred_quat_selected_values = pred_quat_spline[indexes_quat]
956
+ + targ_quat_selected_values = targ_quat_spline[indexes_quat]
957
+ +
958
+ + pred_quat_selected_values = pred_quat_selected_values.reshape(-1, 4)
959
+ + targ_quat_selected_values = targ_quat_selected_values.reshape(-1, 4)
960
+ +
961
+ + # Calculate the geodesic loss
962
+ + pred_rot = quaternion_to_matrix(pred_quat_selected_values).reshape(-1, 3, 3)
963
+ + targ_rot = quaternion_to_matrix(targ_quat_selected_values).reshape(-1, 3, 3)
964
+ +
965
+ + eps = 1e-12
966
+ + r2r1 = pred_rot @ targ_rot.permute(0, 2, 1)
967
+ + trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1)
968
+ + trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps)
969
+ + geodesic_loss = torch.acos(trace).mean()
970
+ +
971
+ + # Calculate the rotation error
972
+ + dot_product = torch.sum(pred_quat_selected_values * targ_quat_selected_values, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps)
973
+ + quaternion_dist = 1 - (dot_product ** 2).mean()
974
+ +
975
+ + # Calculate the L2 loss
976
+ + l2_loss = F.mse_loss(pred_trans_selected_values, targ_trans_selected_values, reduction='mean')
977
+ +
978
+ + # Calculate the smoothness loss for translation and quaternion
979
+ + smoothness_multiplier = 10 ** 2 # Empirically determined multiplier for smoothness loss
980
+ + weight_acceleration = 0.1
981
+ + weight_jerk = 0.05
982
+ +
983
+ + pos_acc = pred_trans_selected_values[2:, :] - 2 * pred_trans_selected_values[1:-1, :] + pred_trans_selected_values[:-2, :]
984
+ + pos_jerk = pred_trans_selected_values[3:, :] - 3 * pred_trans_selected_values[2:-1, :] + 3 * pred_trans_selected_values[1:-2, :] - pred_trans_selected_values[:-3, :]
985
+ +
986
+ + pos_acceleration_loss = torch.mean(pos_acc ** 2)
987
+ + pos_jerk_loss = torch.mean(pos_jerk ** 2)
988
+ +
989
+ + q0 = pred_quat_selected_values[:-1, :]
990
+ + q1 = pred_quat_selected_values[1:, :]
991
+ + sign = torch.where((q0 * q1).sum(dim=-1) < 0, -1.0, 1.0)
992
+ + q1 = sign.unsqueeze(-1) * q1
993
+ +
994
+ + dq = self.q_multiply(q1, self.q_inverse(q0))
995
+ + theta = 2 * torch.acos(torch.clamp(dq[..., 0], -1.0 + 1e-8, 1.0 - 1e-8))
996
+ +
997
+ + rot_acc = theta[2:] - 2*theta[1:-1] + theta[:-2]
998
+ + rot_jerk = theta[3:] - 3*theta[2:-1] + 3*theta[1:-2] - theta[:-3]
999
+ +
1000
+ + rot_acceleration_loss = torch.mean(rot_acc ** 2)
1001
+ + rot_jerk_loss = torch.mean(rot_jerk ** 2)
1002
+ +
1003
+ + alpha_rot = 0.1 # <-- tune this (e.g. 0.1 … 10)
1004
+ +
1005
+ +
1006
+ + acceleration_loss = pos_acceleration_loss + alpha_rot * rot_acceleration_loss
1007
+ + jerk_loss = pos_jerk_loss + alpha_rot * rot_jerk_loss
1008
+ +
1009
+ + smoothness_loss = (
1010
+ + weight_acceleration * acceleration_loss
1011
+ + + weight_jerk * jerk_loss
1012
+ + ) * smoothness_multiplier
1013
+ +
1014
+ +
1015
+ + # Calculate the spline loss
1016
+ + l2_multiplier = 10.0
1017
+ + spline_loss = l2_multiplier * (l2_loss + smoothness_loss) + geodesic_loss + quaternion_dist
1018
+ +
1019
+ + dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss = None, None, None, None
1020
+ +
1021
+ + # Uncomment these lines if you want to use the other losses
1022
+ + '''
1023
+ + dtw = DynamicTimeWarpingLoss()
1024
+ + dtw_loss, _ = dtw.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
1025
+ +
1026
+ + hausdorff = HausdorffDistanceLoss()
1027
+ + hausdorff_loss, _ = hausdorff.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
1028
+ +
1029
+ + frec = FrechetDistanceLoss()
1030
+ + frechet_loss, _ = frec.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
1031
+ +
1032
+ + chamfer = ChamferDistanceLoss()
1033
+ + chamfer_loss, _ = chamfer.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
1034
+ + '''
1035
+ +
1036
+ + return spline_loss, l2_multiplier * l2_loss, l2_multiplier * smoothness_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss
1037
+ +
1038
+ +
1039
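One detail of `_loss` worth calling out: the 2000 comparison points are allotted to keyframe intervals in proportion to arc length with a largest-remainder rounding scheme; each interval gets the floor of its ideal share, and the leftover points go to the intervals with the biggest fractional remainders. A tiny worked example with made-up distances:

    import torch

    dists_c = torch.tensor([1.0, 1.0, 1.0])       # equal per-interval arc lengths
    n_points = 10
    scaled = dists_c / dists_c.sum() * n_points   # ideal shares: [3.33, 3.33, 3.33]
    points = torch.floor(scaled).int()            # [3, 3, 3] -> one point short
    while points.sum() < n_points:
        points[torch.argmax(scaled - points)] += 1   # -> [4, 3, 3]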
+ +class DynamicTimeWarpingLoss(nn.Module):
1040
+ + def __init__(self):
1041
+ + super().__init__()
1042
+ +
1043
+ + def _dtw_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor:
1044
+ + """
1045
+ + Computes the DTW distance between two 2D tensors (T x D),
1046
+ + where T is sequence length and D is feature dimension.
1047
+ + """
1048
+ + # seq1, seq2 shapes: (time_steps, feature_dim)
1049
+ + n, m = seq1.size(0), seq2.size(0)
1050
+ +
1051
+ + # Cost matrix (pairwise distances between all elements)
1052
+ + cost = torch.zeros(n, m, device=seq1.device, dtype=seq1.dtype)
1053
+ + for i in range(n):
1054
+ + for j in range(m):
1055
+ + cost[i, j] = torch.norm(seq1[i] - seq2[j], p=2)
1056
+ +
1057
+ + # Accumulated cost matrix
1058
+ + dist = torch.full((n + 1, m + 1), float('inf'),
1059
+ + device=seq1.device, dtype=seq1.dtype)
1060
+ + dist[0, 0] = 0.0
1061
+ +
1062
+ + # Populate the DP table
1063
+ + for i in range(1, n + 1):
1064
+ + for j in range(1, m + 1):
1065
+ + dist[i, j] = cost[i - 1, j - 1] + torch.min(
1066
+ + torch.min(
1067
+ + dist[i - 1, j], # Insertion
1068
+ + dist[i, j - 1], # Deletion
1069
+ + ),
1070
+ + dist[i - 1, j - 1]# Match
1071
+ + )
1072
+ +
1073
+ + return dist[n, m]
1074
+ +
1075
+ + def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
1076
+ + """
1077
+ + Compute the average DTW loss over a batch of sequences.
1078
+ +
1079
+ + pred, targ shapes: (batch_size, T, D)
1080
+ + """
1081
+ + # Ensure shapes match in batch dimension
1082
+ + assert pred.size(0) == targ.size(0), "Batch sizes must match."
1083
+ +
1084
+ + # Compute DTW distance per sample in the batch
1085
+ + distances = []
1086
+ + for b in range(pred.size(0)):
1087
+ + seq1 = pred[b]
1088
+ + seq2 = targ[b]
1089
+ + dtw_val = self._dtw_distance(seq1, seq2)
1090
+ + distances.append(dtw_val)
1091
+ +
1092
+ + # Stack and take mean to get scalar loss
1093
+ + dtw_loss = torch.stack(distances).mean()
1094
+ + return dtw_loss
1095
+ +
1096
+ + def forward(self, pred: torch.Tensor, targ: torch.Tensor):
1097
+ + """
1098
+ + Returns a tuple: (loss, info_dict),
1099
+ + where loss is a scalar tensor and info_dict is a dictionary
1100
+ + of extra information (e.g., loss components).
1101
+ + """
1102
+ + loss = self._loss(pred, targ)
1103
+ +
1104
+ + info = {
1105
+ + 'dtw': loss.item()
1106
+ + }
1107
+ +
1108
+ + return loss, info
1109
+ +
1110
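The nested Python loops that fill the cost matrix here, and in the Hausdorff, Fréchet, and Chamfer classes below, are O(n*m) interpreter-level iterations. All four metrics start from the same pairwise Euclidean matrix, which `torch.cdist` produces in one vectorized call; only the DTW and Fréchet DP recursions are inherently sequential. A drop-in sketch:

    import torch

    def pairwise_cost(seq1, seq2):
        # Same matrix as the nested loops: cost[i, j] = ||seq1[i] - seq2[j]||_2
        return torch.cdist(seq1, seq2, p=2)   # (n, d) x (m, d) -> (n, m)

    # The symmetric Chamfer distance, for instance, then reduces to:
    # cost = pairwise_cost(set1, set2)
    # chamfer = cost.min(dim=1).values.mean() + cost.min(dim=0).values.mean()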
+ +class HausdorffDistanceLoss(nn.Module):
1111
+ + def __init__(self):
1112
+ + super().__init__()
1113
+ +
1114
+ + def _hausdorff_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor:
1115
+ + """
1116
+ + Computes the Hausdorff distance between two 2D tensors (N x D),
1117
+ + where N is the number of points and D is the feature dimension.
1118
+ +
1119
+ + The Hausdorff distance H(A,B) between two sets A and B is defined as:
1120
+ + H(A, B) = max( h(A, B), h(B, A) ),
1121
+ + where
1122
+ + h(A, B) = max_{a in A} min_{b in B} d(a, b).
1123
+ +
1124
+ + Here, d(a, b) is the Euclidean distance between points a and b.
1125
+ + """
1126
+ + # set1, set2 shapes: (num_points, feature_dim)
1127
+ + n, m = set1.size(0), set2.size(0)
1128
+ +
1129
+ + # Compute pairwise distances
1130
+ + cost = torch.zeros(n, m, device=set1.device, dtype=set1.dtype)
1131
+ + for i in range(n):
1132
+ + for j in range(m):
1133
+ + cost[i, j] = torch.norm(set1[i] - set2[j], p=2)
1134
+ +
1135
+ + # Forward direction: for each point in set1, find distance to closest point in set2
1136
+ + forward_min = cost.min(dim=1)[0] # Shape (n,)
1137
+ + forward_hausdorff = forward_min.max() # max over n
1138
+ +
1139
+ + # Backward direction: for each point in set2, find distance to closest point in set1
1140
+ + backward_min = cost.min(dim=0)[0] # Shape (m,)
1141
+ + backward_hausdorff = backward_min.max() # max over m
1142
+ +
1143
+ + # Hausdorff distance is the max of the two
1144
+ + hausdorff_dist = torch.max(forward_hausdorff, backward_hausdorff)
1145
+ + return hausdorff_dist
1146
+ +
1147
+ + def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
1148
+ + """
1149
+ + Compute the average Hausdorff distance over a batch of point sets.
1150
+ +
1151
+ + pred, targ shapes: (batch_size, N, D)
1152
+ + """
1153
+ + # Ensure shapes match in batch dimension
1154
+ + assert pred.size(0) == targ.size(0), "Batch sizes must match."
1155
+ +
1156
+ + distances = []
1157
+ + for b in range(pred.size(0)):
1158
+ + set1 = pred[b]
1159
+ + set2 = targ[b]
1160
+ + h_dist = self._hausdorff_distance(set1, set2)
1161
+ + distances.append(h_dist)
1162
+ +
1163
+ + # Stack and take mean to get scalar loss
1164
+ + hausdorff_loss = torch.stack(distances).mean()
1165
+ + return hausdorff_loss
1166
+ +
1167
+ + def forward(self, pred: torch.Tensor, targ: torch.Tensor):
1168
+ + """
1169
+ + Returns a tuple: (loss, info_dict),
1170
+ + where loss is a scalar tensor and info_dict is a dictionary
1171
+ + of extra information (e.g., distance components).
1172
+ + """
1173
+ + loss = self._loss(pred, targ)
1174
+ +
1175
+ + info = {
1176
+ + 'hausdorff': loss.item()
1177
+ + }
1178
+ +
1179
+ + return loss, info
1180
+ +
1181
+ +class FrechetDistanceLoss(nn.Module):
1182
+ + def __init__(self):
1183
+ + super().__init__()
1184
+ +
1185
+ + def _frechet_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor:
1186
+ + """
1187
+ + Computes the (discrete) Fréchet distance between two 2D tensors (T x D),
1188
+ + where T is the sequence length and D is the feature dimension.
1189
+ +
1190
+ + The Fréchet distance between two curves in discrete form can be computed
1191
+ + by filling in a DP table "ca" where:
1192
+ +
1193
+ + ca[i, j] = max( d(seq1[i], seq2[j]),
1194
+ + min(ca[i-1, j], ca[i, j-1], ca[i-1, j-1]) )
1195
+ +
1196
+ + with boundary conditions handled appropriately.
1197
+ + Here, d(seq1[i], seq2[j]) is the Euclidean distance.
1198
+ + """
1199
+ + n, m = seq1.size(0), seq2.size(0)
1200
+ +
1201
+ + # Cost matrix (pairwise distances between all elements)
1202
+ + cost = torch.zeros(n, m, device=seq1.device, dtype=seq1.dtype)
1203
+ + for i in range(n):
1204
+ + for j in range(m):
1205
+ + cost[i, j] = torch.norm(seq1[i] - seq2[j], p=2)
1206
+ +
1207
+ + # DP matrix for the Fréchet distance
1208
+ + ca = torch.full((n, m), float('inf'), device=seq1.device, dtype=seq1.dtype)
1209
+ + ca[0, 0] = cost[0, 0]
1210
+ +
1211
+ + # Initialize first row
1212
+ + for i in range(1, n):
1213
+ + ca[i, 0] = torch.max(ca[i - 1, 0], cost[i, 0])
1214
+ +
1215
+ + # Initialize first column
1216
+ + for j in range(1, m):
1217
+ + ca[0, j] = torch.max(ca[0, j - 1], cost[0, j])
1218
+ +
1219
+ + # Populate the DP table
1220
+ + for i in range(1, n):
1221
+ + for j in range(1, m):
1222
+ + ca[i, j] = torch.max(
1223
+ + cost[i, j],
1224
+ + torch.min(
1225
+ + torch.min(
1226
+ + ca[i - 1, j],
1227
+ + ca[i, j - 1],
1228
+ + ),
1229
+ + ca[i - 1, j - 1]
1230
+ + )
1231
+ + )
1232
+ +
1233
+ + return ca[n - 1, m - 1]
1234
+ +
1235
+ + def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
1236
+ + """
1237
+ + Compute the average Fréchet distance over a batch of sequences.
1238
+ +
1239
+ + pred, targ shapes: (batch_size, T, D)
1240
+ + """
1241
+ + # Ensure shapes match in batch dimension
1242
+ + assert pred.size(0) == targ.size(0), "Batch sizes must match."
1243
+ +
1244
+ + distances = []
1245
+ + for b in range(pred.size(0)):
1246
+ + seq1 = pred[b]
1247
+ + seq2 = targ[b]
1248
+ + fd_val = self._frechet_distance(seq1, seq2)
1249
+ + distances.append(fd_val)
1250
+ +
1251
+ + # Stack and take mean to get scalar loss
1252
+ + frechet_loss = torch.stack(distances).mean()
1253
+ + return frechet_loss
1254
+ +
1255
+ + def forward(self, pred: torch.Tensor, targ: torch.Tensor):
1256
+ + """
1257
+ + Returns a tuple: (loss, info_dict),
1258
+ + where loss is a scalar tensor and info_dict is a dictionary
1259
+ + of extra information (e.g., distance components).
1260
+ + """
1261
+ + loss = self._loss(pred, targ)
1262
+ + info = {
1263
+ + 'frechet': loss.item()
1264
+ + }
1265
+ + return loss, info
1266
+ +
1267
+ +class ChamferDistanceLoss(nn.Module):
1268
+ + def __init__(self):
1269
+ + super().__init__()
1270
+ +
1271
+ + def _chamfer_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor:
1272
+ + """
1273
+ + Computes the symmetrical Chamfer distance between
1274
+ + two 2D tensors (N x D), where N is the number of points
1275
+ + and D is the feature dimension.
1276
+ +
1277
+ + The Chamfer distance between two point sets A and B is often defined as:
1278
+ +
1279
+ + d_chamfer(A, B) = 1/|A| ∑_{a ∈ A} min_{b ∈ B} ‖a - b‖₂
1280
+ + + 1/|B| ∑_{b ∈ B} min_{a ∈ A} ‖b - a‖₂,
1281
+ +
1282
+ + where ‖·‖₂ is the Euclidean distance.
1283
+ + """
1284
+ + # set1, set2 shapes: (num_points, feature_dim)
1285
+ + n, m = set1.size(0), set2.size(0)
1286
+ +
1287
+ + cost = torch.zeros(n, m, device=set1.device, dtype=set1.dtype)
1288
+ + for i in range(n):
1289
+ + for j in range(m):
1290
+ + cost[i, j] = torch.norm(set1[i] - set2[j], p=2)
1291
+ +
1292
+ + # For each point in set1, find distance to the closest point in set2
1293
+ + forward_min = cost.min(dim=1)[0] # shape: (n,)
1294
+ + forward_mean = forward_min.mean()
1295
+ +
1296
+ + # For each point in set2, find distance to the closest point in set1
1297
+ + backward_min = cost.min(dim=0)[0] # shape: (m,)
1298
+ + backward_mean = backward_min.mean()
1299
+ +
1300
+ + chamfer_dist = forward_mean + backward_mean
1301
+ + return chamfer_dist
1302
+ +
1303
+ + def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
1304
+ + """
1305
+ + Compute the average Chamfer distance over a batch of point sets.
1306
+ +
1307
+ + pred, targ shapes: (batch_size, N, D)
1308
+ + """
1309
+ + # Ensure shapes match in batch dimension
1310
+ + assert pred.size(0) == targ.size(0), "Batch sizes must match."
1311
+ +
1312
+ + distances = []
1313
+ + for b in range(pred.size(0)):
1314
+ + set1 = pred[b]
1315
+ + set2 = targ[b]
1316
+ + distance_val = self._chamfer_distance(set1, set2)
1317
+ + distances.append(distance_val)
1318
+ +
1319
+ + # Combine into a single scalar
1320
+ + chamfer_loss = torch.stack(distances).mean()
1321
+ + return chamfer_loss
1322
+ +
1323
+ + def forward(self, pred: torch.Tensor, targ: torch.Tensor):
1324
+ + """
1325
+ + Returns a tuple: (loss, info_dict),
1326
+ + where 'loss' is a scalar tensor and 'info_dict' is a dictionary
1327
+ + of extra information (e.g., distance components).
1328
+ + """
1329
+ + loss = self._loss(pred, targ)
1330
+ + info = {
1331
+ + 'chamfer': loss.item()
1332
+ + }
1333
+ + return loss, info
1334
+ +
1335
+ +
1336
+ +def slerp(q1, q2, t):
1337
+ + """Spherical linear interpolation between two quaternions."""
1338
+ + q1 = q1 / np.linalg.norm(q1)
1339
+ + q2 = q2 / np.linalg.norm(q2)
1340
+ + dot = np.dot(q1, q2)
1341
+ +
1342
+ + if dot < 0.0:
1343
+ + q2 = -q2
1344
+ + dot = -dot
1345
+ + # If dot is very close to 1, use linear interpolation
1346
+ +
1347
+ + if dot > 0.9995:
1348
+ + result = q1 + t * (q2 - q1)
1349
+ + result = result / np.linalg.norm(result)
1350
+ + return result
1351
+ +
1352
+ + theta_0 = np.arccos(dot)
1353
+ + theta = theta_0 * t
1354
+ +
1355
+ + q3 = q2 - q1 * dot
1356
+ + q3 = q3 / np.linalg.norm(q3)
1357
+ + return q1 * np.cos(theta) + q3 * np.sin(theta)
1358
+ +
1359
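This NumPy `slerp` uses the orthonormal-basis form: after projecting the q1 component out of q2, the result is q1 * cos(theta) + q3 * sin(theta), with a normalized linear blend as the fallback when the endpoints nearly coincide. A quick sanity check with made-up endpoints (assumes the `slerp` above is in scope):

    import numpy as np

    q_id  = np.array([1.0, 0.0, 0.0, 0.0])                              # identity
    q_z90 = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])  # 90 deg about z
    q_mid = slerp(q_id, q_z90, 0.5)
    # q_mid ~ [0.9239, 0, 0, 0.3827], i.e. a 45 deg rotation about z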
+ +def catmull_rom_spline_with_rotation(control_points, timepoints, horizon):
1360
+ + """Compute Catmull-Rom spline for both position and quaternion rotation."""
1361
+ + spline_points = []
1362
+ + # Extrapolate the initial points
1363
+ + if timepoints[0] != 0:
1364
+ + for t in range(timepoints[0]):
1365
+ + x = control_points[0][0]
1366
+ + y = control_points[0][1]
1367
+ + z = control_points[0][2]
1368
+ + q = control_points[0][3:7]
1369
+ + spline_points.append(np.concatenate([np.array([x, y, z]), q]))
1370
+ +
1371
+ + # Linearly interpolate between the 0th and 1st control points
1372
+ + for t in np.linspace(0, 1, timepoints[1] - timepoints[0] + 1):
1373
+ + x = control_points[0][0] + t * (control_points[1][0] - control_points[0][0])
1374
+ + y = control_points[0][1] + t * (control_points[1][1] - control_points[0][1])
1375
+ + z = control_points[0][2] + t * (control_points[1][2] - control_points[0][2])
1376
+ + q = slerp(control_points[0][3:7], control_points[1][3:7], t)
1377
+ + spline_points.append(np.concatenate([np.array([x, y, z]), q]))
1378
+ +
1379
+ +
1380
+ + # Iterate over the control points
1381
+ + for i in range(1, len(control_points) - 2):
1382
+ + P0 = control_points[i-1][:3]
1383
+ + P1 = control_points[i][:3]
1384
+ + P2 = control_points[i+1][:3]
1385
+ + P3 = control_points[i+2][:3]
1386
+ + Q0 = control_points[i-1][3:7]
1387
+ + Q1 = control_points[i][3:7]
1388
+ + Q2 = control_points[i+1][3:7]
1389
+ + Q3 = control_points[i+2][3:7]
1390
+ +
1391
+ + # Interpolate position (using Catmull-Rom spline)
1392
+ + for idx, t in enumerate(np.linspace(0, 1, timepoints[i+1] - timepoints[i] + 1)):
1393
+ + if idx == 0:
1394
+ + continue
1395
+ +
1396
+ + x = 0.5 * ((2 * P1[0]) + (-P0[0] + P2[0]) * t +
1397
+ + (2 * P0[0] - 5 * P1[0] + 4 * P2[0] - P3[0]) * t**2 +
1398
+ + (-P0[0] + 3 * P1[0] - 3 * P2[0] + P3[0]) * t**3)
1399
+ + y = 0.5 * ((2 * P1[1]) + (-P0[1] + P2[1]) * t +
1400
+ + (2 * P0[1] - 5 * P1[1] + 4 * P2[1] - P3[1]) * t**2 +
1401
+ + (-P0[1] + 3 * P1[1] - 3 * P2[1] + P3[1]) * t**3)
1402
+ + z = 0.5 * ((2 * P1[2]) + (-P0[2] + P2[2]) * t +
1403
+ + (2 * P0[2] - 5 * P1[2] + 4 * P2[2] - P3[2]) * t**2 +
1404
+ + (-P0[2] + 3 * P1[2] - 3 * P2[2] + P3[2]) * t**3)
1405
+ + q = slerp(Q1, Q2, t)
1406
+ + spline_points.append(np.concatenate([np.array([x, y, z]), q]))
1407
+ +
1408
+ + # Linearly interpolate between the second-to-last and last control points
1409
+ + for idx, t in enumerate(np.linspace(0, 1, timepoints[-1] - timepoints[-2] + 1)):
1410
+ + if idx == 0:
1411
+ + continue
1412
+ + x = control_points[-2][0] + t * (control_points[-1][0] - control_points[-2][0])
1413
+ + y = control_points[-2][1] + t * (control_points[-1][1] - control_points[-2][1])
1414
+ + z = control_points[-2][2] + t * (control_points[-1][2] - control_points[-2][2])
1415
+ + q = slerp(control_points[-2][3:7], control_points[-1][3:7], t)
1416
+ + spline_points.append(np.concatenate([np.array([x, y, z]), q]))
1417
+ +
1418
+ + # Extrapolate the rest of the points
1419
+ + if timepoints[-1] != horizon:
1420
+ + for t in range(timepoints[-1] + 1, horizon):
1421
+ + x = control_points[-1][0]
1422
+ + y = control_points[-1][1]
1423
+ + z = control_points[-1][2]
1424
+ + q = control_points[-1][3:7]
1425
+ + spline_points.append(np.concatenate([np.array([x, y, z]), q]))
1426
+ +
1427
+ + stacked_spline_points = np.stack(spline_points, axis=0)
1428
+ +
1429
+ + if control_points.shape[1] != 7:
1430
+ + stacked_spline_points = np.concatenate([stacked_spline_points, np.zeros((stacked_spline_points.shape[0], 1))], axis=1)
1431
+ +
1432
+ +
1433
+ + return stacked_spline_points
1434
+ +
1435
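`catmull_rom_spline_with_rotation` is the trajectory prior used throughout this commit: Catmull-Rom in position and slerp in orientation on interior segments, plain linear interpolation plus slerp on the two boundary segments, and constant extrapolation outside the first and last keyframes. A minimal usage sketch with made-up keyframes, each row being [x, y, z, qw, qx, qy, qz]:

    import numpy as np

    keyframes = np.array([
        [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        [1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    ])
    timepoints = np.array([0, 10, 20, 30])
    traj = catmull_rom_spline_with_rotation(keyframes, timepoints, horizon=40)
    # traj.shape == (40, 7): dense poses, held constant after the last keyframe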
+ +def catmull_rom_loss(trajectories, conditions, loss_fc):
1436
+ + '''
1437
+ + loss for catmull-rom interpolation
1438
+ + '''
1439
+ + batch_size, horizon, transition = trajectories.shape
1440
+ +
1441
+ + # Extract known indices and values
1442
+ + known_indices = np.array(list(conditions.keys()), dtype=int)
1443
+ +
1444
+ + # candidate_no x batch_size x dim
1445
+ + known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0)
1446
+ + known_values = np.moveaxis(known_values, 0, 1)
1447
+ +
1448
+ + # Sort the timepoints
1449
+ + sorted_indices = np.argsort(known_indices)
1450
+ + known_indices = known_indices[sorted_indices]
1451
+ + known_values = known_values[:, sorted_indices]
1452
+ + spline_points = np.array([catmull_rom_spline_with_rotation(known_values[b], known_indices, horizon) for b in range(batch_size)])
1453
+ +
1454
+ + # Convert to tensor and move to the same device as trajectories
1455
+ + spline_points = torch.tensor(spline_points, dtype=torch.float64, device=trajectories.device)
1456
+ + assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}"
1457
+ + return loss_fc(spline_points, trajectories)
1458
+ +
1459
+ Losses = {
1460
+ 'l1': WeightedL1,
1461
+ 'l2': WeightedL2,
1462
+ 'value_l1': ValueL1,
1463
+ 'value_l2': ValueL2,
1464
+ + 'geodesic_l2': GeodesicL2Loss,
1465
+ + 'rotation_translation': RotationTranslationLoss,
1466
+ + 'spline': SplineLoss,
1467
+ }
1468
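The three new registry entries are what connect the configs to the model: maze2d's `'loss_type': 'spline'` resolves to `SplineLoss` when `GaussianDiffusion` builds its loss (upstream this is roughly `self.loss_fn = Losses[loss_type](loss_weights, self.action_dim)`; the new classes take `*args` so they can be constructed the same way and ignore those arguments). A hypothetical lookup:

    loss_fn = Losses['spline']()            # note: __init__ reads scene_scale.json from the CWD
    loss, info = loss_fn(pred, targ, cond)  # info carries the per-term diagnostics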
+ diff --git a/diffuser/models/temporal.py b/diffuser/models/temporal.py
1469
+ index e0b9e5c..0f7854a 100644
1470
+ --- a/diffuser/models/temporal.py
1471
+ +++ b/diffuser/models/temporal.py
1472
+ @@ -17,18 +17,18 @@ class ResidualTemporalBlock(nn.Module):
1473
+ super().__init__()
1474
+
1475
+ self.blocks = nn.ModuleList([
1476
+ - Conv1dBlock(inp_channels, out_channels, kernel_size),
1477
+ - Conv1dBlock(out_channels, out_channels, kernel_size),
1478
+ + Conv1dBlock(inp_channels, out_channels, kernel_size).to(dtype=torch.float64),
1479
+ + Conv1dBlock(out_channels, out_channels, kernel_size).to(dtype=torch.float64),
1480
+ ])
1481
+
1482
+ self.time_mlp = nn.Sequential(
1483
+ nn.Mish(),
1484
+ - nn.Linear(embed_dim, out_channels),
1485
+ + nn.Linear(embed_dim, out_channels).to(dtype=torch.float64),
1486
+ Rearrange('batch t -> batch t 1'),
1487
+ - )
1488
+ + ).to(dtype=torch.float64)
1489
+
1490
+ - self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
1491
+ - if inp_channels != out_channels else nn.Identity()
1492
+ + self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1).to(dtype=torch.float64) \
1493
+ + if inp_channels != out_channels else nn.Identity().to(dtype=torch.float64)
1494
+
1495
+ def forward(self, x, t):
1496
+ '''
1497
+ @@ -37,7 +37,8 @@ class ResidualTemporalBlock(nn.Module):
1498
+ returns:
1499
+ out : [ batch_size x out_channels x horizon ]
1500
+ '''
1501
+ - out = self.blocks[0](x) + self.time_mlp(t)
1502
+ +
1503
+ + out = self.blocks[0](x) + self.time_mlp(t.double())
1504
+ out = self.blocks[1](out)
1505
+ return out + self.residual_conv(x)
1506
+
1507
+ @@ -49,11 +50,11 @@ class TemporalUnet(nn.Module):
1508
+ transition_dim,
1509
+ cond_dim,
1510
+ dim=32,
1511
+ - dim_mults=(1, 2, 4, 8),
1512
+ + dim_mults=(1, 2, 4),
1513
+ ):
1514
+ super().__init__()
1515
+
1516
+ - dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
1517
+ + dims = [(transition_dim + cond_dim), *map(lambda m: dim * m, dim_mults)]
1518
+ in_out = list(zip(dims[:-1], dims[1:]))
1519
+ print(f'[ models/temporal ] Channel dimensions: {in_out}')
1520
+
1521
+ @@ -100,7 +101,7 @@ class TemporalUnet(nn.Module):
1522
+
1523
+ self.final_conv = nn.Sequential(
1524
+ Conv1dBlock(dim, dim, kernel_size=5),
1525
+ - nn.Conv1d(dim, transition_dim, 1),
1526
+ + nn.Conv1d(dim, transition_dim, 1).to(dtype=torch.float64),
1527
+ )
1528
+
1529
+ def forward(self, x, cond, time):
1530
+ @@ -129,7 +130,6 @@ class TemporalUnet(nn.Module):
1531
+ x = upsample(x)
1532
+
1533
+ x = self.final_conv(x)
1534
+ -
1535
+ x = einops.rearrange(x, 'b t h -> b h t')
1536
+ return x
1537
+
1538
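The scattered `.to(torch.float64)` casts across `Conv1dBlock`, `ResidualTemporalBlock`, and `final_conv` all serve one goal: running the U-Net in double precision to match the float64 data path. A less invasive alternative is a single module-level cast after construction, since `nn.Module.double()` recursively converts every parameter and buffer (constructor arguments below are placeholders):

    from diffuser.models.temporal import TemporalUnet

    model = TemporalUnet(horizon=384, transition_dim=8, cond_dim=8)  # placeholder args
    model = model.double()   # one call casts all parameters/buffers to torch.float64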
+ diff --git a/diffuser/utils/arrays.py b/diffuser/utils/arrays.py
1539
+ index c3a9d24..96a7093 100644
1540
+ --- a/diffuser/utils/arrays.py
1541
+ +++ b/diffuser/utils/arrays.py
1542
+ @@ -54,7 +54,7 @@ def batchify(batch):
1543
+ 1) converting np arrays to torch tensors and
1544
+ 2) and ensuring that everything has a batch dimension
1545
+ '''
1546
+ - fn = lambda x: to_torch(x[None])
1547
+ + fn = lambda x: to_torch(x[None], dtype=torch.float64)
1548
+
1549
+ batched_vals = []
1550
+ for field in batch._fields:
1551
+ diff --git a/diffuser/utils/serialization.py b/diffuser/utils/serialization.py
1552
+ index 6cc9db9..039eb64 100644
1553
+ --- a/diffuser/utils/serialization.py
1554
+ +++ b/diffuser/utils/serialization.py
1555
+ @@ -19,7 +19,7 @@ def mkdir(savepath):
1556
+ return False
1557
+
1558
+ def get_latest_epoch(loadpath):
1559
+ - states = glob.glob1(os.path.join(*loadpath), 'state_*')
1560
+ + states = glob.glob1(os.path.join(loadpath), 'state_*')
1561
+ latest_epoch = -1
1562
+ for state in states:
1563
+ epoch = int(state.replace('state_', '').replace('.pt', ''))
1564
+ diff --git a/diffuser/utils/training.py b/diffuser/utils/training.py
1565
+ index be3556e..c21e0f0 100644
1566
+ --- a/diffuser/utils/training.py
1567
+ +++ b/diffuser/utils/training.py
1568
+ @@ -4,16 +4,24 @@ import numpy as np
1569
+ import torch
1570
+ import einops
1571
+ import pdb
1572
+ +from tqdm import tqdm
1573
+ +import wandb
1574
+ +from pytorch3d.transforms import axis_angle_to_quaternion
1575
+
1576
+ from .arrays import batch_to_device, to_np, to_device, apply_dict
1577
+ from .timer import Timer
1578
+ from .cloud import sync_logs
1579
+ +from ..models.helpers import catmull_rom_spline_with_rotation
1580
+
1581
+ def cycle(dl):
1582
+ while True:
1583
+ for data in dl:
1584
+ yield data
1585
+
1586
+ +def assert_no_nan_weights(model):
1587
+ + for name, param in model.named_parameters():
1588
+ + assert not torch.isnan(param).any(), f"NaN detected in parameter: {name}"
1589
+ +
1590
+ class EMA():
1591
+ '''
1592
+ exponential moving average
1593
+ @@ -71,13 +79,35 @@ class Trainer(object):
+ self.gradient_accumulate_every = gradient_accumulate_every
+
+ self.dataset = dataset
+ - self.dataloader = cycle(torch.utils.data.DataLoader(
+ - self.dataset, batch_size=train_batch_size, num_workers=1, shuffle=True, pin_memory=True
+ + dataset_size = len(self.dataset)
+ +
+ + # Read the indices from the .txt file
+ + with open(os.path.join(results_folder, 'train_indices.txt'), 'r') as f:
+ + self.train_indices = f.read()
+ + self.train_indices = [int(i) for i in self.train_indices.split('\n') if i]
+ +
+ + with open(os.path.join(results_folder, 'val_indices.txt'), 'r') as f:
+ + self.val_indices = f.read()
+ + self.val_indices = [int(i) for i in self.val_indices.split('\n') if i]
+ +
+ +
+ + self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)
+ + self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)
+ + self.train_dataloader = cycle(torch.utils.data.DataLoader(
+ + self.train_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False
+ + ))
+ +
+ + self.val_dataloader = cycle(torch.utils.data.DataLoader(
+ + self.val_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False
+ ))
+ +
+ self.dataloader_vis = cycle(torch.utils.data.DataLoader(
+ self.dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True
+ ))
+ self.renderer = renderer
+ +
+ +
+ +
+ self.optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=train_lr)
+
+ self.logdir = results_folder
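This hunk replaces the single shuffled DataLoader with a deterministic train/val split read from index files in the results folder (the residual-diffuser/*_indices.txt files added further down in this commit). A hedged, self-contained sketch of the same mechanism; the TensorDataset, sizes, and index ranges are placeholders:

import os
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

results_folder = '.'  # assumption: index files live next to the checkpoints
with open(os.path.join(results_folder, 'train_indices.txt'), 'w') as f:
    f.write('\n'.join(str(i) for i in range(700)) + '\n')
with open(os.path.join(results_folder, 'val_indices.txt'), 'w') as f:
    f.write('\n'.join(str(i) for i in range(700, 790)) + '\n')

def read_indices(path):
    with open(path) as f:
        return [int(line) for line in f.read().split('\n') if line]

dataset = TensorDataset(torch.randn(878, 16, 7))  # stand-in trajectory dataset

train_loader = DataLoader(Subset(dataset, read_indices('train_indices.txt')),
                          batch_size=32, shuffle=False, pin_memory=True)
val_loader = DataLoader(Subset(dataset, read_indices('val_indices.txt')),
                        batch_size=32, shuffle=False, pin_memory=True)
print(len(train_loader.dataset), len(val_loader.dataset))  # 700 90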
+ @@ -88,6 +118,8 @@
+ self.reset_parameters()
+ self.step = 0
+
+ + self.log_to_wandb = False
+ +
+ def reset_parameters(self):
+ self.ema_model.load_state_dict(self.model.state_dict())
+
+ @@ -102,36 +134,129 @@
+ #-----------------------------------------------------------------------------#
+
+ def train(self, n_train_steps):
+ -
+ + # Save the indices as .txt files
+ + with open(os.path.join(self.logdir, 'train_indices.txt'), 'w') as f:
+ + for idx in self.train_indices:
+ + f.write(f"{idx}\n")
+ + with open(os.path.join(self.logdir, 'val_indices.txt'), 'w') as f:
+ + for idx in self.val_indices:
+ + f.write(f"{idx}\n")
+ +
+ timer = Timer()
+ - for step in range(n_train_steps):
+ + torch.autograd.set_detect_anomaly(True)
+ +
+ + # Setup wandb
+ + if self.log_to_wandb:
+ + wandb.init(
+ + project='trajectory-generation',
+ + config={'lr': self.optimizer.param_groups[0]['lr'], 'batch_size': self.batch_size, 'gradient_accumulate_every': self.gradient_accumulate_every},
+ + )
+ +
+ + for step in tqdm(range(n_train_steps)):
+ +
+ + mean_train_loss = 0.0
+ for i in range(self.gradient_accumulate_every):
+ - batch = next(self.dataloader)
+ + batch = next(self.train_dataloader)
+ batch = batch_to_device(batch)
+ -
+ - loss, infos = self.model.loss(*batch)
+ +
+ + loss, infos = self.model.loss(x=batch.trajectories, cond=batch.conditions)
+ loss = loss / self.gradient_accumulate_every
+ + mean_train_loss += loss.item()
+ loss.backward()
+
+ + if self.log_to_wandb:
+ + wandb.log({
+ + 'step': self.step,
+ + 'train/loss': mean_train_loss
+ + })
+ +
+ + # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+ +
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ + assert_no_nan_weights(self.model)
+ +
+ if self.step % self.update_ema_every == 0:
+ self.step_ema()
+
+ if self.step % self.save_freq == 0:
+ - label = self.step // self.label_freq * self.label_freq
+ + label = self.step
+ + print(f'Saving model at step {self.step}...')
+ self.save(label)
+
+ if self.step % self.log_freq == 0:
+ - infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()])
+ - print(f'{self.step}: {loss:8.4f} | {infos_str} | t: {timer():8.4f}')
+ + val_losses = []
+ + lin_int_losses = []
+ +
+ + val_infos_list = []
+ + lin_int_infos_list = []
+ +
+ + catmull_losses = []
+ + catmull_infos_list = []
+ +
+ + for _ in range(len(self.val_indices)):
+ + val_batch = next(self.val_dataloader)
+ + val_batch = batch_to_device(val_batch)
+ +
+ + traj = self.model.forward(val_batch.conditions, horizon=val_batch.trajectories.shape[1])
+ + val_loss, val_infos = self.model.loss_fn(traj, val_batch.trajectories, cond=val_batch.conditions)
+ +
+ + val_losses.append(val_loss.item())
+ + val_infos_list.append({key: val for key, val in val_infos.items()})
+ +
+ +
+ + (lin_int_loss, lin_int_infos), lin_int_traj = self.linear_interpolation_loss(
+ + val_batch.trajectories, val_batch.conditions, self.model.loss_fn
+ + )
+ + lin_int_losses.append(lin_int_loss.item())
+ + lin_int_infos_list.append({key: val for key, val in lin_int_infos.items()})
+ +
+ + (catmull_loss, catmull_infos), catmull_traj = self.catmull_rom_loss(
+ + val_batch.trajectories, val_batch.conditions, self.model.loss_fn
+ + )
+ +
+ + catmull_losses.append(catmull_loss.item())
+ + catmull_infos_list.append(catmull_infos)
+ +
+ + avg_val_loss = np.mean(val_losses)
+ + avg_lin_int_loss = np.mean(lin_int_losses)
+ +
+ + val_infos = {key: np.mean([info[key] for info in val_infos_list]) for key in val_infos_list[0].keys()}
+ + lin_int_infos = {key: np.mean([info[key] for info in lin_int_infos_list]) for key in lin_int_infos_list[0].keys()}
+
+ - if self.step == 0 and self.sample_freq:
+ - self.render_reference(self.n_reference)
+ + avg_catmull_loss = np.mean(catmull_losses)
+ + catmull_infos = {key: np.mean([info[key] for info in catmull_infos_list]) for key in catmull_infos_list[0].keys()}
+
+ - if self.sample_freq and self.step % self.sample_freq == 0:
+ - self.render_samples(n_samples=self.n_samples)
+ + val_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in val_infos.items()])
+ + lin_int_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in lin_int_infos.items()])
+ + catmull_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in catmull_infos.items()])
+ +
+ +
+ + infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()])
+ + print("Learning Rate: ", self.optimizer.param_groups[0]['lr'])
+ + print(f'Step {self.step}: {loss * self.gradient_accumulate_every:8.4f} | {infos_str} | t: {timer():8.4f}')
+ + print(f'Validation - {self.step}: {avg_val_loss:8.4f} | {val_infos_str} | t: {timer():8.4f}')
+ + print(f'Linear Interpolation Loss - {self.step}: {avg_lin_int_loss:8.4f} | {lin_int_infos_str} | t: {timer():8.4f}')
+ + print(f'Catmull Rom Loss - {self.step}: {avg_catmull_loss:8.4f} | {catmull_infos_str} | t: {timer():8.4f}')
+ + print()
+ +
+ + if self.log_to_wandb:
+ + wandb.log({
+ + 'step': self.step,
+ + 'val/loss': avg_val_loss,
+ + 'val/linear_interp/loss': avg_lin_int_loss,
+ + 'val/linear_interp/quaternion dist.': lin_int_infos['quat. dist.'],
+ + 'val/linear_interp/euclidean dist.': lin_int_infos['trans. error'],
+ + 'val/linear_interp/geodesic loss': lin_int_infos['geodesic dist.'],
+ + 'val/catmull_rom/loss': avg_catmull_loss,
+ + 'val/catmull_rom/quaternion dist.': catmull_infos['quat. dist.'],
+ + 'val/catmull_rom/euclidean dist.': catmull_infos['trans. error'],
+ + 'val/catmull_rom/geodesic loss': catmull_infos['geodesic dist.'],
+ + 'val/quaternion dist.': val_infos['quat. dist.'],
+ + 'val/euclidean dist.': val_infos['trans. error'],
+ + 'val/geodesic loss': val_infos['geodesic dist.'],
+ + })
+
+ self.step += 1
+
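Two details in this hunk are worth isolating: the per-micro-batch loss is divided by gradient_accumulate_every so that repeated backward() calls sum to an averaged gradient, and only then does a single optimizer step run. A minimal sketch of that accumulation pattern with a stand-in model and batch source:

import torch
import torch.nn as nn

model = nn.Linear(7, 7)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
gradient_accumulate_every = 2

def next_batch():
    return torch.randn(32, 7)

mean_train_loss = 0.0
for _ in range(gradient_accumulate_every):
    x = next_batch()
    loss = (model(x) - x).pow(2).mean() / gradient_accumulate_every
    mean_train_loss += loss.item()
    loss.backward()          # gradients accumulate across micro-batches

optimizer.step()             # one parameter update for the whole group
optimizer.zero_grad()
print(f'train loss: {mean_train_loss:8.4f}')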
+ @@ -186,15 +311,6 @@ class Trainer(object):
+ normed_observations = trajectories[:, :, self.dataset.action_dim:]
+ observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations')
+
+ - # from diffusion.datasets.preprocessing import blocks_cumsum_quat
+ - # # observations = conditions + blocks_cumsum_quat(deltas)
+ - # observations = conditions + deltas.cumsum(axis=1)
+ -
+ - #### @TODO: remove block-stacking specific stuff
+ - # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka
+ - # observations = blocks_add_kuka(observations)
+ - ####
+ -
+ savepath = os.path.join(self.logdir, f'_sample-reference.png')
+ self.renderer.composite(savepath, observations)
+
+ @@ -225,9 +341,6 @@
+ # [ 1 x 1 x observation_dim ]
+ normed_conditions = to_np(batch.conditions[0])[:,None]
+
+ - # from diffusion.datasets.preprocessing import blocks_cumsum_quat
+ - # observations = conditions + blocks_cumsum_quat(deltas)
+ - # observations = conditions + deltas.cumsum(axis=1)
+
+ ## [ n_samples x (horizon + 1) x observation_dim ]
+ normed_observations = np.concatenate([
+ @@ -238,10 +351,70 @@
+ ## [ n_samples x (horizon + 1) x observation_dim ]
+ observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations')
+
+ - #### @TODO: remove block-stacking specific stuff
+ - # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka
+ - # observations = blocks_add_kuka(observations)
+ - ####
+ -
+ savepath = os.path.join(self.logdir, f'sample-{self.step}-{i}.png')
+ self.renderer.composite(savepath, observations)
+ +
+ + def linear_interpolation_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None):
+ + batch_size, horizon, transition = trajectories.shape
+ +
+ + # Extract known indices and values
+ + known_indices = np.array(list(conditions.keys()), dtype=int)
+ + # candidate_no x batch_size x dim
+ + known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0)
+ + known_values = np.moveaxis(known_values, 0, 1)
+ +
+ + # Create time steps for interpolation
+ + time_steps = np.linspace(0, horizon, num=horizon)
+ +
+ + # Perform interpolation across all dimensions at once
+ + linear_int_arr = np.array([[
+ + np.interp(time_steps, known_indices, known_values[b, :, dim])
+ + for dim in range(transition)]
+ + for b in range(batch_size)]
+ + ).T # Transpose to match shape (horizon, transition)
+ +
+ + # Convert to tensor and move to the same device as trajectories
+ + linear_int_arr = np.transpose(linear_int_arr, axes=[2, 0, 1])
+ + linear_int_tensor = torch.tensor(linear_int_arr, dtype=torch.float64, device=trajectories.device)
+ +
+ + return loss_fc(linear_int_tensor, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), linear_int_tensor
+ +
+ +
+ + def catmull_rom_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None):
+ + '''
+ + loss for catmull-rom interpolation
+ + '''
+ +
+ + batch_size, horizon, transition = trajectories.shape
+ +
+ + # Extract known indices and values
+ + known_indices = np.array(list(conditions.keys()), dtype=int)
+ + # candidate_no x batch_size x dim
+ + known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0)
+ + known_values = np.moveaxis(known_values, 0, 1)
+ +
+ + # Sort the timepoints
+ + sorted_indices = np.argsort(known_indices)
+ + known_indices = known_indices[sorted_indices]
+ + known_values = known_values[:, sorted_indices]
+ +
+ + spline_points = np.array([catmull_rom_spline_with_rotation(known_values[b], known_indices, horizon) for b in range(batch_size)])
+ +
+ + # Convert to tensor and move to the same device as trajectories
+ + spline_points = torch.tensor(spline_points, dtype=torch.float64, device=trajectories.device)
+ +
+ + assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}"
+ +
+ + return loss_fc(spline_points, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), spline_points
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
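The two methods appended above score how well pure interpolation between the conditioning keyframes reproduces a trajectory, giving the diffusion model a baseline to beat; the Catmull-Rom variant additionally runs positions and rotations through catmull_rom_spline_with_rotation. A hedged re-implementation of the linear baseline, with conditions simplified to a plain dict mapping timestep to a (batch, dim) array:

import numpy as np
import torch

def linear_interpolation(conditions, horizon):
    keys = sorted(conditions.keys())
    known_indices = np.array(keys, dtype=int)
    known_values = np.stack([conditions[k] for k in keys], axis=1)  # batch x keyframes x dim
    batch_size, _, transition = known_values.shape
    time_steps = np.arange(horizon)
    out = np.empty((batch_size, horizon, transition))
    for b in range(batch_size):
        for dim in range(transition):
            out[b, :, dim] = np.interp(time_steps, known_indices, known_values[b, :, dim])
    return torch.tensor(out, dtype=torch.float64)

# keyframes at t=0 and t=15 for a batch of 2 trajectories of 7-dim states
cond = {0: np.zeros((2, 7)), 15: np.ones((2, 7))}
print(linear_interpolation(cond, horizon=16).shape)  # torch.Size([2, 16, 7])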
+ diff --git a/scripts/train.py b/scripts/train.py
+ index 2c5f299..6728d6f 100644
+ --- a/scripts/train.py
+ +++ b/scripts/train.py
+ @@ -108,6 +108,7 @@ utils.report_parameters(model)
+
+ print('Testing forward...', end=' ', flush=True)
+ batch = utils.batchify(dataset[0])
+ +
+ loss, _ = diffusion.loss(*batch)
+ loss.backward()
+ print('✓')
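scripts/train.py gains only a blank line here, but the surrounding lines show a useful idiom: run one forward/backward pass on a single batchified sample before the real loop so shape and dtype mismatches fail immediately. A sketch of the idiom with a stand-in model and loss:

import torch
import torch.nn as nn

model = nn.Linear(7, 7)                  # placeholder for the diffusion model
print('Testing forward...', end=' ', flush=True)
x = torch.randn(1, 7)
loss = (model(x) - x).pow(2).mean()
loss.backward()                          # exercises the full autograd graph
print('✓')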
residual-diffuser/diffusion_config.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:02206556f60d5d7911ade8ae3a68cc6c59c8ce65aa6f16481125122ea83a827b
+ size 316
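This entry and the remaining ADDED files below are not the artifacts themselves but Git LFS pointer files: three lines giving the spec version, the sha256 of the real blob, and its size in bytes. A minimal parser for the format, as a hypothetical helper rather than anything shipped with the repo:

def parse_lfs_pointer(text):
    fields = dict(line.split(' ', 1) for line in text.strip().splitlines())
    algo, digest = fields['oid'].split(':', 1)
    return {'version': fields['version'], 'algo': algo,
            'oid': digest, 'size': int(fields['size'])}

pointer = '''version https://git-lfs.github.com/spec/v1
oid sha256:02206556f60d5d7911ade8ae3a68cc6c59c8ce65aa6f16481125122ea83a827b
size 316
'''
print(parse_lfs_pointer(pointer)['size'])  # 316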
residual-diffuser/model_config.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c7d03e458df1b0f5eec375eaedbc0daaab6a95a996e158fbbfcf4128a25fc1e
+ size 202
residual-diffuser/render_config.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ddf10f1f4223e218c38e6190698e58a382fa33344e2d587bcde53f4700444683
+ size 156
residual-diffuser/state_58000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf0e30c9d06d00f9e8651f621640ab063f766fef6f85d0dd4912fe150ec6f083
+ size 59009153
residual-diffuser/test_indices.txt ADDED
@@ -0,0 +1,100 @@
+ 369
+ 764
+ 474
+ 312
+ 857
+ 384
+ 323
+ 548
+ 796
+ 212
+ 595
+ 388
+ 444
+ 120
+ 598
+ 302
+ 633
+ 688
+ 653
+ 20
+ 665
+ 67
+ 130
+ 56
+ 822
+ 160
+ 169
+ 30
+ 623
+ 200
+ 520
+ 13
+ 273
+ 296
+ 411
+ 530
+ 367
+ 579
+ 788
+ 387
+ 8
+ 216
+ 738
+ 527
+ 35
+ 713
+ 416
+ 422
+ 492
+ 680
+ 757
+ 435
+ 218
+ 643
+ 489
+ 481
+ 54
+ 760
+ 558
+ 485
+ 666
+ 619
+ 806
+ 724
+ 742
+ 452
+ 445
+ 137
+ 165
+ 260
+ 855
+ 95
+ 191
+ 736
+ 71
+ 860
+ 210
+ 176
+ 662
+ 480
+ 583
+ 34
+ 471
+ 772
+ 393
+ 466
+ 469
+ 111
+ 687
+ 125
+ 231
+ 123
+ 366
+ 304
+ 262
+ 97
+ 597
+ 177
+ 636
+ 350
residual-diffuser/train_indices.txt ADDED
@@ -0,0 +1,691 @@
+ 749
+ 792
+ 22
+ 3
+ 775
+ 420
+ 83
+ 284
+ 635
+ 376
+ 700
+ 754
+ 575
+ 115
+ 122
+ 751
+ 826
+ 695
+ 263
+ 577
+ 856
+ 336
+ 249
+ 150
+ 226
+ 248
+ 478
+ 617
+ 535
+ 10
+ 329
+ 46
+ 821
+ 206
+ 807
+ 147
+ 345
+ 766
+ 768
+ 254
+ 164
+ 188
+ 133
+ 437
+ 716
+ 532
+ 391
+ 426
+ 105
+ 728
+ 463
+ 864
+ 5
+ 178
+ 640
+ 774
+ 837
+ 309
+ 348
+ 850
+ 205
+ 314
+ 346
+ 385
+ 423
+ 425
+ 707
+ 163
+ 415
+ 412
+ 599
+ 503
+ 490
+ 319
+ 693
+ 274
+ 156
+ 316
+ 135
+ 721
+ 153
+ 72
+ 162
+ 765
+ 684
+ 138
+ 90
+ 834
+ 229
+ 673
+ 195
+ 94
+ 569
+ 270
+ 786
+ 342
+ 745
+ 356
+ 80
+ 685
+ 677
+ 544
+ 220
+ 672
+ 668
+ 361
+ 7
+ 31
+ 203
+ 142
+ 539
+ 421
+ 378
+ 16
+ 816
+ 443
+ 696
+ 276
+ 151
+ 472
+ 616
+ 305
+ 681
+ 93
+ 491
+ 849
+ 183
+ 300
+ 675
+ 753
+ 294
+ 649
+ 691
+ 175
+ 467
+ 144
+ 501
+ 858
+ 779
+ 869
+ 246
+ 867
+ 333
+ 258
+ 414
+ 84
+ 486
+ 632
+ 0
+ 519
+ 830
+ 600
+ 541
+ 52
+ 631
+ 198
+ 626
+ 278
+ 552
+ 547
+ 235
+ 559
+ 528
+ 353
+ 86
+ 88
+ 718
+ 234
+ 828
+ 295
+ 829
+ 646
+ 874
+ 564
+ 525
+ 810
+ 682
+ 250
+ 861
+ 217
+ 748
+ 113
+ 740
+ 505
+ 770
+ 787
+ 611
+ 642
+ 550
+ 63
+ 567
+ 549
+ 21
+ 588
+ 524
+ 752
+ 747
+ 48
+ 53
+ 285
+ 335
+ 66
+ 453
+ 124
+ 360
+ 815
+ 390
+ 781
+ 543
+ 201
+ 283
+ 141
+ 434
+ 230
+ 613
+ 193
+ 608
+ 508
+ 199
+ 732
+ 741
+ 222
+ 76
+ 555
+ 261
+ 96
+ 436
+ 282
+ 45
+ 589
+ 11
+ 459
+ 382
+ 57
+ 877
+ 70
+ 537
+ 801
+ 129
+ 722
+ 494
+ 823
+ 26
+ 377
+ 326
+ 820
+ 310
+ 77
+ 876
+ 565
+ 504
+ 406
+ 686
+ 811
+ 482
+ 499
+ 507
+ 458
+ 386
+ 847
+ 658
+ 708
+ 100
+ 60
+ 607
+ 817
+ 663
+ 428
+ 859
+ 313
+ 68
+ 32
+ 267
+ 701
+ 139
+ 349
+ 487
+ 289
+ 225
+ 840
+ 375
+ 186
+ 875
+ 832
+ 381
+ 667
+ 777
+ 515
+ 298
+ 862
+ 773
+ 509
+ 715
+ 449
+ 664
+ 358
+ 121
+ 172
+ 594
+ 39
+ 553
+ 468
+ 370
+ 424
+ 570
+ 614
+ 477
+ 645
+ 400
+ 179
+ 441
+ 767
+ 865
+ 55
+ 497
+ 288
+ 704
+ 551
+ 809
+ 498
+ 354
+ 82
+ 408
+ 157
+ 460
+ 98
+ 145
+ 439
+ 591
+ 556
+ 211
+ 606
+ 9
+ 538
+ 719
+ 720
+ 641
+ 240
+ 841
+ 17
+ 112
+ 465
+ 733
+ 372
+ 338
+ 268
+ 219
+ 365
+ 624
+ 114
+ 87
+ 585
+ 128
+ 14
+ 612
+ 803
+ 215
+ 836
+ 297
+ 255
+ 795
+ 495
+ 730
+ 29
+ 838
+ 99
+ 161
+ 814
+ 51
+ 794
+ 36
+ 170
+ 171
+ 644
+ 271
+ 281
+ 131
+ 603
+ 514
+ 562
+ 389
+ 180
+ 239
+ 118
+ 593
+ 407
+ 574
+ 79
+ 714
+ 655
+ 394
+ 622
+ 166
+ 802
+ 780
+ 639
+ 101
+ 19
+ 253
+ 674
+ 763
+ 269
+ 427
+ 364
+ 689
+ 526
+ 800
+ 303
+ 207
+ 630
+ 184
+ 168
+ 290
+ 392
+ 251
+ 561
+ 506
+ 65
+ 450
+ 651
+ 399
+ 484
+ 190
+ 851
+ 28
+ 637
+ 202
+ 657
+ 656
+ 868
+ 808
+ 568
+ 223
+ 236
+ 108
+ 580
+ 853
+ 75
+ 58
+ 327
+ 671
+ 510
+ 578
+ 797
+ 536
+ 110
+ 602
+ 196
+ 798
+ 790
+ 185
+ 670
+ 448
+ 523
+ 692
+ 430
+ 140
+ 383
+ 252
+ 531
+ 746
+ 456
+ 557
+ 522
+ 690
+ 782
+ 286
+ 92
+ 618
+ 723
+ 605
+ 49
+ 563
+ 243
+ 64
+ 455
+ 804
+ 299
+ 292
+ 529
+ 835
+ 2
+ 831
+ 197
+ 479
+ 102
+ 470
+ 47
+ 762
+ 518
+ 703
+ 842
+ 337
+ 678
+ 698
+ 189
+ 247
+ 410
+ 213
+ 401
+ 277
+ 280
+ 173
+ 328
+ 818
+ 744
+ 238
+ 315
+ 676
+ 872
+ 756
+ 244
+ 291
+ 727
+ 155
+ 208
+ 25
+ 339
+ 755
+ 844
+ 592
+ 321
+ 521
+ 546
+ 44
+ 242
+ 759
+ 769
+ 652
+ 181
+ 275
+ 50
+ 833
+ 371
+ 584
+ 758
+ 433
+ 279
+ 627
+ 107
+ 15
+ 109
+ 854
+ 227
+ 596
+ 395
+ 182
+ 778
+ 648
+ 825
+ 628
+ 18
+ 717
+ 398
+ 601
+ 566
+ 625
+ 447
+ 660
+ 647
+ 866
+ 735
+ 462
+ 590
+ 351
+ 659
+ 330
+ 634
+ 126
+ 334
+ 324
+ 783
+ 516
+ 500
+ 743
+ 739
+ 784
+ 457
+ 709
+ 318
+ 726
+ 192
+ 697
+ 69
+ 204
+ 669
+ 461
+ 413
+ 650
+ 362
+ 824
+ 127
+ 871
+ 805
+ 355
+ 442
+ 347
+ 209
+ 117
+ 306
+ 332
+ 379
+ 42
+ 152
+ 512
+ 638
+ 106
+ 187
+ 194
+ 159
+ 89
+ 712
+ 119
+ 307
+ 214
+ 403
+ 705
+ 582
+ 586
+ 264
+ 502
+ 488
+ 409
+ 621
+ 233
+ 340
+ 863
+ 737
+ 432
+ 24
+ 576
+ 454
+ 464
+ 85
+ 404
+ 517
+ 451
+ 513
+ 483
+ 363
+ 317
+ 573
+ 620
+ 74
+ 609
+ 706
+ 301
+ 259
+ 812
+ 679
+ 396
+ 610
+ 476
+ 417
+ 827
+ 405
+ 325
+ 581
+ 6
+ 846
+ 418
+ 132
+ 710
+ 272
+ 571
+ 368
+ 533
+ 839
+ 402
+ 380
+ 245
+ 174
+ 228
+ 266
+ 143
+ 785
+ 587
+ 661
+ 344
+ 38
+ 848
+ 154
+ 265
+ 771
+ 791
+ 761
+ 493
+ 542
+ 359
+ 62
+ 149
residual-diffuser/trainer_config.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b582355f277afd95fac231f750f74b3e2e3e301b0420e40098d54537790de3ce
+ size 381
residual-diffuser/val_indices.txt ADDED
@@ -0,0 +1,87 @@
+ 136
+ 604
+ 683
+ 146
+ 654
+ 373
+ 511
+ 73
+ 343
+ 311
+ 819
+ 331
+ 750
+ 33
+ 341
+ 694
+ 237
+ 572
+ 91
+ 134
+ 357
+ 540
+ 729
+ 776
+ 789
+ 61
+ 852
+ 699
+ 12
+ 232
+ 734
+ 43
+ 27
+ 545
+ 158
+ 224
+ 438
+ 629
+ 104
+ 554
+ 429
+ 534
+ 873
+ 322
+ 496
+ 446
+ 287
+ 397
+ 1
+ 293
+ 725
+ 81
+ 440
+ 419
+ 702
+ 870
+ 308
+ 793
+ 103
+ 843
+ 78
+ 256
+ 475
+ 560
+ 711
+ 813
+ 431
+ 374
+ 731
+ 23
+ 167
+ 37
+ 4
+ 352
+ 116
+ 148
+ 59
+ 845
+ 221
+ 257
+ 40
+ 615
+ 473
+ 320
+ 241
+ 799
+ 41