VLA-RFT committed
Commit f59cd07 · verified · 1 Parent(s): 90d4711

Upload 18 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
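
The new rule routes tokenizer.json through Git LFS, so a clone made without git-lfs installed will contain a small pointer file instead of the tokenizer itself. A minimal sketch for detecting that situation (standard library only; the hard-coded path matches the rule above):

import sys
from pathlib import Path

def is_lfs_pointer(path: str) -> bool:
    """Return True if `path` holds a Git LFS pointer rather than real content."""
    head = Path(path).read_bytes()[:100]
    # Every LFS pointer file begins with this exact spec line.
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")

if is_lfs_pointer("tokenizer.json"):
    sys.exit("tokenizer.json is still an LFS pointer; run `git lfs pull` first")
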
action_head--checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4ba06808a556856dd72aee0ce44087982840594713c4b0ce7b465c490b8d672
+size 102677921
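
The committed file is only this pointer; the ~103 MB checkpoint itself is fetched by `git lfs pull`. The `oid` and `size` fields make the download easy to sanity-check. A sketch, with the caveat that the assumption the file holds a `torch.load`-able action-head state dict is mine and not stated in the commit:

import hashlib
import torch

CKPT = "action_head--checkpoint.pt"
EXPECTED_OID = "c4ba06808a556856dd72aee0ce44087982840594713c4b0ce7b465c490b8d672"
EXPECTED_SIZE = 102677921

blob = open(CKPT, "rb").read()
assert len(blob) == EXPECTED_SIZE, "size mismatch: file may still be an LFS pointer"
assert hashlib.sha256(blob).hexdigest() == EXPECTED_OID, "checksum mismatch"

# Assumed layout: a plain action-head state dict (not verified against the repo code).
state = torch.load(CKPT, map_location="cpu")
print(type(state))
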
added_tokens.json ADDED
@@ -0,0 +1,24 @@
+{
+  "</tool_call>": 151658,
+  "<tool_call>": 151657,
+  "<|box_end|>": 151649,
+  "<|box_start|>": 151648,
+  "<|endoftext|>": 151643,
+  "<|file_sep|>": 151664,
+  "<|fim_middle|>": 151660,
+  "<|fim_pad|>": 151662,
+  "<|fim_prefix|>": 151659,
+  "<|fim_suffix|>": 151661,
+  "<|im_end|>": 151645,
+  "<|im_start|>": 151644,
+  "<|image_pad|>": 151655,
+  "<|object_ref_end|>": 151647,
+  "<|object_ref_start|>": 151646,
+  "<|quad_end|>": 151651,
+  "<|quad_start|>": 151650,
+  "<|repo_name|>": 151663,
+  "<|video_pad|>": 151656,
+  "<|vision_end|>": 151653,
+  "<|vision_pad|>": 151654,
+  "<|vision_start|>": 151652
+}
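
These 22 entries are the standard Qwen2.5 special tokens (ChatML markers, vision/box/quad delimiters, FIM tags, tool-call tags) occupying IDs 151643-151664 above the base vocabulary. A quick consistency check, assuming the repo loads with `transformers` (the "." path is illustrative and should point at a local checkout):

import json
from transformers import AutoTokenizer

# Path is illustrative; point it at this repository's checkout.
tok = AutoTokenizer.from_pretrained(".", trust_remote_code=True)

with open("added_tokens.json") as f:
    added = json.load(f)

# Each declared token should resolve to its declared ID.
for token, expected_id in added.items():
    assert tok.convert_tokens_to_ids(token) == expected_id, token
print(f"all {len(added)} added tokens map to their declared ids")
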
chat_template.jinja ADDED
@@ -0,0 +1,54 @@
+{%- if tools %}
+    {{- '<|im_start|>system\n' }}
+    {%- if messages[0]['role'] == 'system' %}
+        {{- messages[0]['content'] }}
+    {%- else %}
+        {{- 'You are a helpful assistant.' }}
+    {%- endif %}
+    {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+    {%- if messages[0]['role'] == 'system' %}
+        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
+    {%- else %}
+        {{- '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+{%- for message in messages %}
+    {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "assistant" %}
+        {{- '<|im_start|>' + message.role }}
+        {%- if message.content %}
+            {{- '\n' + message.content }}
+        {%- endif %}
+        {%- for tool_call in message.tool_calls %}
+            {%- if tool_call.function is defined %}
+                {%- set tool_call = tool_call.function %}
+            {%- endif %}
+            {{- '\n<tool_call>\n{"name": "' }}
+            {{- tool_call.name }}
+            {{- '", "arguments": ' }}
+            {{- tool_call.arguments | tojson }}
+            {{- '}\n</tool_call>' }}
+        {%- endfor %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- message.content }}
+        {{- '\n</tool_response>' }}
+        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+{%- endif %}
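
The template is the Qwen2.5 ChatML format: an optional <tools> preamble in the system turn, <|im_start|>role ... <|im_end|> framing for each message, assistant tool calls serialized as <tool_call> JSON blocks, and consecutive "tool" messages folded into a single user turn of <tool_response> blocks. A sketch of rendering it through `transformers` (paths illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(".", trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Pick up the red block."},
]
# With add_generation_prompt=True the template appends the final
# '<|im_start|>assistant\n' so generation continues as the assistant.
prompt = tok.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
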
config.json ADDED
@@ -0,0 +1,3214 @@
+{
+  "arch_specifier": "no-align+fused-gelu-mlp",
+  "architectures": [
+    "OpenVLAForActionPrediction"
+  ],
+  "attn_implementation": "flash_attention_2",
+  "auto_map": {
+    "AutoConfig": "configuration_prismatic.OpenVLAConfig",
+    "AutoModelForVision2Seq": "modeling_prismatic.OpenVLAForActionPrediction"
+  },
+  "hf_llm_id": "",
+  "image_resize_strategy": "resize-naive",
+  "image_sizes": [224, 224],
+  "llm_backbone_id": "qwen25-0_5b-extra",
+  "llm_max_length": 2048,
+  "model_type": "openvla",
+  "n_action_bins": 256,
+  "norm_stats": {
+    "austin_buds_dataset_converted_externally_to_rlds": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0],
+        "mean": [-0.07678354531526566, 0.0036849044263362885, 0.05644911900162697, 0.0, 0.0, 0.0, 0.3510494828224182],
+        "min": [-1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [-1.0, -0.9599999785423279, -0.8714285492897034, 0.0, 0.0, 0.0, 0.0],
+        "q99": [1.0, 0.8600000143051147, 1.0, 0.0, 0.0, 0.0, 1.0],
+        "std": [0.6367740631103516, 0.37889179587364197, 0.47796326875686646, 0.0, 0.0, 0.0, 0.47721168398857117]
+      },
+      "num_trajectories": 50,
+      "num_transitions": 34112,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "austin_sailor_dataset_converted_externally_to_rlds": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [1.0, 1.0, 1.0, 0.0, 0.0, 0.375, 1.0],
+        "mean": [0.011825348250567913, 0.006461074110120535, 0.06023626774549484, 0.0, 0.0, 0.0016465914668515325, 0.5260950326919556],
+        "min": [-1.0, -1.0, -1.0, 0.0, 0.0, -0.375, 0.0],
+        "q01": [-1.0, -0.9828571677207947, -0.6000000238418579, 0.0, 0.0, -0.17249999940395355, 0.0],
+        "q99": [1.0, 0.9457142949104309, 1.0, 0.0, 0.0, 0.17892856895923615, 1.0],
+        "std": [0.46348899602890015, 0.41240179538726807, 0.411862850189209, 0.0, 0.0, 0.0578610822558403, 0.49894046783447266]
+      },
+      "num_trajectories": 240,
+      "num_transitions": 353094,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "austin_sirius_dataset_converted_externally_to_rlds": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [1.0002285242080688, 0.960608720779419, 1.105179786682129, 0.0, 0.0, 0.341785728931427, 1.0],
+        "mean": [0.07747682929039001, 0.03195561468601227, 0.04244732856750488, 0.0, 0.0, -0.01603456400334835, 0.43260177969932556],
+        "min": [-1.0183025598526, -0.9800000190734863, -0.9774575233459473, 0.0, 0.0, -0.34607142210006714, 0.0],
+        "q01": [-0.780905865430832, -0.5667179036140442, -0.5254343223571777, 0.0, 0.0, -0.28495091378688814, 0.0],
+        "q99": [0.9569637751579284, 0.6971374487876891, 0.8124888157844541, 0.0, 0.0, 0.1971428543329239, 1.0],
+        "std": [0.3906329572200775, 0.2998155355453491, 0.2782271206378937, 0.0, 0.0, 0.08120622485876083, 0.49528297781944275]
+      },
+      "num_trajectories": 559,
+      "num_transitions": 279939,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "bc_z": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.2165454924106598, 0.1251407265663147, 0.10772687941789627, 0.33544227480888367, 0.28117990493774414, 0.40614867210388184, 1.0],
+        "mean": [-0.009958467446267605, 0.0008958321413956583, 0.004995597992092371, 0.00029755113064311445, -0.008735382929444313, -0.030693737789988518, 0.8344562649726868],
+        "min": [-0.1677047461271286, -0.14630407094955444, -0.10066790133714676, -0.29421567916870117, -0.32101404666900635, -0.4635624885559082, 0.0],
+        "q01": [-0.09220654994249344, -0.06456145539879798, -0.049121275544166565, -0.11594625547528267, -0.14152548640966414, -0.2251061636209488, 0.0],
+        "q99": [0.07628866866230968, 0.058019736707210584, 0.052540797740221024, 0.11740604028105736, 0.11703975558280955, 0.16729306846857078, 1.0],
+        "std": [0.03053455986082554, 0.0231423731893301, 0.020641816779971123, 0.04155943542718887, 0.046427831053733826, 0.0769818127155304, 0.3610210120677948]
+      },
+      "num_trajectories": 43264,
+      "num_transitions": 6015535,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "berkeley_autolab_ur5": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.06666667014360428, 0.06666667014360428, 0.06666667014360428, 1.0],
+        "mean": [0.0005683620693162084, 0.001217700308188796, -0.0005296372692100704, 0.00021029810886830091, 6.0695128922816366e-05, 0.001204986940138042, 0.6298308372497559],
+        "min": [-0.019999999552965164, -0.019999999552965164, -0.019999999552965164, -0.06666667014360428, -0.06666667014360428, -0.06666667014360428, 0.0],
+        "q01": [-0.019999999552965164, -0.019999999552965164, -0.019999999552965164, -0.02628571353852749, -0.06666667014360428, -0.03847619146108627, 0.0],
+        "q99": [0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.031809523701667786, 0.06666667014360428, 0.036571428179740906, 1.0],
+        "std": [0.0115329809486866, 0.007990492507815361, 0.009577835910022259, 0.009432995691895485, 0.016427582129836082, 0.011053967289626598, 0.48267969489097595]
+      },
+      "num_trajectories": 1000,
+      "num_transitions": 97939,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "berkeley_cable_routing": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.9633283019065857, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
+        "mean": [-0.07139874249696732, 0.023609008640050888, 0.10241943597793579, 0.0, 0.0, 0.049671024084091187, 0.0],
+        "min": [-0.9809081554412842, -0.9554349184036255, -0.9994775056838989, 0.0, 0.0, -1.0, 0.0],
+        "q01": [-0.5534318816661835, -0.4797285574674606, -0.5314934802055359, 0.0, 0.0, -0.8855219376087189, 0.0],
+        "q99": [0.42652835428714786, 0.5000944086909298, 0.639823433756829, 0.0, 0.0, 0.984243879914284, 0.0],
+        "std": [0.1815500408411026, 0.1810990273952484, 0.21220779418945312, 0.0, 0.0, 0.3475511968135834, 0.0]
+      },
+      "num_trajectories": 1647,
+      "num_transitions": 42328,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "berkeley_fanuc_manipulation": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.009999999776482582, 0.009999999776482582, 0.009999999776482582, 0.03490658476948738, 0.03490658476948738, 0.03490658476948738, 1.0],
+        "mean": [0.0007744057802483439, -0.00031240080716088414, -0.0015001941937953234, -0.0007515158504247665, -0.00015832878125365824, 0.00014327642566058785, 0.699295699596405],
+        "min": [-0.009999999776482582, -0.009999999776482582, -0.009999999776482582, -0.03490658476948738, -0.03490658476948738, -0.03490658476948738, 0.0],
+        "q01": [-0.009999999776482582, -0.009999999776482582, -0.009999999776482582, -0.03490658476948738, 0.0, -0.03490658476948738, 0.0],
+        "q99": [0.009999999776482582, 0.009999999776482582, 0.009999999776482582, 0.03490658476948738, 0.0, 0.03490658476948738, 1.0],
+        "std": [0.0034070091787725687, 0.0049921851605176926, 0.005344334989786148, 0.00759894959628582, 0.004081866703927517, 0.008568956516683102, 0.4586937427520752]
+      },
+      "num_trajectories": 415,
+      "num_transitions": 62613,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "bridge_orig": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.41691166162490845, 0.25864794850349426, 0.21218234300613403, 3.122201919555664, 1.8618112802505493, 6.280478477478027, 1.0],
+        "mean": [0.0002334194869035855, 0.00013004911306779832, -0.00012762474943883717, -0.0001556558854645118, -0.0004039328487124294, 0.00023557482927571982, 0.5764579176902771],
+        "min": [-0.4007510244846344, -0.13874775171279907, -0.22553899884223938, -3.2010786533355713, -1.8618112802505493, -6.279075622558594, 0.0],
+        "q01": [-0.02872725307941437, -0.04170349963009357, -0.026093858778476715, -0.08092105075716972, -0.09288699507713317, -0.20718276381492615, 0.0],
+        "q99": [0.028309678435325586, 0.040855254605412394, 0.040161586627364146, 0.08192047759890528, 0.07792850524187081, 0.20382574498653397, 1.0],
+        "std": [0.009765930473804474, 0.013689135201275349, 0.012667362578213215, 0.028534092009067535, 0.030637972056865692, 0.07691419124603271, 0.4973701536655426]
+      },
+      "num_trajectories": 60064,
+      "num_transitions": 2135463,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "cmu_stretch": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.02338407188653946, 0.0, 0.023404927924275398, 0.0, 0.0, 0.0, 1.0],
+        "mean": [0.00036304505192674696, 0.0, 0.0016466958913952112, 0.0, 0.0, 0.0, 0.3987048268318176],
+        "min": [-0.019353797659277916, 0.0, -0.02019215188920498, 0.0, 0.0, 0.0, 0.0],
+        "q01": [-0.011175686959177256, 0.0, -0.0032206363626755773, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.014501785952597848, 0.0, 0.015056106168776728, 0.0, 0.0, 0.0, 1.0],
+        "std": [0.004081828519701958, 0.0, 0.0037743328139185905, 0.0, 0.0, 0.0, 0.48963725566864014]
+      },
+      "num_trajectories": 135,
+      "num_transitions": 25016,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "dlr_edan_shared_control_converted_externally_to_rlds": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.18991442024707794, 0.0739002525806427, 0.18064819276332855, 0.0866486132144928, 0.13464981317520142, 0.16910280287265778, 1.0],
+        "mean": [0.006647810339927673, -0.0007657372043468058, 0.006522852927446365, 0.0011679717572405934, -0.006395625416189432, -0.011902998201549053, 0.6985887289047241],
+        "min": [-0.10054297000169754, -0.08427435159683228, -0.13533438742160797, -0.17556548118591309, -0.18485672771930695, -0.2680685818195343, 0.0],
+        "q01": [-0.02987122368067503, -0.06013262912631035, -0.08286409199237824, -0.05924444157630205, -0.15986866518855095, -0.15636983573436739, 0.0],
+        "q99": [0.08832092039287087, 0.042126184627413736, 0.11311905644834042, 0.0643695573508739, 0.03941855944693088, 0.156646853685379, 1.0],
+        "std": [0.021393608301877975, 0.01814231649041176, 0.03374375030398369, 0.01743541844189167, 0.03394376486539841, 0.04641875624656677, 0.4588589072227478]
+      },
+      "num_trajectories": 104,
+      "num_transitions": 8928,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "dobbe": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [38.590423583984375, 17.932697296142578, 4.843764305114746, 1.4372116327285767, 0.4340403974056244, 1.2057193517684937, 0.9998947381973267],
+        "mean": [-0.0001120665911003016, 0.0011229600058868527, -0.00010194431524723768, -7.371398532995954e-05, -0.00067531579406932, -5.6643435527803376e-05, 0.6318281888961792],
+        "min": [-5.700923442840576, -21.605947494506836, -123.72489929199219, -1.7229845523834229, -0.4998578727245331, -0.8867913484573364, 1.4196479014572105e-06],
+        "q01": [-0.01119564864784479, -0.014266146533191203, -0.0071747214533388615, -0.009444301575422287, -0.03990109823644161, -0.017422311007976532, 4.003279136668425e-05],
+        "q99": [0.01015154086053368, 0.017181577533483497, 0.007216989761218411, 0.010380979906767595, 0.03556173853576176, 0.018032474815845446, 0.9982578039169312],
+        "std": [0.04264938458800316, 0.04428559169173241, 0.12224084138870239, 0.005388413090258837, 0.011246449314057827, 0.006287882570177317, 0.39732322096824646]
+      },
+      "num_trajectories": 5208,
+      "num_transitions": 1139911,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "fmb_dataset": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [1.399999976158142, 1.0, 1.399999976158142, 1.0, 1.0, 1.0, 1.0],
+        "mean": [0.059029702097177505, -0.06476633995771408, -0.09787475317716599, 0.004325388930737972, 0.00028963794466108084, -0.04457257315516472, 0.7336440086364746],
+        "min": [-1.399999976158142, -1.399999976158142, -1.0, -1.0, -1.0, -1.0, 0.0],
+        "q01": [-0.8257142901420593, -1.399999976158142, -1.0, -1.0, -0.3028571307659149, -1.0, 0.0],
+        "q99": [1.0, 0.5257142782211304, 1.0, 1.0, 0.3400000035762787, 1.0, 1.0],
+        "std": [0.28809213638305664, 0.2820415794849396, 0.4626740515232086, 0.3266514539718628, 0.10842999070882797, 0.3440099358558655, 0.4435282051563263]
+      },
+      "num_trajectories": 8612,
+      "num_transitions": 1137459,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "fractal20220817_data": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [2.9984593391418457, 22.09052848815918, 2.7507524490356445, 1.570636510848999, 1.5321086645126343, 1.5691522359848022, 1.0],
+        "mean": [0.006987582892179489, 0.006265917327255011, -0.01262515690177679, 0.04333311319351196, -0.005756212864071131, 0.0009130256366916001, 0.5354204773902893],
+        "min": [-2.0204520225524902, -5.497899532318115, -2.031663417816162, -1.569917917251587, -1.569892168045044, -1.570419430732727, 0.0],
+        "q01": [-0.22453527510166169, -0.14820013284683228, -0.231589707583189, -0.3517994859814644, -0.4193011274933815, -0.43643461108207704, 0.0],
+        "q99": [0.17824687153100965, 0.14938379630446405, 0.21842354819178575, 0.5892666035890578, 0.35272657424211445, 0.44796681255102094, 1.0],
+        "std": [0.0692116990685463, 0.05970962345600128, 0.07353084534406662, 0.15610496699810028, 0.13164450228214264, 0.14593800902366638, 0.497110515832901]
+      },
+      "num_trajectories": 87212,
+      "num_transitions": 3786400,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "furniture_bench_dataset_converted_externally_to_rlds": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.10000000149011612, 0.10000000149011612, 0.10000000149011612, 0.8651833534240723, 1.0909736156463623, 2.863185405731201, 1.0],
+        "mean": [0.00014610752987209707, 0.0010830952087417245, 0.0006224989192560315, -0.003303206292912364, -0.0026880695950239897, 0.018242603167891502, 0.48854944109916687],
+        "min": [-0.10495579987764359, -0.10939455777406693, -0.10000000149011612, -0.971906840801239, -1.0475432872772217, -3.06000018119812, 0.0],
+        "q01": [-0.053988199681043625, -0.05049169331789017, -0.032499241530895236, -0.1953887003660202, -0.41674559473991396, -0.8886768388748169, 0.0],
+        "q99": [0.05414841488003723, 0.04965164884924884, 0.060055799782276154, 0.18231668293476103, 0.39867786407470646, 0.8772023963928218, 1.0],
+        "std": [0.01610708422958851, 0.014891477301716805, 0.014014219865202904, 0.058274295181035995, 0.11417088657617569, 0.33479776978492737, 0.49991825222969055]
+      },
+      "num_trajectories": 5100,
+      "num_transitions": 3948057,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.6634981632232666, 0.23428471386432648, 0.4308285415172577, 3.1415927410125732, 0.13647015392780304, 3.141592502593994, 1.0],
+        "mean": [0.5274372696876526, 0.02858201041817665, 0.18712575733661652, 1.2339589595794678, 0.03226623684167862, -1.4199490547180176, 0.5550631880760193],
+        "min": [0.3071657121181488, -0.29754969477653503, 0.06578229367733002, -3.1415927410125732, -0.04584203287959099, -3.141592502593994, 0.0],
+        "q01": [0.3148897051811218, -0.20317550599575043, 0.06785467118024827, -3.140952730178833, -0.029743434861302376, -3.141091251373291, 0.0],
+        "q99": [0.6472805738449097, 0.20846802592277527, 0.36855655312538155, 3.1409926891326903, 0.11424950212240226, 3.1410969257354737, 1.0],
+        "std": [0.08108345419168472, 0.1116757020354271, 0.07747554779052734, 2.8737246990203857, 0.02774704433977604, 2.7678682804107666, 0.49695101380348206]
+      },
+      "num_trajectories": 631,
+      "num_transitions": 146241,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "jaco_play": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0, 1.0],
+        "mean": [0.0009658430935814977, -0.00580078037455678, -0.00395062193274498, 0.0, 0.0, 0.0, 0.34934908151626587],
+        "min": [-0.20000000298023224, -0.20000000298023224, -0.20000000298023224, 0.0, 0.0, 0.0, 0.0],
+        "q01": [-0.20000000298023224, -0.20000000298023224, -0.20000000298023224, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.20000000298023224, 0.20000000298023224, 0.20000000298023224, 0.0, 0.0, 0.0, 1.0],
+        "std": [0.12235074490308762, 0.09678777307271957, 0.11155334860086441, 0.0, 0.0, 0.0, 0.4768252968788147]
+      },
+      "num_trajectories": 1085,
+      "num_transitions": 77965,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "kuka": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.1697135865688324, 0.2777623236179352, 0.43710532784461975, 0.0, 0.0, 1.9684287309646606, 1.0],
+        "mean": [-0.0004668905457947403, 0.00040138536132872105, -0.001280792523175478, 0.0, 0.0, -0.03722453489899635, 0.4131543040275574],
+        "min": [-0.159867063164711, -0.2892282009124756, -0.2795473635196686, 0.0, 0.0, -1.9875637292861938, 0.0],
+        "q01": [-0.06619441494345665, -0.08713878810405731, -0.15083016991615295, 0.0, 0.0, -0.5415697038173676, 0.0],
+        "q99": [0.06601839080452929, 0.08732476785779003, 0.18168179214000715, 0.0, 0.0, 0.2923380345106127, 1.0],
+        "std": [0.02083250693976879, 0.02915887162089348, 0.06422865390777588, 0.0, 0.0, 0.14224295318126678, 0.49086448550224304]
+      },
+      "num_trajectories": 209880,
+      "num_transitions": 2455879,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "nyu_franka_play_dataset_converted_externally_to_rlds": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.06424188613891602, 0.07027634978294373, 0.06129661202430725, 6.281067848205566, 0.1967729926109314, 0.26377415657043457, 1.0],
+        "mean": [0.001021989737637341, -0.00012002651783404872, 0.00032894269679673016, 0.0015034361276775599, -0.002198522910475731, -0.001663230243138969, 0.7230083346366882],
+        "min": [-0.05952230095863342, -0.07232445478439331, -0.06730806827545166, -6.278434753417969, -0.21479034423828125, -0.3627619743347168, 0.0],
+        "q01": [-0.03199600875377655, -0.032861671447753905, -0.03368805110454559, -0.12080862045288086, -0.12175218224525451, -0.11370223641395569, 0.0],
+        "q99": [0.03101520001888276, 0.0373908892273903, 0.03646374464035038, 0.11764093399047852, 0.1258920183777809, 0.09366151213645942, 1.0],
+        "std": [0.01327415369451046, 0.013215910643339157, 0.012822109274566174, 0.2732451558113098, 0.057022541761398315, 0.039172880351543427, 0.44752755761146545]
+      },
+      "num_trajectories": 456,
+      "num_transitions": 44875,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "roboturk": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.39124172925949097, 0.4601028263568878, 0.4870833456516266, 1.816888689994812, 1.8240282535552979, 1.4824820756912231, 1.0],
+        "mean": [0.0014448732836171985, -0.0015945249469950795, -0.0011753785656765103, 0.0023012510500848293, -0.0009382463176734746, -0.00011485807772260159, 0.5746025443077087],
+        "min": [-0.6546999216079712, -0.6365841031074524, -0.4217723608016968, -1.6695482730865479, -1.8023357391357422, -1.4630827903747559, 0.0],
+        "q01": [-0.1342635464668274, -0.19996687173843383, -0.1482972100377083, -0.20720748245716095, -0.09676413893699647, -0.18075634717941286, 0.0],
+        "q99": [0.14956976801157001, 0.1805950567126275, 0.18841815620660796, 0.21615413755178453, 0.09457383215427405, 0.18543301910162005, 1.0],
+        "std": [0.04935386776924133, 0.0635455846786499, 0.061164740473032, 0.09553450345993042, 0.08420111238956451, 0.06517903506755829, 0.49452081322669983]
+      },
+      "num_trajectories": 1995,
+      "num_transitions": 187507,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "stanford_hydra_dataset_converted_externally_to_rlds": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.02499854564666748, 0.02499903365969658, 0.024999922141432762, 0.24974457919597626, 0.24997030198574066, 0.24999946355819702, 1.0],
+        "mean": [0.0007790001109242439, 0.00013707754260394722, -0.0002548607881180942, 0.0012903271708637476, -0.004751681815832853, 0.002692886395379901, 0.48855218291282654],
+        "min": [-0.024999044835567474, -0.024999700486660004, -0.02499929815530777, -0.24993225932121277, -0.2499666064977646, -0.2499932497739792, 0.0],
+        "q01": [-0.019992006458342076, -0.02415412735193968, -0.022941758055239916, -0.11085530579090118, -0.12024572037160397, -0.13314770206809043, 0.0],
+        "q99": [0.022886231057345868, 0.022358838934451335, 0.02410089675337076, 0.12370114490389822, 0.11323311634361738, 0.18474749639630164, 1.0],
+        "std": [0.008022161200642586, 0.009131459519267082, 0.009574338793754578, 0.04122216999530792, 0.0384303517639637, 0.04606688767671585, 0.49976691603660583]
+      },
+      "num_trajectories": 570,
+      "num_transitions": 358234,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "taco_play": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [1.4915844202041626, 2.1842432022094727, 2.6836395263671875, 5.035226821899414, 2.665864944458008, 4.250768661499023, 1.0],
+        "mean": [-0.003845922416076064, 0.009671456180512905, 0.012780580669641495, -0.005403771996498108, -0.009606587700545788, -0.002480733208358288, 0.4263913035392761],
+        "min": [-4.242457866668701, -3.192805051803589, -1.3371467590332031, -4.202683448791504, -2.6722638607025146, -3.3467135429382324, 0.0],
+        "q01": [-0.7106140398979186, -1.056944659948349, -0.5878450274467468, -0.7682853937149048, -0.7180147767066956, -1.5527938604354858, 0.0],
+        "q99": [0.6482916426658629, 1.0051310062408447, 0.9480248689651489, 0.6926478147506714, 0.6351067513227462, 1.628010264635086, 1.0],
+        "std": [0.23254038393497467, 0.36298269033432007, 0.28692901134490967, 0.2617705166339874, 0.2438892275094986, 0.5216503143310547, 0.4946896731853485]
+      },
+      "num_trajectories": 3603,
+      "num_transitions": 237798,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "toto": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [0.6839867234230042, 0.4454185664653778, 0.7984078526496887, 2.120781660079956, 1.371164321899414, 1.4118704795837402, 0.0],
+        "mean": [0.38542115688323975, 0.007769413758069277, 0.3632740378379822, -0.6652036905288696, 0.1890396922826767, 0.03298724442720413, 0.0],
+        "min": [0.09922284632921219, -0.5180193781852722, 0.13791072368621826, -2.635117530822754, -1.0734480619430542, -1.9282547235488892, 0.0],
+        "q01": [0.1756722891330719, -0.3077590811252594, 0.235383919775486, -2.0908505964279174, -0.6191593289375306, -0.7488683319091797, 0.0],
+        "q99": [0.6136963081359863, 0.33704194784164443, 0.6681221985816956, 0.7422861719131538, 0.7955395007133507, 0.740464625358582, 0.0],
+        "std": [0.12211652100086212, 0.19378550350666046, 0.10178236663341522, 0.5725259184837341, 0.29884573817253113, 0.3259911835193634, 0.0]
+      },
+      "num_trajectories": 1003,
+      "num_transitions": 325699,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "ucsd_kitchen_dataset_converted_externally_to_rlds": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [678.0, 400.0, 507.0, 180.00001525878906, 6.000013828277588, 116.99998474121094, 1.0],
+        "mean": [410.37567138671875, 116.9518814086914, 192.35032653808594, -121.22441864013672, -33.84893035888672, 50.016136169433594, 0.741813600063324],
+        "min": [172.0, -166.0, -99.99999237060547, -180.00001525878906, -89.0, -96.00010681152344, 0.0],
+        "q01": [200.00001052856445, -102.31004211425781, -94.99993370056153, -180.00001525878906, -88.00001525878906, -38.999977111816406, 0.0],
+        "q99": [637.0, 368.30999999999995, 493.0, 180.00001525878906, 0.999983012676239, 105.00001525878906, 1.0],
+        "std": [122.81494903564453, 108.8009033203125, 130.303466796875, 116.28205108642578, 27.621843338012695, 41.02094650268555, 0.43763357400894165]
+      },
+      "num_trajectories": 150,
+      "num_transitions": 3970,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
+      }
+    },
+    "utaustin_mutex": {
+      "action": {
+        "mask": [true, true, true, true, true, true, false],
+        "max": [1.0, 1.0, 1.0, 0.375, 0.375, 0.375, 1.0],
+        "mean": [0.06176406890153885, -0.005005486309528351, 0.10216785222291946, -0.03314131125807762, 0.013895004987716675, -0.011317633092403412, 0.5038976669311523],
+        "min": [-1.0, -1.0, -1.0, -0.375, -0.375, -0.375, 0.0],
+        "q01": [-0.4285714328289032, -0.9800000190734863, -0.5571428537368774, -0.375, -0.15642857551574707, -0.335357129573822, 0.0],
+        "q99": [0.5914285778999329, 0.9714285731315613, 1.0, 0.3278571367263794, 0.207857146859169, 0.25607141852378845, 1.0],
+        "std": [0.1875014752149582, 0.4468473494052887, 0.3792876601219177, 0.14097853004932404, 0.06453701853752136, 0.11765272170305252, 0.501045286655426]
+      },
+      "num_trajectories": 1500,
+      "num_transitions": 361883,
+      "proprio": {
+        "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0
3019
+ ]
3020
+ }
3021
+ },
3022
+ "viola": {
3023
+ "action": {
3024
+ "mask": [
3025
+ true,
3026
+ true,
3027
+ true,
3028
+ true,
3029
+ true,
3030
+ true,
3031
+ false
3032
+ ],
3033
+ "max": [
3034
+ 1.0,
3035
+ 1.0,
3036
+ 1.0,
3037
+ 0.375,
3038
+ 0.36321428418159485,
3039
+ 0.375,
3040
+ 1.0
3041
+ ],
3042
+ "mean": [
3043
+ 0.04761844128370285,
3044
+ -0.029204415157437325,
3045
+ 0.05586736649274826,
3046
+ -0.002618510741740465,
3047
+ 0.006867344491183758,
3048
+ -0.01682133786380291,
3049
+ 0.7323777675628662
3050
+ ],
3051
+ "min": [
3052
+ -1.0,
3053
+ -1.0,
3054
+ -1.0,
3055
+ -0.375,
3056
+ -0.375,
3057
+ -0.375,
3058
+ 0.0
3059
+ ],
3060
+ "q01": [
3061
+ -0.9628571271896362,
3062
+ -1.0,
3063
+ -1.0,
3064
+ -0.26249998807907104,
3065
+ -0.21321429312229156,
3066
+ -0.3385714292526245,
3067
+ 0.0
3068
+ ],
3069
+ "q99": [
3070
+ 0.9114285707473755,
3071
+ 0.868571400642395,
3072
+ 1.0,
3073
+ 0.2817857265472412,
3074
+ 0.2239285707473755,
3075
+ 0.3557142913341522,
3076
+ 1.0
3077
+ ],
3078
+ "std": [
3079
+ 0.39157867431640625,
3080
+ 0.4076525568962097,
3081
+ 0.40077948570251465,
3082
+ 0.10023996233940125,
3083
+ 0.0844319611787796,
3084
+ 0.10375042259693146,
3085
+ 0.44260647892951965
3086
+ ]
3087
+ },
3088
+ "num_trajectories": 150,
3089
+ "num_transitions": 76324,
3090
+ "proprio": {
3091
+ "max": [
3092
+ 0.0,
3093
+ 0.0,
3094
+ 0.0,
3095
+ 0.0,
3096
+ 0.0,
3097
+ 0.0,
3098
+ 0.0
3099
+ ],
3100
+ "mean": [
3101
+ 0.0,
3102
+ 0.0,
3103
+ 0.0,
3104
+ 0.0,
3105
+ 0.0,
3106
+ 0.0,
3107
+ 0.0
3108
+ ],
3109
+ "min": [
3110
+ 0.0,
3111
+ 0.0,
3112
+ 0.0,
3113
+ 0.0,
3114
+ 0.0,
3115
+ 0.0,
3116
+ 0.0
3117
+ ],
3118
+ "q01": [
3119
+ 0.0,
3120
+ 0.0,
3121
+ 0.0,
3122
+ 0.0,
3123
+ 0.0,
3124
+ 0.0,
3125
+ 0.0
3126
+ ],
3127
+ "q99": [
3128
+ 0.0,
3129
+ 0.0,
3130
+ 0.0,
3131
+ 0.0,
3132
+ 0.0,
3133
+ 0.0,
3134
+ 0.0
3135
+ ],
3136
+ "std": [
3137
+ 0.0,
3138
+ 0.0,
3139
+ 0.0,
3140
+ 0.0,
3141
+ 0.0,
3142
+ 0.0,
3143
+ 0.0
3144
+ ]
3145
+ }
3146
+ }
3147
+ },
3148
+ "output_projector_states": false,
3149
+ "pad_to_multiple_of": 64,
3150
+ "pad_token_id": 32000,
3151
+ "text_config": {
3152
+ "attention_dropout": 0.0,
3153
+ "bos_token_id": 151643,
3154
+ "eos_token_id": 151643,
3155
+ "hidden_act": "silu",
3156
+ "hidden_size": 896,
3157
+ "initializer_range": 0.02,
3158
+ "intermediate_size": 4864,
3159
+ "layer_types": [
3160
+ "full_attention",
3161
+ "full_attention",
3162
+ "full_attention",
3163
+ "full_attention",
3164
+ "full_attention",
3165
+ "full_attention",
3166
+ "full_attention",
3167
+ "full_attention",
3168
+ "full_attention",
3169
+ "full_attention",
3170
+ "full_attention",
3171
+ "full_attention",
3172
+ "full_attention",
3173
+ "full_attention",
3174
+ "full_attention",
3175
+ "full_attention",
3176
+ "full_attention",
3177
+ "full_attention",
3178
+ "full_attention",
3179
+ "full_attention",
3180
+ "full_attention",
3181
+ "full_attention",
3182
+ "full_attention",
3183
+ "full_attention"
3184
+ ],
3185
+ "max_position_embeddings": 32768,
3186
+ "max_window_layers": 24,
3187
+ "model_type": "qwen2",
3188
+ "num_attention_heads": 14,
3189
+ "num_hidden_layers": 24,
3190
+ "num_key_value_heads": 2,
3191
+ "rms_norm_eps": 1e-06,
3192
+ "rope_scaling": null,
3193
+ "rope_theta": 1000000.0,
3194
+ "sliding_window": null,
3195
+ "tie_word_embeddings": true,
3196
+ "torch_dtype": "bfloat16",
3197
+ "use_cache": true,
3198
+ "use_mrope": false,
3199
+ "use_sliding_window": false,
3200
+ "vocab_size": 151936
3201
+ },
3202
+ "timm_model_ids": [
3203
+ "vit_large_patch14_reg4_dinov2.lvd142m",
3204
+ "vit_so400m_patch14_siglip_224"
3205
+ ],
3206
+ "timm_override_act_layers": [
3207
+ null,
3208
+ null
3209
+ ],
3210
+ "torch_dtype": "bfloat16",
3211
+ "transformers_version": "4.54.1",
3212
+ "use_fused_vision_backbone": true,
3213
+ "vision_backbone_id": "dinosiglip-vit-so-224px"
3214
+ }
configuration_prismatic.py ADDED
@@ -0,0 +1,144 @@
1
+ """
2
+ configuration_prismatic.py
3
+
4
+ HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
5
+ Default configuration specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ from transformers import PretrainedConfig
11
+ from transformers.models.auto import CONFIG_MAPPING
12
+
13
+ # === Utilities for Mapping Prismatic names to HF names ===
14
+ # fmt: off
15
+ VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
16
+ "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],
17
+
18
+ "clip-vit-l-336px": [336],
19
+ "siglip-vit-so400m-384px": [384],
20
+
21
+ "dinoclip-vit-l-336px": [336, 336],
22
+ "dinosiglip-vit-so-224px": [224, 224],
23
+ "dinosiglip-vit-so-384px": [384, 384],
24
+ }
25
+ VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
26
+ "clip-vit-l": ["vit_large_patch14_clip_224.openai"],
27
+ "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],
28
+
29
+ "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
30
+ "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],
31
+
32
+ "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
33
+ "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],
34
+
35
+ "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
36
+ "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
37
+ "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
38
+ }
39
+ TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
40
+ "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
41
+ "dinov2-vit-l": [None], "in1k-vit-l": [None],
42
+ "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
43
+ "dinoclip-vit-l-336px": [None, "quick_gelu"],
44
+ "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
45
+ }
46
+
47
+ LLM_BACKBONE_TO_HF_PATH = {
48
+ "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
49
+ "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
50
+
51
+ "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",
52
+
53
+ "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
54
+ "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
55
+
56
+ "phi-2-3b": "microsoft/phi-2",
57
+ "qwen25-0_5b-extra": "Qwen/Qwen2.5-0.5B", "qwen25-0_5b-pure": "Qwen/Qwen2.5-0.5B"
60
+ }
61
+ LLM_BACKBONE_TO_HF_METACLASS = {
62
+ "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
63
+ "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",
64
+
65
+ "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",
66
+
67
+ "phi-2-3b": "phi",
68
+ "qwen25-0_5b-extra": "qwen2" ,"qwen25-0_5b-pure": "qwen2"
69
+ }
70
+
71
+ VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
72
+ VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
73
+ # fmt: on
74
+
75
+
76
+ class PrismaticConfig(PretrainedConfig):
77
+ model_type: str = "prismatic"
78
+ is_composition: bool = False
79
+
80
+ def __init__(
81
+ self,
82
+ vision_backbone_id: str = "siglip-vit-so400m",
83
+ llm_backbone_id: str = "vicuna-v15-7b",
84
+ arch_specifier: str = "no-align+gelu-mlp",
85
+ use_fused_vision_backbone: Optional[bool] = None,
86
+ image_resize_strategy: str = "letterbox",
87
+ text_config: Optional[Dict[str, Any]] = None,
88
+ llm_max_length: int = 2048,
89
+ pad_token_id: int = 32000,
90
+ pad_to_multiple_of: int = 64,
91
+ output_projector_states: bool = False,
92
+ **kwargs: str,
93
+ ) -> None:
94
+ if vision_backbone_id not in VALID_VISION_BACKBONES:
95
+ raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")
96
+
97
+ if llm_backbone_id not in VALID_LLM_BACKBONES:
98
+ raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")
99
+
100
+ # Set Prismatic Configuration Fields
101
+ self.vision_backbone_id = vision_backbone_id
102
+ self.llm_backbone_id = llm_backbone_id
103
+ self.arch_specifier = arch_specifier
104
+ self.output_projector_states = output_projector_states
105
+
106
+ # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
107
+ self.use_fused_vision_backbone = (
108
+ use_fused_vision_backbone
109
+ if use_fused_vision_backbone is not None
110
+ else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
111
+ )
112
+
113
+ self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
114
+ self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
115
+ self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
116
+ self.image_resize_strategy = image_resize_strategy
117
+
118
+ self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
119
+ self.llm_max_length = llm_max_length
120
+ self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of
121
+
122
+ # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
123
+ self.text_config = (
124
+ CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
125
+ if text_config is not None
126
+ else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
127
+ )
128
+
129
+ # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
130
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
131
+
132
+
133
+ class OpenVLAConfig(PrismaticConfig):
134
+ model_type: str = "openvla"
135
+
136
+ def __init__(
137
+ self,
138
+ norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
139
+ n_action_bins: int = 256,
140
+ **kwargs: str,
141
+ ) -> None:
142
+ self.norm_stats, self.n_action_bins = norm_stats, n_action_bins
143
+
144
+ super().__init__(**kwargs)
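
A minimal usage sketch (not part of the uploaded file) showing how the backbone/LLM mappings above resolve at construction time; the class and field names come from this file, and the chosen IDs are just one valid pairing:

    from configuration_prismatic import OpenVLAConfig

    config = OpenVLAConfig(
        vision_backbone_id="dinosiglip-vit-so-224px",
        llm_backbone_id="qwen25-0_5b-pure",
        n_action_bins=256,
    )
    assert config.use_fused_vision_backbone  # inferred from the "dinosiglip" prefix
    assert config.timm_model_ids == [
        "vit_large_patch14_reg4_dinov2.lvd142m",
        "vit_so400m_patch14_siglip_224",
    ]
    assert config.hf_llm_id == "Qwen/Qwen2.5-0.5B"  # via LLM_BACKBONE_TO_HF_PATH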
dataset_statistics.json ADDED
@@ -0,0 +1,526 @@
1
+ {
2
+ "libero_spatial_no_noops": {
3
+ "action": {
4
+ "mean": [
5
+ 0.15312477946281433,
6
+ 0.13707244396209717,
7
+ -0.15526829659938812,
8
+ -0.005176456645131111,
9
+ -0.011208733543753624,
10
+ -0.02019423246383667,
11
+ 0.4578818082809448
12
+ ],
13
+ "std": [
14
+ 0.4127277433872223,
15
+ 0.34724390506744385,
16
+ 0.50869220495224,
17
+ 0.037265900522470474,
18
+ 0.07244452834129333,
19
+ 0.05762367323040962,
20
+ 0.4982788562774658
21
+ ],
22
+ "max": [
23
+ 0.9375,
24
+ 0.9375,
25
+ 0.9375,
26
+ 0.1971428543329239,
27
+ 0.33642858266830444,
28
+ 0.375,
29
+ 1.0
30
+ ],
31
+ "min": [
32
+ -0.9375,
33
+ -0.9375,
34
+ -0.9375,
35
+ -0.1875,
36
+ -0.3675000071525574,
37
+ -0.36000001430511475,
38
+ 0.0
39
+ ],
40
+ "q01": [
41
+ -0.7454732114076613,
42
+ -0.6616071462631226,
43
+ -0.9375,
44
+ -0.1071428582072258,
45
+ -0.20678570866584778,
46
+ -0.1842857152223587,
47
+ 0.0
48
+ ],
49
+ "q99": [
50
+ 0.9375,
51
+ 0.8758928775787354,
52
+ 0.9321428537368774,
53
+ 0.1039285734295845,
54
+ 0.17678570747375488,
55
+ 0.14571428298950195,
56
+ 1.0
57
+ ],
58
+ "mask": [
59
+ true,
60
+ true,
61
+ true,
62
+ true,
63
+ true,
64
+ true,
65
+ false
66
+ ]
67
+ },
68
+ "proprio": {
69
+ "mean": [
70
+ -0.024462517350912094,
71
+ 0.10652990639209747,
72
+ 1.058049201965332,
73
+ 3.0628392696380615,
74
+ -0.10464007407426834,
75
+ 0.08307299017906189,
76
+ 0.01995452307164669,
77
+ -0.02016276866197586
78
+ ],
79
+ "std": [
80
+ 0.1101478859782219,
81
+ 0.137846902012825,
82
+ 0.10442808270454407,
83
+ 0.1045105829834938,
84
+ 0.41120994091033936,
85
+ 0.2176690399646759,
86
+ 0.01726088672876358,
87
+ 0.017111632972955704
88
+ ],
89
+ "max": [
90
+ 0.1759040206670761,
91
+ 0.3904820382595062,
92
+ 1.3290715217590332,
93
+ 3.4566118717193604,
94
+ 1.2268599271774292,
95
+ 1.0429412126541138,
96
+ 0.041053611785173416,
97
+ 0.000775813648942858
98
+ ],
99
+ "min": [
100
+ -0.3095473051071167,
101
+ -0.29250794649124146,
102
+ 0.9095591306686401,
103
+ 2.497488260269165,
104
+ -1.8006486892700195,
105
+ -0.7207611203193665,
106
+ -0.0004703797458205372,
107
+ -0.041536275297403336
108
+ ],
109
+ "q01": [
110
+ -0.2727657300233841,
111
+ -0.23721413239836692,
112
+ 0.9160063165426254,
113
+ 2.77949666261673,
114
+ -1.3187511622905732,
115
+ -0.41989982962608335,
116
+ 0.001503719249740243,
117
+ -0.03989770736545324
118
+ ],
119
+ "q99": [
120
+ 0.13529365032911292,
121
+ 0.3629165390133857,
122
+ 1.2862326657772063,
123
+ 3.2829698753356933,
124
+ 0.9332760351896285,
125
+ 0.6325724506378171,
126
+ 0.039933966137468815,
127
+ -0.001671919699292631
128
+ ]
129
+ },
130
+ "num_transitions": 52970,
131
+ "num_trajectories": 432
132
+ },
133
+ "libero_object_no_noops": {
134
+ "action": {
135
+ "mean": [
136
+ 0.07096527516841888,
137
+ 0.1349887251853943,
138
+ -0.046013835817575455,
139
+ 0.0012351985787972808,
140
+ 0.006998839322477579,
141
+ -0.015027610585093498,
142
+ 0.46428999304771423
143
+ ],
144
+ "std": [
145
+ 0.26812371611595154,
146
+ 0.4384680688381195,
147
+ 0.44749751687049866,
148
+ 0.024446608498692513,
149
+ 0.04935549572110176,
150
+ 0.04210718348622322,
151
+ 0.4987916350364685
152
+ ],
153
+ "max": [
154
+ 0.9375,
155
+ 0.8919642567634583,
156
+ 0.9375,
157
+ 0.17678570747375488,
158
+ 0.35035714507102966,
159
+ 0.1810714304447174,
160
+ 1.0
161
+ ],
162
+ "min": [
163
+ -0.8839285969734192,
164
+ -0.9375,
165
+ -0.9375,
166
+ -0.15000000596046448,
167
+ -0.29035714268684387,
168
+ -0.32892856001853943,
169
+ 0.0
170
+ ],
171
+ "q01": [
172
+ -0.5383928418159485,
173
+ -0.8758928775787354,
174
+ -0.9375,
175
+ -0.06964285671710968,
176
+ -0.11678571254014969,
177
+ -0.15964286029338837,
178
+ 0.0
179
+ ],
180
+ "q99": [
181
+ 0.8464285731315613,
182
+ 0.84375,
183
+ 0.9375,
184
+ 0.08142857253551483,
185
+ 0.14892856776714325,
186
+ 0.0867857113480568,
187
+ 1.0
188
+ ],
189
+ "mask": [
190
+ true,
191
+ true,
192
+ true,
193
+ true,
194
+ true,
195
+ true,
196
+ false
197
+ ]
198
+ },
199
+ "proprio": {
200
+ "mean": [
201
+ -0.029990315437316895,
202
+ -0.007947145961225033,
203
+ 0.20293475687503815,
204
+ 3.1086409091949463,
205
+ -0.21404768526554108,
206
+ -0.11307074129581451,
207
+ 0.029380440711975098,
208
+ -0.03055672161281109
209
+ ],
210
+ "std": [
211
+ 0.06694885343313217,
212
+ 0.1760847419500351,
213
+ 0.07807066291570663,
214
+ 0.08684844523668289,
215
+ 0.335404634475708,
216
+ 0.2072829008102417,
217
+ 0.00956575945019722,
218
+ 0.009197483770549297
219
+ ],
220
+ "max": [
221
+ 0.14580604434013367,
222
+ 0.33216384053230286,
223
+ 0.3857804834842682,
224
+ 3.4003844261169434,
225
+ 0.7954911589622498,
226
+ 0.6642207503318787,
227
+ 0.04104341194033623,
228
+ -0.00018117300351150334
229
+ ],
230
+ "min": [
231
+ -0.1765444278717041,
232
+ -0.29457300901412964,
233
+ 0.008128180168569088,
234
+ 2.2890501022338867,
235
+ -1.883241891860962,
236
+ -1.0600427389144897,
237
+ 0.0006495157140307128,
238
+ -0.041782498359680176
239
+ ],
240
+ "q01": [
241
+ -0.14911890715360643,
242
+ -0.25978428691625594,
243
+ 0.009925739830359817,
244
+ 2.7545341420173646,
245
+ -1.3996034812927245,
246
+ -0.6867720144987106,
247
+ 0.008197814421728254,
248
+ -0.04015838988125324
249
+ ],
250
+ "q99": [
251
+ 0.09063626825809479,
252
+ 0.29066365867853167,
253
+ 0.3370887073874472,
254
+ 3.2611824750900267,
255
+ 0.32092821151018125,
256
+ 0.4037663781642913,
257
+ 0.039891827926039694,
258
+ -0.009106044843792932
259
+ ]
260
+ },
261
+ "num_transitions": 66984,
262
+ "num_trajectories": 454
263
+ },
264
+ "libero_goal_no_noops": {
265
+ "action": {
266
+ "mean": [
267
+ 0.04721052572131157,
268
+ 0.02883528731763363,
269
+ -0.14858423173427582,
270
+ -0.002501001814380288,
271
+ 0.026408176869153976,
272
+ 0.027379784733057022,
273
+ 0.6299911737442017
274
+ ],
275
+ "std": [
276
+ 0.3968808054924011,
277
+ 0.3473387062549591,
278
+ 0.49239954352378845,
279
+ 0.0553317591547966,
280
+ 0.0784476175904274,
281
+ 0.10008786618709564,
282
+ 0.4827007055282593
283
+ ],
284
+ "max": [
285
+ 0.9375,
286
+ 0.9375,
287
+ 0.9375,
288
+ 0.3557142913341522,
289
+ 0.375,
290
+ 0.375,
291
+ 1.0
292
+ ],
293
+ "min": [
294
+ -0.9375,
295
+ -0.9375,
296
+ -0.9375,
297
+ -0.2582142949104309,
298
+ -0.375,
299
+ -0.2871428430080414,
300
+ 0.0
301
+ ],
302
+ "q01": [
303
+ -0.8785714507102966,
304
+ -0.7553571462631226,
305
+ -0.9375,
306
+ -0.1510714292526245,
307
+ -0.1639285683631897,
308
+ -0.13777500048279764,
309
+ 0.0
310
+ ],
311
+ "q99": [
312
+ 0.9375,
313
+ 0.9107142686843872,
314
+ 0.9375,
315
+ 0.20357142388820648,
316
+ 0.26357144117355347,
317
+ 0.375,
318
+ 1.0
319
+ ],
320
+ "mask": [
321
+ true,
322
+ true,
323
+ true,
324
+ true,
325
+ true,
326
+ true,
327
+ false
328
+ ]
329
+ },
330
+ "proprio": {
331
+ "mean": [
332
+ -0.09923479706048965,
333
+ 0.013597898185253143,
334
+ 1.0694578886032104,
335
+ 2.8289811611175537,
336
+ 0.307992547750473,
337
+ -0.2742873728275299,
338
+ 0.028092363849282265,
339
+ -0.027339326217770576
340
+ ],
341
+ "std": [
342
+ 0.11653962731361389,
343
+ 0.11478123068809509,
344
+ 0.1048782616853714,
345
+ 0.557030439376831,
346
+ 0.7221670746803284,
347
+ 0.3647960424423218,
348
+ 0.015074768103659153,
349
+ 0.014990939758718014
350
+ ],
351
+ "max": [
352
+ 0.13579000532627106,
353
+ 0.33316105604171753,
354
+ 1.3660105466842651,
355
+ 3.473310708999634,
356
+ 2.6688623428344727,
357
+ 0.8255361318588257,
358
+ 0.04233968257904053,
359
+ 0.0010111660230904818
360
+ ],
361
+ "min": [
362
+ -0.46141114830970764,
363
+ -0.30129560828208923,
364
+ 0.9083037972450256,
365
+ 0.35277295112609863,
366
+ -1.4858465194702148,
367
+ -1.5227035284042358,
368
+ -0.0013586411951109767,
369
+ -0.042040832340717316
370
+ ],
371
+ "q01": [
372
+ -0.42401049643754957,
373
+ -0.27338370531797407,
374
+ 0.911226047873497,
375
+ 1.3085840785503386,
376
+ -0.691297555565834,
377
+ -1.130668159723282,
378
+ 0.0016738151130266487,
379
+ -0.040336399003863335
380
+ ],
381
+ "q99": [
382
+ 0.08990443304181095,
383
+ 0.26473945528268716,
384
+ 1.2910678112506866,
385
+ 3.2425890421867365,
386
+ 2.3376442337036116,
387
+ 0.4659483411908149,
388
+ 0.040610933862626555,
389
+ -0.0015016929572448147
390
+ ]
391
+ },
392
+ "num_transitions": 52042,
393
+ "num_trajectories": 428
394
+ },
395
+ "libero_10_no_noops": {
396
+ "action": {
397
+ "mean": [
398
+ 0.018203141167759895,
399
+ 0.05858383700251579,
400
+ -0.05592375248670578,
401
+ 0.004626910667866468,
402
+ 0.002896096557378769,
403
+ -0.007673157844692469,
404
+ 0.5457824468612671
405
+ ],
406
+ "std": [
407
+ 0.2825472950935364,
408
+ 0.3590468764305115,
409
+ 0.3673798739910126,
410
+ 0.03770606219768524,
411
+ 0.05429606884717941,
412
+ 0.08725270628929138,
413
+ 0.4981527626514435
414
+ ],
415
+ "max": [
416
+ 0.9375,
417
+ 0.9375,
418
+ 0.9375,
419
+ 0.30000001192092896,
420
+ 0.29357144236564636,
421
+ 0.375,
422
+ 1.0
423
+ ],
424
+ "min": [
425
+ -0.9375,
426
+ -0.9375,
427
+ -0.9375,
428
+ -0.23642857372760773,
429
+ -0.3053571283817291,
430
+ -0.3675000071525574,
431
+ 0.0
432
+ ],
433
+ "q01": [
434
+ -0.6348214149475098,
435
+ -0.7741071581840515,
436
+ -0.7633928656578064,
437
+ -0.09749999642372131,
438
+ -0.14819999992847435,
439
+ -0.2742857038974762,
440
+ 0.0
441
+ ],
442
+ "q99": [
443
+ 0.7714285850524902,
444
+ 0.8464285731315613,
445
+ 0.9375,
446
+ 0.13928571343421936,
447
+ 0.15964286029338837,
448
+ 0.3246428668498993,
449
+ 1.0
450
+ ],
451
+ "mask": [
452
+ true,
453
+ true,
454
+ true,
455
+ true,
456
+ true,
457
+ true,
458
+ false
459
+ ]
460
+ },
461
+ "proprio": {
462
+ "mean": [
463
+ -0.04190631955862045,
464
+ 0.03539435938000679,
465
+ 0.8257158398628235,
466
+ 2.908320426940918,
467
+ -0.5562182664871216,
468
+ -0.1664901226758957,
469
+ 0.02831663191318512,
470
+ -0.02856140024960041
471
+ ],
472
+ "std": [
473
+ 0.10743344575166702,
474
+ 0.14424683153629303,
475
+ 0.257233202457428,
476
+ 0.3441365361213684,
477
+ 1.2344205379486084,
478
+ 0.357982873916626,
479
+ 0.013308685272932053,
480
+ 0.013174619525671005
481
+ ],
482
+ "max": [
483
+ 0.21031762659549713,
484
+ 0.39128610491752625,
485
+ 1.3332009315490723,
486
+ 3.6714255809783936,
487
+ 3.560650587081909,
488
+ 1.386339545249939,
489
+ 0.04160946607589722,
490
+ 0.0013633022317662835
491
+ ],
492
+ "min": [
493
+ -0.4828203022480011,
494
+ -0.3255046010017395,
495
+ 0.445506751537323,
496
+ 1.1321442127227783,
497
+ -3.641430377960205,
498
+ -1.842738389968872,
499
+ -0.0010040868073701859,
500
+ -0.04111652821302414
501
+ ],
502
+ "q01": [
503
+ -0.3899900782108307,
504
+ -0.2838300323486328,
505
+ 0.44795057058334353,
506
+ 1.8810229921340942,
507
+ -2.886677579879761,
508
+ -1.1599004411697387,
509
+ 0.002066459748893976,
510
+ -0.04001387819647789
511
+ ],
512
+ "q99": [
513
+ 0.1530261474847791,
514
+ 0.32915401458740223,
515
+ 1.2546923208236693,
516
+ 3.303542451858519,
517
+ 2.7496529006957933,
518
+ 0.6893712210655194,
519
+ 0.040048558115959164,
520
+ -0.0017598449345678235
521
+ ]
522
+ },
523
+ "num_transitions": 101469,
524
+ "num_trajectories": 379
525
+ }
526
+ }
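
As context for these statistics, a hedged sketch of how `q01`/`q99` blocks like the ones above are commonly consumed (function and variable names here are illustrative, not taken from the uploaded files): action dimensions with `mask == true` are scaled to [-1, 1] by their 1st/99th percentiles, which is robust to outliers, while masked-out dimensions (e.g. the gripper) pass through unchanged:

    import numpy as np

    def normalize_action(action, stats):
        # `stats` is one "action" block from the JSON above
        q01, q99 = np.asarray(stats["q01"]), np.asarray(stats["q99"])
        mask = np.asarray(stats["mask"])
        scaled = 2.0 * (np.asarray(action) - q01) / np.maximum(q99 - q01, 1e-8) - 1.0
        return np.where(mask, scaled, action)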
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97eaf96d22d0438f69cc6a739f771b8a7b7972dec4c7df3f31c2178bd1990a5a
3
+ size 2505232584
modeling_prismatic.py ADDED
@@ -0,0 +1,1607 @@
1
+ """
2
+ modeling_prismatic.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions.
5
+ Inherits from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained,
6
+ but exactly replicate the logic in `prismatic.models.vlms.prismatic.py`.
7
+ """
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from functools import partial
12
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
13
+
14
+ import numpy as np
15
+ import timm
16
+ import tokenizers
17
+ import torch
18
+ import torch.nn as nn
19
+ import transformers
20
+ from timm.models.vision_transformer import LayerScale
21
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import ModelOutput
23
+
24
+ from prismatic.training.train_utils import (
25
+ get_current_action_mask,
26
+ get_next_actions_mask,
27
+ )
28
+ from prismatic.vla.constants import (
29
+ ACTION_DIM,
30
+ ACTION_PROPRIO_NORMALIZATION_TYPE,
31
+ ACTION_TOKEN_BEGIN_IDX,
32
+ IGNORE_INDEX,
33
+ NUM_ACTIONS_CHUNK,
34
+ STOP_INDEX,
35
+ NormalizationType,
36
+ NUM_TOKENS
37
+ )
38
+
39
+ from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
40
+
41
+
42
+
43
+ # Set up logger
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ # === Utility Functions for Monkey-Patching ===
48
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
49
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
50
+ result = fn(*args, **kwargs)
51
+ return result[0] if isinstance(result, tuple) else result
52
+
53
+ return wrapper
54
+
55
+
56
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
57
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
58
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
59
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
61
+
62
+
63
+ def ls_apply_patch(ls_module: LayerScale):
64
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
65
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
66
+ del ls_module.gamma
67
+
68
+
69
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
70
+ class PrismaticVisionBackbone(nn.Module):
71
+ """
72
+ Vision backbone for Prismatic models that handles image feature extraction.
73
+
74
+ Supports both single backbone (e.g., SigLIP) and fused backbone (e.g., SigLIP + DINOv2) configurations.
75
+ For fused backbones, features from both models are concatenated along the feature dimension.
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ use_fused_vision_backbone: bool,
81
+ image_sizes: List[int],
82
+ timm_model_ids: List[str],
83
+ timm_override_act_layers: List[Optional[str]],
84
+ ) -> None:
85
+ """
86
+ Initialize the vision backbone.
87
+
88
+ Args:
89
+ use_fused_vision_backbone: Whether to use two backbones and fuse their features
90
+ image_sizes: List of image sizes for each backbone
91
+ timm_model_ids: List of TIMM model IDs to use for each backbone
92
+ timm_override_act_layers: List of activation layer overrides for each backbone
93
+ """
94
+ super().__init__()
95
+ self.use_fused_vision_backbone = use_fused_vision_backbone
96
+ self.num_images_in_input = 1 # Default value, can be overridden later
97
+
98
+ # Validate number of (fused) vision backbones
99
+ if len(timm_model_ids) > 2:
100
+ raise ValueError("Prismatic models only support up to 2 (fused) vision backbones!")
101
+
102
+ # Create primary featurizer
103
+ self.featurizer = self._create_featurizer(
104
+ model_id=timm_model_ids[0], img_size=image_sizes[0], act_layer=timm_override_act_layers[0]
105
+ )
106
+ self.embed_dim = self.featurizer.embed_dim
107
+
108
+ # Create secondary featurizer if using fused backbone
109
+ if self.use_fused_vision_backbone:
110
+ self.fused_featurizer = self._create_featurizer(
111
+ model_id=timm_model_ids[1], img_size=image_sizes[1], act_layer=timm_override_act_layers[1]
112
+ )
113
+ self.embed_dim += self.fused_featurizer.embed_dim
114
+
115
+ # Patch LayerScale modules for HF compatibility
116
+ self._patch_layer_scales()
117
+
118
+ def _create_featurizer(self, model_id: str, img_size: int, act_layer: Optional[str]) -> nn.Module:
119
+ """
120
+ Create a TIMM-based featurizer model with appropriate configurations.
121
+
122
+ Args:
123
+ model_id: The TIMM model ID to load
124
+ img_size: Input image size for the model
125
+ act_layer: Override for the activation layer type
126
+
127
+ Returns:
128
+ A configured featurizer model
129
+ """
130
+ featurizer = timm.create_model(
131
+ model_id,
132
+ pretrained=False,
133
+ num_classes=0,
134
+ img_size=img_size,
135
+ act_layer=act_layer,
136
+ )
137
+
138
+ # Monkey-patch the forward function to extract the second-to-last layer features
139
+ num_blocks = len(featurizer.blocks)
140
+ featurizer.forward = unpack_tuple(partial(featurizer.get_intermediate_layers, n={num_blocks - 2}))
141
+
142
+ return featurizer
143
+
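
Concretely (a hedged example, assuming a 24-block ViT-L backbone): the monkey-patch above makes `featurizer(x)` return `featurizer.get_intermediate_layers(x, n={22})[0]`, i.e. the patch tokens after the second-to-last transformer block rather than the final block's output.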
144
+ def _patch_layer_scales(self) -> None:
145
+ """
146
+ Patch all LayerScale modules to be compatible with HF's parameter naming.
147
+
148
+ HF Transformers overwrites parameters with names containing 'gamma',
149
+ so we need to rename and modify the forward method.
150
+ """
151
+ # Patch primary featurizer
152
+ for module in self.featurizer.modules():
153
+ if isinstance(module, LayerScale):
154
+ ls_apply_patch(module)
155
+
156
+ # Patch secondary featurizer if it exists
157
+ if self.use_fused_vision_backbone:
158
+ for module in self.fused_featurizer.modules():
159
+ if isinstance(module, LayerScale):
160
+ ls_apply_patch(module)
161
+
162
+ def get_num_patches(self) -> int:
163
+ """
164
+ Returns the number of vision patches output by the vision backbone.
165
+
166
+ Returns:
167
+ Number of patches per image
168
+ """
169
+ return self.featurizer.patch_embed.num_patches
170
+
171
+ def get_num_images_in_input(self) -> int:
172
+ """
173
+ Returns the number of input images for the vision backbone.
174
+
175
+ Returns:
176
+ Number of images expected in the input
177
+ """
178
+ return self.num_images_in_input
179
+
180
+ def set_num_images_in_input(self, num_images_in_input: int) -> None:
181
+ """
182
+ Sets the number of input images for the vision backbone.
183
+
184
+ Args:
185
+ num_images_in_input: Number of images to expect in the input
186
+ """
187
+ self.num_images_in_input = num_images_in_input
188
+
189
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
190
+ """
191
+ Implements the forward pass for the vision backbone.
192
+
193
+ If `self.use_fused_vision_backbone == True`, uses both SigLIP and DINOv2 transformers to extract visual features
194
+ (otherwise uses SigLIP only). Allows multi-image inputs (but only for fused vision backbone).
195
+
196
+ Args:
197
+ pixel_values (torch.Tensor): Pixels for input image(s), (B, C, H, W).
198
+ """
199
+ if self.num_images_in_input == 1:
200
+ if not self.use_fused_vision_backbone:
201
+ return self.featurizer(pixel_values)
202
+
203
+ # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
204
+ img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
205
+ patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
206
+
207
+ return torch.cat([patches, patches_fused], dim=2)
208
+
209
+ else:
210
+ assert self.use_fused_vision_backbone, "Multi-image inputs require using fused backbone!"
211
+
212
+ # Split `pixel_values` into individual images (each with 6 channels: 3 for SigLIP + 3 for DINOv2)
213
+ images = torch.split(pixel_values, [6] * self.num_images_in_input, dim=1)
214
+
215
+ # Process each image and collect patches
216
+ all_patches = []
217
+ for img in images:
218
+ # Split each image further into two stacks of channels (each with 3 channels)
219
+ img_regular, img_fused = torch.split(img, [3, 3], dim=1)
220
+
221
+ # Get patches from both SigLIP and DINOv2 vision transformers
222
+ patches = self.featurizer(img_regular)
223
+ patches_fused = self.fused_featurizer(img_fused)
224
+
225
+ # Concatenate SigLIP and DINOv2 patches along the hidden dimension
226
+ combined_patches = torch.cat([patches, patches_fused], dim=2)
227
+ all_patches.append(combined_patches)
228
+
229
+ # Concatenate all patches along the patch dimension
230
+ return torch.cat(all_patches, dim=1)
231
+
232
+
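
A hedged shape walk-through for the single-image fused path above, with embed dims assumed from the default "dinosiglip" TIMM checkpoints (DINOv2 ViT-L/14: 1024, SigLIP SO400M/14: 1152):

    # pixel_values:  (B, 6, 224, 224)  -- 3 DINOv2 channels stacked on 3 SigLIP channels
    # patches:       (B, 256, 1024)    -- primary (DINOv2) featurizer
    # patches_fused: (B, 256, 1152)    -- fused (SigLIP) featurizer
    # output:        (B, 256, 2176)    -- torch.cat along the feature dim (dim=2)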
233
+ # === Prismatic Projector (nn.Module) Definitions ===
234
+ class PrismaticProjector(nn.Module):
235
+ def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
236
+ super().__init__()
237
+ self.use_fused_vision_backbone = use_fused_vision_backbone
238
+ self.vision_dim, self.llm_dim = vision_dim, llm_dim
239
+
240
+ # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
241
+ if not self.use_fused_vision_backbone:
242
+ self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
243
+ self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
244
+ self.act_fn1 = nn.GELU()
245
+ else:
246
+ initial_projection_dim = 4 * vision_dim
247
+ self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
248
+ self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
249
+ self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
250
+ self.act_fn1 = nn.GELU()
251
+ self.act_fn2 = nn.GELU()
252
+
253
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
254
+ if not self.use_fused_vision_backbone:
255
+ projected_features = self.fc1(img_patches)
256
+ projected_features = self.act_fn1(projected_features)
257
+ projected_features = self.fc2(projected_features)
258
+ else:
259
+ projected_features = self.fc1(img_patches)
260
+ projected_features = self.act_fn1(projected_features)
261
+ projected_features = self.fc2(projected_features)
262
+ projected_features = self.act_fn2(projected_features)
263
+ projected_features = self.fc3(projected_features)
264
+
265
+ return projected_features
266
+
267
+
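
For the fused branch above, a hedged dimension walk-through using the hidden size from this repo's config.json (llm_dim = 896) and a fused vision_dim of 2176 (1024 + 1152, assumed from the two default backbones):

    # fc1: 2176 -> 8704  (initial_projection_dim = 4 * vision_dim), then GELU
    # fc2: 8704 -> 896   (project down to the LLM embedding space), then GELU
    # fc3: 896  -> 896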
268
+ # === Main HF Class Definitions ===
269
+ @dataclass
270
+ class PrismaticCausalLMOutputWithPast(ModelOutput):
271
+ """Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
272
+
273
+ loss: Optional[torch.FloatTensor] = None
274
+ logits: torch.FloatTensor = None
275
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
276
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
277
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
278
+
279
+ # Additions for VLMs
280
+ projector_features: Optional[torch.FloatTensor] = None
281
+
282
+
283
+ class PrismaticPreTrainedModel(PreTrainedModel):
284
+ config_class: PretrainedConfig = PrismaticConfig
285
+ base_model_prefix: str = "model"
286
+ supports_gradient_checkpointing: bool = True
287
+
288
+ _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
289
+ _skip_keys_device_placement: str = "past_key_values"
290
+ _supports_flash_attn_2: bool = True
291
+
294
+ def _init_weights(self, module: nn.Module) -> None:
295
+ # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
296
+ # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
297
+ # https://github.com/TRI-ML/prismatic-vlms
298
+ std = (
299
+ self.config.initializer_range
300
+ if hasattr(self.config, "initializer_range")
301
+ else self.config.text_config.initializer_range
302
+ )
303
+
304
+ if hasattr(module, "class_embedding"):
305
+ module.class_embedding.data.normal_(mean=0.0, std=std)
306
+
307
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
308
+ module.weight.data.normal_(mean=0.0, std=std)
309
+ if module.bias is not None:
310
+ module.bias.data.zero_()
311
+ elif isinstance(module, nn.Embedding):
312
+ module.weight.data.normal_(mean=0.0, std=std)
313
+ if module.padding_idx is not None:
314
+ module.weight.data[module.padding_idx].zero_()
315
+
316
+ @property
317
+ def _supports_sdpa(self) -> bool:
318
+ """Check LLM supports SDPA Attention"""
319
+ return self.language_model._supports_sdpa
320
+
321
+
322
+ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
323
+ def __init__(self, config: PrismaticConfig) -> None:
324
+ super().__init__(config)
325
+
326
+ # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
327
+ if config.use_fused_vision_backbone is None:
328
+ raise ValueError("Missing config field `use_fused_vision_backbone`")
329
+
330
+ if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
331
+ raise NotImplementedError(
332
+ "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
333
+ "if you urgently need support for latest TIMM versions."
334
+ )
335
+
336
+ if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
337
+ logger.warning(
338
+ f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
339
+ f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
340
+ f"there might be inference-time regressions due to dependency changes. If in doubt, please"
341
+ f"use the above versions."
342
+ )
344
+ # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
345
+ self.vision_backbone = PrismaticVisionBackbone(
346
+ config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
347
+ )
348
+
349
+ # Create Multimodal Projector
350
+ self.projector = PrismaticProjector(
351
+ config.use_fused_vision_backbone,
352
+ vision_dim=self.vision_backbone.embed_dim,
353
+ llm_dim=config.text_config.hidden_size,
354
+ )
355
+
356
+ # Instantiate LLM Backbone
357
+ self.language_model = AutoModelForCausalLM.from_config(
358
+ config.text_config, attn_implementation=config._attn_implementation
359
+ )
360
+
361
+ self.vocab_size = config.text_config.vocab_size
362
+ self.pad_token_id = config.pad_token_id
363
+ self.llm_dim = config.text_config.hidden_size
+ self.version = None  # model-variant tag; set via `set_version()` and branched on in `forward()` (assumed-safe default)
364
+
365
+ # Action query tokens
366
+ self.action_queries = nn.Embedding(NUM_TOKENS, self.llm_dim)
367
+ self.action_queries.weight.data.zero_()
368
+
369
+ # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
370
+ self.post_init()
371
+
372
+ # === `PreTrainedModel` Boilerplate ===
373
+ def get_input_embeddings(self) -> nn.Module:
374
+ return self.language_model.get_input_embeddings()
375
+
+ def set_version(self, version: str):
376
+ self.version = version
377
+ return self.version
378
+
379
+
380
+ def set_input_embeddings(self, value: nn.Module) -> None:
381
+ self.language_model.set_input_embeddings(value)
382
+
383
+ def get_output_embeddings(self) -> nn.Module:
384
+ return self.language_model.get_output_embeddings()
385
+
386
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
387
+ self.language_model.set_output_embeddings(new_embeddings)
388
+
389
+ def get_decoder(self) -> nn.Module:
390
+ return self.language_model.get_decoder()
391
+
392
+ def set_decoder(self, decoder: nn.Module) -> None:
393
+ self.language_model.set_decoder(decoder)
394
+
395
+ def tie_weights(self) -> None:
396
+ self.language_model.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
397
+
398
+ def resize_token_embeddings(
399
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
400
+ ) -> nn.Embedding:
401
+ updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
402
+
403
+ # Update config/instance variables
404
+ self.config.text_config.vocab_size = updated_embeddings.num_embeddings
405
+ self.vocab_size = updated_embeddings.num_embeddings
406
+
407
+ return updated_embeddings
408
+
409
+ def _replace_input_embeddings(self, input_embeddings, all_actions_mask, noisy_action_features):
410
+ """
411
+ Replace embeddings in input_embeddings at positions where all_actions_mask is True
412
+ with embeddings from noisy_action_features, using vectorized operations.
413
+
414
+ Args:
415
+ input_embeddings: Tensor of shape (B, S, D)
416
+ all_actions_mask: Boolean tensor of shape (B, S)
417
+ noisy_action_features: Tensor of shape (B, K, D) where K is the number of True values in mask per sample
418
+
419
+ Returns:
420
+ Modified input_embeddings tensor
421
+ """
422
+ # Clone input to avoid modifying the original tensor
423
+ new_input_embeddings = input_embeddings.clone()
424
+
425
+ # Create a tensor with the same shape of input_embeddings to hold the noisy action features
426
+ repositioned_noisy_action_features = torch.zeros_like(input_embeddings)
427
+
428
+ # Create batch indices for splicing
429
+ batch_indices = torch.arange(input_embeddings.shape[0], device=input_embeddings.device)
430
+ batch_indices = batch_indices.unsqueeze(1).expand(-1, noisy_action_features.shape[1])
431
+
432
+ # Get indices where mask is True for each sample
433
+ masked_indices = torch.stack([torch.where(mask)[0] for mask in all_actions_mask])
434
+
435
+ # Move the noisy action features into their correct positions
438
+ repositioned_noisy_action_features[batch_indices, masked_indices] = noisy_action_features
439
+
440
+ # Combine original input embeddings and noisy action embeddings using the mask
441
+ new_input_embeddings = torch.where(
442
+ all_actions_mask.unsqueeze(-1), repositioned_noisy_action_features, new_input_embeddings
443
+ )
444
+
445
+ return new_input_embeddings
446
+
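
A toy trace of the splice implemented above (shapes only; values illustrative), for B = 1, S = 5, K = 2:

    # all_actions_mask      = [[F, F, T, T, F]]
    # masked_indices        = [[2, 3]]
    # noisy_action_features : (1, 2, D)
    # result rows           = [orig_0, orig_1, noisy_0, noisy_1, orig_4]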
447
+ def _process_action_masks(self, labels):
448
+ """Helper to get action masks from labels"""
449
+ current_action_mask = get_current_action_mask(labels)
450
+ next_actions_mask = get_next_actions_mask(labels)
451
+ all_actions_mask = current_action_mask | next_actions_mask # (B, seq_len)
452
+ return all_actions_mask
453
+
454
+ def _process_vision_features(self, pixel_values, language_embeddings=None, use_film=False):
455
+ """Process vision features with optional FiLM conditioning"""
456
+ if use_film:
457
+ # FiLM: Infuse language inputs into visual features
458
+ patch_features = self.vision_backbone(pixel_values, language_embeddings) # (bsz, 256 * num_images, D)
459
+ else:
460
+ patch_features = self.vision_backbone(pixel_values) # (bsz, 256 * num_images, D)
461
+
462
+ # Project patch embeddings into language embedding space
463
+ return self.projector(patch_features)
464
+
465
+ def _process_proprio_features(self, projected_patch_embeddings, proprio, proprio_projector):
466
+ """Process proprioceptive features and append to vision features"""
467
+ if proprio_projector is not None and proprio is not None:
468
+ # projected_patch_embeddings: (bsz, num_patches * num_images, llm_dim)
469
+ # proprio: (bsz, proprio_dim) or (proprio_dim,)
470
+ proprio = proprio.reshape(projected_patch_embeddings.shape[0], -1) # (bsz, proprio_dim)
471
+ proprio_features = proprio_projector(proprio) # (bsz, llm_dim)
472
+ proprio_features = proprio_features.unsqueeze(dim=1) # (bsz, 1, llm_dim)
473
+ # For simplicity, just append proprio token to the end of projected vision patch tokens
474
+ return torch.cat((projected_patch_embeddings, proprio_features), dim=1)
475
+ return projected_patch_embeddings
476
+
477
+ def _build_multimodal_attention(self, input_embeddings, projected_patch_embeddings, attention_mask):
478
+ """Build multimodal embeddings and attention mask"""
479
+ # Update attention mask
481
+ projected_patch_attention_mask = None
482
+ if attention_mask is not None:
483
+ projected_patch_attention_mask = torch.full(
484
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
485
+ fill_value=True,
486
+ dtype=attention_mask.dtype,
487
+ device=attention_mask.device,
488
+ )
489
+
490
+ # Build multimodal embeddings & attention mask; insert embeddings after <BOS> token (1:)
491
+ multimodal_embeddings = torch.cat(
492
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
493
+ )
494
+
495
+ multimodal_attention_mask = None
496
+ if attention_mask is not None:
497
+ multimodal_attention_mask = torch.cat(
498
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
499
+ )
500
+
501
+ return multimodal_embeddings, multimodal_attention_mask
502
+
503
+ def _build_multimodal_labels(self, labels, projected_patch_embeddings):
504
+ """Build multimodal labels with IGNORE_INDEX for patch embeddings"""
505
+ if labels is not None:
506
+ projected_patch_labels = torch.full(
507
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
508
+ fill_value=IGNORE_INDEX,
509
+ dtype=labels.dtype,
510
+ device=labels.device,
511
+ )
512
+ return torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
513
+ return None
514
+
515
+ # === Core Prismatic VLM `forward()` Logic ===
516
+ def forward(
517
+ self,
518
+ input_ids: Optional[torch.LongTensor] = None,
519
+ attention_mask: Optional[torch.Tensor] = None,
520
+ pixel_values: Optional[torch.FloatTensor] = None,
521
+ labels: Optional[torch.LongTensor] = None,
522
+ inputs_embeds: Optional[torch.FloatTensor] = None,
523
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
524
+ use_cache: Optional[bool] = None,
525
+ output_attentions: Optional[bool] = None,
526
+ output_hidden_states: Optional[bool] = None,
527
+ output_projector_features: Optional[bool] = None,
528
+ return_dict: Optional[bool] = None,
529
+ proprio=None,
530
+ proprio_projector=None,
531
+ noisy_actions=None,
532
+ noisy_action_projector=None,
533
+ diffusion_timestep_embeddings=None,
534
+ use_film: bool = False,
535
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
536
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
537
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
538
+ output_hidden_states = (
539
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
540
+ )
541
+ output_projector_features = output_projector_features if output_projector_features is not None else False
542
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
543
+
544
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
545
+ use_cache = use_cache and not self.training
546
+
547
+ # Instantiate Placeholder for Projector Features
548
+ projected_patch_embeddings = None
549
+
550
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
551
+ if input_ids.shape[1] == 1:
552
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
553
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
554
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
555
+
556
+ language_model_output = self.language_model(
557
+ input_ids=input_ids,
558
+ attention_mask=None,
559
+ position_ids=None,
560
+ past_key_values=past_key_values,
561
+ inputs_embeds=None,
562
+ labels=None,
563
+ use_cache=use_cache,
564
+ output_attentions=output_attentions,
565
+ output_hidden_states=output_hidden_states,
566
+ return_dict=return_dict,
567
+ )
568
+
569
+ # === Handle Unimodal Forward ===
570
+ elif pixel_values is None:
571
+ assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
572
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
573
+
574
+ language_model_output = self.language_model(
575
+ input_ids=input_ids,
576
+ attention_mask=attention_mask,
577
+ position_ids=None,
578
+ past_key_values=None,
579
+ inputs_embeds=None,
580
+ labels=labels,
581
+ use_cache=use_cache,
582
+ output_attentions=output_attentions,
583
+ output_hidden_states=output_hidden_states,
584
+ return_dict=return_dict,
585
+ )
586
+
587
+ # === Handle Multimodal Forward ===
588
+ elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
589
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
590
+
591
+ # Get input embeddings (from language model embeddings)
592
+ input_embeddings = self.get_input_embeddings()(input_ids) # (B, seq_len, D)
593
+
595
+ # Extract action masks
596
+ all_actions_mask = self._process_action_masks(labels)
597
+
598
+ # Extract the language portion of the input embeddings (i.e. remove the action tokens portion)
601
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
602
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
603
+ ) # (B, lang_seq_len, llm_dim)
604
+
605
+ # Get visual features
606
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
607
+
608
+ # Add proprioceptive state if provided
609
+ if self.version == 'v1':
610
+ pass
611
+ else:
612
+ projected_patch_embeddings = self._process_proprio_features(
613
+ projected_patch_embeddings, proprio, proprio_projector
614
+ )
615
+
616
+ # [Diffusion] Add diffusion timestep embedding if provided
617
+ if diffusion_timestep_embeddings is not None:
618
+ if self.version == 'v1':
619
+ pass
620
+ else:
621
+ # For simplicity, just append diffusion timestep embedding to the end of projected vision patch tokens
622
+ projected_patch_embeddings = torch.cat(
623
+ (projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
624
+ )
625
+
626
+
627
+ # Process action embeddings
628
+ if noisy_actions is not None:
630
+ if self.version == 'v1':
631
+ # action_queries = self.action_queries.weight # (1, h)
632
+ # action_queries = action_queries.view(1, 1, action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
633
+ # input_embeddings = torch.cat((input_embeddings, action_queries), dim=1) # (b, n_tokens+chunk_size, h)
634
+ # action_attention_mask = None
635
+ # action_attention_mask = torch.full(
636
+ # (action_queries.shape[0], action_queries.shape[1]),
637
+ # fill_value=True,
638
+ # dtype=attention_mask.dtype,
639
+ # device=attention_mask.device,)
640
+ # attention_mask = torch.cat([attention_mask, action_attention_mask], dim=1)
642
+ action_queries = self.action_queries.weight # (1, h)
643
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
644
+ # action_queries = action_queries * 0.0  # Debug-only: zero out the queries to rule out their influence
645
+ all_actions_mask = self._process_action_masks(labels)
646
+ input_embeddings = self._replace_input_embeddings(
647
+ input_embeddings, all_actions_mask, action_queries)
648
+ # import pdb; pdb.set_trace()
649
+ # all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
650
+ # input_embeddings = input_embeddings * ~all_actions_mask
651
+
652
+ else:
653
+ # Get mask corresponding to all action tokens
654
+ all_actions_mask = self._process_action_masks(labels)
655
+
656
+ # Reshape noisy actions into individual action tokens
657
+ # noisy_actions: (B, chunk_len, action_dim) -> (B, chunk_len * action_dim, 1)
658
+ B = noisy_actions.shape[0]
659
+ noisy_actions = noisy_actions.reshape(B, -1).unsqueeze(-1)
660
+ # Project noisy action tokens into language model embedding space
661
+ noisy_action_features = noisy_action_projector(noisy_actions) # (B, chunk_len * action_dim, llm_dim)
662
+ # Replace embeddings of the action tokens with noisy action embeddings
663
+ input_embeddings = self._replace_input_embeddings(
664
+ input_embeddings, all_actions_mask, noisy_action_features)
665
+
666
+ else:
667
+ if self.version == 'v1':
668
+ action_queries = self.action_queries.weight # (1, h)
669
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
670
+ all_actions_mask = self._process_action_masks(labels)
671
+ input_embeddings = self._replace_input_embeddings(
672
+ input_embeddings, all_actions_mask, action_queries)
673
+ # import pdb; pdb.set_trace()
674
+ # all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
675
+ # input_embeddings = input_embeddings * ~all_actions_mask
676
+ else:
677
+ # Replace the embeddings of the action tokens with zeros
678
+ # (Later on, the positional embeddings will be added to them)
679
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
680
+ input_embeddings = input_embeddings * ~all_actions_mask
681
+
682
+
683
+ # Build multimodal embeddings & attention mask
684
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
685
+ input_embeddings, projected_patch_embeddings, attention_mask
686
+ )
687
+ # import pdb; pdb.set_trace()
688
+ # Build labels for multimodal sequence if needed
689
+ multimodal_labels = self._build_multimodal_labels(labels, projected_patch_embeddings)
690
+
691
+ # import pdb; pdb.set_trace()
692
+ # Dispatch to language model
693
+ if self.version == 'v1':
694
+ # import pdb; pdb.set_trace()
695
+ language_model_output = self.language_model(
696
+ input_ids=None,
697
+ attention_mask=multimodal_attention_mask,
698
+ position_ids=None,
699
+ past_key_values=None,
700
+ inputs_embeds=multimodal_embeddings,
701
+ labels=None,
702
+ use_cache=use_cache,
703
+ output_attentions=output_attentions,
704
+ output_hidden_states=output_hidden_states,
705
+ return_dict=return_dict,
706
+ )
707
+ # import pdb; pdb.set_trace()
708
+ else:
709
+ language_model_output = self.language_model(
710
+ input_ids=None,
711
+ attention_mask=multimodal_attention_mask,
712
+ position_ids=None,
713
+ past_key_values=None,
714
+ inputs_embeds=multimodal_embeddings,
715
+ labels=multimodal_labels,
716
+ use_cache=use_cache,
717
+ output_attentions=output_attentions,
718
+ output_hidden_states=output_hidden_states,
719
+ return_dict=return_dict,
720
+ )
721
+
722
+ # === Otherwise =>> Assume Invalid! ===
723
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
724
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
725
+
726
+ else:
727
+ raise ValueError(
728
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
729
+ f"=> `input_ids` = {input_ids is not None}\n"
730
+ f"=> `attention_mask` = {attention_mask is not None}\n"
731
+ f"=> `pixel_values` = {pixel_values is not None}\n"
732
+ f"=> `labels` = {labels is not None}\n"
733
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
734
+ f"=> `past_key_values` = {past_key_values is not None}\n"
735
+ f"=> `use_cache` = {use_cache}"
736
+ )
737
+
738
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
739
+ if not return_dict:
740
+ if output_projector_features and (projected_patch_embeddings is not None):
741
+ return *language_model_output, projected_patch_embeddings
742
+
743
+ return language_model_output
744
+
745
+ if self.version == 'v1':
746
+ return PrismaticCausalLMOutputWithPast(
747
+ loss=language_model_output.loss,
748
+ past_key_values=language_model_output.past_key_values,
749
+ hidden_states=language_model_output.hidden_states,
750
+ attentions=language_model_output.attentions,
751
+ projector_features=projected_patch_embeddings,
752
+ )
753
+ else:
754
+ return PrismaticCausalLMOutputWithPast(
755
+ loss=language_model_output.loss,
756
+ logits=language_model_output.logits,
757
+ past_key_values=language_model_output.past_key_values,
758
+ hidden_states=language_model_output.hidden_states,
759
+ attentions=language_model_output.attentions,
760
+ projector_features=projected_patch_embeddings,
761
+ )
762
+
763
+ # === GenerationMixin Methods ===
764
+ def prepare_inputs_for_generation(
765
+ self,
766
+ input_ids: Optional[torch.Tensor] = None,
767
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
768
+ inputs_embeds: Optional[torch.FloatTensor] = None,
769
+ pixel_values: Optional[torch.FloatTensor] = None,
770
+ attention_mask: Optional[torch.Tensor] = None,
771
+ **kwargs: str,
772
+ ) -> Dict[str, torch.Tensor]:
773
+ """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
774
+ if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
775
+ (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
776
+ ):
777
+ raise ValueError("Generation with batch size > 1 is not currently supported!")
778
+
779
+ # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
780
+ if past_key_values is not None:
781
+ input_ids = input_ids[:, -1:]
782
+
783
+ # If `inputs_embeds` are passed, we only want to use them in the first generation step
784
+ if inputs_embeds is not None and past_key_values is None:
785
+ model_inputs = {"inputs_embeds": inputs_embeds}
786
+ else:
787
+ model_inputs = {"input_ids": input_ids}
788
+
789
+ # Make sure `pixel_values` are preserved in `model_inputs`
790
+ model_inputs.update(
791
+ {
792
+ "attention_mask": attention_mask,
793
+ "pixel_values": pixel_values,
794
+ "past_key_values": past_key_values,
795
+ "use_cache": kwargs.get("use_cache"),
796
+ }
797
+ )
798
+
799
+ return model_inputs
800
+
801
+ # Defer to Language Model (all handle this differently, with different return types)
802
+ def _reorder_cache(self, *args, **kwargs) -> Any:
803
+ return self.language_model._reorder_cache(*args, **kwargs)
804
+
805
+
806
+ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
807
+ config_class: PretrainedConfig = OpenVLAConfig
808
+ _supports_sdpa = False
809
+
810
+ def __init__(self, config: OpenVLAConfig) -> None:
811
+ super().__init__(config)
812
+ self.norm_stats = config.norm_stats
813
+ # import pdb; pdb.set_trace()
814
+
815
+ # Compute action bins
816
+ self.bins = np.linspace(-1, 1, config.n_action_bins)
817
+ self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
818
+
819
+ # Compute vocab size for de-tokenization -- revert added "multiple of"
820
+ self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
821
+
822
+ def _prepare_input_for_action_prediction(self, input_ids, attention_mask):
823
+ """Prepares input for action prediction by adding necessary tokens"""
824
+ # Add NUM_TOKENS placeholder tokens to input_ids to stand in for the action tokens
825
+ placeholder_action_token_ids = (
826
+ torch.ones((input_ids.shape[0], NUM_TOKENS)).to(input_ids.device).to(input_ids.dtype)
827
+ )
828
+ input_ids = torch.cat([input_ids, placeholder_action_token_ids], dim=-1)
829
+
830
+ # Add stop token to sequence (needed in non-causal bi-directional self-attention, as it appears at train time)
831
+ stop_token_id = torch.ones((input_ids.shape[0], 1)).to(input_ids.device).to(input_ids.dtype) * STOP_INDEX
832
+ input_ids = torch.cat([input_ids, stop_token_id], dim=-1)
833
+
834
+ # Extend the attention mask to fit the new shape of input
835
+ # Note: Only batch size == 1 supported right now
836
+ mask_extension = (
837
+ torch.ones((attention_mask.shape[0], input_ids.shape[-1] - attention_mask.shape[-1]))
838
+ .to(attention_mask.device)
839
+ .to(attention_mask.dtype)
840
+ )
841
+ attention_mask = torch.cat([attention_mask, mask_extension], dim=-1)
842
+
843
+ return input_ids, attention_mask
844
+
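+ # Shape bookkeeping sketch for the method above (kept as a comment; the prompt length of 28
+ # and NUM_TOKENS = 64 are illustrative values only, not constants taken from this repo):
+ #
+ #   input_ids:      (1, 28) --+ NUM_TOKENS placeholders--> (1, 92) --+ stop token--> (1, 93)
+ #   attention_mask: (1, 28) --+ ones for every new token----------------------------> (1, 93)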
845
+ def _prepare_labels_for_action_prediction(self, labels, input_ids):
846
+ """Creates labels tensor for action prediction if not provided"""
847
+ # Extend labels tensor with fake action labels
848
+ ARBITRARY_ACTION_TOKEN_IDX = ACTION_TOKEN_BEGIN_IDX + 1
849
+ labels_extension = (
850
+ torch.ones((labels.shape[0], input_ids.shape[-1] - labels.shape[-1])).to(labels.device).to(labels.dtype)
851
+ * ARBITRARY_ACTION_TOKEN_IDX
852
+ )
853
+ labels = torch.cat([labels, labels_extension], dim=-1)
854
+
855
+ # Replace last label token with stop token
856
+ labels[:, -1] = STOP_INDEX
857
+
858
+ return labels
859
+
860
+ def _unnormalize_actions(self, normalized_actions, unnorm_key=None):
861
+ """Unnormalize actions using dataset statistics"""
862
+ action_norm_stats = self.get_action_stats(unnorm_key)
863
+
864
+ if ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS:
865
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["min"], dtype=bool))
866
+ action_high, action_low = np.array(action_norm_stats["max"]), np.array(action_norm_stats["min"])
867
+ elif ACTION_PROPRIO_NORMALIZATION_TYPE == NormalizationType.BOUNDS_Q99:
868
+ mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
869
+ action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
870
+ else:
871
+ raise ValueError("Unsupported action/proprio normalization type detected!")
872
+
873
+ actions = np.where(
874
+ mask,
875
+ 0.5 * (normalized_actions + 1) * (action_high - action_low + 1e-8) + action_low,
876
+ normalized_actions,
877
+ )
878
+
879
+ return actions
880
+
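+ # Worked example of the BOUNDS_Q99 branch above (a commented sketch; the q01/q99 values are
+ # made up -- real values come from `norm_stats`):
+ #
+ #   import numpy as np
+ #   q01, q99 = np.array([-0.5]), np.array([0.5])
+ #   normalized = np.array([0.2])
+ #   # 0.5 * (0.2 + 1) * (0.5 - (-0.5) + 1e-8) + (-0.5) ~= 0.1
+ #   unnormalized = 0.5 * (normalized + 1) * (q99 - q01 + 1e-8) + q01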
881
+ def _run_flow_matching_prediction(
882
+ self,
883
+ input_embeddings,
884
+ all_actions_mask,
885
+ noise,
886
+ action_head,
887
+ projected_patch_embeddings,
888
+ labels,
889
+ attention_mask,
890
+ NUM_PATCHES,
891
+ NUM_PROMPT_TOKENS,
892
+ noisy_action_projector
893
+ ):
894
+ """Run flow matching-based action prediction"""
895
+ # Clone embedding for reuse in each timestep
896
+ # orig_projected_patch_embeddings = projected_patch_embeddings.clone()
897
+
898
+ dt = -1.0 / action_head.num_flow_steps
899
+ dt = torch.tensor(dt, dtype=torch.bfloat16, device=labels.device)
900
+
901
+ curr_noisy_actions = noise
902
+ time = torch.tensor(1.0, dtype=torch.bfloat16, device=labels.device)
903
+ while time >= -dt / 2:
904
+ B = curr_noisy_actions.shape[0]
905
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
906
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
907
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
908
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
909
+
910
+ # Replace action token embeddings with noisy action embeddings
911
+ input_embeddings = self._replace_input_embeddings(
912
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
913
+ )
914
+
915
+ # Build multimodal embeddings and attention mask
916
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
917
+ input_embeddings, projected_patch_embeddings, attention_mask
918
+ )
919
+
920
+ # Forward pass through language model
921
+ language_model_output = self.language_model(
922
+ input_ids=None,
923
+ attention_mask=multimodal_attention_mask,
924
+ position_ids=None,
925
+ past_key_values=None,
926
+ inputs_embeds=multimodal_embeddings,
927
+ labels=None,
928
+ use_cache=None,
929
+ output_attentions=False,
930
+ output_hidden_states=True,
931
+ return_dict=True,
932
+ )
933
+
934
+ # Extract hidden states for action portion of response
935
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
936
+ actions_hidden_states = last_hidden_states[
937
+ :,
938
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
939
+ :,
940
+ ] # (B, act_chunk_len, D)
941
+
942
+ # Predict the flow (velocity) field and take one Euler step: x_t -> x_{t+dt}
943
+ flow_pred = action_head.predict_flow(actions_hidden_states)
944
+ curr_noisy_actions += dt * flow_pred
945
+ time += dt
946
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
947
+
948
+ # Return final actions
949
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
950
+
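+ # The loop above is plain Euler integration of a learned velocity field from t=1 (noise) to
+ # t=0 (actions). A minimal standalone sketch (`velocity_fn` is a hypothetical stand-in; kept
+ # commented out so the module still imports cleanly):
+ #
+ #   import torch
+ #   def euler_integrate(velocity_fn, x, num_steps=10):
+ #       dt = -1.0 / num_steps
+ #       t = torch.tensor(1.0)
+ #       while t >= -dt / 2:                 # runs exactly num_steps steps
+ #           x = x + dt * velocity_fn(x, t)  # x_t -> x_{t+dt}
+ #           t = t + dt
+ #       return x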
951
+ def _run_flow_matching_prediction_V1(
952
+ self,
953
+ input_embeddings,
954
+ all_actions_mask,
955
+ noise,
956
+ action_head,
957
+ projected_patch_embeddings,
958
+ labels,
959
+ attention_mask,
960
+ NUM_PATCHES,
961
+ NUM_PROMPT_TOKENS,
962
+ noisy_action_projector,
963
+ proprio,
964
+ proprio_projector,
965
+ ):
966
+ """Run flow matching-based action prediction"""
967
+ # Clone embedding for reuse in each timestep
968
+ # orig_projected_patch_embeddings = projected_patch_embeddings.clone()
969
+
970
+ dt = -1.0 / action_head.num_flow_steps
971
+ dt = torch.tensor(dt, dtype=torch.bfloat16, device=labels.device)
972
+
973
+ curr_noisy_actions = noise
974
+ time = torch.tensor(1.0, dtype=torch.bfloat16, device=labels.device)
975
+
976
+ action_queries = self.action_queries.weight # (1, h)
977
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
978
+ # Replace action token embeddings with the learned action query embeddings
979
+ input_embeddings = self._replace_input_embeddings(input_embeddings.clone(), all_actions_mask, action_queries)
980
+ # all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
981
+ # input_embeddings = input_embeddings * ~all_actions_mask
982
+
983
+ # Build multimodal embeddings and attention mask
984
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
985
+ input_embeddings, projected_patch_embeddings, attention_mask
986
+ )
987
+
988
+ # Forward pass through language model
989
+ language_model_output = self.language_model(
990
+ input_ids=None,
991
+ attention_mask=multimodal_attention_mask,
992
+ position_ids=None,
993
+ past_key_values=None,
994
+ inputs_embeds=multimodal_embeddings,
995
+ labels=None,
996
+ use_cache=None,
997
+ output_attentions=False,
998
+ output_hidden_states=True,
999
+ return_dict=True,
1000
+ )
1001
+
1002
+ # Extract hidden states for action portion of response
1003
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
1004
+ actions_hidden_states = last_hidden_states[:, NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
1005
+ batch_size = last_hidden_states.shape[0]
1006
+ task_latent_states = last_hidden_states[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES, -1)
1007
+ all_hidden_states = torch.cat((task_latent_states, actions_hidden_states), dim=2)
1008
+
1009
+ while time >= -dt / 2:
1010
+ B = curr_noisy_actions.shape[0]
1011
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
1012
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
1013
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
1014
+
1015
+ timesteps = torch.Tensor([1.0 - time]).to(labels.device)
1016
+ timestep_embeddings = (
1017
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
1018
+ ) # (1, llm_dim)
1019
+ timestep_embeddings = timestep_embeddings.unsqueeze(1) # (1, 1, llm_dim)
1020
+
1021
+ # Predict the flow (velocity) field and take one Euler step: x_t -> x_{t+dt}
1022
+ flow_pred = action_head.predict_flow(all_hidden_states,
1023
+ noisy_actions=curr_noisy_actions,
1024
+ timestep_embeddings=timestep_embeddings,
1025
+ noisy_action_projector=noisy_action_projector,
1026
+ proprio=proprio,
1027
+ proprio_projector=proprio_projector)
1028
+
1029
+ curr_noisy_actions += dt * flow_pred
1030
+ time += dt
1031
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1032
+
1033
+ # Return final actions
1034
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
1035
+
1036
+ def _run_diffusion_prediction(
1037
+ self,
1038
+ input_embeddings,
1039
+ all_actions_mask,
1040
+ noise,
1041
+ action_head,
1042
+ projected_patch_embeddings,
1043
+ labels,
1044
+ attention_mask,
1045
+ NUM_PATCHES,
1046
+ NUM_PROMPT_TOKENS,
1047
+ noisy_action_projector,
1048
+ ):
1049
+ """Run diffusion-based action prediction"""
1050
+ # Set diffusion timestep values
1051
+ action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
1052
+ # Clone embedding for reuse in each timestep
1053
+ orig_projected_patch_embeddings = projected_patch_embeddings.clone()
1054
+ curr_noisy_actions = noise
1055
+
1056
+ # Reverse diffusion: Iteratively denoise to generate action prediction
1057
+ for t in action_head.noise_scheduler.timesteps:
1058
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
1059
+ # embedding, and diffusion timestep embedding)
1060
+ timesteps = torch.Tensor([t]).to(labels.device)
1061
+ diffusion_timestep_embeddings = (
1062
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
1063
+ ) # (B, llm_dim)
1064
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
1065
+
1066
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
1067
+ # (Later on, the positional embeddings will be added to them)
1068
+
1069
+ # For simplicity, append diffusion timestep embedding to the end of projected vision tokens
1070
+ projected_patch_embeddings = torch.cat(
1071
+ (orig_projected_patch_embeddings, diffusion_timestep_embeddings), dim=1
1072
+ )
1073
+
1074
+ # Reshape and project noisy actions into language embedding space
1075
+ B = curr_noisy_actions.shape[0]
1076
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
1077
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
1078
+ noisy_action_features = noisy_action_projector(curr_noisy_actions)
1079
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
1080
+
1081
+ # Replace action token embeddings with noisy action embeddings
1082
+ input_embeddings = self._replace_input_embeddings(
1083
+ input_embeddings.clone(), all_actions_mask, noisy_action_features
1084
+ )
1085
+
1086
+ # Build multimodal embeddings and attention mask
1087
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1088
+ input_embeddings, projected_patch_embeddings, attention_mask
1089
+ )
1090
+
1091
+ # Forward pass through language model
1092
+ language_model_output = self.language_model(
1093
+ input_ids=None,
1094
+ attention_mask=multimodal_attention_mask,
1095
+ position_ids=None,
1096
+ past_key_values=None,
1097
+ inputs_embeds=multimodal_embeddings,
1098
+ labels=None,
1099
+ use_cache=None,
1100
+ output_attentions=False,
1101
+ output_hidden_states=True,
1102
+ return_dict=True,
1103
+ )
1104
+
1105
+ # Extract hidden states for action portion of response
1106
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
1107
+ actions_hidden_states = last_hidden_states[
1108
+ :,
1109
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1110
+ :,
1111
+ ] # (B, act_chunk_len, D)
1112
+
1113
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
1114
+ noise_pred = action_head.predict_noise(actions_hidden_states)
1115
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
1116
+
1117
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1118
+
1119
+ # Return final actions
1120
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
1121
+
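+ # For reference, the denoising loop above follows the standard diffusers-style sampling recipe;
+ # a minimal commented sketch (assumes the `diffusers` package, with `predict_noise` as a
+ # hypothetical stand-in for the action head's noise predictor):
+ #
+ #   import torch
+ #   from diffusers import DDIMScheduler
+ #   scheduler = DDIMScheduler(num_train_timesteps=1000)
+ #   scheduler.set_timesteps(50)
+ #   x = torch.randn(1, 8, 7)                       # start from pure noise
+ #   for t in scheduler.timesteps:
+ #       eps = predict_noise(x, t)                  # model's noise estimate
+ #       x = scheduler.step(eps, t, x).prev_sample  # x_t -> x_{t-1}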
1122
+ def _run_diffusion_prediction_V1(
1123
+ self,
1124
+ input_embeddings,
1125
+ all_actions_mask,
1126
+ noise,
1127
+ action_head,
1128
+ projected_patch_embeddings,
1129
+ labels,
1130
+ attention_mask,
1131
+ NUM_PATCHES,
1132
+ NUM_PROMPT_TOKENS,
1133
+ noisy_action_projector,
1134
+ proprio,
1135
+ proprio_projector,
1136
+ ):
1137
+ """Run diffusion-based action prediction"""
1138
+ # Set diffusion timestep values
1139
+ action_head.noise_scheduler.set_timesteps(action_head.num_diffusion_steps)
1140
+ # Clone embedding for reuse in each timestep
1141
+ curr_noisy_actions = noise
1142
+
1143
+ # import pdb; pdb.set_trace()
1144
+
1145
+ action_queries = self.action_queries.weight # (1, h)
1146
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
1147
+ # Replace action token embeddings with the learned action query embeddings
1148
+ input_embeddings = self._replace_input_embeddings(input_embeddings.clone(), all_actions_mask, action_queries)
1149
+ # input_embeddings = torch.cat((input_embeddings, action_queries), dim=1) # (b, n_tokens+chunk_size, h)
1150
+ # action_attention_mask = None
1151
+ # action_attention_mask = torch.full(
1152
+ # (action_queries.shape[0], action_queries.shape[1]),
1153
+ # fill_value=True,
1154
+ # dtype=attention_mask.dtype,
1155
+ # device=attention_mask.device,)
1156
+ # attention_mask = torch.cat([attention_mask, action_attention_mask], dim=1)
1157
+
1158
+ # Build multimodal embeddings and attention mask
1159
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1160
+ input_embeddings, projected_patch_embeddings, attention_mask
1161
+ )
1162
+
1163
+ # import pdb; pdb.set_trace()
1164
+ # Forward pass through language model
1165
+ language_model_output = self.language_model(
1166
+ input_ids=None,
1167
+ attention_mask=multimodal_attention_mask,
1168
+ position_ids=None,
1169
+ past_key_values=None,
1170
+ inputs_embeds=multimodal_embeddings,
1171
+ labels=None,
1172
+ use_cache=None,
1173
+ output_attentions=False,
1174
+ output_hidden_states=True,
1175
+ return_dict=True,
1176
+ )
1177
+ multi_layer_hidden_states = []
1178
+ # import pdb; pdb.set_trace()
1179
+ for item in language_model_output.hidden_states[0:]:
1180
+ # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
1181
+ # Get hidden states for text portion of prompt+response (after the vision patches)
1182
+ text_hidden_states = item
1183
+ # Get hidden states for action portion of response
1184
+ actions_hidden_states = text_hidden_states[:, NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
1186
+ batch_size = item.shape[0]
1187
+ task_latent_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES, -1)
1188
+ all_hidden_states = torch.cat((task_latent_states, actions_hidden_states), dim=2)
1189
+ multi_layer_hidden_states.append(all_hidden_states)
1191
+ multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim=1)
1193
+
1196
+ # Reverse diffusion: Iteratively denoise to generate action prediction
1197
+ for t in action_head.noise_scheduler.timesteps:
1198
+ # Get diffusion model's noise prediction (conditioned on VLA latent embedding, current noisy action
1199
+ # embedding, and diffusion timestep embedding)
1200
+ timesteps = torch.Tensor([t]).to(labels.device)
1201
+ diffusion_timestep_embeddings = (
1202
+ action_head.time_encoder(timesteps).to(curr_noisy_actions.dtype).to(curr_noisy_actions.device)
1203
+ ) # (B, llm_dim)
1204
+ diffusion_timestep_embeddings = diffusion_timestep_embeddings.unsqueeze(1) # (B, 1, llm_dim)
1205
+
1206
+ # [Diffusion] Replace the embeddings of the action tokens with noisy actions
1207
+ # (Later on, the positional embeddings will be added to them)
1208
+
1209
+ # Reshape and project noisy actions into language embedding space
1210
+ B = curr_noisy_actions.shape[0]
1211
+ orig_curr_noisy_actions_shape = curr_noisy_actions.shape
1212
+ curr_noisy_actions = curr_noisy_actions.reshape(B, -1).unsqueeze(-1)
1213
+ curr_noisy_actions = curr_noisy_actions.reshape(orig_curr_noisy_actions_shape)
1214
+
1215
+ # Predict noise and update noisy actions: x_t -> x_{t-1}
1216
+ # noise_pred = action_head.predict_noise(actions_hidden_states)
1217
+ noise_pred = action_head.predict_noise(multi_layer_hidden_states,
1218
+ noisy_actions=curr_noisy_actions,
1219
+ timestep_embeddings=diffusion_timestep_embeddings,
1220
+ noisy_action_projector=noisy_action_projector,
1221
+ proprio=proprio,
1222
+ proprio_projector=proprio_projector)
1223
+
1224
+ curr_noisy_actions = action_head.noise_scheduler.step(noise_pred, t, curr_noisy_actions).prev_sample
1225
+ curr_noisy_actions = curr_noisy_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1226
+
1227
+ # Return final actions
1228
+ return curr_noisy_actions.float().cpu().detach().numpy(), actions_hidden_states
1229
+
1230
+ def _regression_or_discrete_prediction_V1(
1231
+ self,
1232
+ input_embeddings,
1233
+ all_actions_mask,
1234
+ projected_patch_embeddings,
1235
+ attention_mask,
1236
+ labels,
1237
+ NUM_PATCHES,
1238
+ NUM_PROMPT_TOKENS,
1239
+ action_head=None,
1240
+ proprio=None,
1241
+ proprio_projector=None,
1242
+ ):
1243
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
1244
+
1245
+ action_queries = self.action_queries.weight # (1, h)
1246
+ action_queries = action_queries.view(1, action_queries.shape[0], action_queries.shape[1]).repeat(input_embeddings.shape[0], 1, 1) # (b, chunk_size, h)
1247
+ # Replace action token embeddings with the learned action query embeddings
1248
+ input_embeddings = self._replace_input_embeddings(input_embeddings.clone(), all_actions_mask, action_queries)
1249
+
1250
+ # Build multimodal embeddings and attention mask
1251
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1252
+ input_embeddings, projected_patch_embeddings, attention_mask
1253
+ )
1254
+
1255
+ # Forward pass through language model
1256
+ language_model_output = self.language_model(
1257
+ input_ids=None,
1258
+ attention_mask=multimodal_attention_mask,
1259
+ position_ids=None,
1260
+ past_key_values=None,
1261
+ inputs_embeds=multimodal_embeddings,
1262
+ labels=None,
1263
+ use_cache=None,
1264
+ output_attentions=False,
1265
+ output_hidden_states=True,
1266
+ return_dict=True,
1267
+ )
1268
+
1269
+ # Extract hidden states for action tokens
1270
+ multi_layer_hidden_states = []
1271
+ # import pdb; pdb.set_trace()
1272
+ for item in language_model_output.hidden_states[0:]:
1273
+ # last_hidden_states = output.hidden_states[-1] # (B, seq_len, D)
1274
+ # Get hidden states for text portion of prompt+response (after the vision patches)
1275
+ text_hidden_states = item
1276
+ # Get hidden states for action portion of response
1277
+ actions_hidden_states = text_hidden_states[:, NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + NUM_TOKENS, :].reshape(1, 1, NUM_TOKENS, -1).to(torch.bfloat16)
1279
+ batch_size = item.shape[0]
1280
+ task_latent_states = item[:, :NUM_PATCHES].reshape(batch_size, 1, NUM_PATCHES, -1)
1281
+ all_hidden_states = torch.cat((task_latent_states, actions_hidden_states), dim=2)
1282
+ multi_layer_hidden_states.append(all_hidden_states)
1284
+ multi_layer_hidden_states = torch.cat(multi_layer_hidden_states, dim=1)
1286
+
1287
+ # Handle different prediction methods
1288
+ if action_head is not None:
1289
+ # L1 regression prediction
1290
+ normalized_actions = action_head.predict_action(multi_layer_hidden_states,
1291
+ proprio=proprio,
1292
+ proprio_projector=proprio_projector)
1293
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1294
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
1295
+ else:
1296
+ # Discrete token-based prediction
1297
+ predicted_action_token_ids = (
1298
+ language_model_output.logits[
1299
+ :,
1300
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1301
+ ]
1302
+ .argmax(dim=2)
1303
+ .cpu()
1304
+ .numpy()
1305
+ )
1306
+ discretized_actions = self.vocab_size - predicted_action_token_ids
1307
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
1308
+ normalized_actions = self.bin_centers[discretized_actions]
1309
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1310
+
1311
+ return normalized_actions, actions_hidden_states
1312
+
1313
+ def _regression_or_discrete_prediction(
1314
+ self,
1315
+ input_embeddings,
1316
+ all_actions_mask,
1317
+ projected_patch_embeddings,
1318
+ attention_mask,
1319
+ labels,
1320
+ NUM_PATCHES,
1321
+ NUM_PROMPT_TOKENS,
1322
+ action_head=None,
1323
+ ):
1324
+ """Run L1 regression-based continuous action prediction or discrete action tokens prediction."""
1325
+ # Zero out action token embeddings
1326
+ all_actions_mask = all_actions_mask.unsqueeze(-1) # (B, seq_len, 1)
1327
+ input_embeddings = input_embeddings * ~all_actions_mask
1328
+
1329
+ # Build multimodal embeddings and attention mask
1330
+ multimodal_embeddings, multimodal_attention_mask = self._build_multimodal_attention(
1331
+ input_embeddings, projected_patch_embeddings, attention_mask
1332
+ )
1333
+
1334
+ # Forward pass through language model
1335
+ language_model_output = self.language_model(
1336
+ input_ids=None,
1337
+ attention_mask=multimodal_attention_mask,
1338
+ position_ids=None,
1339
+ past_key_values=None,
1340
+ inputs_embeds=multimodal_embeddings,
1341
+ labels=None,
1342
+ use_cache=None,
1343
+ output_attentions=False,
1344
+ output_hidden_states=True,
1345
+ return_dict=True,
1346
+ )
1347
+
1348
+ # Extract hidden states for action tokens
1349
+ last_hidden_states = language_model_output.hidden_states[-1] # (B, seq_len, D)
1350
+ actions_hidden_states = last_hidden_states[
1351
+ :,
1352
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1353
+ :,
1354
+ ] # (B, act_chunk_len, D)
1355
+
1356
+ # Handle different prediction methods
1357
+ if action_head is not None:
1358
+ # L1 regression prediction
1359
+ normalized_actions = action_head.predict_action(actions_hidden_states)
1360
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1361
+ normalized_actions = normalized_actions.float().cpu().detach().numpy()
1362
+ else:
1363
+ # Discrete token-based prediction
1364
+ predicted_action_token_ids = (
1365
+ language_model_output.logits[
1366
+ :,
1367
+ NUM_PATCHES + NUM_PROMPT_TOKENS : NUM_PATCHES + NUM_PROMPT_TOKENS + ACTION_DIM * NUM_ACTIONS_CHUNK,
1368
+ ]
1369
+ .argmax(dim=2)
1370
+ .cpu()
1371
+ .numpy()
1372
+ )
1373
+ discretized_actions = self.vocab_size - predicted_action_token_ids
1374
+ discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
1375
+ normalized_actions = self.bin_centers[discretized_actions]
1376
+ normalized_actions = normalized_actions.reshape(NUM_ACTIONS_CHUNK, ACTION_DIM)
1377
+
1378
+ return normalized_actions, actions_hidden_states
1379
+
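+ # De-tokenization sketch for the discrete branch above (commented; vocab_size and the token id
+ # are illustrative -- the convention is that the last n_action_bins vocab ids encode actions in
+ # reverse order):
+ #
+ #   import numpy as np
+ #   n_bins, vocab_size = 256, 32000
+ #   bins = np.linspace(-1, 1, n_bins)
+ #   bin_centers = (bins[:-1] + bins[1:]) / 2.0
+ #   token_ids = np.array([31800])
+ #   bin_idx = np.clip(vocab_size - token_ids - 1, 0, bin_centers.shape[0] - 1)
+ #   normalized_action = bin_centers[bin_idx]       # value in [-1, 1]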
1380
+ def predict_action(
1381
+ self,
1382
+ input_ids: Optional[torch.LongTensor] = None,
1383
+ unnorm_key: Optional[str] = None,
1384
+ proprio=None,
1385
+ proprio_projector=None,
1386
+ action_head=None,
1387
+ noisy_action_projector=None,
1388
+ use_film: bool = False,
1389
+ **kwargs: str,
1390
+ ) -> Tuple[np.ndarray, torch.Tensor]:
1391
+ """Predict actions from input sequence, with options for different prediction methods.
1392
+
1393
+ Args:
1394
+ input_ids: Input token ids
1395
+ unnorm_key: Key for unnormalization statistics
1396
+ proprio: Proprioceptive features
1397
+ proprio_projector: Projector for proprioceptive features
1398
+ action_head: Optional head for L1 regression or diffusion-based prediction
1399
+ noisy_action_projector: Projector for noisy actions in diffusion-based prediction
1400
+ use_film: Whether to use FiLM conditioning
1401
+ **kwargs: Additional arguments including pixel_values and attention_mask
1402
+
1403
+ Returns:
1404
+ Tuple of (unnormalized_actions, action_hidden_states)
1405
+ """
1406
+ # import pdb; pdb.set_trace()
1407
+ # If the special empty token ('') does not already appear after the colon (':') token in the prompt
1408
+ # (after "OUT:" or "ASSISTANT:"), insert it to match the inputs seen at training time
1409
+
1410
+ # Note: if this is MiniVLA, do NOT add this check!
1411
+ # if not torch.all(input_ids[:, -1] == 29871):
1412
+ # input_ids = torch.cat(
1413
+ # (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
1414
+ # )
1415
+
1416
+
1417
+ pixel_values = kwargs["pixel_values"] # [1, 12, 224, 224]
1418
+ attention_mask = kwargs["attention_mask"] #
1419
+
1420
+ # Create fake labels tensor (needed for action mask)
1421
+ labels = input_ids.clone()
1422
+ labels[:] = IGNORE_INDEX
1423
+
1424
+ # Get number of tokens in prompt (excluding the start token)
1425
+ NUM_PROMPT_TOKENS = input_ids.shape[-1] - 1 # Subtract 1 for the start (<BOS>) token
1426
+
1428
+
1429
+ # Prepare inputs by adding necessary tokens
1430
+ input_ids, attention_mask = self._prepare_input_for_action_prediction(input_ids, attention_mask)
1431
+
1432
+ # Update labels tensor for action mask computation later
1433
+ labels = self._prepare_labels_for_action_prediction(labels, input_ids)
1434
+
1435
+ # Get input embeddings and action masks
1436
+ input_embeddings = self.get_input_embeddings()(input_ids)
1437
+ all_actions_mask = self._process_action_masks(labels)
1438
+
1439
+ # Extract language embeddings
1440
+ language_embeddings = input_embeddings[~all_actions_mask].reshape(
1441
+ input_embeddings.shape[0], -1, input_embeddings.shape[2]
1442
+ )
1443
+
1444
+ # Process vision features
1445
+ projected_patch_embeddings = self._process_vision_features(pixel_values, language_embeddings, use_film)
1446
+
1447
+ # Add proprioceptive features if provided
1448
+ use_proprio = proprio_projector is not None and proprio is not None
1449
+ if use_proprio:
1450
+ proprio = torch.Tensor(proprio).to(projected_patch_embeddings.device, dtype=projected_patch_embeddings.dtype)
1451
+ if self.version == 'v1':
1452
+ pass
1453
+ else:
1454
+ projected_patch_embeddings = self._process_proprio_features(
1455
+ projected_patch_embeddings, proprio, proprio_projector
1456
+ )
1457
+ # import pdb; pdb.set_trace()
1458
+ # Use diffusion if provided, otherwise use regression or discrete prediction
1459
+ use_diffusion = noisy_action_projector is not None and hasattr(action_head, "noise_scheduler")
1460
+ use_flow_matching = noisy_action_projector is not None and hasattr(action_head, "sample_actions")
1461
+
1462
+
1463
+ # Calculate number of patches (including proprio token and/or diffusion timestep embedding if present)
1464
+ NUM_PATCHES = self.vision_backbone.get_num_patches() * self.vision_backbone.get_num_images_in_input()
1465
+ if self.version == 'v1':
1466
+ # if use_diffusion:
1467
+ # NUM_PATCHES += 1
1468
+ pass
1469
+ else:
1470
+ if use_proprio:
1471
+ NUM_PATCHES += 1
1472
+ if use_diffusion:
1473
+ NUM_PATCHES += 1
1474
+
1475
+ # import pdb; pdb.set_trace()
1476
+ if use_flow_matching:
1477
+ # Sample random noise with shape equal to output action, used as the starting state for flow matching
1478
+ noise = action_head.sample_noise((1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device)
1479
+ # breakpoint()
1480
+ if self.version == 'v1':
1481
+ normalized_actions, actions_hidden_states = self._run_flow_matching_prediction_V1(
1482
+ input_embeddings, # [1, 86, 4096]
1483
+ all_actions_mask, # [1, 86]
1484
+ noise, # [1,8, 7]
1485
+ action_head,
1486
+ projected_patch_embeddings, # [1, 512, 4096]
1487
+ labels, # [1, 86]
1488
+ attention_mask, # [1, 86]
1489
+ NUM_PATCHES, # 512
1490
+ NUM_PROMPT_TOKENS, # 28
1491
+ noisy_action_projector,
1492
+ proprio, # [8]
1493
+ proprio_projector,
1494
+ )
1495
+ else:
1496
+ # Run flow matching-based prediction
1497
+ normalized_actions, actions_hidden_states = self._run_flow_matching_prediction(
1498
+ input_embeddings,
1499
+ all_actions_mask,
1500
+ noise,
1501
+ action_head,
1502
+ projected_patch_embeddings,
1503
+ labels,
1504
+ attention_mask,
1505
+ NUM_PATCHES,
1506
+ NUM_PROMPT_TOKENS,
1507
+ noisy_action_projector
1508
+ )
1509
+ elif use_diffusion:
1510
+ # Sample random noise with shape equal to output action, used as the starting state for reverse diffusion
1511
+ noise = torch.randn(
1512
+ size=(1, NUM_ACTIONS_CHUNK, ACTION_DIM), device=input_embeddings.device, dtype=input_embeddings.dtype
1513
+ )
1514
+ # import pdb; pdb.set_trace()
1515
+ if self.version == 'v1':
1518
+ # Run diffusion-based prediction
1519
+ normalized_actions, actions_hidden_states = self._run_diffusion_prediction_V1(
1520
+ input_embeddings, # [1, 86, 4096]
1521
+ all_actions_mask, # [1, 86]
1522
+ noise, # [1,8, 7]
1523
+ action_head,
1524
+ projected_patch_embeddings, # [1, 512, 4096]
1525
+ labels, # [1, 86]
1526
+ attention_mask, # [1, 86]
1527
+ NUM_PATCHES, # 512
1528
+ NUM_PROMPT_TOKENS, # 28
1529
+ noisy_action_projector,
1530
+ proprio, # [8]
1531
+ proprio_projector,
1532
+ )
1533
+ else:
1534
+ # Run diffusion-based prediction
1535
+ normalized_actions, actions_hidden_states = self._run_diffusion_prediction(
1536
+ input_embeddings,
1537
+ all_actions_mask,
1538
+ noise,
1539
+ action_head,
1540
+ projected_patch_embeddings,
1541
+ labels,
1542
+ attention_mask,
1543
+ NUM_PATCHES,
1544
+ NUM_PROMPT_TOKENS,
1545
+ noisy_action_projector,
1546
+ )
1547
+
1548
+ else:
1549
+ if self.version == 'v1':
1550
+ # Run regression or discrete token-based prediction
1551
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction_V1(
1552
+ input_embeddings,
1553
+ all_actions_mask,
1554
+ projected_patch_embeddings,
1555
+ attention_mask,
1556
+ labels,
1557
+ NUM_PATCHES,
1558
+ NUM_PROMPT_TOKENS,
1559
+ action_head=action_head,
1560
+ proprio=proprio, # [8]
1561
+ proprio_projector=proprio_projector,
1562
+ )
1563
+ else:
1564
+ # Run regression or discrete token-based prediction
1565
+ normalized_actions, actions_hidden_states = self._regression_or_discrete_prediction(
1566
+ input_embeddings,
1567
+ all_actions_mask,
1568
+ projected_patch_embeddings,
1569
+ attention_mask,
1570
+ labels,
1571
+ NUM_PATCHES,
1572
+ NUM_PROMPT_TOKENS,
1573
+ action_head,
1574
+ )
1575
+
1576
+ # import pdb; pdb.set_trace()
1577
+ # Unnormalize predicted actions
1578
+ actions = self._unnormalize_actions(normalized_actions, unnorm_key)
1579
+
1580
+ return actions, actions_hidden_states
1581
+
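+ # End-to-end usage sketch (commented; the checkpoint path, prompt, and `unnorm_key` are
+ # placeholders, and `action_head` / the projectors are assumed to be loaded separately from
+ # the *.pt files shipped alongside this model):
+ #
+ #   from PIL import Image
+ #   from transformers import AutoModelForVision2Seq, AutoProcessor
+ #   processor = AutoProcessor.from_pretrained("<checkpoint>", trust_remote_code=True)
+ #   vla = AutoModelForVision2Seq.from_pretrained("<checkpoint>", trust_remote_code=True)
+ #   inputs = processor("In: What action should the robot take?\nOut:", Image.open("obs.png"))
+ #   actions, _ = vla.predict_action(**inputs, unnorm_key="<dataset>", action_head=action_head)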
1582
+ @staticmethod
1583
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
1584
+ """Validate and resolve the unnormalization key for action statistics"""
1585
+ if unnorm_key is None:
1586
+ assert len(norm_stats) == 1, (
1587
+ f"Your model was trained on more than one dataset, "
1588
+ f"please pass a `unnorm_key` from the following options to choose the statistics "
1589
+ f"used for un-normalizing actions: {norm_stats.keys()}"
1590
+ )
1591
+ unnorm_key = next(iter(norm_stats.keys()))
1592
+
1593
+ assert unnorm_key in norm_stats, (
1594
+ f"The `unnorm_key` you chose is not in the set of available dataset statistics, "
1595
+ f"please choose from: {norm_stats.keys()}"
1596
+ )
1597
+ return unnorm_key
1598
+
1599
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
1600
+ """Get the dimensionality of the policy's action space."""
1601
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1602
+ return len(self.norm_stats[unnorm_key]["action"]["min"])
1603
+
1604
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
1605
+ """Get all the logged statistics for the given dataset."""
1606
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
1607
+ return self.norm_stats[unnorm_key]["action"]
noisy_action_projector--checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a82abb2054bf0de6087e91d393dd777f7c5a925fe741d6016ea17186c050b6e8
3
+ size 1613600
preprocessor_config.json ADDED
@@ -0,0 +1,114 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "processing_prismatic.PrismaticImageProcessor",
4
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
5
+ },
6
+ "image_processor_type": "PrismaticImageProcessor",
7
+ "image_resize_strategy": "resize-naive",
8
+ "input_sizes": [
9
+ [
10
+ 3,
11
+ 224,
12
+ 224
13
+ ],
14
+ [
15
+ 3,
16
+ 224,
17
+ 224
18
+ ]
19
+ ],
20
+ "interpolations": [
21
+ "bicubic",
22
+ "bicubic"
23
+ ],
24
+ "means": [
25
+ [
26
+ 0.485,
27
+ 0.456,
28
+ 0.406
29
+ ],
30
+ [
31
+ 0.5,
32
+ 0.5,
33
+ 0.5
34
+ ]
35
+ ],
36
+ "processor_class": "PrismaticProcessor",
37
+ "stds": [
38
+ [
39
+ 0.229,
40
+ 0.224,
41
+ 0.225
42
+ ],
43
+ [
44
+ 0.5,
45
+ 0.5,
46
+ 0.5
47
+ ]
48
+ ],
49
+ "tvf_crop_params": [
50
+ {
51
+ "output_size": [
52
+ 224,
53
+ 224
54
+ ]
55
+ },
56
+ {
57
+ "output_size": [
58
+ 224,
59
+ 224
60
+ ]
61
+ }
62
+ ],
63
+ "tvf_do_letterbox": false,
64
+ "tvf_letterbox_fill": null,
65
+ "tvf_normalize_params": [
66
+ {
67
+ "inplace": false,
68
+ "mean": [
69
+ 0.484375,
70
+ 0.455078125,
71
+ 0.40625
72
+ ],
73
+ "std": [
74
+ 0.228515625,
75
+ 0.2236328125,
76
+ 0.224609375
77
+ ]
78
+ },
79
+ {
80
+ "inplace": false,
81
+ "mean": [
82
+ 0.5,
83
+ 0.5,
84
+ 0.5
85
+ ],
86
+ "std": [
87
+ 0.5,
88
+ 0.5,
89
+ 0.5
90
+ ]
91
+ }
92
+ ],
93
+ "tvf_resize_params": [
94
+ {
95
+ "antialias": true,
96
+ "interpolation": 3,
97
+ "max_size": null,
98
+ "size": [
99
+ 224,
100
+ 224
101
+ ]
102
+ },
103
+ {
104
+ "antialias": true,
105
+ "interpolation": 3,
106
+ "max_size": null,
107
+ "size": [
108
+ 224,
109
+ 224
110
+ ]
111
+ }
112
+ ],
113
+ "use_fused_vision_backbone": true
114
+ }
processing_prismatic.py ADDED
@@ -0,0 +1,258 @@
1
+ """
2
+ processing_prismatic.py
3
+
4
+ HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
5
+ specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, ClassVar, List, Optional, Tuple, Union
9
+
10
+ import timm.data
11
+ import torch
12
+ import torchvision.transforms.functional as TVF
13
+ from PIL import Image
14
+ from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
15
+ from transformers import PreTrainedTokenizerBase
16
+ from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
17
+ from transformers.processing_utils import ProcessorMixin
18
+ from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
19
+ from transformers.utils import TensorType
20
+
21
+
22
+ # === Image Processing ===
23
+ def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
24
+ """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
25
+ (w, h), max_wh = image.size, max(image.size)
26
+ horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
27
+ padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
28
+
29
+ return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
30
+
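+ # Example (commented sketch): a 640x480 image gets 0 px of horizontal and 80 px of vertical
+ # padding, yielding a 640x640 square before the resize transform runs:
+ #
+ #   from PIL import Image
+ #   img = Image.new("RGB", (640, 480))
+ #   padded = letterbox_pad_transform(img, (127, 127, 127))
+ #   assert padded.size == (640, 640)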
31
+
32
+ class PrismaticImageProcessor(ImageProcessingMixin):
33
+ model_input_names: ClassVar[List[str]] = ["pixel_values"]
34
+
35
+ def __init__(
36
+ self,
37
+ use_fused_vision_backbone: bool = False,
38
+ image_resize_strategy: str = "letterbox",
39
+ input_sizes: Optional[List[Tuple[int, int, int]]] = None,
40
+ interpolations: Optional[List[str]] = None,
41
+ means: Optional[List[Tuple[float, float, float]]] = None,
42
+ stds: Optional[List[Tuple[float, float, float]]] = None,
43
+ **kwargs: str,
44
+ ) -> None:
45
+ """
46
+ Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
47
+ created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
48
+
49
+ @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
50
+ @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
51
+ @param input_sizes: [TIMM :: `data_cfg`] Input image sizes, one (channels, width, height) tuple per backbone
53
+ @param interpolations: [TIMM :: `data_cfg`] Interpolations as strings (default: "bicubic")
54
+ @param means: [TIMM :: `data_cfg`] Normalization means as float tuples (two entries if `fused_backbone`)
55
+ @param stds: [TIMM :: `data_cfg`] Normalization stds as float tuples (two entries if `fused_backbone`)
55
+ """
56
+ self.use_fused_vision_backbone = use_fused_vision_backbone
57
+ self.image_resize_strategy = image_resize_strategy
58
+
59
+ # Handle `None` default values
60
+ input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
61
+ means = [(0.5, 0.5, 0.5)] if means is None else means
62
+ stds = [(0.5, 0.5, 0.5)] if stds is None else stds
63
+
64
+ # TIMM `data_cfg` Parameters
65
+ self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
66
+
67
+ # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
68
+ self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
69
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
70
+
71
+ for idx in range(len(input_sizes)):
72
+ transform = timm.data.create_transform(
73
+ input_size=self.input_sizes[idx],
74
+ interpolation=self.interpolations[idx],
75
+ mean=self.means[idx],
76
+ std=self.stds[idx],
77
+ crop_pct=1.0, # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
78
+ crop_mode="center", # Default crop mode -- no-op when `crop_pct == 1.0`
79
+ is_training=False, # No image augmentations when loading the transform!
80
+ )
81
+
82
+ # [Validation] Ensure appropriate transform structure, expected sizes
83
+ if not (
84
+ isinstance(transform, Compose)
85
+ and (len(transform.transforms) == 4)
86
+ and isinstance(transform.transforms[0], Resize)
87
+ and isinstance(transform.transforms[1], CenterCrop)
88
+ and isinstance(transform.transforms[2], ToTensor)
89
+ and isinstance(transform.transforms[3], Normalize)
90
+ and (transform.transforms[0].size == self.input_sizes[idx][-1])
91
+ and (transform.transforms[1].size == self.input_sizes[idx][-2:])
92
+ ):
93
+ raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
94
+
95
+ # HF Image Processors *must* be JSON-serializable; as such, we cannot store a torchvision transform as an attribute.
96
+ # => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
97
+ resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
98
+ self.tvf_resize_params.append(
99
+ {
100
+ "size": resize_t.size,
101
+ "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
102
+ "max_size": None,
103
+ "antialias": True,
104
+ }
105
+ )
106
+ self.tvf_crop_params.append({"output_size": crop_t.size})
107
+ self.tvf_normalize_params.append(
108
+ {
109
+ "mean": norm_t.mean.float().numpy().tolist(),
110
+ "std": norm_t.std.float().numpy().tolist(),
111
+ "inplace": False,
112
+ }
113
+ )
114
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
115
+
116
+ # Handle Prismatic `image_resize_strategy`
117
+ if self.image_resize_strategy == "resize-naive":
118
+ self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
119
+ elif self.image_resize_strategy == "letterbox":
120
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
121
+ elif self.image_resize_strategy == "resize-crop":
122
+ pass
123
+ else:
124
+ raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
125
+
126
+ # Dispatch **kwargs to super()
127
+ super().__init__(**kwargs)
128
+
129
+ def apply_transform(self, img: Image.Image) -> torch.Tensor:
130
+ """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
131
+ if self.tvf_do_letterbox:
132
+ img = letterbox_pad_transform(img, self.tvf_letterbox_fill)
133
+
134
+ # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
135
+ imgs_t = []
136
+ for idx in range(len(self.input_sizes)):
138
+ img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
139
+ img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
140
+ img_idx_t = TVF.to_tensor(img_idx)
141
+ img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
142
+ imgs_t.append(img_idx_t)
143
+
144
+ # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
145
+ img_t = torch.vstack(imgs_t)
146
+
147
+ return img_t
148
+
149
+ def preprocess(
150
+ self,
151
+ images: Union[Image.Image, List[Image.Image]],
152
+ return_tensors: Optional[Union[str, TensorType]] = None,
153
+ **_: str,
154
+ ) -> BatchFeature:
155
+ """
156
+ Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
157
+ explicitly only handle PIL.Image.Image instances for simplicity.
158
+
159
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
160
+ @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
161
+
162
+ @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
163
+ """
164
+ if not isinstance(images, list):
165
+ images = [images]
166
+
167
+ # Apply `self.apply_transform` to each image (returns a list of torch.Tensors); stack into a "batched" Tensor
168
+ pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
169
+
170
+ # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
171
+ return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
172
+
173
+ def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
174
+ return self.preprocess(images, **kwargs)
175
+
176
+
177
+ # === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
178
+ # =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
179
+ class PrismaticProcessor(ProcessorMixin):
180
+ attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
181
+ image_processor_class: str = "AutoImageProcessor"
182
+ tokenizer_class: str = "AutoTokenizer"
183
+
184
+ def __init__(
185
+ self,
186
+ image_processor: Optional[ImageProcessingMixin] = None,
187
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
188
+ ) -> None:
189
+ super().__init__(image_processor, tokenizer)
190
+
191
+ def __call__(
192
+ self,
193
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
194
+ images: Union[Image.Image, List[Image.Image]],
195
+ padding: Union[bool, str, PaddingStrategy] = False,
196
+ truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
197
+ max_length: Optional[int] = None,
198
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
199
+ ) -> BatchFeature:
200
+ """
201
+ Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,
202
+ forwards images to PrismaticImageProcessor.
203
+
204
+ @param text: The (batch) of text to encode; must be a string or list of strings.
205
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
206
+ @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
207
+ @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
208
+ @param max_length: Maximum length (in tokens) to truncate
209
+ @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
210
+
211
+ @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
212
+ """
213
+ pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
214
+ text_inputs = self.tokenizer(
215
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
216
+ )
217
+
218
+ # [Validate] Need same number of images and text inputs!
219
+ if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
220
+ raise ValueError("Batch is malformed; expected same number of images and text inputs!")
221
+
222
+ return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
223
+
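+ # Usage sketch (commented; assumes this processor was loaded via AutoProcessor as wired up in
+ # preprocessor_config.json): one prompt plus one image yields aligned text and image tensors.
+ #
+ #   from PIL import Image
+ #   batch = processor("In: What action should the robot take?\nOut:", Image.open("frame.png"))
+ #   # batch["input_ids"]: (1, L); batch["pixel_values"]: (1, C, 224, 224), channel-stacked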
224
+ # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
225
+ def batch_decode(
226
+ self,
227
+ sequences: Union[List[int], List[List[int]], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
228
+ skip_special_tokens: bool = False,
229
+ clean_up_tokenization_spaces: Optional[bool] = None,
230
+ **kwargs: str,
231
+ ) -> List[str]:
232
+ return self.tokenizer.batch_decode(
233
+ sequences=sequences,
234
+ skip_special_tokens=skip_special_tokens,
235
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
236
+ **kwargs,
237
+ )
238
+
239
+ def decode(
240
+ self,
241
+ token_ids: Union[int, List[int], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
242
+ skip_special_tokens: bool = False,
243
+ clean_up_tokenization_spaces: Optional[bool] = None,
244
+ **kwargs: str,
245
+ ) -> str:
246
+ return self.tokenizer.decode(
247
+ token_ids=token_ids,
248
+ skip_special_tokens=skip_special_tokens,
249
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
250
+ **kwargs,
251
+ )
252
+
253
+ @property
254
+ def model_input_names(self) -> List[str]:
255
+ tokenizer_input_names = self.tokenizer.model_input_names
256
+ image_processor_input_names = self.image_processor.model_input_names
257
+
258
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
processor_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_prismatic.PrismaticProcessor"
4
+ },
5
+ "processor_class": "PrismaticProcessor"
6
+ }
proprio_projector--checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17e19497e8139475764063850684538879dbe9ce5bcf94664be2209c2a089cb7
3
+ size 1626104
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
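Note that `eos_token` and `pad_token` are the same token, `<|endoftext|>` (ID 151643 per added_tokens.json), so padding positions cannot be distinguished from end-of-sequence by token ID alone; downstream code should rely on `attention_mask`. A quick sanity check, assuming the repo is available at a hypothetical local path:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this-repo")
assert tok.pad_token == tok.eos_token == "<|endoftext|>"
print(tok.pad_token_id)  # 151643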
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c5ae00e602b8860cbd784ba82a8aa14e8feecec692e7076590d014d7b7fdafa
+ size 11421896
tokenizer_config.json ADDED
@@ -0,0 +1,211 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "auto_map": {
+     "AutoProcessor": "processing_prismatic.PrismaticProcessor"
+   },
+   "bos_token": null,
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|endoftext|>",
+   "errors": "replace",
+   "extra_special_tokens": {},
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "processor_class": "PrismaticProcessor",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
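The `added_tokens_decoder` map above pins the Qwen2 control tokens to fixed IDs (151643 through 151664), with only the chat/vision markers flagged `special` and hence strippable via `skip_special_tokens=True`. A hedged sanity check against this config (hypothetical local path again):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this-repo")
assert tok.convert_tokens_to_ids("<|im_start|>") == 151644
assert tok.convert_tokens_to_ids("<|file_sep|>") == 151664
print(tok.model_max_length)  # 131072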
vocab.json ADDED
The diff for this file is too large to render. See raw diff