Upload model
- README.md +0 -0
- chat_template.jinja +91 -0
- config.json +1118 -0
- configuration_longcat_next.py +153 -0
- configuration_longcat_ngram.py +218 -0
- cosy24k_vocoder.py +552 -0
- generation_config.json +36 -0
- image_refiner.py +748 -0
- model-00001-of-00022.safetensors +3 -0
- model-00002-of-00022.safetensors +3 -0
- model-00003-of-00022.safetensors +3 -0
- model-00004-of-00022.safetensors +3 -0
- model-00005-of-00022.safetensors +3 -0
- model-00006-of-00022.safetensors +3 -0
- model-00007-of-00022.safetensors +3 -0
- model-00008-of-00022.safetensors +3 -0
- model-00009-of-00022.safetensors +3 -0
- model-00010-of-00022.safetensors +3 -0
- model-00011-of-00022.safetensors +3 -0
- model-00012-of-00022.safetensors +3 -0
- model-00013-of-00022.safetensors +3 -0
- model-00014-of-00022.safetensors +3 -0
- model-00015-of-00022.safetensors +3 -0
- model-00016-of-00022.safetensors +3 -0
- model-00017-of-00022.safetensors +3 -0
- model-00018-of-00022.safetensors +3 -0
- model-00019-of-00022.safetensors +3 -0
- model-00020-of-00022.safetensors +3 -0
- model-00021-of-00022.safetensors +3 -0
- model-00022-of-00022.safetensors +3 -0
- model.safetensors.index.json +0 -0
- model_extra_tensors.safetensors +3 -0
- modeling_longcat_next.py +824 -0
- modeling_longcat_ngram.py +426 -0
- modular_longcat_next.py +157 -0
- modular_longcat_next_audio.py +2039 -0
- modular_longcat_next_visual.py +1077 -0
- parse_model_response.py +158 -0
- preprocessor_config.json +19 -0
- processing_longcat_next.py +279 -0
- processor_config.json +6 -0
- quantization_config.json +75 -0
- refiner_modules.py +1330 -0
- special_tokens_map.json +85 -0
- tokenizer.json +0 -0
- tokenizer_config.json +2300 -0
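The listing above is the complete file set for this upload: 22 safetensors weight shards plus the configuration, processor, tokenizer, and remote-code Python modules. A minimal sketch of fetching only the lightweight files with huggingface_hub follows; the repo id "org/LongCat-Next" is a placeholder and not part of this commit.

from huggingface_hub import snapshot_download

# Download configs, tokenizer files, and remote code, skipping the ~22 weight shards.
local_dir = snapshot_download(
    repo_id="org/LongCat-Next",  # placeholder id, substitute the actual repository
    allow_patterns=["*.json", "*.py", "*.jinja"],
)
print(local_dir)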
README.md
ADDED
File without changes
chat_template.jinja
ADDED
@@ -0,0 +1,91 @@
{%- set tool_choice = tool_choice | default('auto') %}
{%- set ns = namespace(tool_types = [], last_query_index = -1, suffix_to_move = '') %}

{%- if tools and tool_choice != 'none' %}
{{- "<longcat_tool_declare>\n"-}}
{{- "# Tools\n" }}
{{- "You have access to the following tools:\n\n" }}
{%- for tool in tools %}
{%- if tool.type not in ns.tool_types %}
{%- set ns.tool_types = ns.tool_types + [tool.type] %}
{{- "## Tool namespace: " ~ tool.type ~ "\n\n" }}
{%- endif %}
{%- if tool.type == 'code_interpreter' %}
{%- set tool = {"type":"code_interpreter","function":{"name":"code_interpreter_preview","description":"The code will be executed in a stateful Jupyter notebook sandbox environment, only supports local computation, data processing, and file operations.\nCode sandbox environment (network isolated) Any external network requests or online API calls are prohibited.\nIf online functionality is needed, please use other permitted tools.\nCode will respond with the output of the execution or time out after 60.0 seconds. ","parameters":{"type":"object","properties":{"language":{"type":"string","description":"The programming language of the code to be executed. Available values: python (Default), java, go, js, ts, c, c++."},"code":{"type":"string","description":"Python code to be executed must not include the following:\n- Importing network libraries such as requests, httplib, etc.\n- Any form of HTTP requests.\n- External API calls.\n- Network port operations. Example: ```python\nimport pandas as pd\npd.DataFrame({'A':[1,2]})\n```"},"timeout":{"type":"number","description":"The maximum execution time of the code, in seconds. Default is 60.0."}}},"required":["code"]}} %}
{%- endif %}
{{- "### Tool name: " + tool.function.name + "\n" }}
{{- "Description: " + tool.function.description + "\n\n" }}
{{- "InputSchema: " + tool.function.parameters | tojson(ensure_ascii=False) + "\n\n" }}
{%- endfor %}
{{- '**Note**: For each function call, output the function name and arguments within the following XML format:\n<longcat_tool_call>{function-name}\n<longcat_arg_key>{arg-key-1}</longcat_arg_key>\n<longcat_arg_value>{arg-value-1}</longcat_arg_value>\n<longcat_arg_key>{arg-key-2}</longcat_arg_key>\n<longcat_arg_value>{arg-value-2}</longcat_arg_value>\n...\n</longcat_tool_call>\n' }}
{{- "</longcat_tool_declare>"-}}
{%- for idx in range(messages|length - 1) %}
{%- set msg = messages[idx] %}
{%- if msg.role == 'assistant' and not msg.tool_calls %}
{%- set ns.last_query_index = idx %}
{%- endif %}
{%- endfor%}
{%- endif %}

{%- for msg in messages %}
{%- if msg.role == "system" %}
{{- "<longcat_system>" + msg.content }}
{%- elif msg.role == "user" %}
{{- "<longcat_user>" }}
{%- if msg["files"] %}
{{- '<longcat_files>\n' ~ msg.files | tojson(indent=2) ~ '\n</longcat_files>' }}
{%- endif %}

{%- if add_generation_prompt and loop.last and msg.content is string and msg.content.endswith("<longcat_img_start>") %}
{%- set ns.suffix_to_move = "<longcat_img_start>" %}
{{- msg.content[:-19] }}
{%- elif add_generation_prompt and loop.last and msg.content is string and msg.content.endswith("<longcat_audiogen_start>") %}
{%- set ns.suffix_to_move = "<longcat_audiogen_start>" %}
{{- msg.content[:-24] }}
{%- else %}
{{- msg.content }}
{%- endif %}

{%- elif msg.role == "assistant" %}
{{- "<longcat_assistant>" }}
{%- if enable_thinking == true and msg.reasoning_content and ns.tool_types != [] and loop.index0 > ns.last_query_index %}
{{- "\n<longcat_think>\n" ~ msg.reasoning_content ~ "\n</longcat_think>\n" }}
{%- endif %}
{%- if msg.content%}
{{- msg.content }}
{%- endif %}
{%- if msg.tool_calls %}
{%- for tool_call in msg.tool_calls -%}
{{- "<longcat_tool_call>" ~ tool_call.function.name ~ "\n" -}}
{% set _args = tool_call.function.arguments %}
{% for k, v in _args.items() %}
{{- "<longcat_arg_key>" ~ k ~ "</longcat_arg_key>\n" -}}
{{- "<longcat_arg_value>" ~ (v if v is string else v | tojson(ensure_ascii=False)) ~ "</longcat_arg_value>\n" -}}
{% endfor %}
{{- "</longcat_tool_call>\n" }}
{%- endfor %}
{%- endif %}
{{- "</longcat_s>" -}}
{%- elif msg.role == "tool" %}
{%- if messages[loop.index0 - 1].role != "tool"%}
{{- "<longcat_user>" -}}
{%- endif %}
{{- "<longcat_tool_response>" ~ msg.content ~ "</longcat_tool_response>"-}}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{%- if enable_thinking == true %}
{{- " /think_on" }}
{%- if thinking_budget %}
{%- if thinking_budget < 1024 %}
{%- set thinking_budget = 1024 %}
{%- endif%}
{{- "\nthinking_budget: < " ~ thinking_budget ~ "."}}
{%- endif %}
{{- " <longcat_assistant><longcat_think>\n"}}
{%- elif enable_thinking == false %}
{{- " /think_off <longcat_assistant><longcat_think>\n\n</longcat_think>\n" }}
{%- else %}
{{- "<longcat_assistant>" ~ ns.suffix_to_move }}
{%- endif %}
{%- endif %}
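The template above handles tool declaration (<longcat_tool_declare>), tool-call serialization as <longcat_arg_key>/<longcat_arg_value> pairs, and the /think_on and /think_off switches, with thinking_budget clamped to at least 1024. A minimal rendering sketch, assuming the tokenizer in this repo picks up chat_template.jinja and that extra keyword arguments such as enable_thinking are forwarded into the template (standard apply_chat_template behavior); the repo id is a placeholder.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("org/LongCat-Next", trust_remote_code=True)  # placeholder repo id

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 17 * 24?"},
]
# tokenize=False returns the rendered prompt string instead of token ids.
prompt = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
    enable_thinking=True,   # appends " /think_on ... <longcat_assistant><longcat_think>"
    thinking_budget=2048,   # values below 1024 are raised to 1024 by the template
)
print(prompt)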
config.json
ADDED
@@ -0,0 +1,1118 @@
{
  "architectures": [
    "LongcatNextForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "audio_config": {
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "apply_spec_augment": false,
    "attention_dropout": 0.0,
    "audio_delim_token_id": 131116,
    "audio_end_token_id": 131104,
    "audio_head_transformer_dims": 3072,
    "audio_head_transformer_ffn_scale": 16,
    "audio_head_transformer_layers": 4,
    "audio_pad_token_id": 131105,
    "audio_start_token_id": 131103,
    "audiogen_end_token_id": 131124,
    "audiogen_start_token_id": 131123,
    "audiotext_end_token_id": 131121,
    "audiotext_pad_token_id": 131122,
    "audiotext_start_token_id": 131120,
    "avg_pooler": 4,
    "classifier_proj_size": 256,
    "cosy24kvocoder_config": {
      "_name_or_path": "",
      "add_cross_attention": false,
      "architectures": null,
      "bad_words_ids": null,
      "begin_suppress_tokens": null,
      "bos_token_id": null,
      "chunk_size_feed_forward": 0,
      "cross_attention_hidden_size": null,
      "decoder_start_token_id": null,
      "diversity_penalty": 0.0,
      "do_sample": false,
      "dtype": null,
      "early_stopping": false,
      "encoder_no_repeat_ngram_size": 0,
      "eos_token_id": null,
      "exponential_decay_length_penalty": null,
      "finetuning_task": null,
      "forced_bos_token_id": null,
      "forced_eos_token_id": null,
      "id2label": {
        "0": "LABEL_0",
        "1": "LABEL_1"
      },
      "is_decoder": false,
      "is_encoder_decoder": false,
      "label2id": {
        "LABEL_0": 0,
        "LABEL_1": 1
      },
      "length_penalty": 1.0,
      "max_length": 20,
      "min_length": 0,
      "model_type": "",
      "no_repeat_ngram_size": 0,
      "num_beam_groups": 1,
      "num_beams": 1,
      "num_return_sequences": 1,
      "output_attentions": false,
      "output_hidden_states": false,
      "output_scores": false,
      "pad_token_id": null,
      "prefix": null,
      "problem_type": null,
      "pruned_heads": {},
      "remove_invalid_values": false,
      "repetition_penalty": 1.0,
      "return_dict": true,
      "return_dict_in_generate": false,
      "sep_token_id": null,
      "suppress_tokens": null,
      "task_specific_params": null,
      "temperature": 1.0,
      "tf_legacy_loss": false,
      "tie_encoder_decoder": false,
      "tie_word_embeddings": true,
      "tokenizer_class": null,
      "top_k": 50,
      "top_p": 1.0,
      "torchscript": false,
      "typical_p": 1.0,
      "use_bfloat16": false,
      "weight_path": "WEIGHT_PATH_TO_LONGCAT_NEXT/cosy24k_vocoder/hift.pt"
    },
    "d_model": 1280,
    "decoder_attention_heads": 20,
    "decoder_ffn_dim": 5120,
    "decoder_kernel_size": 3,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 8,
    "decoder_stride_size": 2,
    "dropout": 0.0,
    "encoder_attention_heads": 20,
    "encoder_ffn_dim": 5120,
    "encoder_layerdrop": 0.0,
    "encoder_layers": 32,
    "flow_matching_config": {
      "_name_or_path": "",
      "act_fn": "gelu",
      "add_cross_attention": false,
      "architectures": null,
      "attention_head_dim": 64,
      "bad_words_ids": null,
      "begin_suppress_tokens": null,
      "bos_token_id": null,
      "cal_mel_mae": true,
      "cfm_params": {
        "_name_or_path": "",
        "add_cross_attention": false,
        "architectures": null,
        "bad_words_ids": null,
        "begin_suppress_tokens": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": null,
        "decoder_start_token_id": null,
        "diversity_penalty": 0.0,
        "do_sample": false,
        "dtype": null,
        "early_stopping": false,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": null,
        "exponential_decay_length_penalty": null,
        "finetuning_task": null,
        "forced_bos_token_id": null,
        "forced_eos_token_id": null,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "inference_cfg_rate": 0.7,
        "is_decoder": false,
        "is_encoder_decoder": false,
        "label2id": {
          "LABEL_0": 0,
          "LABEL_1": 1
        },
        "length_penalty": 1.0,
        "max_length": 20,
        "min_length": 0,
        "model_type": "",
        "no_repeat_ngram_size": 0,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_return_sequences": 1,
        "output_attentions": false,
        "output_hidden_states": false,
        "output_scores": false,
        "pad_token_id": null,
        "prefix": null,
        "problem_type": null,
        "pruned_heads": {},
        "remove_invalid_values": false,
        "repetition_penalty": 1.0,
        "return_dict": true,
        "return_dict_in_generate": false,
        "sep_token_id": null,
        "sigma_min": 1e-06,
        "solver": "euler",
        "suppress_tokens": null,
        "t_scheduler": "cosine",
        "task_specific_params": null,
        "temperature": 1.0,
        "tf_legacy_loss": false,
        "tie_encoder_decoder": false,
        "tie_word_embeddings": true,
        "tokenizer_class": null,
        "top_k": 50,
        "top_p": 1.0,
        "torchscript": false,
        "training_cfg_rate": 0.2,
        "typical_p": 1.0,
        "use_bfloat16": false
      },
      "channels": [
        256
      ],
      "chunk_size_feed_forward": 0,
      "cross_attention_hidden_size": null,
      "decoder_start_token_id": null,
      "diffusion_steps": 10,
      "diversity_penalty": 0.0,
      "do_sample": false,
      "dropout": 0.0,
      "dtype": null,
      "early_stopping": false,
      "encoder_no_repeat_ngram_size": 0,
      "eos_token_id": null,
      "exponential_decay_length_penalty": null,
      "finetuning_task": null,
      "forced_bos_token_id": null,
      "forced_eos_token_id": null,
      "id2label": {
        "0": "LABEL_0",
        "1": "LABEL_1"
      },
      "in_channels": 80,
      "is_decoder": false,
      "is_encoder_decoder": false,
      "label2id": {
        "LABEL_0": 0,
        "LABEL_1": 1
      },
      "length_penalty": 1.0,
      "max_length": 20,
      "min_length": 0,
      "model_type": "",
      "n_blocks": 4,
      "no_repeat_ngram_size": 0,
      "num_beam_groups": 1,
      "num_beams": 1,
      "num_heads": 8,
      "num_mid_blocks": 12,
      "num_return_sequences": 1,
      "output_attentions": false,
      "output_hidden_states": false,
      "output_scores": false,
      "pad_token_id": null,
      "prefix": null,
      "prenet_activation_function": "gelu",
      "prenet_attention_heads": 8,
      "prenet_d_model": 512,
      "prenet_ffn_dim": 2048,
      "prenet_in_dim": 1280,
      "prenet_max_source_positions": 5000,
      "prenet_nlayers": 12,
      "prenet_out_dim": 80,
      "prenet_target_mel_length_scale_ratio": 1.0,
      "problem_type": null,
      "pruned_heads": {},
      "remove_invalid_values": false,
      "repetition_penalty": 1.0,
      "return_dict": true,
      "return_dict_in_generate": false,
      "sep_token_id": null,
      "spk_emb_dim": 0,
      "suppress_tokens": null,
      "task_specific_params": null,
      "temperature": 1.0,
      "tf_legacy_loss": false,
      "tie_encoder_decoder": false,
      "tie_word_embeddings": true,
      "tokenizer_class": null,
      "top_k": 50,
      "top_p": 1.0,
      "torchscript": false,
      "typical_p": 1.0,
      "use_bfloat16": false,
      "use_hidden_states_before_dconv2": true
    },
    "hop_length": 160,
    "init_std": 0.02,
    "kernel_size": 3,
    "mask_feature_length": 10,
    "mask_feature_min_masks": 0,
    "mask_feature_prob": 0.0,
    "mask_time_length": 10,
    "mask_time_min_masks": 2,
    "mask_time_prob": 0.05,
    "max_audio_seconds": 30,
    "max_source_positions": 1500,
    "max_target_positions": 448,
    "median_filter_width": 7,
    "model_type": "longcat_next_audio",
    "n_fft": 400,
    "num_hidden_layers": 32,
    "num_mel_bins": 128,
    "sampling_rate": 16000,
    "scale_embedding": false,
    "stride_size": 2,
    "use_cache": true,
    "use_weighted_layer_sum": false,
    "vocab_size": 51865,
    "vocoder_config": {
      "_name_or_path": "",
      "add_cross_attention": false,
      "architectures": null,
      "bad_words_ids": null,
      "begin_suppress_tokens": null,
      "bos_token_id": null,
      "channels": [
        256,
        256,
        256,
        256,
        256
      ],
      "chunk_size_feed_forward": 0,
      "cross_attention_hidden_size": null,
      "decoder_start_token_id": null,
      "diversity_penalty": 0.0,
      "do_sample": false,
      "dtype": null,
      "early_stopping": false,
      "encoder_no_repeat_ngram_size": 0,
      "eos_token_id": null,
      "exponential_decay_length_penalty": null,
      "finetuning_task": null,
      "forced_bos_token_id": null,
      "forced_eos_token_id": null,
      "hop_length": 256,
      "id2label": {
        "0": "LABEL_0",
        "1": "LABEL_1"
      },
      "is_decoder": false,
      "is_encoder_decoder": false,
      "label2id": {
        "LABEL_0": 0,
        "LABEL_1": 1
      },
      "length_penalty": 1.0,
      "max_length": 20,
      "min_length": 0,
      "model_type": "",
      "no_repeat_ngram_size": 0,
      "num_beam_groups": 1,
      "num_beams": 1,
      "num_mel_bins": 80,
      "num_return_sequences": 1,
      "output_attentions": false,
      "output_hidden_states": false,
      "output_scores": false,
      "pad_token_id": null,
      "prefix": null,
      "problem_type": null,
      "pruned_heads": {},
      "remove_invalid_values": false,
      "repetition_penalty": 1.0,
      "return_dict": true,
      "return_dict_in_generate": false,
      "sampling_rate": 16000,
      "sep_token_id": null,
      "suppress_tokens": null,
      "task_specific_params": null,
      "temperature": 1.0,
      "tf_legacy_loss": false,
      "tie_encoder_decoder": false,
      "tie_word_embeddings": true,
      "tokenizer_class": null,
      "top_k": 50,
      "top_p": 1.0,
      "torchscript": false,
      "typical_p": 1.0,
      "use_bfloat16": false
    },
    "vq_config": {
      "_name_or_path": "",
      "add_cross_attention": false,
      "architectures": null,
      "bad_words_ids": null,
      "begin_suppress_tokens": null,
      "bos_token_id": null,
      "chunk_size_feed_forward": 0,
      "codebook_sizes": [
        8192,
        4096,
        2048,
        1024,
        1024,
        1024,
        1024,
        1024
      ],
      "cross_attention_hidden_size": null,
      "decoder_start_token_id": null,
      "diversity_penalty": 0.0,
      "do_sample": false,
      "dtype": null,
      "early_stopping": false,
      "encoder_no_repeat_ngram_size": 0,
      "eos_token_id": null,
      "exponential_decay_length_penalty": null,
      "finetuning_task": null,
      "forced_bos_token_id": null,
      "forced_eos_token_id": null,
      "id2label": {
        "0": "LABEL_0",
        "1": "LABEL_1"
      },
      "is_decoder": false,
      "is_encoder_decoder": false,
      "label2id": {
        "LABEL_0": 0,
        "LABEL_1": 1
      },
      "length_penalty": 1.0,
      "max_length": 20,
      "min_length": 0,
      "model_type": "",
      "no_repeat_ngram_size": 0,
      "num_beam_groups": 1,
      "num_beams": 1,
      "num_return_sequences": 1,
      "output_attentions": false,
      "output_hidden_states": false,
      "output_scores": false,
      "pad_token_id": null,
      "prefix": null,
      "problem_type": null,
      "pruned_heads": {},
      "remove_invalid_values": false,
      "repetition_penalty": 1.0,
      "return_dict": true,
      "return_dict_in_generate": false,
      "sep_token_id": null,
      "suppress_tokens": null,
      "task_specific_params": null,
      "temperature": 1.0,
      "tf_legacy_loss": false,
      "tie_encoder_decoder": false,
      "tie_word_embeddings": true,
      "tokenizer_class": null,
      "top_k": 50,
      "top_p": 1.0,
      "torchscript": false,
      "typical_p": 1.0,
      "use_bfloat16": false
    }
  },
  "audio_offset": 131125,
  "auto_map": {
    "AutoConfig": "configuration_longcat_next.LongcatNextConfig",
    "AutoModel": "modeling_longcat_next.LongcatNextModel",
    "AutoModelForCausalLM": "modeling_longcat_next.LongcatNextForCausalLM"
  },
  "bos_token_id": 1,
  "dtype": "bfloat16",
  "emb_neighbor_num": 4,
  "emb_split_num": 4,
  "eos_token_id": 2,
  "expert_ffn_hidden_size": 1024,
  "ffn_hidden_size": 6144,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 3072,
  "initializer_range": 0.02,
  "kv_lora_rank": 512,
  "max_position_embeddings": 131072,
  "mla_scale_kv_lora": true,
  "mla_scale_q_lora": true,
  "model_type": "longcat_next",
  "moe_topk": 12,
  "n_routed_experts": 256,
  "ngram_vocab_size_ratio": 78,
  "num_attention_heads": 32,
  "num_hidden_layers": 28,
  "num_key_value_heads": 32,
  "num_layers": 14,
  "oe_ignored_token_ids": [
    131072,
    131073,
    131074,
    131075,
    131076,
    131077,
    131078,
    131079,
    131080,
    131081,
    131082,
    131083,
    131084,
    131085,
    131086,
    131087,
    131088,
    131089,
    131090,
    131091,
    131092,
    131093,
    131094,
    131095,
    131096,
    131097,
    131098,
    131099,
    131100,
    131101,
    131102,
    131103,
    131104,
    131105,
    131106,
    131107,
    131108,
    131109,
    131110,
    131111,
    131112,
    131113,
    131114,
    131115,
    131116,
    131117,
    131118,
    131119,
    131120,
    131121,
    131122,
    131123,
    131124
  ],
  "q_lora_rank": 1536,
  "qk_head_dim": 192,
  "qk_nope_head_dim": 128,
  "qk_rope_head_dim": 64,
  "quantization_config": {
    "autoround_version": "0.13.0",
    "batch_size": 1,
    "bits": 4,
    "block_name_to_quantize": "model.layers",
    "data_type": "int",
    "extra_config": {
      ".*classifier.*": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.0.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.1.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.10.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.11.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.12.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.13.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.2.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.3.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.4.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.5.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.6.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.7.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.8.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      },
      "model.layers.9.mlp.router.classifier": {
        "bits": 16,
        "data_type": "float"
      }
    },
    "gradient_accumulate_steps": 8,
    "group_size": 128,
    "packing_format": "auto_round:auto_gptq",
    "quant_method": "auto-round",
    "seqlen": 512,
    "sym": true
  },
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000000,
  "routed_scaling_factor": 6.0,
  "text_vocab_plus_multimodal_special_token_size": 131125,
  "text_vocab_size": 131072,
  "tie_word_embeddings": false,
  "transformers_version": "4.57.6",
  "use_cache": true,
  "v_head_dim": 128,
  "visual_config": {
    "depth": 32,
    "fullatt_block_indexes": [
      7,
      15,
      23,
      31
    ],
    "hidden_act": "silu",
    "hidden_size": 1280,
    "image_end_token_id": 131107,
    "image_head_transformer_dims": 2048,
    "image_head_transformer_ffn_scale": 16,
    "image_head_transformer_layers": 4,
    "image_newline_token_id": 131109,
    "image_pad_token_id": 131108,
    "image_start_token_id": 131106,
    "in_channels": 3,
    "initializer_range": 0.02,
    "intermediate_size": 3420,
    "model_type": "longcat_next_visual",
    "num_heads": 16,
    "out_hidden_size": 3584,
    "patch_size": 14,
    "spatial_merge_size": 2,
    "temporal_patch_size": 2,
    "tokens_per_second": 4,
    "visual_decoder_config": {
      "_name_or_path": "",
      "add_cross_attention": false,
      "architectures": null,
      "bad_words_ids": null,
      "begin_suppress_tokens": null,
      "bos_token_id": null,
      "chunk_size_feed_forward": 0,
      "codebook_dim": 3584,
      "cross_attention_hidden_size": null,
      "decoder_start_token_id": null,
      "diversity_penalty": 0.0,
      "do_sample": false,
      "dtype": null,
      "early_stopping": false,
      "encoder_no_repeat_ngram_size": 0,
      "eos_token_id": null,
      "exponential_decay_length_penalty": null,
      "finetuning_task": null,
      "forced_bos_token_id": null,
      "forced_eos_token_id": null,
      "id2label": {
        "0": "LABEL_0",
        "1": "LABEL_1"
      },
      "image_decoder_config": {
        "_name_or_path": "",
        "add_cross_attention": false,
        "architectures": null,
        "attention_dropout": 0.0,
        "bad_words_ids": null,
        "begin_suppress_tokens": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "codebook_dim": 3584,
        "cross_attention_hidden_size": null,
        "decoder_start_token_id": null,
        "distill_taps": [
          3,
          7,
          15,
          23
        ],
        "diversity_penalty": 0.0,
        "do_sample": false,
        "dtype": null,
        "early_stopping": false,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": null,
        "exponential_decay_length_penalty": null,
        "finetuning_task": null,
        "forced_bos_token_id": null,
        "forced_eos_token_id": null,
        "hidden_act": "gelu",
        "hidden_size": 1024,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "intermediate_size": 2730,
        "is_decoder": false,
        "is_encoder_decoder": false,
        "k_bias": false,
        "label2id": {
          "LABEL_0": 0,
          "LABEL_1": 1
        },
        "layer_norm_eps": 1e-06,
        "length_penalty": 1.0,
        "max_length": 20,
        "min_length": 0,
        "model_type": "",
        "no_repeat_ngram_size": 0,
        "num_attention_heads": 16,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_hidden_layers": 32,
        "num_return_sequences": 1,
        "output_attentions": false,
        "output_hidden_states": false,
        "output_scores": false,
        "pad_token_id": null,
        "patch_size": 14,
        "prefix": null,
        "problem_type": null,
        "pruned_heads": {},
        "q_bias": true,
        "remove_invalid_values": false,
        "repetition_penalty": 1.0,
        "return_dict": true,
        "return_dict_in_generate": false,
        "sep_token_id": null,
        "spatial_merge_size": 2,
        "subln": true,
        "suppress_tokens": null,
        "swiglu": true,
        "task_specific_params": null,
        "teacher_dims": {
          "15": 1280,
          "23": 1280,
          "3": 1280,
          "7": 1280
        },
        "temperature": 1.0,
        "temporal_patch_size": 2,
        "tf_legacy_loss": false,
        "tie_encoder_decoder": false,
        "tie_word_embeddings": true,
        "tokenizer_class": null,
        "top_k": 50,
        "top_p": 1.0,
        "torchscript": false,
        "typical_p": 1.0,
        "use_bfloat16": false,
        "v_bias": true
      },
      "is_decoder": false,
      "is_encoder_decoder": false,
      "label2id": {
        "LABEL_0": 0,
        "LABEL_1": 1
      },
      "length_penalty": 1.0,
      "max_length": 20,
      "min_length": 0,
      "model_type": "",
      "no_repeat_ngram_size": 0,
      "num_beam_groups": 1,
      "num_beams": 1,
      "num_return_sequences": 1,
      "output_attentions": false,
      "output_hidden_states": false,
      "output_scores": false,
      "pad_token_id": null,
      "prefix": null,
      "problem_type": null,
      "pruned_heads": {},
      "remove_invalid_values": false,
      "repetition_penalty": 1.0,
      "return_dict": true,
      "return_dict_in_generate": false,
      "scheduler_config": {
        "_name_or_path": "",
        "add_cross_attention": false,
        "architectures": null,
        "bad_words_ids": null,
        "begin_suppress_tokens": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": null,
        "decoder_start_token_id": null,
        "diversity_penalty": 0.0,
        "do_sample": false,
        "dtype": null,
        "dynamic_time_shift": true,
        "early_stopping": false,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": null,
        "exponential_decay_length_penalty": null,
        "finetuning_task": null,
        "forced_bos_token_id": null,
        "forced_eos_token_id": null,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "is_decoder": false,
        "is_encoder_decoder": false,
        "label2id": {
          "LABEL_0": 0,
          "LABEL_1": 1
        },
        "length_penalty": 1.0,
        "max_length": 20,
        "min_length": 0,
        "model_type": "",
        "no_repeat_ngram_size": 0,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_return_sequences": 1,
        "num_train_timesteps": 1000,
        "output_attentions": false,
        "output_hidden_states": false,
        "output_scores": false,
        "pad_token_id": null,
        "prefix": null,
        "problem_type": null,
        "pruned_heads": {},
        "remove_invalid_values": false,
        "repetition_penalty": 1.0,
        "return_dict": true,
        "return_dict_in_generate": false,
        "sep_token_id": null,
        "suppress_tokens": null,
        "task_specific_params": null,
        "temperature": 1.0,
        "tf_legacy_loss": false,
        "tie_encoder_decoder": false,
        "tie_word_embeddings": true,
        "tokenizer_class": null,
        "top_k": 50,
        "top_p": 1.0,
        "torchscript": false,
        "typical_p": 1.0,
        "use_bfloat16": false
      },
      "sep_token_id": null,
      "suppress_tokens": null,
      "task_specific_params": null,
      "temperature": 1.0,
      "tf_legacy_loss": false,
      "tie_encoder_decoder": false,
      "tie_word_embeddings": true,
      "tokenizer_class": null,
      "top_k": 50,
      "top_p": 1.0,
      "torchscript": false,
      "transformer_config": {
        "_name_or_path": "",
        "add_cross_attention": false,
        "architectures": null,
        "axes_dim_rope": [
          40,
          40,
          40
        ],
        "axes_lens": [
          10000,
          10000,
          10000
        ],
        "bad_words_ids": null,
        "begin_suppress_tokens": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": null,
        "decoder_start_token_id": null,
        "diversity_penalty": 0.0,
        "do_sample": false,
        "dtype": null,
        "early_stopping": false,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": null,
        "exponential_decay_length_penalty": null,
        "finetuning_task": null,
        "forced_bos_token_id": null,
        "forced_eos_token_id": null,
        "hidden_size": 2520,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "in_channels": 16,
        "is_decoder": false,
        "is_encoder_decoder": false,
        "label2id": {
          "LABEL_0": 0,
          "LABEL_1": 1
        },
        "length_penalty": 1.0,
        "max_length": 20,
        "min_length": 0,
        "model_type": "",
        "multiple_of": 256,
        "no_repeat_ngram_size": 0,
        "norm_eps": 1e-05,
        "num_attention_heads": 21,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_kv_heads": 7,
        "num_layers": 32,
        "num_refiner_layers": 2,
        "num_return_sequences": 1,
        "output_attentions": false,
        "output_hidden_states": false,
        "output_scores": false,
        "pad_token_id": null,
        "patch_size": 2,
        "prefix": null,
        "problem_type": null,
        "pruned_heads": {},
        "remove_invalid_values": false,
        "repetition_penalty": 1.0,
        "return_dict": true,
        "return_dict_in_generate": false,
        "sep_token_id": null,
        "suppress_tokens": null,
        "task_specific_params": null,
        "temperature": 1.0,
        "text_feat_dim": 2048,
        "tf_legacy_loss": false,
        "tie_encoder_decoder": false,
        "tie_word_embeddings": true,
        "timestep_scale": 1000.0,
        "tokenizer_class": null,
        "top_k": 50,
        "top_p": 1.0,
        "torchscript": false,
        "typical_p": 1.0,
        "use_bfloat16": false
      },
      "typical_p": 1.0,
      "use_bfloat16": false,
      "vae_config": {
        "_name_or_path": "",
        "act_fn": "silu",
        "add_cross_attention": false,
        "architectures": null,
        "bad_words_ids": null,
        "begin_suppress_tokens": null,
        "block_out_channels": [
          128,
          256,
          512,
          512
        ],
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": null,
        "decoder_start_token_id": null,
        "diversity_penalty": 0.0,
        "do_sample": false,
        "down_block_types": [
          "DownEncoderBlock2D",
          "DownEncoderBlock2D",
          "DownEncoderBlock2D",
          "DownEncoderBlock2D"
        ],
        "dtype": null,
        "early_stopping": false,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": null,
        "exponential_decay_length_penalty": null,
        "finetuning_task": null,
        "force_upcast": true,
        "forced_bos_token_id": null,
        "forced_eos_token_id": null,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "in_channels": 3,
        "is_decoder": false,
        "is_encoder_decoder": false,
        "label2id": {
          "LABEL_0": 0,
          "LABEL_1": 1
        },
        "latent_channels": 16,
        "layers_per_block": 2,
        "length_penalty": 1.0,
        "max_length": 20,
        "mid_block_add_attention": true,
        "min_length": 0,
        "model_type": "",
        "no_repeat_ngram_size": 0,
        "norm_num_groups": 32,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_return_sequences": 1,
        "out_channels": 3,
        "output_attentions": false,
        "output_hidden_states": false,
        "output_scores": false,
        "pad_token_id": null,
        "prefix": null,
        "problem_type": null,
        "pruned_heads": {},
        "remove_invalid_values": false,
        "repetition_penalty": 1.0,
        "return_dict": true,
        "return_dict_in_generate": false,
        "sample_size": 1024,
        "scaling_factor": 0.3611,
        "sep_token_id": null,
        "shift_factor": 0.1159,
        "suppress_tokens": null,
        "task_specific_params": null,
        "temperature": 1.0,
        "tf_legacy_loss": false,
        "tie_encoder_decoder": false,
        "tie_word_embeddings": true,
        "tokenizer_class": null,
        "top_k": 50,
        "top_p": 1.0,
        "torchscript": false,
        "typical_p": 1.0,
        "up_block_types": [
          "UpDecoderBlock2D",
          "UpDecoderBlock2D",
          "UpDecoderBlock2D",
          "UpDecoderBlock2D"
        ],
        "use_bfloat16": false,
        "use_post_quant_conv": false,
        "use_quant_conv": false
      },
      "weight_path": "WEIGHT_PATH_TO_LONGCAT_NEXT/image_decoder/image_decoder.safetensors"
    },
    "vq_config": {
      "_name_or_path": "",
      "add_cross_attention": false,
      "architectures": null,
      "bad_words_ids": null,
      "begin_suppress_tokens": null,
      "bos_token_id": null,
      "chunk_size_feed_forward": 0,
      "codebook_dim": 3584,
      "codebook_size": 16384,
      "codebook_sizes": [
        16384,
        16384,
        16384,
        16384,
        16384,
        16384,
        16384,
        16384
      ],
      "commit_loss_ratio": 0.25,
      "cross_attention_hidden_size": null,
      "decay": 0.99,
      "decoder_start_token_id": null,
      "depth": 8,
      "diversity_penalty": 0.0,
      "do_sample": false,
      "dtype": null,
      "early_stopping": false,
      "encoder_no_repeat_ngram_size": 0,
      "entropy_loss_ratio": 0,
      "eos_token_id": null,
      "exponential_decay_length_penalty": null,
      "finetuning_task": null,
      "forced_bos_token_id": null,
      "forced_eos_token_id": null,
      "id2label": {
        "0": "LABEL_0",
        "1": "LABEL_1"
      },
      "in_channels": 3584,
      "is_decoder": false,
      "is_encoder_decoder": false,
      "label2id": {
        "LABEL_0": 0,
        "LABEL_1": 1
      },
      "length_penalty": 1.0,
      "max_length": 20,
      "min_length": 0,
      "model_type": "",
      "no_repeat_ngram_size": 0,
      "num_beam_groups": 1,
      "num_beams": 1,
      "num_return_sequences": 1,
      "output_attentions": false,
      "output_hidden_states": false,
      "output_scores": false,
      "pad_token_id": null,
      "prefix": null,
      "problem_type": null,
      "pruned_heads": {},
      "quant_conv": true,
      "quantizer_type": "rq",
      "remove_invalid_values": false,
      "repetition_penalty": 1.0,
      "restart_unused_codes": true,
      "return_dict": true,
      "return_dict_in_generate": false,
      "sep_token_id": null,
      "shared_codebook": true,
      "suppress_tokens": null,
      "task_specific_params": null,
      "temperature": 1.0,
      "tf_legacy_loss": false,
      "tie_encoder_decoder": false,
      "tie_word_embeddings": true,
      "tokenizer_class": null,
      "top_k": 50,
      "top_p": 1.0,
      "torchscript": false,
      "typical_p": 1.0,
      "use_bfloat16": false,
      "vq_loss_ratio": 0
    },
    "window_size": 112
  },
  "visual_embedding_layer_hidden_act": "silu",
  "visual_embedding_layer_intermediate_size": 8192,
  "visual_offset": 150581,
  "vocab_size": 282624,
  "zero_expert_num": 128,
  "zero_expert_type": "identity"
}
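config.json describes a mixture-of-experts decoder (moe_topk 12 over 256 routed experts plus 128 identity "zero" experts) with MLA-style low-rank attention (q_lora_rank 1536, kv_lora_rank 512), a Qwen2.5-VL-style visual tower and a Whisper-style audio tower, quantized to 4-bit with auto-round in auto_gptq packing. A rough loading sketch, not an official recipe: it assumes the auto_map remote code resolves with trust_remote_code, that a GPTQ-compatible backend for the auto-round weights is installed, and the repo id is a placeholder.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo = "org/LongCat-Next"  # placeholder repo id
cfg = AutoConfig.from_pretrained(repo, trust_remote_code=True)
print(cfg.model_type, cfg.moe_topk, cfg.n_routed_experts)  # longcat_next 12 256

# Unquantized tensors are stored in bfloat16 per "dtype" in config.json.
model = AutoModelForCausalLM.from_pretrained(
    repo,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)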
configuration_longcat_next.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 2 |
+
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
|
| 3 |
+
from transformers.models.whisper.configuration_whisper import WhisperConfig

from .configuration_longcat_ngram import LongcatFlashNgramConfig


class LongcatNextConfig(LongcatFlashNgramConfig):
    model_type = "longcat_next"
    def __init__(
        self,
        vocab_size=131072,
        hidden_size=6144,
        num_hidden_layers=56,
        num_layers=28,
        num_attention_heads=64,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=131072,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        ffn_hidden_size=12288,
        q_lora_rank=1536,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        head_dim=64,
        v_head_dim=128,
        qk_head_dim=None,
        moe_topk=12,
        n_routed_experts=512,
        zero_expert_num=256,
        expert_ffn_hidden_size=2048,
        routed_scaling_factor=6.0,
        emb_neighbor_num=None,
        emb_split_num=None,
        ngram_vocab_size_ratio=None,
        oe_ignored_token_ids=[],
        text_vocab_size=131072,  # text vocab size (vocab_size = text_vocab_size + audio_token + visual_token + multimodal_special_token_list)
        text_vocab_plus_multimodal_special_token_size=131125,
        visual_embedding_layer_intermediate_size=8192,
        visual_embedding_layer_hidden_act="silu",
        visual_offset=150581,
        audio_offset=131125,
        visual_config={},
        audio_config={},
        **kwargs,
    ):
        self.text_vocab_size = text_vocab_size
        self.text_vocab_plus_multimodal_special_token_size = text_vocab_plus_multimodal_special_token_size
        self.visual_embedding_layer_intermediate_size = visual_embedding_layer_intermediate_size
        self.visual_embedding_layer_hidden_act = visual_embedding_layer_hidden_act
        self.visual_offset = visual_offset
        self.audio_offset = audio_offset
        self.visual_config = LongcatNextVisualConfig(**visual_config)
        self.audio_config = LongcatNextAudioConfig(**audio_config)
        oe_ignored_token_ids = oe_ignored_token_ids or list(range(self.text_vocab_size, self.text_vocab_plus_multimodal_special_token_size))

        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            ffn_hidden_size=ffn_hidden_size,
            q_lora_rank=q_lora_rank,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            head_dim=head_dim,
            v_head_dim=v_head_dim,
            qk_head_dim=qk_head_dim,
            moe_topk=moe_topk,
            n_routed_experts=n_routed_experts,
            zero_expert_num=zero_expert_num,
            expert_ffn_hidden_size=expert_ffn_hidden_size,
            routed_scaling_factor=routed_scaling_factor,
            emb_neighbor_num=emb_neighbor_num,
            emb_split_num=emb_split_num,
            ngram_vocab_size_ratio=ngram_vocab_size_ratio,
            oe_ignored_token_ids=oe_ignored_token_ids,
            **kwargs,
        )

class LongcatNextVisualConfig(Qwen2_5_VLVisionConfig):
    model_type = "longcat_next_visual"
    base_config_key = ""

    def __init__(
        self,
        image_start_token_id=131106,
        image_end_token_id=131107,
        image_pad_token_id=131108,
        image_newline_token_id=131109,
        vq_config={},
        visual_decoder_config={},
        **kwargs,
    ):
        self.image_start_token_id = image_start_token_id
        self.image_end_token_id = image_end_token_id
        self.image_pad_token_id = image_pad_token_id
        self.image_newline_token_id = image_newline_token_id
        self.vq_config = PretrainedConfig(**vq_config)
        self.visual_decoder_config = PretrainedConfig(**visual_decoder_config)
        self.visual_decoder_config.image_decoder_config = PretrainedConfig(**getattr(self.visual_decoder_config, "image_decoder_config", {}))
        self.visual_decoder_config.transformer_config = PretrainedConfig(**getattr(self.visual_decoder_config, "transformer_config", {}))
        self.visual_decoder_config.vae_config = PretrainedConfig(**getattr(self.visual_decoder_config, "vae_config", {}))
        self.visual_decoder_config.scheduler_config = PretrainedConfig(**getattr(self.visual_decoder_config, "scheduler_config", {}))
        super().__init__(**kwargs)

class LongcatNextAudioConfig(WhisperConfig):
    model_type = "longcat_next_audio"
    base_config_key = ""

    def __init__(
        self,
        vq_config={},
        vocoder_config={},
        flow_matching_config={},
        cosy24kvocoder_config={},
        **kwargs
    ):
        self.vq_config = PretrainedConfig(**vq_config)
        self.vocoder_config = PretrainedConfig(**vocoder_config)
        self.flow_matching_config = PretrainedConfig(**flow_matching_config)
        self.flow_matching_config.cfm_params = PretrainedConfig(**getattr(self.flow_matching_config, "cfm_params", {}))
        self.cosy24kvocoder_config = PretrainedConfig(**cosy24kvocoder_config)
        super().__init__(**kwargs)


__all__ = ["LongcatNextConfig", "LongcatNextVisualConfig", "LongcatNextAudioConfig"]
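For orientation, the default values in `LongcatNextConfig` above imply a fixed token-id layout: text tokens first, then the multimodal special tokens, then audio tokens (starting at `audio_offset`), then visual tokens (starting at `visual_offset`). The sketch below only restates that arithmetic with the literal defaults; it does not import this repository's modules, and the range interpretation is an assumption read off the defaults rather than an official API.

```python
# Token-id layout implied by the LongcatNextConfig defaults above (plain arithmetic,
# no repo imports; the range interpretation is an assumption, not an official API).
text_vocab_size = 131072
text_plus_special = 131125   # text_vocab_plus_multimodal_special_token_size
audio_offset = 131125        # audio tokens appear to start right after the special tokens
visual_offset = 150581       # visual tokens start here

multimodal_special_ids = list(range(text_vocab_size, text_plus_special))
print(len(multimodal_special_ids))        # 53 -> also the default fill for oe_ignored_token_ids
print(audio_offset == text_plus_special)  # True
```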
configuration_longcat_ngram.py
ADDED
@@ -0,0 +1,218 @@
from transformers.models.longcat_flash import LongcatFlashConfig


class LongcatFlashNgramConfig(LongcatFlashConfig):
    r"""
    This is the configuration class to store the configuration of a [`LongcatFlashNgramModel`]. It is used to instantiate
    a LongCat Flash model with N-gram enhanced embeddings according to the specified arguments, defining the model architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 131072):
            Vocabulary size of the LongCat Flash model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`LongcatFlashNgramModel`]
        hidden_size (`int`, *optional*, defaults to 6144):
            Dimension of the hidden representations.
        num_hidden_layers (`int`, *optional*, defaults to 56):
            Number of hidden layers in the Transformer decoder.
        num_layers (`int`, *optional*, defaults to 28):
            Number of layers, each with 2 sublayers.
        num_attention_heads (`int`, *optional*, defaults to 64):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting from a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 131072):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon value used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie input and output embeddings.
        rope_theta (`float`, *optional*, defaults to 10000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        ffn_hidden_size (`int`, *optional*, defaults to 12288):
            Dimension of the MLP representations.
        q_lora_rank (`int`, *optional*, defaults to 1536):
            The rank of the query LoRA projection in MLA (Multi-head Latent Attention).
        kv_lora_rank (`int`, *optional*, defaults to 512):
            The rank of the key-value LoRA projection in MLA.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            The dimension of the non-position encoding part of query/key heads.
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            The dimension of the RoPE part of query/key heads.
        head_dim (`int`, *optional*, defaults to 64):
            Standard dimension of qk heads, unused except for CI.
        v_head_dim (`int`, *optional*, defaults to 128):
            The dimension of value heads.
        qk_head_dim (`int`, *optional*):
            The total dimension of query/key heads. If not specified, set to `qk_nope_head_dim + qk_rope_head_dim`.
        moe_topk (`int`, *optional*, defaults to 12):
            Number of experts to route to for each token in the MoE layer.
        n_routed_experts (`int`, *optional*, defaults to 512):
            Number of routed experts in the MoE layer.
        zero_expert_num (`int`, *optional*, defaults to 256):
            Number of zero experts (identity function) to add to the expert pool.
        expert_ffn_hidden_size (`int`, *optional*, defaults to 2048):
            Hidden size of individual expert FFN layers.
        routed_scaling_factor (`float`, *optional*, defaults to 6.0):
            Scaling factor applied to the routing weights.
        emb_neighbor_num (`int`, *optional*):
            Maximum N-gram length for N-gram embeddings. This parameter determines the context window size for N-gram computation. Higher values capture
            longer-range lexical patterns but increase memory usage.
        emb_split_num (`int`, *optional*):
            Number of hash functions (or splits) to use for N-gram embeddings. Multiple hash functions help improve the quality of N-gram representations.
        ngram_vocab_size_ratio (`float`, *optional*):
            Ratio multiplier for N-gram vocabulary size relative to the base vocabulary size. The N-gram vocabulary
            size is calculated as `vocab_size * ngram_vocab_size_ratio`.

    Example:
    ```python
    >>> from transformers import LongcatFlashNgramModel, LongcatFlashNgramConfig

    >>> # Initializing a LongCat Flash N-gram style configuration
    >>> configuration = LongcatFlashNgramConfig(
    ...     emb_neighbor_num=3,
    ...     emb_split_num=4,
    ...     ngram_vocab_size_ratio=1.5
    ... )

    >>> # Initializing a model from the configuration
    >>> model = LongcatFlashNgramModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "longcat_flash_ngram"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.*.q_b_proj": "colwise",
        "layers.*.self_attn.*.kv_b_proj": "colwise",
        "layers.*.self_attn.*.o_proj": "rowwise",
        "layers.*.mlps.*.gate_proj": "colwise",
        "layers.*.mlps.*.up_proj": "colwise",
        "layers.*.mlps.*.down_proj": "rowwise",
        "layers.*.mlp.experts.*.gate_proj": "colwise",
        "layers.*.mlp.experts.*.up_proj": "colwise",
        "layers.*.mlp.experts.*.down_proj": "rowwise",
    }

    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=131072,
        hidden_size=6144,
        num_hidden_layers=56,
        num_layers=28,
        num_attention_heads=64,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=131072,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        ffn_hidden_size=12288,
        q_lora_rank=1536,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        head_dim=64,
        v_head_dim=128,
        qk_head_dim=None,
        moe_topk=12,
        n_routed_experts=512,
        zero_expert_num=256,
        expert_ffn_hidden_size=2048,
        routed_scaling_factor=6.0,
        emb_neighbor_num=None,
        emb_split_num=None,
        ngram_vocab_size_ratio=None,
        oe_ignored_token_ids=[],
        **kwargs,
    ):
        # N-gram embedding specific parameters
        self.emb_neighbor_num = emb_neighbor_num
        self.emb_split_num = emb_split_num
        self.ngram_vocab_size_ratio = ngram_vocab_size_ratio
        self.oe_ignored_token_ids = oe_ignored_token_ids

        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            ffn_hidden_size=ffn_hidden_size,
            q_lora_rank=q_lora_rank,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            head_dim=head_dim,
            v_head_dim=v_head_dim,
            qk_head_dim=qk_head_dim,
            moe_topk=moe_topk,
            n_routed_experts=n_routed_experts,
            zero_expert_num=zero_expert_num,
            expert_ffn_hidden_size=expert_ffn_hidden_size,
            routed_scaling_factor=routed_scaling_factor,
            **kwargs,
        )


__all__ = ["LongcatFlashNgramConfig"]
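The docstring above pins down how the N-gram vocabulary is sized relative to the base vocabulary. A minimal sketch of that relationship follows, using the documented example values; the truncation to an integer is an assumption, and the actual sizing/sharding logic lives in modeling_longcat_ngram.py.

```python
# N-gram vocabulary sizing per the docstring above:
#   ngram_vocab_size = vocab_size * ngram_vocab_size_ratio
# (int() rounding is an assumption; see modeling_longcat_ngram.py for the real logic.)
vocab_size = 131072
ngram_vocab_size_ratio = 1.5   # docstring example value

ngram_vocab_size = int(vocab_size * ngram_vocab_size_ratio)
print(ngram_vocab_size)        # 196608
```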
cosy24k_vocoder.py
ADDED
@@ -0,0 +1,552 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HIFI-GAN"""

from typing import Dict, Optional, List
import numpy as np
from scipy.signal import get_window
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform
from torch.nn import Parameter
from torch import nn, sin, pow


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake ∶= x + 1/a * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x

def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

"""hifigan based generator implementation.

This code is modified from https://github.com/jik876/hifi-gan
 ,https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN

"""


class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN."""
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            self.convs1.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )
            self.convs2.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs1))
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs2))
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            xt = self.activations2[idx](xt)
            xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            remove_weight_norm(self.convs2[idx])


class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-wavefrom (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_thoreshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SinGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: [B, 1, sample_len]
        """

        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate

        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        # .      for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threhold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


class HiFTGenerator(nn.Module):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493
    """
    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: List[int] = [8, 8],
        upsample_kernel_sizes: List[int] = [16, 16],
        istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: List[int] = [3, 7, 11],
        resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: List[int] = [7, 11],
        source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor: torch.nn.Module = None,
    ):
        super(HiFTGenerator, self).__init__()

        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = f0_predictor

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        self.m_source.remove_weight_norm()
        for l in self.source_downs:
            remove_weight_norm(l)
        for l in self.source_resblocks:
            l.remove_weight_norm()

    def _stft(self, x):
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    def forward(
        self,
        batch: dict,
        # device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        speech_feat = batch['speech_feat'].transpose(1, 2)  # .to(device)
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # mel+source->speech
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, f0

    @torch.inference_mode()
    def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # use cache_source to avoid glitch
        if cache_source.shape[2] != 0:
            s[:, :, :cache_source.shape[2]] = cache_source
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, s


class ConvRNNF0Predictor(nn.Module):
    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512
                 ):
        super().__init__()

        self.num_class = num_class
        self.condnet = nn.Sequential(
            weight_norm(
                nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
            weight_norm(
                nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
            ),
            nn.ELU(),
        )
        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.condnet(x)
        x = x.transpose(1, 2)
        return torch.abs(self.classifier(x).squeeze(-1))


class Cosy24kVocoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.hifigan_generator = HiFTGenerator(
            in_channels=80,
            base_channels=512,
            nb_harmonics=8,
            sampling_rate=24000,
            nsf_alpha=0.1,
            nsf_sigma=0.003,
            nsf_voiced_threshold=10,
            upsample_rates=[8, 5, 3],
            upsample_kernel_sizes=[16, 11, 7],
            resblock_kernel_sizes=[3, 7, 11],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            source_resblock_kernel_sizes=[7, 7, 11],
            source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            lrelu_slope=0.1,
            audio_limit=0.99,
            f0_predictor=ConvRNNF0Predictor(
                num_class=1,
                in_channels=80,
                cond_channels=512,
            ),
        )

    def decode(self, mel, device="cuda"):
        """
        Args: mel: (batch_size, n_frames, n_mel)
        """
        generated_speech, f0 = self.hifigan_generator.forward(
            {"speech_feat": mel.transpose(1, 2)},  # device=device
        )
        return generated_speech

    @classmethod
    def from_pretrained(cls, model_path: str):
        """Load a pretrained model from a checkpoint."""
        model = cls()
        model.hifigan_generator.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=True)
        model.eval()
        return model
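The Snake activation defined near the top of this file is the only non-standard activation in the vocoder, so a quick numeric check of its documented formula (snake(x) = x + (1/alpha) * sin^2(alpha * x)) is a cheap sanity test. The sketch below assumes cosy24k_vocoder.py is importable as a plain module (it has no relative imports) and only exercises Snake, not the full generator.

```python
# Numeric check of the Snake activation defined above, with the default alpha = 1.
import torch
from cosy24k_vocoder import Snake  # assumes this file is on the Python path

act = Snake(in_features=4)         # linear-scale alpha, initialised to ones
x = torch.randn(2, 4, 10)          # (B, C=4, T)
y = act(x)

expected = x + torch.sin(x) ** 2   # snake(x) with alpha == 1 (up to the 1e-9 stabiliser)
print(torch.allclose(y, expected, atol=1e-6))  # True
```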
generation_config.json
ADDED
@@ -0,0 +1,36 @@
{
  "audio_generation_config": {
    "audio_parallel_decoding": false,
    "custom_params": {
      "sampling_rate": 24000,
      "wave_concat_overlap": 1200
    },
    "do_sample": true,
    "repetition_penalty": 1.3,
    "temperature": 0.5,
    "top_k": 5,
    "top_p": 0.85
  },
  "bos_token_id": 1,
  "do_sample": true,
  "eos_token_id": 2,
  "max_new_tokens": 2048,
  "pad_token_id": 3,
  "repetition_penalty": 1.1,
  "temperature": 0.4,
  "top_k": 20,
  "top_p": 0.85,
  "transformers_version": "4.57.6",
  "visual_generation_config": {
    "custom_params": {
      "anyres_prefix": "<longcat_img_token_size>{h} {w}</longcat_img_token_size>",
      "cfg_scale": 3.0,
      "token_h": 37,
      "token_w": 37
    },
    "do_sample": true,
    "temperature": 0.5,
    "top_k": 1024,
    "top_p": 0.75
  }
}
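Note that this generation config carries top-level text-decoding defaults plus nested "audio_generation_config" and "visual_generation_config" blocks with their own sampling settings. The snippet below only reads the file as plain JSON to show where each setting lives; how the modeling code consumes the nested blocks is defined elsewhere in this repository.

```python
# Plain JSON access to the nested generation settings above (no repo imports).
import json

with open("generation_config.json") as f:
    gen_cfg = json.load(f)

print(gen_cfg["temperature"], gen_cfg["top_k"])                        # 0.4 20   (text decoding)
print(gen_cfg["audio_generation_config"]["custom_params"]["sampling_rate"])  # 24000
print(gen_cfg["visual_generation_config"]["custom_params"]["cfg_scale"])     # 3.0
```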
image_refiner.py
ADDED
@@ -0,0 +1,748 @@
| 1 |
+
"""Image refiner: refiner pipeline, refiner container, and utilities.
|
| 2 |
+
|
| 3 |
+
Contains:
|
| 4 |
+
- RefinerImageProcessor: Image pre/post-processing for the diffusion pipeline
|
| 5 |
+
- RefinerPipeline: DiffusionPipeline for image refinement
|
| 6 |
+
- ImageRefinerContainer: nn.Module container for refiner sub-modules
|
| 7 |
+
- IdentityWithArgs: Placeholder module for cond_proj
|
| 8 |
+
- de_transform / tensor2pil: Tensor-to-PIL conversion utilities
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import inspect
|
| 12 |
+
import math
|
| 13 |
+
import warnings
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
from safetensors.torch import load_file
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
from diffusers import DiffusionPipeline
|
| 25 |
+
from diffusers.configuration_utils import register_to_config
|
| 26 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
|
| 27 |
+
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
| 28 |
+
from .refiner_modules import FlowMatchEulerDiscreteScheduler
|
| 29 |
+
|
| 30 |
+
from .refiner_modules import Transformer2DModel, RotaryPosEmbed
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Helpers
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _clean_config_dict(cfg, cls=None) -> dict:
|
| 38 |
+
"""Convert a PretrainedConfig to a clean dict for model construction.
|
| 39 |
+
|
| 40 |
+
If ``cls`` is provided, only keeps keys that match the cls.__init__ params
|
| 41 |
+
(allowlist approach). Otherwise falls back to blocklist filtering.
|
| 42 |
+
"""
|
| 43 |
+
if hasattr(cfg, "to_dict"):
|
| 44 |
+
d = cfg.to_dict()
|
| 45 |
+
elif isinstance(cfg, dict):
|
| 46 |
+
d = dict(cfg)
|
| 47 |
+
else:
|
| 48 |
+
d = {k: v for k, v in vars(cfg).items()}
|
| 49 |
+
|
| 50 |
+
if cls is not None:
|
| 51 |
+
import inspect
|
| 52 |
+
sig = inspect.signature(cls.__init__)
|
| 53 |
+
valid_keys = set(sig.parameters.keys()) - {"self"}
|
| 54 |
+
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
|
| 55 |
+
# Has **kwargs — can't filter by allowlist, fall through to blocklist
|
| 56 |
+
pass
|
| 57 |
+
else:
|
| 58 |
+
return {k: v for k, v in d.items() if k in valid_keys}
|
| 59 |
+
|
| 60 |
+
# Blocklist: remove HuggingFace PretrainedConfig metadata
|
| 61 |
+
_PRETRAINED_CONFIG_KEYS = {
|
| 62 |
+
"_name_or_path", "transformers_version", "model_type", "_commit_hash",
|
| 63 |
+
"_attn_implementation", "_attn_implementation_autoset", "return_dict",
|
| 64 |
+
"output_hidden_states", "output_attentions", "use_bfloat16",
|
| 65 |
+
"torchscript", "torch_dtype", "is_encoder_decoder", "is_decoder",
|
| 66 |
+
"add_cross_attention", "tie_encoder_decoder", "tie_word_embeddings",
|
| 67 |
+
"cross_attention_hidden_size", "chunk_size_feed_forward", "decoder_start_token_id",
|
| 68 |
+
"architectures", "finetuning_task", "id2label", "label2id", "prefix",
|
| 69 |
+
"problem_type", "tokenizer_class", "task_specific_params", "pruned_heads",
|
| 70 |
+
"bos_token_id", "eos_token_id", "pad_token_id", "sep_token_id",
|
| 71 |
+
"max_length", "min_length", "do_sample", "early_stopping",
|
| 72 |
+
"num_beams", "num_beam_groups", "diversity_penalty", "temperature",
|
| 73 |
+
"top_k", "top_p", "typical_p", "repetition_penalty", "length_penalty",
|
| 74 |
+
"no_repeat_ngram_size", "encoder_no_repeat_ngram_size", "bad_words_ids",
|
| 75 |
+
"num_return_sequences", "output_scores", "return_dict_in_generate",
|
| 76 |
+
"forced_bos_token_id", "forced_eos_token_id", "remove_invalid_values",
|
| 77 |
+
"exponential_decay_length_penalty", "suppress_tokens", "begin_suppress_tokens",
|
| 78 |
+
"tf_legacy_loss", "dtype",
|
| 79 |
+
}
|
| 80 |
+
return {k: v for k, v in d.items() if not k.startswith("_") and k not in _PRETRAINED_CONFIG_KEYS}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
# Image Refiner Container (nn.Module for state_dict loading)
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class ImageRefinerContainer(nn.Module):
|
| 89 |
+
"""Container for refiner components.
|
| 90 |
+
|
| 91 |
+
Holds base_transformer, vae, cond_proj as nn.Module children so their
|
| 92 |
+
parameters appear in the parent model's state_dict and are loaded
|
| 93 |
+
automatically via from_pretrained.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
def __init__(self, visual_decoder_config):
|
| 97 |
+
super().__init__()
|
| 98 |
+
|
| 99 |
+
tc = visual_decoder_config.transformer_config
|
| 100 |
+
vc = visual_decoder_config.vae_config
|
| 101 |
+
|
| 102 |
+
self.base_transformer = Transformer2DModel(**_clean_config_dict(tc))
|
| 103 |
+
|
| 104 |
+
self.vae = AutoencoderKL(**_clean_config_dict(vc))
|
| 105 |
+
self.vae.requires_grad_(False)
|
| 106 |
+
|
| 107 |
+
text_feat_dim = getattr(tc, "text_feat_dim", 3584)
|
| 108 |
+
codebook_dim = getattr(visual_decoder_config, "codebook_dim", text_feat_dim)
|
| 109 |
+
if codebook_dim != text_feat_dim:
|
| 110 |
+
self.cond_proj = nn.Linear(codebook_dim, text_feat_dim)
|
| 111 |
+
else:
|
| 112 |
+
self.cond_proj = IdentityWithArgs()
|
| 113 |
+
|
| 114 |
+
@classmethod
|
| 115 |
+
def from_pretrained(cls, config, model_path: str):
|
| 116 |
+
model = cls(config)
|
| 117 |
+
weight_dict = load_file(model_path, device="cpu")
|
| 118 |
+
model.load_state_dict({k.removeprefix("image_refiner."): v for k, v in weight_dict.items() if k.startswith("image_refiner.")}, strict=True)
|
| 119 |
+
model.eval()
|
| 120 |
+
return model
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def device(self):
|
| 124 |
+
return next(self.parameters()).device
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
def dtype(self):
|
| 128 |
+
return next(self.parameters()).dtype
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class RefinerImageProcessor(VaeImageProcessor):
|
| 132 |
+
"""Image processor for refiner - extends diffusers' VaeImageProcessor."""
|
| 133 |
+
|
| 134 |
+
@register_to_config
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
do_resize: bool = True,
|
| 138 |
+
vae_scale_factor: int = 16,
|
| 139 |
+
resample: str = "lanczos",
|
| 140 |
+
max_pixels: Optional[int] = None,
|
| 141 |
+
max_side_length: Optional[int] = None,
|
| 142 |
+
do_normalize: bool = True,
|
| 143 |
+
do_binarize: bool = False,
|
| 144 |
+
do_convert_grayscale: bool = False,
|
| 145 |
+
):
|
| 146 |
+
super().__init__(
|
| 147 |
+
do_resize=do_resize,
|
| 148 |
+
vae_scale_factor=vae_scale_factor,
|
| 149 |
+
resample=resample,
|
| 150 |
+
do_normalize=do_normalize,
|
| 151 |
+
do_binarize=do_binarize,
|
| 152 |
+
do_convert_grayscale=do_convert_grayscale,
|
| 153 |
+
)
|
| 154 |
+
self.max_pixels = max_pixels
|
| 155 |
+
self.max_side_length = max_side_length
|
| 156 |
+
|
| 157 |
+
def get_new_height_width(
|
| 158 |
+
self,
|
| 159 |
+
image: Union["PIL.Image.Image", np.ndarray, torch.Tensor],
|
| 160 |
+
height: Optional[int] = None,
|
| 161 |
+
width: Optional[int] = None,
|
| 162 |
+
max_pixels: Optional[int] = None,
|
| 163 |
+
max_side_length: Optional[int] = None,
|
| 164 |
+
) -> Tuple[int, int]:
|
| 165 |
+
import PIL.Image
|
| 166 |
+
|
| 167 |
+
if height is None:
|
| 168 |
+
if isinstance(image, PIL.Image.Image):
|
| 169 |
+
height = image.height
|
| 170 |
+
elif isinstance(image, torch.Tensor):
|
| 171 |
+
height = image.shape[2]
|
| 172 |
+
else:
|
| 173 |
+
height = image.shape[1]
|
| 174 |
+
|
| 175 |
+
if width is None:
|
| 176 |
+
if isinstance(image, PIL.Image.Image):
|
| 177 |
+
width = image.width
|
| 178 |
+
elif isinstance(image, torch.Tensor):
|
| 179 |
+
width = image.shape[3]
|
| 180 |
+
else:
|
| 181 |
+
width = image.shape[2]
|
| 182 |
+
|
| 183 |
+
if max_side_length is None:
|
| 184 |
+
max_side_length = self.max_side_length
|
| 185 |
+
if max_pixels is None:
|
| 186 |
+
max_pixels = self.max_pixels
|
| 187 |
+
|
| 188 |
+
ratio = 1.0
|
| 189 |
+
if max_side_length is not None:
|
| 190 |
+
max_side_length_ratio = max_side_length / max(height, width)
|
| 191 |
+
else:
|
| 192 |
+
max_side_length_ratio = 1.0
|
| 193 |
+
|
| 194 |
+
cur_pixels = height * width
|
| 195 |
+
max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5 if max_pixels is not None else 1.0
|
| 196 |
+
ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0)
|
| 197 |
+
|
| 198 |
+
sf = self.config.vae_scale_factor
|
| 199 |
+
new_height = int(height * ratio) // sf * sf
|
| 200 |
+
new_width = int(width * ratio) // sf * sf
|
| 201 |
+
return new_height, new_width
|
| 202 |
+
|
| 203 |
+
def preprocess(
|
| 204 |
+
self,
|
| 205 |
+
image: PipelineImageInput,
|
| 206 |
+
height: Optional[int] = None,
|
| 207 |
+
width: Optional[int] = None,
|
| 208 |
+
max_pixels: Optional[int] = None,
|
| 209 |
+
max_side_length: Optional[int] = None,
|
| 210 |
+
resize_mode: str = "default",
|
| 211 |
+
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
| 212 |
+
) -> torch.Tensor:
|
| 213 |
+
import PIL.Image
|
| 214 |
+
|
| 215 |
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
| 216 |
+
|
| 217 |
+
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
| 218 |
+
if isinstance(image, torch.Tensor):
|
| 219 |
+
image = image.unsqueeze(1)
|
| 220 |
+
else:
|
| 221 |
+
if image.shape[-1] == 1:
|
| 222 |
+
image = np.expand_dims(image, axis=0)
|
| 223 |
+
else:
|
| 224 |
+
image = np.expand_dims(image, axis=-1)
|
| 225 |
+
|
| 226 |
+
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
|
| 227 |
+
warnings.warn(
|
| 228 |
+
"Passing `image` as a list of 4d np.ndarray is deprecated. "
|
| 229 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
|
| 230 |
+
FutureWarning,
|
| 231 |
+
)
|
| 232 |
+
image = np.concatenate(image, axis=0)
|
| 233 |
+
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
|
| 234 |
+
warnings.warn(
|
| 235 |
+
"Passing `image` as a list of 4d torch.Tensor is deprecated. "
|
| 236 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
|
| 237 |
+
FutureWarning,
|
| 238 |
+
)
|
| 239 |
+
image = torch.cat(image, axis=0)
|
| 240 |
+
|
| 241 |
+
if not is_valid_image_imagelist(image):
|
| 242 |
+
raise ValueError(
|
| 243 |
+
f"Input is in incorrect format. Currently, we only support "
|
| 244 |
+
f"{', '.join(str(x) for x in supported_formats)}"
|
| 245 |
+
)
|
| 246 |
+
if not isinstance(image, list):
|
| 247 |
+
image = [image]
|
| 248 |
+
|
| 249 |
+
if isinstance(image[0], PIL.Image.Image):
|
| 250 |
+
if crops_coords is not None:
|
| 251 |
+
image = [i.crop(crops_coords) for i in image]
|
| 252 |
+
if self.config.do_resize:
|
| 253 |
+
height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
|
| 254 |
+
image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
|
| 255 |
+
if self.config.do_convert_grayscale:
|
| 256 |
+
image = [self.convert_to_grayscale(i) for i in image]
|
| 257 |
+
image = self.pil_to_numpy(image)
|
| 258 |
+
image = self.numpy_to_pt(image)
|
| 259 |
+
elif isinstance(image[0], np.ndarray):
|
| 260 |
+
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
| 261 |
+
image = self.numpy_to_pt(image)
|
| 262 |
+
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
|
| 263 |
+
if self.config.do_resize:
|
| 264 |
+
image = self.resize(image, height, width)
|
| 265 |
+
elif isinstance(image[0], torch.Tensor):
|
| 266 |
+
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
| 267 |
+
if self.config.do_convert_grayscale and image.ndim == 3:
|
| 268 |
+
image = image.unsqueeze(1)
|
| 269 |
+
channel = image.shape[1]
|
| 270 |
+
if channel == self.config.vae_latent_channels:
|
| 271 |
+
return image
|
| 272 |
+
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
|
| 273 |
+
if self.config.do_resize:
|
| 274 |
+
image = self.resize(image, height, width)
|
| 275 |
+
|
| 276 |
+
do_normalize = self.config.do_normalize
|
| 277 |
+
if do_normalize and image.min() < 0:
|
| 278 |
+
warnings.warn(
|
| 279 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. "
|
| 280 |
+
f"The expected value range for image tensor is [0,1] when passing as pytorch tensor or numpy Array. "
|
| 281 |
+
f"You passed `image` with value range [{image.min()},{image.max()}]",
|
| 282 |
+
FutureWarning,
|
| 283 |
+
)
|
| 284 |
+
do_normalize = False
|
| 285 |
+
if do_normalize:
|
| 286 |
+
image = self.normalize(image)
|
| 287 |
+
|
| 288 |
+
if self.config.do_binarize:
|
| 289 |
+
image = self.binarize(image)
|
| 290 |
+
|
| 291 |
+
return image
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@dataclass
|
| 295 |
+
class RefinerOutput:
|
| 296 |
+
images: Union[List[Image.Image], torch.Tensor]
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class IdentityWithArgs(nn.Module):
|
| 300 |
+
"""Placeholder Identity module for cond_proj."""
|
| 301 |
+
|
| 302 |
+
def __init__(self, dtype=torch.float32, device=None):
|
| 303 |
+
super().__init__()
|
| 304 |
+
self.register_buffer("_dummy", torch.zeros((), dtype=dtype, device=device))
|
| 305 |
+
|
| 306 |
+
@property
|
| 307 |
+
def dtype(self):
|
| 308 |
+
return self._dummy.dtype
|
| 309 |
+
|
| 310 |
+
@property
|
| 311 |
+
def device(self):
|
| 312 |
+
return self._dummy.device
|
| 313 |
+
|
| 314 |
+
def forward(self, x, *args, **kwargs):
|
| 315 |
+
return x
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def _retrieve_timesteps(
|
| 319 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 320 |
+
num_inference_steps: Optional[int] = None,
|
| 321 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 322 |
+
timesteps: Optional[List[int]] = None,
|
| 323 |
+
**kwargs,
|
| 324 |
+
):
|
| 325 |
+
# If scheduler uses dynamic shifting and caller passed num_tokens, compute mu
|
| 326 |
+
# (consistent with the refiner pipeline in the training code)
|
| 327 |
+
num_tokens = kwargs.pop("num_tokens", None)
|
| 328 |
+
if num_tokens is not None and getattr(scheduler.config, "use_dynamic_shifting", False):
|
| 329 |
+
# Compute mu from num_tokens using scheduler's linear interpolation
|
| 330 |
+
base_shift = getattr(scheduler.config, "base_shift", 0.5)
|
| 331 |
+
max_shift = getattr(scheduler.config, "max_shift", 1.15)
|
| 332 |
+
base_seq_len = getattr(scheduler.config, "base_image_seq_len", 256)
|
| 333 |
+
max_seq_len = getattr(scheduler.config, "max_image_seq_len", 4096)
|
| 334 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 335 |
+
b = base_shift - m * base_seq_len
|
| 336 |
+
mu = num_tokens * m + b
|
| 337 |
+
kwargs["mu"] = mu
|
| 338 |
+
|
| 339 |
+
accepted = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 340 |
+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
|
| 341 |
+
|
| 342 |
+
if timesteps is not None:
|
| 343 |
+
if "timesteps" not in accepted:
|
| 344 |
+
raise ValueError(
|
| 345 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 346 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 347 |
+
)
|
| 348 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **filtered_kwargs)
|
| 349 |
+
timesteps = scheduler.timesteps
|
| 350 |
+
num_inference_steps = len(timesteps)
|
| 351 |
+
else:
|
| 352 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **filtered_kwargs)
|
| 353 |
+
timesteps = scheduler.timesteps
|
| 354 |
+
return timesteps, num_inference_steps
|
| 355 |
+
|
| 356 |
+
|
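# A small worked example of the dynamic-shifting "mu" computed in _retrieve_timesteps above,
# using the default values the code falls back to (base_shift=0.5, max_shift=1.15,
# base_image_seq_len=256, max_image_seq_len=4096). The token count is a hypothetical value.
def _example_dynamic_shift_mu(num_tokens: int = 1024) -> float:
    base_shift, max_shift = 0.5, 1.15
    base_seq_len, max_seq_len = 256, 4096
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)  # slope of the linear ramp
    b = base_shift - m * base_seq_len
    return num_tokens * m + b  # ~0.63 for 1024 latent tokens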
| 357 |
+
class RefinerPipeline(DiffusionPipeline):
|
| 358 |
+
"""
|
| 359 |
+
Image refiner evaluation pipeline.
|
| 360 |
+
|
| 361 |
+
- Conditioning (cond) comes from the upstream model as encoder_hidden_states (quants / last_latent)
|
| 362 |
+
- grid_thw_list is used to split cond (consistent with training)
|
| 363 |
+
- image is used as the reference image
|
| 364 |
+
- Supports FlowMatchEulerDiscreteScheduler together with a velocity-prediction model
|
| 365 |
+
"""
|
| 366 |
+
|
| 367 |
+
def __init__(
|
| 368 |
+
self,
|
| 369 |
+
vae: AutoencoderKL,
|
| 370 |
+
transformer: Transformer2DModel,
|
| 371 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 372 |
+
cond_proj: Optional[nn.Module] = None,
|
| 373 |
+
):
|
| 374 |
+
super().__init__()
|
| 375 |
+
|
| 376 |
+
self.register_modules(
|
| 377 |
+
vae=vae,
|
| 378 |
+
transformer=transformer,
|
| 379 |
+
scheduler=scheduler,
|
| 380 |
+
cond_proj=cond_proj if cond_proj is not None else IdentityWithArgs(),
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
self.vae_scale_factor = (
|
| 384 |
+
2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 385 |
+
if hasattr(self.vae.config, "block_out_channels")
|
| 386 |
+
else 8
|
| 387 |
+
)
|
| 388 |
+
self.image_processor = RefinerImageProcessor(
|
| 389 |
+
vae_scale_factor=self.vae_scale_factor * 2, do_resize=True
|
| 390 |
+
)
|
| 391 |
+
self.patch_size = int(getattr(self.transformer.config, "patch_size", 16))
|
| 392 |
+
|
| 393 |
+
self._num_timesteps: int = 0
|
| 394 |
+
self._current_timestep: Optional[torch.Tensor] = None
|
| 395 |
+
self._interrupt: bool = False
|
| 396 |
+
self._freqs_cis: Optional[torch.Tensor] = None
|
| 397 |
+
self._text_guidance_scale: float = 1.0
|
| 398 |
+
self._image_guidance_scale: float = 1.0
|
| 399 |
+
self._cfg_range: Tuple[float, float] = (0.0, 1.0)
|
| 400 |
+
|
| 401 |
+
@torch.no_grad()
|
| 402 |
+
def _get_freqs_cis(self, device, dtype):
|
| 403 |
+
if self._freqs_cis is None:
|
| 404 |
+
self._freqs_cis = RotaryPosEmbed.get_freqs_cis(
|
| 405 |
+
self.transformer.config.axes_dim_rope,
|
| 406 |
+
self.transformer.config.axes_lens,
|
| 407 |
+
theta=10000,
|
| 408 |
+
)
|
| 409 |
+
return self._freqs_cis
|
| 410 |
+
|
| 411 |
+
@staticmethod
|
| 412 |
+
def _split_tokens(
|
| 413 |
+
encoder_hidden_states: torch.Tensor,
|
| 414 |
+
grid_thw_list: List[Tuple[int, int, int]],
|
| 415 |
+
) -> List[torch.Tensor]:
|
| 416 |
+
splits = [int(h) * int(w) // 4 for (_, h, w) in grid_thw_list]
|
| 417 |
+
return list(torch.split(encoder_hidden_states, splits, dim=1))
|
| 418 |
+
|
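# A minimal sketch of the split sizes computed by _split_tokens above: each (t, h, w) entry
# in grid_thw_list contributes h * w // 4 conditioning tokens (presumably a 2x2 merge of the
# visual grid). The grid values below are hypothetical.
def _example_split_sizes():
    grid_thw_list = [(1, 32, 32), (1, 16, 16)]
    splits = [int(h) * int(w) // 4 for (_, h, w) in grid_thw_list]
    assert splits == [256, 64]
    return splits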
| 419 |
+
@staticmethod
|
| 420 |
+
def _looks_like_latents(x: Union[torch.Tensor, Image.Image], latent_ch_hint: int = 16) -> bool:
|
| 421 |
+
if not isinstance(x, torch.Tensor):
|
| 422 |
+
return False
|
| 423 |
+
if x.ndim not in (3, 4):
|
| 424 |
+
return False
|
| 425 |
+
c = int(x.shape[-3])
|
| 426 |
+
if c == 3:
|
| 427 |
+
return False
|
| 428 |
+
if c == latent_ch_hint:
|
| 429 |
+
return True
|
| 430 |
+
if c > 3 and c <= 32:
|
| 431 |
+
return True
|
| 432 |
+
return False
|
| 433 |
+
|
| 434 |
+
@torch.no_grad()
|
| 435 |
+
def _preprocess_to_vae_range(self, img: torch.Tensor) -> torch.Tensor:
|
| 436 |
+
if img.dtype not in (torch.float32, torch.float16, torch.bfloat16):
|
| 437 |
+
img = img.float()
|
| 438 |
+
if img.max() > 1.5:
|
| 439 |
+
img = img / 255.0
|
| 440 |
+
if img.min() >= 0.0 and img.max() <= 1.0:
|
| 441 |
+
img = img * 2.0 - 1.0
|
| 442 |
+
return img.clamp(-1, 1)
|
| 443 |
+
|
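# A minimal sketch of the value-range handling in _preprocess_to_vae_range above: uint8-style
# values in [0, 255] are rescaled to [0, 1] and then mapped to the VAE's expected [-1, 1] range.
def _example_vae_range():
    import torch
    img = torch.tensor([0.0, 127.5, 255.0])
    img = img / 255.0 if img.max() > 1.5 else img
    img = img * 2.0 - 1.0 if img.min() >= 0.0 and img.max() <= 1.0 else img
    return img.clamp(-1, 1)  # tensor([-1., 0., 1.])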
| 444 |
+
@torch.no_grad()
|
| 445 |
+
def _encode_image_to_latents(
|
| 446 |
+
self,
|
| 447 |
+
img_any: Union[Image.Image, torch.Tensor],
|
| 448 |
+
device,
|
| 449 |
+
dtype,
|
| 450 |
+
) -> Tuple[torch.Tensor, int, int]:
|
| 451 |
+
latent_ch_hint = int(getattr(getattr(self.vae, "config", None), "latent_channels", 16))
|
| 452 |
+
|
| 453 |
+
if self._looks_like_latents(img_any, latent_ch_hint=latent_ch_hint):
|
| 454 |
+
z = img_any
|
| 455 |
+
if z.ndim == 3:
|
| 456 |
+
z = z.unsqueeze(0)
|
| 457 |
+
z = z.to(device=device, dtype=dtype)
|
| 458 |
+
H_lat, W_lat = z.shape[-2], z.shape[-1]
|
| 459 |
+
return z, H_lat, W_lat
|
| 460 |
+
|
| 461 |
+
if isinstance(img_any, Image.Image):
|
| 462 |
+
img = torch.from_numpy(
|
| 463 |
+
np.array(img_any).astype("float32") / 255.0
|
| 464 |
+
).permute(2, 0, 1).unsqueeze(0)
|
| 465 |
+
elif isinstance(img_any, torch.Tensor):
|
| 466 |
+
img = img_any
|
| 467 |
+
if img.ndim == 3:
|
| 468 |
+
img = img.unsqueeze(0)
|
| 469 |
+
else:
|
| 470 |
+
raise TypeError("Unsupported image type. Use PIL.Image or torch.Tensor or latent Tensor.")
|
| 471 |
+
|
| 472 |
+
img = self._preprocess_to_vae_range(img)
|
| 473 |
+
|
| 474 |
+
H, W = img.shape[-2:]
|
| 475 |
+
base = self.patch_size * self.vae_scale_factor
|
| 476 |
+
target_H = max(base, math.ceil(H / base) * base)
|
| 477 |
+
target_W = max(base, math.ceil(W / base) * base)
|
| 478 |
+
if (H != target_H) or (W != target_W):
|
| 479 |
+
img = F.interpolate(img, size=(target_H, target_W), mode="bilinear", align_corners=False)
|
| 480 |
+
|
| 481 |
+
img = img.to(device=device, dtype=self.vae.dtype)
|
| 482 |
+
|
| 483 |
+
posterior = self.vae.encode(img).latent_dist
|
| 484 |
+
z0 = posterior.sample()
|
| 485 |
+
if getattr(self.vae.config, "shift_factor", None) is not None:
|
| 486 |
+
z0 = z0 - self.vae.config.shift_factor
|
| 487 |
+
if getattr(self.vae.config, "scaling_factor", None) is not None:
|
| 488 |
+
z0 = z0 * self.vae.config.scaling_factor
|
| 489 |
+
|
| 490 |
+
z0 = z0.to(device=device, dtype=dtype)
|
| 491 |
+
H_lat, W_lat = z0.shape[-2], z0.shape[-1]
|
| 492 |
+
return z0, H_lat, W_lat
|
| 493 |
+
|
| 494 |
+
@staticmethod
|
| 495 |
+
def _expand_to_list(x, n):
|
| 496 |
+
if x is None:
|
| 497 |
+
return [None] * n
|
| 498 |
+
if isinstance(x, (Image.Image, torch.Tensor)):
|
| 499 |
+
return [x] * n
|
| 500 |
+
assert isinstance(x, list), "`image` must be a PIL Image, a Tensor, or a list of them."
|
| 501 |
+
assert len(x) == n, "`len(image)` must equal number of image chunks"
|
| 502 |
+
return x
|
| 503 |
+
|
| 504 |
+
@torch.no_grad()
|
| 505 |
+
def _denoise_once(
|
| 506 |
+
self,
|
| 507 |
+
cond_tokens: torch.Tensor,
|
| 508 |
+
ref_img: Optional[Union[Image.Image, torch.Tensor]],
|
| 509 |
+
num_inference_steps: int = 28,
|
| 510 |
+
timesteps: Optional[List[int]] = None,
|
| 511 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 512 |
+
output_type: str = "pil",
|
| 513 |
+
text_guidance_scale: float = 1.0,
|
| 514 |
+
image_guidance_scale: float = 1.0,
|
| 515 |
+
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
| 516 |
+
enable_processor_bar: bool = True,
|
| 517 |
+
):
|
| 518 |
+
device = cond_tokens.device
|
| 519 |
+
weight_dtype = self.transformer.dtype
|
| 520 |
+
|
| 521 |
+
self._text_guidance_scale = text_guidance_scale
|
| 522 |
+
self._image_guidance_scale = image_guidance_scale
|
| 523 |
+
self._cfg_range = cfg_range
|
| 524 |
+
|
| 525 |
+
cond_tokens = cond_tokens.to(device=device, dtype=weight_dtype)
|
| 526 |
+
text_feats = self.cond_proj(cond_tokens)
|
| 527 |
+
B, L, _ = text_feats.shape
|
| 528 |
+
text_mask = torch.ones(B, L, device=device, dtype=torch.bool)
|
| 529 |
+
|
| 530 |
+
ref_image_hidden_states = None
|
| 531 |
+
H_lat: int
|
| 532 |
+
W_lat: int
|
| 533 |
+
|
| 534 |
+
if ref_img is not None:
|
| 535 |
+
if isinstance(ref_img, torch.Tensor) and ref_img.ndim == 4 and ref_img.shape[0] == B:
|
| 536 |
+
z_ref, H_lat, W_lat = self._encode_image_to_latents(ref_img, device=device, dtype=weight_dtype)
|
| 537 |
+
ref_image_hidden_states = [[z_ref[b]] for b in range(B)]
|
| 538 |
+
else:
|
| 539 |
+
z_ref, H_lat, W_lat = self._encode_image_to_latents(ref_img, device=device, dtype=weight_dtype)
|
| 540 |
+
z_single = z_ref[0]
|
| 541 |
+
ref_image_hidden_states = [[z_single] for _ in range(B)]
|
| 542 |
+
else:
|
| 543 |
+
H_lat = W_lat = 128 // self.vae_scale_factor
|
| 544 |
+
|
| 545 |
+
C_lat = getattr(self.transformer.config, "in_channels", None)
|
| 546 |
+
if C_lat is None:
|
| 547 |
+
if ref_image_hidden_states is not None:
|
| 548 |
+
C_lat = ref_image_hidden_states[0][0].shape[0]
|
| 549 |
+
else:
|
| 550 |
+
raise ValueError("transformer.config.in_channels is None and no ref_img was provided.")
|
| 551 |
+
latents_shape = (B, C_lat, H_lat, W_lat)
|
| 552 |
+
|
| 553 |
+
if isinstance(generator, list):
|
| 554 |
+
if len(generator) != B:
|
| 555 |
+
raise ValueError(
|
| 556 |
+
f"len(generator)={len(generator)} must equal B={B} when passing list of generators."
|
| 557 |
+
)
|
| 558 |
+
latents = torch.stack(
|
| 559 |
+
[
|
| 560 |
+
torch.randn(
|
| 561 |
+
(1, C_lat, H_lat, W_lat),
|
| 562 |
+
generator=generator[i],
|
| 563 |
+
device=device,
|
| 564 |
+
dtype=weight_dtype,
|
| 565 |
+
).squeeze(0)
|
| 566 |
+
for i in range(B)
|
| 567 |
+
],
|
| 568 |
+
dim=0,
|
| 569 |
+
)
|
| 570 |
+
else:
|
| 571 |
+
latents = torch.randn(latents_shape, generator=generator, device=device, dtype=weight_dtype)
|
| 572 |
+
|
| 573 |
+
num_tokens = H_lat * W_lat
|
| 574 |
+
timesteps_sched, num_inference_steps = _retrieve_timesteps(
|
| 575 |
+
self.scheduler,
|
| 576 |
+
num_inference_steps=num_inference_steps,
|
| 577 |
+
device=device,
|
| 578 |
+
timesteps=timesteps,
|
| 579 |
+
num_tokens=num_tokens,
|
| 580 |
+
)
|
| 581 |
+
num_warmup_steps = max(len(timesteps_sched) - num_inference_steps * self.scheduler.order, 0)
|
| 582 |
+
self._num_timesteps = len(timesteps_sched)
|
| 583 |
+
|
| 584 |
+
freqs_cis = self._get_freqs_cis(device=device, dtype=weight_dtype)
|
| 585 |
+
|
| 586 |
+
progress_bar = self.progress_bar(total=num_inference_steps) if enable_processor_bar else None
|
| 587 |
+
for i, t in enumerate(timesteps_sched):
|
| 588 |
+
if self._interrupt:
|
| 589 |
+
continue
|
| 590 |
+
self._current_timestep = t
|
| 591 |
+
|
| 592 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 593 |
+
|
| 594 |
+
step_frac = i / max(len(timesteps_sched) - 1, 1)
|
| 595 |
+
use_cfg = (cfg_range[0] <= step_frac <= cfg_range[1]) and (
|
| 596 |
+
text_guidance_scale > 1.0 or image_guidance_scale > 1.0
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
if not use_cfg:
|
| 600 |
+
optional_kwargs: Dict[str, Any] = {}
|
| 601 |
+
if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
|
| 602 |
+
optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states
|
| 603 |
+
model_pred = self.transformer(
|
| 604 |
+
latents, timestep, text_feats, freqs_cis, text_mask, **optional_kwargs
|
| 605 |
+
)
|
| 606 |
+
else:
|
| 607 |
+
text_uncond = torch.zeros_like(text_feats)
|
| 608 |
+
|
| 609 |
+
opt_kwargs_text: Dict[str, Any] = {}
|
| 610 |
+
if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
|
| 611 |
+
opt_kwargs_text["ref_image_hidden_states"] = ref_image_hidden_states
|
| 612 |
+
|
| 613 |
+
model_pred_text = self.transformer(
|
| 614 |
+
latents, timestep, text_feats, freqs_cis, text_mask, **opt_kwargs_text
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
opt_kwargs_ref: Dict[str, Any] = {}
|
| 618 |
+
if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
|
| 619 |
+
opt_kwargs_ref["ref_image_hidden_states"] = ref_image_hidden_states
|
| 620 |
+
|
| 621 |
+
model_pred_ref = self.transformer(
|
| 622 |
+
latents, timestep, text_uncond, freqs_cis, text_mask, **opt_kwargs_ref
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
opt_kwargs_uncond: Dict[str, Any] = {}
|
| 626 |
+
if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
|
| 627 |
+
opt_kwargs_uncond["ref_image_hidden_states"] = None
|
| 628 |
+
|
| 629 |
+
model_pred_uncond = self.transformer(
|
| 630 |
+
latents, timestep, text_uncond, freqs_cis, text_mask, **opt_kwargs_uncond
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
|
| 634 |
+
model_pred = (
|
| 635 |
+
model_pred_uncond
|
| 636 |
+
+ image_guidance_scale * (model_pred_ref - model_pred_uncond)
|
| 637 |
+
+ text_guidance_scale * (model_pred_text - model_pred_ref)
|
| 638 |
+
)
|
| 639 |
+
elif text_guidance_scale > 1.0:
|
| 640 |
+
model_pred = model_pred_uncond + text_guidance_scale * (model_pred_text - model_pred_uncond)
|
| 641 |
+
elif image_guidance_scale > 1.0:
|
| 642 |
+
model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond)
|
| 643 |
+
else:
|
| 644 |
+
model_pred = model_pred_text
|
| 645 |
+
|
| 646 |
+
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
|
| 647 |
+
latents = latents.to(dtype=weight_dtype)
|
| 648 |
+
|
| 649 |
+
if progress_bar is not None:
|
| 650 |
+
if i == len(timesteps_sched) - 1 or (
|
| 651 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
| 652 |
+
):
|
| 653 |
+
progress_bar.update()
|
| 654 |
+
|
| 655 |
+
if progress_bar is not None:
|
| 656 |
+
progress_bar.close()
|
| 657 |
+
|
| 658 |
+
self._current_timestep = None
|
| 659 |
+
|
| 660 |
+
latents = latents.to(dtype=self.vae.dtype)
|
| 661 |
+
if getattr(self.vae.config, "scaling_factor", None) is not None:
|
| 662 |
+
latents = latents / self.vae.config.scaling_factor
|
| 663 |
+
if getattr(self.vae.config, "shift_factor", None) is not None:
|
| 664 |
+
latents = latents + self.vae.config.shift_factor
|
| 665 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 666 |
+
|
| 667 |
+
images = self.image_processor.postprocess(image, output_type=output_type)
|
| 668 |
+
return images
|
| 669 |
+
|
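# A minimal sketch of the dual classifier-free guidance combination used in _denoise_once above
# when both scales are > 1: the image-guidance term pulls the unconditional prediction toward the
# reference-conditioned one, and the text-guidance term pulls further toward the fully
# conditioned prediction. The prediction tensors here are placeholders.
def _example_dual_cfg(pred_uncond, pred_ref, pred_text, text_scale=1.5, image_scale=1.5):
    return (
        pred_uncond
        + image_scale * (pred_ref - pred_uncond)
        + text_scale * (pred_text - pred_ref)
    )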
| 670 |
+
@torch.no_grad()
|
| 671 |
+
def __call__(
|
| 672 |
+
self,
|
| 673 |
+
*,
|
| 674 |
+
encoder_hidden_states: torch.Tensor,
|
| 675 |
+
grid_thw_list: List[Tuple[int, int, int]],
|
| 676 |
+
image: Union[Image.Image, torch.Tensor, List[Union[Image.Image, torch.Tensor]], None] = None,
|
| 677 |
+
num_inference_steps: int = 28,
|
| 678 |
+
timesteps: Optional[List[int]] = None,
|
| 679 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 680 |
+
output_type: str = "pil",
|
| 681 |
+
return_dict: bool = True,
|
| 682 |
+
text_guidance_scale: float = 1.5,
|
| 683 |
+
image_guidance_scale: float = 1.5,
|
| 684 |
+
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
| 685 |
+
enable_processor_bar: bool = True,
|
| 686 |
+
**kwargs,
|
| 687 |
+
) -> Union[RefinerOutput, List[Image.Image], torch.Tensor]:
|
| 688 |
+
self._interrupt = False
|
| 689 |
+
|
| 690 |
+
token_chunks = self._split_tokens(encoder_hidden_states, grid_thw_list)
|
| 691 |
+
ref_list = self._expand_to_list(image, len(token_chunks))
|
| 692 |
+
|
| 693 |
+
results_pil: List[Image.Image] = []
|
| 694 |
+
results_pt: Optional[torch.Tensor] = None
|
| 695 |
+
|
| 696 |
+
for tok, _, img_any in zip(token_chunks, grid_thw_list, ref_list):
|
| 697 |
+
imgs = self._denoise_once(
|
| 698 |
+
cond_tokens=tok,
|
| 699 |
+
ref_img=img_any,
|
| 700 |
+
num_inference_steps=num_inference_steps,
|
| 701 |
+
timesteps=timesteps,
|
| 702 |
+
generator=generator,
|
| 703 |
+
output_type=output_type,
|
| 704 |
+
text_guidance_scale=text_guidance_scale,
|
| 705 |
+
image_guidance_scale=image_guidance_scale,
|
| 706 |
+
cfg_range=cfg_range,
|
| 707 |
+
enable_processor_bar=enable_processor_bar,
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
if output_type == "pil":
|
| 711 |
+
results_pil += imgs
|
| 712 |
+
else:
|
| 713 |
+
results_pt = imgs if results_pt is None else torch.cat([results_pt, imgs], dim=0)
|
| 714 |
+
|
| 715 |
+
if not return_dict:
|
| 716 |
+
return results_pil if output_type == "pil" else results_pt
|
| 717 |
+
return RefinerOutput(images=results_pil if output_type == "pil" else results_pt)
|
| 718 |
+
|
| 719 |
+
|
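# A hedged usage sketch for RefinerPipeline.__call__ above. All variable names and shapes are
# hypothetical; in practice the pipeline, conditioning tokens, and reference image come from the
# upstream LongCat model.
#
#   out = refiner_pipeline(
#       encoder_hidden_states=cond_tokens,   # (1, sum(h*w//4), dim) upstream features
#       grid_thw_list=[(1, 64, 64)],         # one image chunk: 64x64 grid -> 1024 tokens
#       image=reference_pil_image,           # optional reference image (or latents)
#       num_inference_steps=28,
#       text_guidance_scale=1.5,
#       image_guidance_scale=1.5,
#   )
#   refined_image = out.images[0]            # PIL.Image when output_type="pil"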
| 720 |
+
def de_transform(
|
| 721 |
+
tensor: torch.Tensor,
|
| 722 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
| 723 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
| 724 |
+
rescale_factor: float = 1 / 255,
|
| 725 |
+
) -> torch.Tensor:
|
| 726 |
+
"""De-normalize and de-rescale, suitable for images processed by Qwen2VLImageProcessor."""
|
| 727 |
+
if tensor.ndim == 3:
|
| 728 |
+
tensor = tensor.unsqueeze(0)
|
| 729 |
+
mean_t = torch.tensor(mean).view(1, -1, 1, 1).to(tensor.device)
|
| 730 |
+
std_t = torch.tensor(std).view(1, -1, 1, 1).to(tensor.device)
|
| 731 |
+
tensor = tensor * std_t + mean_t
|
| 732 |
+
tensor = tensor / rescale_factor
|
| 733 |
+
tensor = torch.clamp(tensor / 255.0, 0, 1)
|
| 734 |
+
return tensor
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def tensor2pil(image_t: torch.Tensor, image_mean, image_std) -> Image.Image:
|
| 738 |
+
"""Convert a tensor to a PIL Image."""
|
| 739 |
+
image_t = image_t.detach().cpu()
|
| 740 |
+
rescale_factor = 1 / 255
|
| 741 |
+
sample = de_transform(
|
| 742 |
+
image_t,
|
| 743 |
+
mean=image_mean,
|
| 744 |
+
std=image_std,
|
| 745 |
+
rescale_factor=rescale_factor,
|
| 746 |
+
)[0]
|
| 747 |
+
ndarr = sample.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
| 748 |
+
return Image.fromarray(ndarr)
|
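# A minimal sketch of how tensor2pil/de_transform above round-trip a normalized image tensor,
# using the same CLIP-style mean/std defaults that de_transform declares. The random tensor is
# only a placeholder for a Qwen2VL-normalized image.
def _example_tensor2pil_roundtrip():
    import torch
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)
    normalized = torch.randn(3, 64, 64)          # stand-in for a normalized image tensor
    pil_img = tensor2pil(normalized, mean, std)  # de-normalizes, clamps, converts to PIL
    return pil_img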
model-00001-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1685f75f9d024166baee24ab136d5ab2ea647b4de907db74bb7e3430c33bb65b
|
| 3 |
+
size 4295486032
|
model-00002-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cfecc6d4d4b4f6cc86c4c07c9de9f25624ee78e06af78834192927917303deee
|
| 3 |
+
size 4295486168
|
model-00003-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4c5c6321ac4cfb40a3c1861ec138ebb257dd168ce662a470ca4365c6892528d7
|
| 3 |
+
size 4295486720
|
model-00004-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc4bb766d38246e270f534c3516a6350f4f0ff8e8cf9feac0c895827c8e2c905
|
| 3 |
+
size 4295493080
|
model-00005-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4960f236e0ed45b722c59f87c74625a755877f93ce1e6aa8c7405f074d96a4d1
|
| 3 |
+
size 3463467600
|
model-00006-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b9a31a5592a4b0dca851f92097116d388ded88b7d3ffbf7e6002f2e9038ff93
|
| 3 |
+
size 1736441984
|
model-00007-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d2ddc8f496759a158eb7d12639d7cc6bddf6cfe1ba79a97a5ea09cfe282f0116
|
| 3 |
+
size 5234492032
|
model-00008-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ad29cecd2c485c19269e8031a4418d1f2475f26b090fb73a893611e93544f629
|
| 3 |
+
size 5234493056
|
model-00009-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c59490055ec3cb72512512eb3113473eee00c2093f7105475de6ddf96e6628b8
|
| 3 |
+
size 5234494080
|
model-00010-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:732fde4d1ddc0052328defd8ec9977c921159e4d73ba24358e6ffda024a1c177
|
| 3 |
+
size 5234495104
|
model-00011-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9809ea9ab215857b644c00600c1cc91191d3b4fb0f66dc83fc87c85c8beb364d
|
| 3 |
+
size 5234496128
|
model-00012-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bcc28eaae6c44a109120814106af240405f921dfceb337650beaac6e06e04902
|
| 3 |
+
size 5234497152
|
model-00013-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fc1599f3ff951c2fe58ef405a2e0d510cd8fa6635734cee56f6fbf7e6bffe61f
|
| 3 |
+
size 5234498176
|
model-00014-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4392e44efe5ee3a0f98b34bb5eec1def2d2a4299931be30cf91f5bfb7bd45c8d
|
| 3 |
+
size 5234499200
|
model-00015-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e1fed0dd9555b4823a9c2f7b99104bdd9c37a4c6c347b2e884473880de38cd2c
|
| 3 |
+
size 5234500224
|
model-00016-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f4530ef721cf3469ebeb290e14ec600f3e8f2366a19ef245ebc597e2116bad21
|
| 3 |
+
size 5234501248
|
model-00017-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:307e83e55c4198f22815c59baf3654b38021d34e31da01f18394ce2fd49138e5
|
| 3 |
+
size 5234502272
|
model-00018-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:68db1b41a05207aa11bcedbc60f1f596829c25ac5c3679e8579b835b5dd39a37
|
| 3 |
+
size 5234503296
|
model-00019-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:994959dcb058c87110fd04eb0a1699b4c24097316cf0e13a9ce155e697ebc5c1
|
| 3 |
+
size 1841871936
|
model-00020-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e33625b0e732967b9df36a485e9068bfcab0a9f4e378c7d34a6b92dbfd935a7
|
| 3 |
+
size 3479146344
|
model-00021-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:596ba2561f673c46c02eea86e6b7adc328818e0eaeca438e023eded892737006
|
| 3 |
+
size 4017358136
|
model-00022-of-00022.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8479513d47867a0ba235caa5d5ae677c4bab15bb81eee8ee95027d45465d0be4
|
| 3 |
+
size 1403332336
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
|
|
|
model_extra_tensors.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:60243cd66e81d84bd3f63ff5bd5b13a292aab1fc6fab7880c527898b021c789a
|
| 3 |
+
size 2819054408
|
modeling_longcat_next.py
ADDED
|
@@ -0,0 +1,824 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2026 Meituan
|
| 3 |
+
# This code is licensed under the MIT License, for details, see the ./LICENSE file.
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from typing import Optional, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from transformers.cache_utils import Cache
|
| 15 |
+
from transformers.generation.configuration_utils import GenerationConfig
|
| 16 |
+
from transformers.generation.logits_process import LogitsProcessorList
|
| 17 |
+
from transformers.generation.stopping_criteria import StoppingCriteriaList
|
| 18 |
+
from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, GenerateNonBeamOutput
|
| 19 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 20 |
+
from transformers.models.longcat_flash.modeling_longcat_flash import LongcatFlashForCausalLM
|
| 21 |
+
from transformers.processing_utils import Unpack
|
| 22 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
| 23 |
+
|
| 24 |
+
from .configuration_longcat_next import LongcatNextConfig
|
| 25 |
+
from .modeling_longcat_ngram import LongcatFlashNgramModel, NgramCache
|
| 26 |
+
from .modular_longcat_next import CasualDepthTransformerHead
|
| 27 |
+
from .modular_longcat_next_audio import LongcatNextAudioTokenizer
|
| 28 |
+
from .modular_longcat_next_visual import LongcatNextVisualTokenizer
|
| 29 |
+
|
| 30 |
+
from .cosy24k_vocoder import Cosy24kVocoder
|
| 31 |
+
from .image_refiner import ImageRefinerContainer
|
| 32 |
+
from .refiner_modules import FlowMatchEulerDiscreteScheduler
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class LongcatNextForCausalLMOutputWithPast(CausalLMOutputWithPast):
|
| 38 |
+
visual_loss: Optional[torch.FloatTensor] = None
|
| 39 |
+
visual_logits: Optional[torch.FloatTensor] = None
|
| 40 |
+
visual_ids: Optional[torch.LongTensor] = None
|
| 41 |
+
audio_loss: Optional[torch.FloatTensor] = None
|
| 42 |
+
audio_logits: Optional[torch.FloatTensor] = None
|
| 43 |
+
audio_ids: Optional[torch.LongTensor] = None
|
| 44 |
+
|
| 45 |
+
@dataclass
|
| 46 |
+
class LongcatNextForCausalLMGenerateDecoderOnlyOutput(GenerateDecoderOnlyOutput):
|
| 47 |
+
visual_ids: Optional[torch.LongTensor] = None
|
| 48 |
+
audio_ids: Optional[torch.LongTensor] = None
|
| 49 |
+
audio_text_ids: Optional[torch.LongTensor] = None
|
| 50 |
+
|
| 51 |
+
@dataclass
|
| 52 |
+
class LongcatNextForCausalLMGenerateEncoderDecoderOutput(GenerateEncoderDecoderOutput):
|
| 53 |
+
visual_ids: Optional[torch.LongTensor] = None
|
| 54 |
+
audio_ids: Optional[torch.LongTensor] = None
|
| 55 |
+
audio_text_ids: Optional[torch.LongTensor] = None
|
| 56 |
+
|
| 57 |
+
@dataclass
|
| 58 |
+
class LongcatNextForCausalLMGenerationStatus:
|
| 59 |
+
mode: str = "text"
|
| 60 |
+
current_image_token_num: int = -1
|
| 61 |
+
audio_parallel_decoding: bool = False
|
| 62 |
+
is_audio_text_end: bool = False
|
| 63 |
+
is_audio_start: bool = False
|
| 64 |
+
last_step_mode: Optional[str] = None
|
| 65 |
+
|
| 66 |
+
def __init__(self, visual_generation_config, audio_generation_config):
|
| 67 |
+
self.visual_generation_config = visual_generation_config
|
| 68 |
+
self.h = self.visual_generation_config.custom_params["token_h"]
|
| 69 |
+
self.w = self.visual_generation_config.custom_params["token_w"]
|
| 70 |
+
self.anyres_prefix = self.visual_generation_config.custom_params["anyres_prefix"].format(h=self.h, w=self.w)
|
| 71 |
+
self.audio_generation_config = audio_generation_config
|
| 72 |
+
self.audio_parallel_decoding = audio_generation_config.audio_parallel_decoding
|
| 73 |
+
|
| 74 |
+
def switch_to(self, modal):
|
| 75 |
+
assert modal in ["text", "visual", "audio"]
|
| 76 |
+
self.mode = modal
|
| 77 |
+
self.current_image_token_num = 0 if modal == "visual" else -1
|
| 78 |
+
self.is_audio_text_end = False
|
| 79 |
+
self.is_audio_start = False
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def is_img_newline(self):
|
| 83 |
+
return ((self.current_image_token_num + 1) % (self.w + 1)) == 0 and not self.is_img_end
|
| 84 |
+
|
| 85 |
+
@property
|
| 86 |
+
def is_img_end(self):
|
| 87 |
+
return (self.current_image_token_num + 1) / (self.w + 1) == self.h
|
| 88 |
+
|
| 89 |
+
|
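# A small worked example of the image-grid bookkeeping in the two properties above, with a
# hypothetical 2x3 token grid (h=2 rows, w=3 image tokens per row plus one newline slot):
# every (w + 1)-th generated token lands on a row boundary.
def _example_image_grid_positions(h: int = 2, w: int = 3):
    row_boundaries = [n for n in range(h * (w + 1)) if (n + 1) % (w + 1) == 0]
    # row_boundaries == [3, 7]; position 3 is an image-newline, while position 7 satisfies
    # (n + 1) / (w + 1) == h and is therefore treated as the image end instead.
    return row_boundaries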
| 90 |
+
class LongcatNextModel(LongcatFlashNgramModel):
|
| 91 |
+
_keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
|
| 92 |
+
config_class = LongcatNextConfig
|
| 93 |
+
|
| 94 |
+
def __init__(self, config):
|
| 95 |
+
super().__init__(config)
|
| 96 |
+
self.visual_tokenizer = LongcatNextVisualTokenizer(config)
|
| 97 |
+
self.audio_tokenizer = LongcatNextAudioTokenizer(config)
|
| 98 |
+
|
| 99 |
+
self._init_multimodal_constants(config)
|
| 100 |
+
self.post_init()
|
| 101 |
+
|
| 102 |
+
def _init_multimodal_constants(self, config):
|
| 103 |
+
name2id_dict = {
|
| 104 |
+
"image_newline_token_id": self.config.visual_config.image_newline_token_id,
|
| 105 |
+
"image_end_token_id": self.config.visual_config.image_end_token_id,
|
| 106 |
+
"image_pad_token_id": self.config.visual_config.image_pad_token_id,
|
| 107 |
+
"audiotext_start_token_id": config.audio_config.audiotext_start_token_id,
|
| 108 |
+
"audiotext_pad_token_id": self.config.audio_config.audiotext_pad_token_id,
|
| 109 |
+
"audiogen_end_token_id": config.audio_config.audiogen_end_token_id,
|
| 110 |
+
"audio_pad_token_id": self.config.audio_config.audio_pad_token_id,
|
| 111 |
+
}
|
| 112 |
+
for k, v in name2id_dict.items():
|
| 113 |
+
self.register_buffer(k, torch.tensor([v], dtype=torch.long), persistent=False)
|
| 114 |
+
visual_offset_list = [config.visual_offset] + config.visual_config.vq_config.codebook_sizes[:-1]
|
| 115 |
+
visual_offset_vals = torch.cumsum(torch.tensor(visual_offset_list, dtype=torch.long), dim=0)
|
| 116 |
+
self.register_buffer("visual_offset_vals", visual_offset_vals, persistent=False)
|
| 117 |
+
audio_offset_list = [config.audio_offset] + config.audio_config.vq_config.codebook_sizes[:-1]
|
| 118 |
+
audio_offset_vals = torch.cumsum(torch.tensor(audio_offset_list, dtype=torch.long), dim=0)
|
| 119 |
+
self.register_buffer("audio_offset_vals", audio_offset_vals, persistent=False)
|
| 120 |
+
print(f"{self.visual_offset_vals=}")
|
| 121 |
+
print(f"{self.audio_offset_vals=}")
|
| 122 |
+
|
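# A minimal sketch of the cumulative-offset construction in _init_multimodal_constants above:
# each codebook level is mapped into its own disjoint id range of the shared embedding table by
# cumulatively summing [modality_offset] + codebook_sizes[:-1]. The numbers below are
# hypothetical, not the model's actual configuration values.
def _example_codebook_offsets():
    import torch
    modality_offset, codebook_sizes = 152_000, [8192, 8192, 8192]
    offsets = torch.cumsum(torch.tensor([modality_offset] + codebook_sizes[:-1]), dim=0)
    assert offsets.tolist() == [152_000, 160_192, 168_384]
    return offsets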
| 123 |
+
def forward(
|
| 124 |
+
self,
|
| 125 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 126 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 127 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 128 |
+
past_key_values: Optional[Cache] = None,
|
| 129 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 130 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 131 |
+
use_cache: Optional[bool] = None,
|
| 132 |
+
visual_inputs=None,
|
| 133 |
+
visual_ids=None,
|
| 134 |
+
audio_inputs=None,
|
| 135 |
+
audio_ids=None,
|
| 136 |
+
audio_text_ids=None,
|
| 137 |
+
multimodal_generation_status=None,
|
| 138 |
+
**kwargs
|
| 139 |
+
) -> BaseModelOutputWithPast:
|
| 140 |
+
|
| 141 |
+
if input_ids is None:
|
| 142 |
+
raise ValueError("You must specify input_ids")
|
| 143 |
+
|
| 144 |
+
# Extract N-gram context if available
|
| 145 |
+
ngram_context = None
|
| 146 |
+
if isinstance(past_key_values, NgramCache) and past_key_values.ngram_context is not None:
|
| 147 |
+
ngram_context = past_key_values.ngram_context
|
| 148 |
+
|
| 149 |
+
# assert input_ids.size(0) == 1, "only supports bs=1 for now"  # when bs == 2, index 1 carries the unconditional branch for image generation (CFG)
|
| 150 |
+
special_visual_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask = self.get_placeholder_mask(input_ids[:1]) # seq-dim
|
| 151 |
+
|
| 152 |
+
if inputs_embeds is None:
|
| 153 |
+
input_ids[:, special_visual_mask | special_audio_mask | special_audio_text_pad_mask | special_audio_text_start_mask] = 0
|
| 154 |
+
filled_text_pad_mask = torch.ones_like(special_audio_mask)
|
| 155 |
+
audio_text_position_mask = (special_audio_text_pad_mask | special_audio_text_start_mask | special_audio_mask)
|
| 156 |
+
|
| 157 |
+
if audio_text_ids is not None and audio_text_ids.size(1) > 0 and audio_text_position_mask.sum() > 0:
|
| 158 |
+
filled_text = audio_text_ids[:, -audio_text_position_mask.sum():]
|
| 159 |
+
filled_text_pad_mask = (filled_text==self.config.audio_config.audiotext_pad_token_id)[0]
|
| 160 |
+
input_ids[:, audio_text_position_mask] = filled_text
|
| 161 |
+
input_ids[input_ids == self.config.audio_config.audiotext_pad_token_id] = 0
|
| 162 |
+
|
| 163 |
+
inputs_embeds = self.ngram_embeddings(input_ids, ngram_context=ngram_context)
|
| 164 |
+
inputs_embeds[:, (special_visual_mask | (special_audio_mask & filled_text_pad_mask))] = 0
|
| 165 |
+
|
| 166 |
+
if special_audio_text_start_mask.sum() > 0:
|
| 167 |
+
audio_text_start_embedding = self.embed_tokens(self.audiotext_start_token_id)
|
| 168 |
+
if multimodal_generation_status.last_step_mode is None: # prefill
|
| 169 |
+
inputs_embeds[:1, special_audio_text_start_mask] += audio_text_start_embedding
|
| 170 |
+
else:
|
| 171 |
+
inputs_embeds[:, special_audio_text_start_mask] += audio_text_start_embedding
|
| 172 |
+
|
| 173 |
+
if visual_inputs is not None:
|
| 174 |
+
visual_ids = self.get_visual_ids(**visual_inputs) # [<bs=1>*seq, lev]
|
| 175 |
+
|
| 176 |
+
if visual_ids is not None and special_visual_mask.sum() > 0:
|
| 177 |
+
visual_embeddings = self.get_visual_embeddings(visual_ids[-special_visual_mask.sum():]) # -> [seq, dim]
|
| 178 |
+
if multimodal_generation_status.last_step_mode is None: # prefill
|
| 179 |
+
inputs_embeds[:1, special_visual_mask] = visual_embeddings.to(inputs_embeds.device)
|
| 180 |
+
else:
|
| 181 |
+
inputs_embeds[:, special_visual_mask] = visual_embeddings.to(inputs_embeds.device)
|
| 182 |
+
|
| 183 |
+
if audio_inputs is not None:
|
| 184 |
+
audio_ids = self.get_audio_ids(**audio_inputs) # -> [<bs=1>*seq, lev]
|
| 185 |
+
|
| 186 |
+
if audio_ids is not None and special_audio_mask.sum() > 0:
|
| 187 |
+
audio_embeddings = self.get_audio_embeddings(audio_ids[-special_audio_mask.sum():]) # -> [seq, dim]
|
| 188 |
+
if multimodal_generation_status.last_step_mode is None: # prefill
|
| 189 |
+
inputs_embeds[:1, special_audio_mask] += audio_embeddings.to(inputs_embeds.device)
|
| 190 |
+
else:
|
| 191 |
+
inputs_embeds[:, special_audio_mask] += audio_embeddings.to(inputs_embeds.device)
|
| 192 |
+
|
| 193 |
+
# Initialize NgramCache if needed
|
| 194 |
+
if use_cache and past_key_values is None:
|
| 195 |
+
past_key_values = NgramCache(config=self.config)
|
| 196 |
+
|
| 197 |
+
# Update N-gram context
|
| 198 |
+
if use_cache and isinstance(past_key_values, NgramCache):
|
| 199 |
+
past_key_values.update_ngram_context(input_ids)
|
| 200 |
+
|
| 201 |
+
return super().forward(
|
| 202 |
+
input_ids=None,
|
| 203 |
+
attention_mask=attention_mask,
|
| 204 |
+
position_ids=position_ids,
|
| 205 |
+
past_key_values=past_key_values,
|
| 206 |
+
inputs_embeds=inputs_embeds,
|
| 207 |
+
cache_position=cache_position,
|
| 208 |
+
use_cache=use_cache,
|
| 209 |
+
**kwargs
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def get_visual_ids(self, pixel_values, visual_grid_thw, offset=True):
|
| 213 |
+
visual_ids = self.visual_tokenizer.encode(pixel_values, visual_grid_thw)
|
| 214 |
+
if offset:
|
| 215 |
+
visual_ids += self.visual_offset_vals.to(visual_ids.device)
|
| 216 |
+
return visual_ids
|
| 217 |
+
|
| 218 |
+
def get_audio_ids(self, audio, encoder_length, bridge_length, offset=True):
|
| 219 |
+
audio_ids = self.audio_tokenizer.encode(audio, encoder_length, bridge_length)
|
| 220 |
+
if offset:
|
| 221 |
+
audio_ids += self.audio_offset_vals.to(audio_ids.device)
|
| 222 |
+
return audio_ids
|
| 223 |
+
|
| 224 |
+
@torch.no_grad()
|
| 225 |
+
def decode_visual_ids_and_save(
|
| 226 |
+
self,
|
| 227 |
+
visual_ids,
|
| 228 |
+
save_prefix,
|
| 229 |
+
token_h,
|
| 230 |
+
token_w,
|
| 231 |
+
**kwargs,
|
| 232 |
+
):
|
| 233 |
+
visual_ids -= self.visual_offset_vals.to(visual_ids.device)
|
| 234 |
+
|
| 235 |
+
if not (save_prefix.startswith("./") or save_prefix.startswith("/")):
|
| 236 |
+
save_prefix = f"./{save_prefix}"
|
| 237 |
+
os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
|
| 238 |
+
return self.visual_tokenizer.lazy_decode_and_save(visual_ids, token_h, token_w, f"{save_prefix}_{0}.png")
|
| 239 |
+
|
| 240 |
+
@torch.no_grad()
|
| 241 |
+
def decode_audio_ids_and_save(
|
| 242 |
+
self,
|
| 243 |
+
audio_ids,
|
| 244 |
+
save_prefix,
|
| 245 |
+
sampling_rate,
|
| 246 |
+
wave_concat_overlap,
|
| 247 |
+
**kwargs,
|
| 248 |
+
):
|
| 249 |
+
audio_ids -= self.audio_offset_vals.to(audio_ids.device)
|
| 250 |
+
|
| 251 |
+
if not (save_prefix.startswith("./") or save_prefix.startswith("/")):
|
| 252 |
+
save_prefix = f"./{save_prefix}"
|
| 253 |
+
os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
|
| 254 |
+
save_path = f"{save_prefix}_{0}.wav"
|
| 255 |
+
self.audio_tokenizer.lazy_decode_and_save(audio_ids, sampling_rate, wave_concat_overlap, save_path)
|
| 256 |
+
return [save_path]
|
| 257 |
+
|
| 258 |
+
def get_visual_embeddings(self, visual_ids):
|
| 259 |
+
visual_embeddings = self.embed_tokens(visual_ids).sum(dim=1) # [seq, lev] -> [seq, lev, dim] -> [seq, dim]
|
| 260 |
+
visual_embeddings = self.visual_tokenizer.visual_embedding_layer(visual_embeddings)
|
| 261 |
+
return visual_embeddings
|
| 262 |
+
|
| 263 |
+
def get_audio_embeddings(self, audio_ids):
|
| 264 |
+
audio_embeddings = self.embed_tokens(audio_ids).sum(dim=1)
|
| 265 |
+
return audio_embeddings
|
| 266 |
+
|
| 267 |
+
def get_placeholder_mask(self, input_ids: torch.LongTensor):
|
| 268 |
+
special_image_mask = (input_ids == self.config.visual_config.image_pad_token_id).squeeze(0)
|
| 269 |
+
special_audio_mask = (input_ids == self.config.audio_config.audio_pad_token_id).squeeze(0)
|
| 270 |
+
special_audio_text_start_mask = (input_ids == self.config.audio_config.audiotext_start_token_id).squeeze(0)
|
| 271 |
+
special_audio_text_pad_mask = (input_ids == self.config.audio_config.audiotext_pad_token_id).squeeze(0)
|
| 272 |
+
return special_image_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class LongcatNextForCausalLM(LongcatFlashForCausalLM):
|
| 276 |
+
_keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
|
| 277 |
+
_no_split_modules = [
|
| 278 |
+
"LongcatFlashDecoderLayer",
|
| 279 |
+
"CasualDepthTransformerHead",
|
| 280 |
+
]
|
| 281 |
+
config_class = LongcatNextConfig
|
| 282 |
+
|
| 283 |
+
def __init__(self, config):
|
| 284 |
+
super().__init__(config)
|
| 285 |
+
self.config = config
|
| 286 |
+
self.model = LongcatNextModel(config)
|
| 287 |
+
self.lm_head = nn.Linear(config.hidden_size, config.text_vocab_plus_multimodal_special_token_size, bias=False)
|
| 288 |
+
|
| 289 |
+
self.visual_head = CasualDepthTransformerHead(
|
| 290 |
+
hidden_size=config.hidden_size,
|
| 291 |
+
codebook_sizes=config.visual_config.vq_config.codebook_sizes,
|
| 292 |
+
transformer_layer_num=config.visual_config.image_head_transformer_layers,
|
| 293 |
+
transformer_dim=config.visual_config.image_head_transformer_dims,
|
| 294 |
+
transformer_ffn_scale=config.visual_config.image_head_transformer_ffn_scale,
|
| 295 |
+
)
|
| 296 |
+
self.audio_head = CasualDepthTransformerHead(
|
| 297 |
+
hidden_size=config.hidden_size,
|
| 298 |
+
codebook_sizes=config.audio_config.vq_config.codebook_sizes,
|
| 299 |
+
transformer_layer_num=config.audio_config.audio_head_transformer_layers,
|
| 300 |
+
transformer_dim=config.audio_config.audio_head_transformer_dims,
|
| 301 |
+
transformer_ffn_scale=config.audio_config.audio_head_transformer_ffn_scale,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
self.post_init()
|
| 305 |
+
|
| 306 |
+
@can_return_tuple
|
| 307 |
+
@auto_docstring
|
| 308 |
+
def forward(
|
| 309 |
+
self,
|
| 310 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 311 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 312 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 313 |
+
past_key_values: Optional[Cache] = None,
|
| 314 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 315 |
+
labels: Optional[torch.LongTensor] = None,
|
| 316 |
+
use_cache: Optional[bool] = None,
|
| 317 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 318 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 319 |
+
visual_inputs=None,
|
| 320 |
+
visual_ids=None,
|
| 321 |
+
audio_inputs=None,
|
| 322 |
+
audio_ids=None,
|
| 323 |
+
audio_text_ids=None,
|
| 324 |
+
multimodal_generation_status: LongcatNextForCausalLMGenerationStatus = None,
|
| 325 |
+
visual_generation_config: GenerationConfig = None,
|
| 326 |
+
audio_generation_config: GenerationConfig = None,
|
| 327 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 328 |
+
) -> CausalLMOutputWithPast:
|
| 329 |
+
r"""
|
| 330 |
+
visual_inputs (`BatchFeature`, *optional*):
|
| 331 |
+
Visual inputs returned by the processor, containing pixel values and grid metadata for image encoding.
|
| 332 |
+
visual_ids (`torch.LongTensor` of shape `(num_visual_tokens, num_codebooks)`, *optional*):
|
| 333 |
+
Quantized visual token ids from the visual tokenizer, used to build visual embeddings during generation.
|
| 334 |
+
audio_inputs (`BatchFeature`, *optional*):
|
| 335 |
+
Audio inputs returned by the processor, containing mel-spectrogram features and length metadata.
|
| 336 |
+
audio_ids (`torch.LongTensor` of shape `(num_audio_tokens, num_codebooks)`, *optional*):
|
| 337 |
+
Quantized audio token ids from the audio tokenizer, used to build audio embeddings during generation.
|
| 338 |
+
audio_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 339 |
+
Token ids for the audio text transcript generated alongside audio tokens.
|
| 340 |
+
multimodal_generation_status (`LongcatNextForCausalLMGenerationStatus`, *optional*):
|
| 341 |
+
Stateful object tracking the current multimodal generation mode (text / visual / audio) and
|
| 342 |
+
associated counters used to route logits to the correct head during auto-regressive decoding.
|
| 343 |
+
visual_generation_config (`GenerationConfig`, *optional*):
|
| 344 |
+
Generation configuration for the visual head, controlling sampling parameters such as
|
| 345 |
+
`temperature`, `top_k`, `top_p`, and custom parameters like `cfg_scale` and `anyres_config`.
|
| 346 |
+
audio_generation_config (`GenerationConfig`, *optional*):
|
| 347 |
+
Generation configuration for the audio head, controlling sampling parameters such as
|
| 348 |
+
`temperature`, `top_k`, `top_p`, `repetition_penalty`, and `audio_parallel_decoding`.
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
if multimodal_generation_status.mode == "visual" and visual_generation_config.custom_params["cfg_scale"] != 1.0 and input_ids.size(0) == 1:
|
| 352 |
+
input_ids = input_ids.repeat((2, 1))
|
| 353 |
+
|
| 354 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 355 |
+
input_ids=input_ids,
|
| 356 |
+
attention_mask=attention_mask,
|
| 357 |
+
position_ids=position_ids,
|
| 358 |
+
past_key_values=past_key_values,
|
| 359 |
+
inputs_embeds=inputs_embeds,
|
| 360 |
+
use_cache=use_cache,
|
| 361 |
+
cache_position=cache_position,
|
| 362 |
+
visual_inputs=visual_inputs,
|
| 363 |
+
visual_ids=visual_ids,
|
| 364 |
+
audio_inputs=audio_inputs,
|
| 365 |
+
audio_ids=audio_ids,
|
| 366 |
+
audio_text_ids=audio_text_ids,
|
| 367 |
+
multimodal_generation_status=multimodal_generation_status,
|
| 368 |
+
**kwargs,
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
hidden_states = outputs.last_hidden_state
|
| 372 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 373 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 374 |
+
slice_hidden_states = hidden_states[:, slice_indices, :]
|
| 375 |
+
|
| 376 |
+
loss, logits = None, None
|
| 377 |
+
if multimodal_generation_status.mode == "visual" and \
|
| 378 |
+
(not multimodal_generation_status.is_img_newline) and (not multimodal_generation_status.is_img_end):
|
| 379 |
+
visual_ids = self.get_multimodal_logits_and_ids(
|
| 380 |
+
self.visual_head,
|
| 381 |
+
visual_ids,
|
| 382 |
+
slice_hidden_states,
|
| 383 |
+
self.model.embed_tokens,
|
| 384 |
+
self.config.visual_config.vq_config.codebook_sizes,
|
| 385 |
+
self.model.visual_offset_vals,
|
| 386 |
+
visual_generation_config,
|
| 387 |
+
)
|
| 388 |
+
else:
|
| 389 |
+
logits = self.lm_head(slice_hidden_states)
|
| 390 |
+
|
| 391 |
+
if multimodal_generation_status.mode == "audio" and multimodal_generation_status.is_audio_start:
|
| 392 |
+
audio_ids = self.get_multimodal_logits_and_ids(
|
| 393 |
+
self.audio_head,
|
| 394 |
+
audio_ids,
|
| 395 |
+
slice_hidden_states,
|
| 396 |
+
self.model.embed_tokens,
|
| 397 |
+
self.config.audio_config.vq_config.codebook_sizes,
|
| 398 |
+
self.model.audio_offset_vals,
|
| 399 |
+
audio_generation_config,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
return LongcatNextForCausalLMOutputWithPast(
|
| 403 |
+
loss=loss,
|
| 404 |
+
logits=logits,
|
| 405 |
+
past_key_values=outputs.past_key_values,
|
| 406 |
+
hidden_states=outputs.hidden_states,
|
| 407 |
+
attentions=outputs.attentions,
|
| 408 |
+
visual_ids=visual_ids,
|
| 409 |
+
audio_ids=audio_ids,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
def get_multimodal_logits_and_ids(
|
| 413 |
+
self,
|
| 414 |
+
head_model,
|
| 415 |
+
multimodal_ids,
|
| 416 |
+
hidden_states,
|
| 417 |
+
multimodal_embedding_layer,
|
| 418 |
+
codebook_sizes,
|
| 419 |
+
offset_vals,
|
| 420 |
+
multimodal_generation_config,
|
| 421 |
+
):
|
| 422 |
+
next_token_ids = torch.zeros(hidden_states.size(0), len(codebook_sizes), dtype=torch.long, device=hidden_states.device)
|
| 423 |
+
multimodal_embedding_layer = multimodal_embedding_layer.to(hidden_states.device)
|
| 424 |
+
|
| 425 |
+
for level, _ in enumerate(codebook_sizes):
|
| 426 |
+
logits = head_model(hidden_states, next_token_ids, multimodal_embedding_layer, level) # -> (bs, 1, dim)
|
| 427 |
+
next_token_id = self.inner_sample(logits, multimodal_ids[None, :, level]-offset_vals[level], multimodal_generation_config) # (bs, 1)
|
| 428 |
+
next_token_id += offset_vals[level]
|
| 429 |
+
next_token_ids[:, level] = next_token_id
|
| 430 |
+
|
| 431 |
+
return next_token_ids[:1]
|
| 432 |
+
|
| 433 |
+
def inner_sample(
|
| 434 |
+
self,
|
| 435 |
+
next_token_logits: torch.Tensor,
|
| 436 |
+
multimodal_ids: torch.LongTensor,
|
| 437 |
+
generation_config: GenerationConfig,
|
| 438 |
+
) -> torch.Tensor:
|
| 439 |
+
logits_processor = self._get_logits_processor(generation_config)
|
| 440 |
+
|
| 441 |
+
if "cfg_scale" in generation_config.custom_params and generation_config.custom_params["cfg_scale"] != 1.0:
|
| 442 |
+
cond_logits, uncond_logits = next_token_logits.chunk(2, dim=0)
|
| 443 |
+
next_token_logits = generation_config.custom_params["cfg_scale"] * (cond_logits - uncond_logits) + uncond_logits
|
| 444 |
+
|
| 445 |
+
next_token_scores = logits_processor(multimodal_ids, next_token_logits.to(multimodal_ids.device))
|
| 446 |
+
if generation_config.do_sample:
|
| 447 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
| 448 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 449 |
+
else:
|
| 450 |
+
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
| 451 |
+
return next_tokens
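# A minimal sketch of the classifier-free-guidance step above, with hypothetical
# tensor sizes: the batch stacks [conditional; unconditional] rows and the logits
# are recombined as cfg_scale * (cond - uncond) + uncond, so cfg_scale == 1.0
# reduces to the plain conditional logits.
#
#     import torch
#     cfg_scale = 3.0
#     logits = torch.randn(2, 1, 8)                  # [cond; uncond]
#     cond, uncond = logits.chunk(2, dim=0)
#     guided = cfg_scale * (cond - uncond) + uncond  # shape (1, 1, 8)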
|
| 452 |
+
|
| 453 |
+
@torch.no_grad()
|
| 454 |
+
def generate(self, inputs=None, **kwargs):
|
| 455 |
+
"""Override to ensure NgramCache is used."""
|
| 456 |
+
|
| 457 |
+
if "past_key_values" not in kwargs or kwargs["past_key_values"] is None:
|
| 458 |
+
kwargs["past_key_values"] = NgramCache(config=self.config)
|
| 459 |
+
|
| 460 |
+
return super().generate(
|
| 461 |
+
inputs=inputs,
|
| 462 |
+
**kwargs,
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
def prepare_inputs_for_generation(
|
| 466 |
+
self,
|
| 467 |
+
input_ids,
|
| 468 |
+
visual_ids,
|
| 469 |
+
audio_ids,
|
| 470 |
+
audio_text_ids,
|
| 471 |
+
multimodal_generation_status,
|
| 472 |
+
generation_config,
|
| 473 |
+
attention_mask,
|
| 474 |
+
cache_position,
|
| 475 |
+
**kwargs,
|
| 476 |
+
):
|
| 477 |
+
extra_new_tokens = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device)
|
| 478 |
+
if visual_ids is None:
|
| 479 |
+
visual_ids = torch.empty(0, len(self.config.visual_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device)
|
| 480 |
+
if audio_ids is None:
|
| 481 |
+
audio_ids = torch.empty(0, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device)
|
| 482 |
+
if audio_text_ids is None:
|
| 483 |
+
audio_text_ids = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device)
|
| 484 |
+
|
| 485 |
+
def insert_ids(new_ids, _input_ids, _attention_mask, _cache_position, position=0):
|
| 486 |
+
if position < 0:
|
| 487 |
+
parts = [_input_ids[:, :position], new_ids, _input_ids[:, position:]]
|
| 488 |
+
else:
|
| 489 |
+
parts = [_input_ids, new_ids]
|
| 490 |
+
_input_ids = torch.cat(parts, dim=1)
|
| 491 |
+
insert_len = new_ids.size(1)
|
| 492 |
+
_attention_mask = F.pad(_attention_mask, (0, insert_len), value=1)
|
| 493 |
+
insert_position = _cache_position[-1] + 1 + torch.arange(insert_len, device=_cache_position.device)
|
| 494 |
+
_cache_position = torch.cat([_cache_position, insert_position])
|
| 495 |
+
return _input_ids, _attention_mask, _cache_position
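# A small worked example of insert_ids with made-up ids: with position=-1 the new
# ids are spliced in just before the last token, the attention mask is right-padded
# with ones, and cache_position grows monotonically.
#
#     _input_ids = [[5, 6, 7]]      # 7 plays the role of the image-start marker
#     new_ids    = [[8, 9]]
#     insert_ids(new_ids, _input_ids, mask, cache_pos, position=-1)
#     -> input_ids       [[5, 6, 8, 9, 7]]
#     -> attention_mask  gains two trailing 1s
#     -> cache_position  gains [last + 1, last + 2]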
|
| 496 |
+
|
| 497 |
+
# multimodal generation status change
|
| 498 |
+
if cache_position[0] != 0:
|
| 499 |
+
multimodal_generation_status.last_step_mode = multimodal_generation_status.mode
|
| 500 |
+
|
| 501 |
+
if multimodal_generation_status.mode == "visual":
|
| 502 |
+
multimodal_generation_status.current_image_token_num += 1
|
| 503 |
+
|
| 504 |
+
if (input_ids[:, -1] == self.config.visual_config.image_start_token_id).all():
|
| 505 |
+
multimodal_generation_status.switch_to("visual")
|
| 506 |
+
anyres_prefix_ids = self.text_tokenizer.encode(multimodal_generation_status.anyres_prefix, return_tensors="pt")
|
| 507 |
+
anyres_prefix_ids = anyres_prefix_ids.to(input_ids.device)
|
| 508 |
+
extra_new_tokens = torch.cat([extra_new_tokens, anyres_prefix_ids], dim=1)
|
| 509 |
+
input_ids, attention_mask, cache_position = insert_ids(anyres_prefix_ids, input_ids, attention_mask, cache_position, position=-1)
|
| 510 |
+
if input_ids.size(0) == 1: # cfg, change bs=1 -> 2
|
| 511 |
+
input_ids = input_ids.repeat((2, 1))
|
| 512 |
+
input_ids[1, :-(anyres_prefix_ids.size(-1)+1)] = 0
|
| 513 |
+
print(f"change to cfg, input_ids: {input_ids}")
|
| 514 |
+
attention_mask = attention_mask.repeat((2, 1))
|
| 515 |
+
|
| 516 |
+
elif (input_ids[:, -1] == self.config.audio_config.audiogen_start_token_id).all():
|
| 517 |
+
multimodal_generation_status.switch_to("audio")
|
| 518 |
+
|
| 519 |
+
elif (input_ids[:, -1] == self.config.audio_config.audiotext_start_token_id).all():
|
| 520 |
+
multimodal_generation_status.is_audio_start = True
|
| 521 |
+
|
| 522 |
+
elif ((input_ids[:, -1] == self.config.visual_config.image_end_token_id) | (input_ids[:, -1] == self.config.audio_config.audiogen_end_token_id)).all():
|
| 523 |
+
multimodal_generation_status.switch_to("text")
|
| 524 |
+
|
| 525 |
+
model_inputs = super().prepare_inputs_for_generation(
|
| 526 |
+
input_ids=input_ids,
|
| 527 |
+
visual_ids=visual_ids,
|
| 528 |
+
audio_ids=audio_ids,
|
| 529 |
+
audio_text_ids=audio_text_ids,
|
| 530 |
+
attention_mask=attention_mask,
|
| 531 |
+
cache_position=cache_position,
|
| 532 |
+
**kwargs,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
if model_inputs["cache_position"][0] != 0:
|
| 536 |
+
model_inputs["visual_inputs"] = None
|
| 537 |
+
model_inputs["audio_inputs"] = None
|
| 538 |
+
|
| 539 |
+
return model_inputs, multimodal_generation_status, extra_new_tokens
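# Rough summary of the sentinel-token transitions handled above (a hypothetical
# lookup written out for readability; the actual control flow is the if/elif chain):
#
#     # image_start_token_id      -> switch_to("visual"), insert anyres prefix, CFG batch of 2
#     # audiogen_start_token_id   -> switch_to("audio")
#     # audiotext_start_token_id  -> is_audio_start = True
#     # image_end_token_id or
#     # audiogen_end_token_id     -> switch_to("text")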
|
| 540 |
+
|
| 541 |
+
def _sample(
|
| 542 |
+
self,
|
| 543 |
+
input_ids: torch.LongTensor,
|
| 544 |
+
logits_processor: LogitsProcessorList,
|
| 545 |
+
stopping_criteria: StoppingCriteriaList,
|
| 546 |
+
generation_config: GenerationConfig,
|
| 547 |
+
synced_gpus: bool = False,
|
| 548 |
+
streamer: Optional["BaseStreamer"] = None,
|
| 549 |
+
visual_ids=None,
|
| 550 |
+
audio_ids=None,
|
| 551 |
+
audio_text_ids=None,
|
| 552 |
+
**model_kwargs,
|
| 553 |
+
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
| 554 |
+
r"""
|
| 555 |
+
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
|
| 556 |
+
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
| 557 |
+
|
| 558 |
+
Parameters:
|
| 559 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 560 |
+
The sequence used as a prompt for the generation.
|
| 561 |
+
logits_processor (`LogitsProcessorList`):
|
| 562 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
| 563 |
+
used to modify the prediction scores of the language modeling head applied at each generation step.
|
| 564 |
+
stopping_criteria (`StoppingCriteriaList`):
|
| 565 |
+
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
| 566 |
+
used to tell if the generation loop should stop.
|
| 567 |
+
generation_config ([`~generation.GenerationConfig`]):
|
| 568 |
+
The generation configuration to be used as parametrization of the decoding method.
|
| 569 |
+
synced_gpus (`bool`):
|
| 570 |
+
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
| 571 |
+
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
| 572 |
+
streamer (`BaseStreamer`, *optional*):
|
| 573 |
+
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
| 574 |
+
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
| 575 |
+
model_kwargs:
|
| 576 |
+
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
| 577 |
+
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
| 578 |
+
|
| 579 |
+
Return:
|
| 580 |
+
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
|
| 581 |
+
A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
| 582 |
+
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
| 583 |
+
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
|
| 584 |
+
`model.config.is_encoder_decoder=True`.
|
| 585 |
+
"""
|
| 586 |
+
# init values
|
| 587 |
+
pad_token_id = generation_config._pad_token_tensor
|
| 588 |
+
output_attentions = generation_config.output_attentions
|
| 589 |
+
output_hidden_states = generation_config.output_hidden_states
|
| 590 |
+
output_scores = generation_config.output_scores
|
| 591 |
+
output_logits = generation_config.output_logits
|
| 592 |
+
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 593 |
+
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
| 594 |
+
do_sample = generation_config.do_sample
|
| 595 |
+
|
| 596 |
+
# init attention / hidden states / scores tuples
|
| 597 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
| 598 |
+
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
| 599 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
| 600 |
+
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
| 601 |
+
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
| 602 |
+
|
| 603 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
| 604 |
+
if return_dict_in_generate and self.config.is_encoder_decoder:
|
| 605 |
+
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
| 606 |
+
encoder_hidden_states = (
|
| 607 |
+
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
# keep track of which sequences are already finished
|
| 611 |
+
batch_size, cur_len = input_ids.shape[:2]
|
| 612 |
+
this_peer_finished = False
|
| 613 |
+
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
| 614 |
+
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
| 615 |
+
|
| 616 |
+
model_forward = self.__call__
|
| 617 |
+
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
|
| 618 |
+
if compile_forward:
|
| 619 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
| 620 |
+
# If we use FA2 and a static cache, we cannot compile with fullgraph
|
| 621 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 622 |
+
# only raise warning if the user passed an explicit compile-config
|
| 623 |
+
if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
|
| 624 |
+
logger.warning_once(
|
| 625 |
+
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
|
| 626 |
+
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
|
| 627 |
+
)
|
| 628 |
+
generation_config.compile_config.fullgraph = False
|
| 629 |
+
model_forward = self.get_compiled_call(generation_config.compile_config)
|
| 630 |
+
|
| 631 |
+
if generation_config.prefill_chunk_size is not None:
|
| 632 |
+
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
|
| 633 |
+
is_prefill = False
|
| 634 |
+
else:
|
| 635 |
+
is_prefill = True
|
| 636 |
+
|
| 637 |
+
visual_generation_config = GenerationConfig(**generation_config.visual_generation_config)
|
| 638 |
+
audio_generation_config = GenerationConfig(**generation_config.audio_generation_config)
|
| 639 |
+
multimodal_generation_status = LongcatNextForCausalLMGenerationStatus(visual_generation_config, audio_generation_config)
|
| 640 |
+
|
| 641 |
+
pbar = tqdm(iter(int, 1), desc="Generating", unit="tok")
|
| 642 |
+
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
| 643 |
+
# prepare model inputs
|
| 644 |
+
model_inputs, multimodal_generation_status, extra_new_tokens = self.prepare_inputs_for_generation(
|
| 645 |
+
input_ids,
|
| 646 |
+
visual_ids,
|
| 647 |
+
audio_ids,
|
| 648 |
+
audio_text_ids,
|
| 649 |
+
multimodal_generation_status,
|
| 650 |
+
generation_config,
|
| 651 |
+
**model_kwargs,
|
| 652 |
+
)
|
| 653 |
+
if extra_new_tokens.size(1) > 0:
|
| 654 |
+
input_ids = torch.cat([input_ids[:, :-1], extra_new_tokens, input_ids[:, -1:]], dim=1)
|
| 655 |
+
model_kwargs["attention_mask"] = model_inputs["attention_mask"]
|
| 656 |
+
model_kwargs["cache_position"] = model_inputs["cache_position"]
|
| 657 |
+
|
| 658 |
+
if multimodal_generation_status.mode == "text" and multimodal_generation_status.last_step_mode == "visual":
|
| 659 |
+
next_tokens = generation_config._eos_token_tensor
|
| 660 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
| 661 |
+
if streamer is not None:
|
| 662 |
+
streamer.put(next_tokens.cpu())
|
| 663 |
+
break
|
| 664 |
+
|
| 665 |
+
visual_ids = model_inputs["visual_ids"]
|
| 666 |
+
audio_ids = model_inputs["audio_ids"]
|
| 667 |
+
audio_text_ids = model_inputs["audio_text_ids"]
|
| 668 |
+
|
| 669 |
+
if is_prefill:
|
| 670 |
+
outputs = self(**model_inputs, return_dict=True, multimodal_generation_status=multimodal_generation_status, visual_generation_config=visual_generation_config, audio_generation_config=audio_generation_config)
|
| 671 |
+
is_prefill = False
|
| 672 |
+
else:
|
| 673 |
+
outputs = model_forward(**model_inputs, return_dict=True, multimodal_generation_status=multimodal_generation_status, visual_generation_config=visual_generation_config, audio_generation_config=audio_generation_config)
|
| 674 |
+
|
| 675 |
+
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
| 676 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
| 677 |
+
outputs,
|
| 678 |
+
model_kwargs,
|
| 679 |
+
is_encoder_decoder=self.config.is_encoder_decoder,
|
| 680 |
+
num_new_tokens=1,
|
| 681 |
+
)
|
| 682 |
+
if synced_gpus and this_peer_finished:
|
| 683 |
+
continue
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
# multimodal generation
|
| 687 |
+
if multimodal_generation_status.mode == "text" or \
|
| 688 |
+
(multimodal_generation_status.mode == "audio" and not multimodal_generation_status.is_audio_text_end):
|
| 689 |
+
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
| 690 |
+
# (the clone itself is always small)
|
| 691 |
+
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
| 692 |
+
|
| 693 |
+
# pre-process distribution
|
| 694 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
| 695 |
+
|
| 696 |
+
# Store scores, attentions and hidden_states when required
|
| 697 |
+
if return_dict_in_generate:
|
| 698 |
+
if output_scores:
|
| 699 |
+
scores += (next_token_scores,)
|
| 700 |
+
if output_logits:
|
| 701 |
+
raw_logits += (next_token_logits,)
|
| 702 |
+
if output_attentions:
|
| 703 |
+
decoder_attentions += (
|
| 704 |
+
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
| 705 |
+
)
|
| 706 |
+
if self.config.is_encoder_decoder:
|
| 707 |
+
cross_attentions += (outputs.cross_attentions,)
|
| 708 |
+
|
| 709 |
+
if output_hidden_states:
|
| 710 |
+
decoder_hidden_states += (
|
| 711 |
+
(outputs.decoder_hidden_states,)
|
| 712 |
+
if self.config.is_encoder_decoder
|
| 713 |
+
else (outputs.hidden_states,)
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
# token selection
|
| 717 |
+
if do_sample:
|
| 718 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
| 719 |
+
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
|
| 720 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 721 |
+
else:
|
| 722 |
+
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
| 723 |
+
|
| 724 |
+
# audio_text_ids done
|
| 725 |
+
if multimodal_generation_status.mode == "audio" and (next_tokens == self.config.audio_config.audiotext_pad_token_id).all():
|
| 726 |
+
multimodal_generation_status.is_audio_text_end = True
|
| 727 |
+
|
| 728 |
+
elif multimodal_generation_status.mode == "visual":
|
| 729 |
+
if multimodal_generation_status.is_img_end:
|
| 730 |
+
next_tokens = self.model.image_end_token_id.to(input_ids.device)
|
| 731 |
+
|
| 732 |
+
elif multimodal_generation_status.is_img_newline:
|
| 733 |
+
next_tokens = self.model.image_newline_token_id.to(input_ids.device)
|
| 734 |
+
|
| 735 |
+
else:
|
| 736 |
+
visual_ids = torch.cat([visual_ids, outputs.visual_ids], dim=0) # [seq, lev]
|
| 737 |
+
next_tokens = self.model.image_pad_token_id.to(input_ids.device)
|
| 738 |
+
|
| 739 |
+
else: # mode == "audio" and multimodal_generation_status.is_audio_text_end
|
| 740 |
+
next_tokens = self.model.audio_pad_token_id.to(input_ids.device)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
if multimodal_generation_status.mode == "audio":
|
| 744 |
+
# audio_text_ids update
|
| 745 |
+
audio_text_next_tokens = self.model.audiotext_pad_token_id.to(input_ids.device)
|
| 746 |
+
if not multimodal_generation_status.is_audio_text_end:
|
| 747 |
+
audio_text_next_tokens, next_tokens = next_tokens, audio_text_next_tokens
|
| 748 |
+
audio_text_ids = torch.cat((audio_text_ids, audio_text_next_tokens[:, None]), dim=1)
|
| 749 |
+
|
| 750 |
+
# audio_ids update
|
| 751 |
+
if multimodal_generation_status.is_audio_start:
|
| 752 |
+
if outputs.audio_ids[-1, 0] == (self.model.audio_offset_vals[1]): # offset + (level_1_len)
|
| 753 |
+
next_tokens = self.model.audiogen_end_token_id.to(input_ids.device)
|
| 754 |
+
else:
|
| 755 |
+
next_tokens = self.model.audio_pad_token_id.to(input_ids.device)
|
| 756 |
+
audio_ids = torch.cat([audio_ids, outputs.audio_ids], dim=0)
|
| 757 |
+
|
| 758 |
+
elif (multimodal_generation_status.audio_parallel_decoding) or \
|
| 759 |
+
(not multimodal_generation_status.audio_parallel_decoding and multimodal_generation_status.is_audio_text_end):
|
| 760 |
+
next_tokens = self.model.audiotext_start_token_id.to(input_ids.device)
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
# finished sentences should have their next token be a padding token
|
| 764 |
+
if has_eos_stopping_criteria:
|
| 765 |
+
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
| 766 |
+
|
| 767 |
+
# update generated ids, model inputs, and length for next step
|
| 768 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
| 769 |
+
|
| 770 |
+
# TODO: streaming mm ids
|
| 771 |
+
if streamer is not None:
|
| 772 |
+
streamer.put(next_tokens.cpu())
|
| 773 |
+
|
| 774 |
+
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
| 775 |
+
this_peer_finished = unfinished_sequences.max() == 0
|
| 776 |
+
cur_len += 1
|
| 777 |
+
|
| 778 |
+
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
| 779 |
+
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
| 780 |
+
del outputs
|
| 781 |
+
|
| 782 |
+
pbar.update(1)
|
| 783 |
+
pbar.set_postfix({
|
| 784 |
+
"recent_5toks": f"{input_ids[:, -5:].tolist()}",
|
| 785 |
+
})
|
| 786 |
+
|
| 787 |
+
pbar.close()
|
| 788 |
+
|
| 789 |
+
if streamer is not None:
|
| 790 |
+
streamer.end()
|
| 791 |
+
|
| 792 |
+
if return_dict_in_generate:
|
| 793 |
+
if self.config.is_encoder_decoder:
|
| 794 |
+
return LongcatNextForCausalLMGenerateEncoderDecoderOutput(
|
| 795 |
+
sequences=input_ids,
|
| 796 |
+
scores=scores,
|
| 797 |
+
logits=raw_logits,
|
| 798 |
+
encoder_attentions=encoder_attentions,
|
| 799 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 800 |
+
decoder_attentions=decoder_attentions,
|
| 801 |
+
cross_attentions=cross_attentions,
|
| 802 |
+
decoder_hidden_states=decoder_hidden_states,
|
| 803 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
| 804 |
+
visual_ids=visual_ids,
|
| 805 |
+
audio_ids=audio_ids,
|
| 806 |
+
audio_text_ids=audio_text_ids,
|
| 807 |
+
)
|
| 808 |
+
else:
|
| 809 |
+
return LongcatNextForCausalLMGenerateDecoderOnlyOutput(
|
| 810 |
+
sequences=input_ids,
|
| 811 |
+
scores=scores,
|
| 812 |
+
logits=raw_logits,
|
| 813 |
+
attentions=decoder_attentions,
|
| 814 |
+
hidden_states=decoder_hidden_states,
|
| 815 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
| 816 |
+
visual_ids=visual_ids,
|
| 817 |
+
audio_ids=audio_ids,
|
| 818 |
+
audio_text_ids=audio_text_ids,
|
| 819 |
+
)
|
| 820 |
+
else:
|
| 821 |
+
return input_ids, visual_ids, audio_ids, audio_text_ids
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
__all__ = ["LongcatNextModel", "LongcatNextForCausalLM"]
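The overridden generate() above only guarantees that an NgramCache is supplied as past_key_values; decoding then runs through the custom _sample loop, which returns the text sequence together with the visual / audio / audio-text id streams. A minimal usage sketch, assuming the repository's config registers these classes for trust_remote_code auto-loading and that a local checkpoint path exists (both assumptions, not shown in this diff):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./LongCat-Next"  # hypothetical local path to this repository
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
)

inputs = tokenizer("Describe a small red bird.", return_tensors="pt").to(model.device)
# Without return_dict_in_generate, the custom _sample returns the four id streams directly.
sequences, visual_ids, audio_ids, audio_text_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(sequences[0], skip_special_tokens=True))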
|
modeling_longcat_ngram.py
ADDED
|
@@ -0,0 +1,426 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2026 Meituan
|
| 3 |
+
# This code is licensed under the MIT License, for details, see the ./LICENSE file.
|
| 4 |
+
|
| 5 |
+
from typing import Optional, Tuple, Dict, List
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 12 |
+
from transformers.masking_utils import create_causal_mask
|
| 13 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 14 |
+
from transformers.processing_utils import Unpack
|
| 15 |
+
from transformers.utils import auto_docstring, logging
|
| 16 |
+
from transformers.models.longcat_flash.modeling_longcat_flash import (
|
| 17 |
+
LongcatFlashForCausalLM,
|
| 18 |
+
LongcatFlashModel,
|
| 19 |
+
LongcatFlashRMSNorm,
|
| 20 |
+
LongcatFlashRotaryEmbedding,
|
| 21 |
+
LongcatFlashDecoderLayer,
|
| 22 |
+
LongcatFlashPreTrainedModel,
|
| 23 |
+
)
|
| 24 |
+
from .configuration_longcat_ngram import LongcatFlashNgramConfig
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@auto_docstring
|
| 30 |
+
class LongcatFlashNgramPreTrainedModel(LongcatFlashPreTrainedModel):
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class NgramCache(DynamicCache):
|
| 35 |
+
"""
|
| 36 |
+
Extended DynamicCache for storing N-gram context alongside KV cache.
|
| 37 |
+
"""
|
| 38 |
+
def __init__(self, config=None):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.ngram_context = None
|
| 41 |
+
# Keep only n-1 tokens (minimum needed for N-gram computation)
|
| 42 |
+
self.max_context_len = config.emb_neighbor_num - 1
|
| 43 |
+
self.oe_ignored_token_ids = torch.tensor(config.oe_ignored_token_ids, dtype=torch.long)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def update_ngram_context(self, new_tokens: torch.Tensor) -> None:
|
| 47 |
+
"""
|
| 48 |
+
Update N-gram context with window management.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
new_tokens: New tokens to append, shape (batch_size, seq_len)
|
| 52 |
+
"""
|
| 53 |
+
new_tokens = new_tokens.clone()
|
| 54 |
+
new_tokens[torch.isin(new_tokens, self.oe_ignored_token_ids.to(new_tokens.device))] = 0
|
| 55 |
+
|
| 56 |
+
if self.ngram_context is None:
|
| 57 |
+
self.ngram_context = new_tokens
|
| 58 |
+
else:
|
| 59 |
+
self.ngram_context = torch.cat([self.ngram_context, new_tokens], dim=-1)
|
| 60 |
+
|
| 61 |
+
# Truncate to maintain constant memory footprint
|
| 62 |
+
if self.ngram_context.size(-1) > self.max_context_len:
|
| 63 |
+
self.ngram_context = self.ngram_context[..., -self.max_context_len:]
|
| 64 |
+
|
| 65 |
+
def reorder_cache(self, beam_idx: torch.LongTensor) -> "Cache":
|
| 66 |
+
"""Reorder cache for beam search."""
|
| 67 |
+
# Reorder parent's KV cache
|
| 68 |
+
super().reorder_cache(beam_idx)
|
| 69 |
+
|
| 70 |
+
# Reorder N-gram context
|
| 71 |
+
if self.ngram_context is not None:
|
| 72 |
+
self.ngram_context = self.ngram_context.index_select(0, beam_idx.to(self.ngram_context.device))
|
| 73 |
+
|
| 74 |
+
return self
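# A small sketch of the sliding window, assuming emb_neighbor_num = 4 (so
# max_context_len = 3) and no ignored token ids:
#
#     cache.update_ngram_context(torch.tensor([[11, 12, 13, 14, 15]]))
#     cache.ngram_context   # tensor([[13, 14, 15]])  - only the last n-1 tokens survive
#     cache.update_ngram_context(torch.tensor([[16]]))
#     cache.ngram_context   # tensor([[14, 15, 16]])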
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class EmbeddingWithMask(nn.Embedding):
|
| 78 |
+
def forward(self, input: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
| 79 |
+
"""
|
| 80 |
+
Args:
|
| 81 |
+
input (torch.Tensor): Input indices of shape (batch_size, seq_len)
|
| 82 |
+
mask (torch.Tensor): Boolean mask of shape (batch_size, seq_len).
|
| 83 |
+
True means compute, False means skip and return 0.
|
| 84 |
+
Returns:
|
| 85 |
+
torch.Tensor: Embeddings of shape (batch_size, seq_len, embedding_dim)
|
| 86 |
+
"""
|
| 87 |
+
if mask is not None:
|
| 88 |
+
# Ensure mask is boolean
|
| 89 |
+
mask = mask.bool()
|
| 90 |
+
else:
|
| 91 |
+
mask = torch.ones_like(input, dtype=torch.bool)
|
| 92 |
+
|
| 93 |
+
batch_size, seq_len = input.shape
|
| 94 |
+
embedding_dim = self.embedding_dim
|
| 95 |
+
|
| 96 |
+
# 1. Initialize the output tensor with zeros on the correct device
|
| 97 |
+
output = torch.zeros(
|
| 98 |
+
(batch_size, seq_len, embedding_dim),
|
| 99 |
+
device=input.device,
|
| 100 |
+
dtype=self.weight.dtype
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# 2. Filter out the valid indices using the mask
|
| 104 |
+
# valid_indices is a 1D tensor containing only the elements where mask is True
|
| 105 |
+
valid_indices = input[mask]
|
| 106 |
+
|
| 107 |
+
# 3. Only perform the embedding lookup if there is at least one valid index
|
| 108 |
+
if valid_indices.numel() > 0:
|
| 109 |
+
# Look up only the necessary embeddings (saves compute/memory bandwidth)
|
| 110 |
+
valid_embeddings = F.embedding(
|
| 111 |
+
valid_indices, self.weight, self.padding_idx, self.max_norm,
|
| 112 |
+
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
| 113 |
+
|
| 114 |
+
# 4. Scatter the valid embeddings back to their original positions in the output tensor
|
| 115 |
+
output[mask] = valid_embeddings
|
| 116 |
+
|
| 117 |
+
return output
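# Minimal usage sketch with hypothetical sizes: masked-out positions come back as
# all-zero rows instead of going through the embedding lookup.
#
#     emb  = EmbeddingWithMask(10, 4)
#     ids  = torch.tensor([[1, 2, 3]])
#     mask = torch.tensor([[True, False, True]])
#     out  = emb(ids, mask)   # out[0, 1] is a zero vector; rows 0 and 2 are real lookups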
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class NgramEmbedding(nn.Module):
|
| 121 |
+
"""
|
| 122 |
+
Computes embeddings enriched with N-gram features without maintaining internal state.
|
| 123 |
+
"""
|
| 124 |
+
def __init__(self, config, base_embeddings):
|
| 125 |
+
super().__init__()
|
| 126 |
+
self.config = config
|
| 127 |
+
self.word_embeddings = base_embeddings
|
| 128 |
+
|
| 129 |
+
# self.m = config.ngram_vocab_size_ratio * config.vocab_size
|
| 130 |
+
self.m = config.ngram_vocab_size_ratio * config.text_vocab_size
|
| 131 |
+
self.k = config.emb_split_num
|
| 132 |
+
self.n = config.emb_neighbor_num
|
| 133 |
+
self.oe_ignored_token_ids = torch.tensor(config.oe_ignored_token_ids)
|
| 134 |
+
|
| 135 |
+
self._init_ngram_embeddings()
|
| 136 |
+
self._vocab_mods_cache = None
|
| 137 |
+
|
| 138 |
+
def _init_ngram_embeddings(self) -> None:
|
| 139 |
+
"""Initialize N-gram embedding and projection layers."""
|
| 140 |
+
num_embedders = self.k * (self.n - 1)
|
| 141 |
+
emb_dim = self.config.hidden_size // num_embedders
|
| 142 |
+
|
| 143 |
+
embedders = []
|
| 144 |
+
post_projs = []
|
| 145 |
+
|
| 146 |
+
for i in range(num_embedders):
|
| 147 |
+
vocab_size = int(self.m + i * 2 + 1)
|
| 148 |
+
emb = EmbeddingWithMask(vocab_size, emb_dim, padding_idx=self.config.pad_token_id)
|
| 149 |
+
proj = nn.Linear(emb_dim, self.config.hidden_size, bias=False)
|
| 150 |
+
embedders.append(emb)
|
| 151 |
+
post_projs.append(proj)
|
| 152 |
+
|
| 153 |
+
self.embedders = nn.ModuleList(embedders)
|
| 154 |
+
self.post_projs = nn.ModuleList(post_projs)
|
| 155 |
+
|
| 156 |
+
def _shift_right_ignore_eos(self, tensor: torch.Tensor, n: int, eos_token_id: int = 2) -> torch.Tensor:
|
| 157 |
+
p, q = tensor.shape
|
| 158 |
+
# special_token / modal positions are treated as 0
|
| 159 |
+
special_tokens = 0
|
| 160 |
+
|
| 161 |
+
if n == 0:
|
| 162 |
+
return tensor.clone()
|
| 163 |
+
|
| 164 |
+
if n >= q:
|
| 165 |
+
return torch.zeros_like(tensor)
|
| 166 |
+
|
| 167 |
+
result = torch.zeros_like(tensor)
|
| 168 |
+
|
| 169 |
+
# Find all special_token/modal/EOS locations
|
| 170 |
+
special_mask = (tensor == special_tokens)
|
| 171 |
+
total_mask = (tensor == eos_token_id) | special_mask
|
| 172 |
+
|
| 173 |
+
# Calculate the segment ID to which each position belongs
|
| 174 |
+
eos_cumsum = total_mask.long().cumsum(dim=1)
|
| 175 |
+
# Shift right by 1 so each EOS position still belongs to the preceding segment and the next segment starts immediately after it
|
| 176 |
+
segment_ids = torch.cat([
|
| 177 |
+
torch.zeros(p, 1, dtype=torch.long, device=tensor.device),
|
| 178 |
+
eos_cumsum[:, :-1]
|
| 179 |
+
], dim=1)
|
| 180 |
+
|
| 181 |
+
col_indices = torch.arange(q, device=tensor.device).unsqueeze(0).expand(p, q)
|
| 182 |
+
# Number of segments
|
| 183 |
+
max_segments = segment_ids.max().item() + 1
|
| 184 |
+
segment_starts = torch.full((p, max_segments), q, dtype=torch.long, device=tensor.device)
|
| 185 |
+
# Calculate the starting position of each segment
|
| 186 |
+
segment_starts.scatter_reduce_(1, segment_ids, col_indices, reduce='amin', include_self=False)
|
| 187 |
+
|
| 188 |
+
# Get the start position of the segment to which each position belongs
|
| 189 |
+
segment_start_per_pos = torch.gather(segment_starts, 1, segment_ids)
|
| 190 |
+
|
| 191 |
+
# Calculate the offset of each position within the segment
|
| 192 |
+
offset_in_segment = col_indices - segment_start_per_pos
|
| 193 |
+
|
| 194 |
+
# Each position takes its value from the position n steps earlier within the same segment
|
| 195 |
+
source_offset = offset_in_segment - n
|
| 196 |
+
valid_mask = source_offset >= 0
|
| 197 |
+
|
| 198 |
+
# Calculate the actual source index
|
| 199 |
+
source_indices = segment_start_per_pos + torch.clamp(source_offset, min=0)
|
| 200 |
+
|
| 201 |
+
# Gather values at source_indices
|
| 202 |
+
result = torch.gather(tensor, 1, source_indices)
|
| 203 |
+
|
| 204 |
+
# Zero out invalid positions and special-token positions
|
| 205 |
+
result = result * valid_mask * (~special_mask)
|
| 206 |
+
|
| 207 |
+
return result
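# Worked example with a hypothetical row and eos_token_id = 2: the shift never
# crosses an EOS / special-token boundary, and positions with no in-segment
# source become 0.
#
#     tensor = [[5, 6, 7, 2, 8, 9]]
#     _shift_right_ignore_eos(tensor, n=1, eos_token_id=2)
#     -> [[0, 5, 6, 7, 0, 8]]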
|
| 208 |
+
|
| 209 |
+
def _precompute_vocab_mods(self) -> Dict[Tuple[int, int], List[int]]:
|
| 210 |
+
"""Precompute modular arithmetic values for vocabulary."""
|
| 211 |
+
if self._vocab_mods_cache is not None:
|
| 212 |
+
return self._vocab_mods_cache
|
| 213 |
+
|
| 214 |
+
vocab_mods = {}
|
| 215 |
+
vocab_size = self.config.text_vocab_size
|
| 216 |
+
|
| 217 |
+
for i in range(2, self.n + 1):
|
| 218 |
+
for j in range(self.k):
|
| 219 |
+
index = (i - 2) * self.k + j
|
| 220 |
+
emb_vocab_dim = int(self.m + index * 2 + 1)
|
| 221 |
+
|
| 222 |
+
mods = []
|
| 223 |
+
power_mod = 1
|
| 224 |
+
for _ in range(i - 1):
|
| 225 |
+
power_mod = (power_mod * vocab_size) % emb_vocab_dim
|
| 226 |
+
mods.append(power_mod)
|
| 227 |
+
|
| 228 |
+
vocab_mods[(i, j)] = mods
|
| 229 |
+
|
| 230 |
+
self._vocab_mods_cache = vocab_mods
|
| 231 |
+
return vocab_mods
|
| 232 |
+
|
| 233 |
+
def _get_ngram_ids(
|
| 234 |
+
self,
|
| 235 |
+
input_ids: torch.Tensor,
|
| 236 |
+
shifted_ids: Dict[int, torch.Tensor],
|
| 237 |
+
vocab_mods: List[int],
|
| 238 |
+
ngram: int
|
| 239 |
+
) -> torch.Tensor:
|
| 240 |
+
"""Compute N-gram hash IDs using polynomial rolling hash."""
|
| 241 |
+
ngram_ids = input_ids.clone()
|
| 242 |
+
for k in range(2, ngram + 1):
|
| 243 |
+
ngram_ids = ngram_ids + shifted_ids[k] * vocab_mods[k - 2]
|
| 244 |
+
return ngram_ids
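# Sketch of the rolling-hash arithmetic with small hypothetical numbers: for a
# 3-gram, the id at position t combines the current token with the tokens one and
# two steps back, weighted by vocab_size**1 and vocab_size**2 modulo the embedder's
# vocabulary size; the caller then reduces the result mod emb_vocab_dim to pick a row.
#
#     vocab_size, emb_vocab_dim = 100, 257
#     mods = [100 % 257, (100 * 100) % 257]                    # what _precompute_vocab_mods yields
#     ngram_id = tok_t + tok_t1 * mods[0] + tok_t2 * mods[1]   # tok_t1/tok_t2: shifted tokens
#     row = ngram_id % emb_vocab_dim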
|
| 245 |
+
|
| 246 |
+
def forward(
|
| 247 |
+
self,
|
| 248 |
+
input_ids: torch.Tensor,
|
| 249 |
+
ngram_context: Optional[torch.Tensor] = None
|
| 250 |
+
) -> torch.Tensor:
|
| 251 |
+
"""
|
| 252 |
+
Stateless forward pass.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
input_ids: Current input token IDs of shape (batch_size, seq_len)
|
| 256 |
+
ngram_context: Optional historical context of shape (batch_size, context_len)
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
Embedding tensor of shape (batch_size, seq_len, hidden_size)
|
| 260 |
+
"""
|
| 261 |
+
seq_len = input_ids.size(-1)
|
| 262 |
+
|
| 263 |
+
# Determine complete context
|
| 264 |
+
if ngram_context is not None:
|
| 265 |
+
context = torch.cat([ngram_context[..., -(self.n-1):], input_ids], dim=-1)
|
| 266 |
+
else:
|
| 267 |
+
context = input_ids.clone()
|
| 268 |
+
|
| 269 |
+
# Skip N-gram look-up for oe_ignored_token_ids
|
| 270 |
+
oe_ignored_mask = torch.isin(input_ids, self.oe_ignored_token_ids.to(device=input_ids.device))
|
| 271 |
+
context[torch.isin(context, self.oe_ignored_token_ids.to(device=context.device))] = 0
|
| 272 |
+
|
| 273 |
+
# Base word embeddings
|
| 274 |
+
device = self.word_embeddings.weight.device
|
| 275 |
+
x = self.word_embeddings(input_ids.to(device)).clone()
|
| 276 |
+
|
| 277 |
+
# Precompute modular values
|
| 278 |
+
vocab_mods = self._precompute_vocab_mods()
|
| 279 |
+
|
| 280 |
+
# Compute shifted IDs
|
| 281 |
+
shifted_ids = {}
|
| 282 |
+
for i in range(2, self.n + 1):
|
| 283 |
+
shifted_ids[i] = self._shift_right_ignore_eos(
|
| 284 |
+
context, i - 1, eos_token_id=self.config.eos_token_id
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Add N-gram embeddings
|
| 288 |
+
for i in range(2, self.n + 1):
|
| 289 |
+
for j in range(self.k):
|
| 290 |
+
index = (i - 2) * self.k + j
|
| 291 |
+
emb_vocab_dim = int(self.m + index * 2 + 1)
|
| 292 |
+
|
| 293 |
+
ngram_ids = self._get_ngram_ids(context, shifted_ids, vocab_mods[(i, j)], ngram=i)
|
| 294 |
+
new_ids = (ngram_ids % emb_vocab_dim)[..., -seq_len:]
|
| 295 |
+
text_mask = new_ids > 0
|
| 296 |
+
|
| 297 |
+
embedder_device = self.embedders[index].weight.device
|
| 298 |
+
x_ngram = self.embedders[index](new_ids.to(embedder_device), text_mask)
|
| 299 |
+
|
| 300 |
+
proj_device = self.post_projs[index].weight.device
|
| 301 |
+
x_proj = self.post_projs[index](x_ngram.to(proj_device))
|
| 302 |
+
x = x + x_proj.to(x.device)
|
| 303 |
+
|
| 304 |
+
# Normalize
|
| 305 |
+
x[~oe_ignored_mask] /= (1 + self.k * (self.n - 1))
|
| 306 |
+
|
| 307 |
+
return x
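# Quick arithmetic check on the combination above, with hypothetical settings
# emb_split_num k = 2 and emb_neighbor_num n = 3: there are k * (n - 1) = 4
# auxiliary embedders, each of width hidden_size // 4, and for non-ignored tokens
# the summed word + n-gram embeddings are divided by 1 + k * (n - 1) = 5, i.e. a
# plain average over the five contributions.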
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class LongcatFlashNgramModel(LongcatFlashModel):
|
| 311 |
+
"""LongcatFlash model with N-gram enhanced embeddings."""
|
| 312 |
+
_keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
|
| 313 |
+
config_class = LongcatFlashNgramConfig
|
| 314 |
+
|
| 315 |
+
def __init__(self, config):
|
| 316 |
+
super().__init__(config)
|
| 317 |
+
|
| 318 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 319 |
+
self.ngram_embeddings = NgramEmbedding(config, self.embed_tokens)
|
| 320 |
+
|
| 321 |
+
self.layers = nn.ModuleList(
|
| 322 |
+
[LongcatFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_layers)]
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
self.head_dim = config.head_dim
|
| 326 |
+
self.config.num_hidden_layers = 2 * config.num_layers
|
| 327 |
+
self.norm = LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 328 |
+
self.rotary_emb = LongcatFlashRotaryEmbedding(config=config)
|
| 329 |
+
self.gradient_checkpointing = False
|
| 330 |
+
|
| 331 |
+
self.post_init()
|
| 332 |
+
|
| 333 |
+
def forward(
|
| 334 |
+
self,
|
| 335 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 336 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 337 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 338 |
+
past_key_values: Optional[Cache] = None,
|
| 339 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 340 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 341 |
+
use_cache: Optional[bool] = None,
|
| 342 |
+
**kwargs
|
| 343 |
+
) -> BaseModelOutputWithPast:
|
| 344 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 345 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 346 |
+
|
| 347 |
+
# Extract N-gram context if available
|
| 348 |
+
ngram_context = None
|
| 349 |
+
if isinstance(past_key_values, NgramCache) and past_key_values.ngram_context is not None:
|
| 350 |
+
ngram_context = past_key_values.ngram_context
|
| 351 |
+
|
| 352 |
+
if inputs_embeds is None:
|
| 353 |
+
inputs_embeds = self.ngram_embeddings(input_ids, ngram_context=ngram_context)
|
| 354 |
+
|
| 355 |
+
# Initialize NgramCache if needed
|
| 356 |
+
if use_cache and past_key_values is None:
|
| 357 |
+
past_key_values = NgramCache(config=self.config)
|
| 358 |
+
|
| 359 |
+
# Update N-gram context
|
| 360 |
+
if use_cache and isinstance(past_key_values, NgramCache) and input_ids is not None:
|
| 361 |
+
past_key_values.update_ngram_context(input_ids)
|
| 362 |
+
|
| 363 |
+
# Prepare cache position
|
| 364 |
+
if cache_position is None:
|
| 365 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 366 |
+
cache_position = torch.arange(
|
| 367 |
+
inputs_embeds.shape[1], device=inputs_embeds.device
|
| 368 |
+
) + past_seen_tokens
|
| 369 |
+
|
| 370 |
+
if position_ids is None:
|
| 371 |
+
position_ids = cache_position.unsqueeze(0)
|
| 372 |
+
|
| 373 |
+
# Create causal mask
|
| 374 |
+
causal_mask = create_causal_mask(
|
| 375 |
+
config=self.config,
|
| 376 |
+
input_embeds=inputs_embeds,
|
| 377 |
+
attention_mask=attention_mask,
|
| 378 |
+
cache_position=cache_position,
|
| 379 |
+
past_key_values=past_key_values,
|
| 380 |
+
position_ids=position_ids,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# Forward through decoder layers
|
| 384 |
+
hidden_states = inputs_embeds
|
| 385 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 386 |
+
|
| 387 |
+
for decoder_layer in self.layers[: self.config.num_layers]:
|
| 388 |
+
hidden_states = decoder_layer(
|
| 389 |
+
hidden_states,
|
| 390 |
+
attention_mask=causal_mask,
|
| 391 |
+
position_ids=position_ids,
|
| 392 |
+
past_key_values=past_key_values,
|
| 393 |
+
cache_position=cache_position,
|
| 394 |
+
position_embeddings=position_embeddings,
|
| 395 |
+
**kwargs,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
hidden_states = self.norm(hidden_states)
|
| 399 |
+
|
| 400 |
+
return BaseModelOutputWithPast(
|
| 401 |
+
last_hidden_state=hidden_states,
|
| 402 |
+
past_key_values=past_key_values,
|
| 403 |
+
hidden_states=None,
|
| 404 |
+
attentions=None,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class LongcatFlashNgramForCausalLM(LongcatFlashForCausalLM):
|
| 409 |
+
"""LongcatFlash model for causal language modeling with N-gram embeddings."""
|
| 410 |
+
_keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
|
| 411 |
+
config_class = LongcatFlashNgramConfig
|
| 412 |
+
|
| 413 |
+
def __init__(self, config):
|
| 414 |
+
super().__init__(config)
|
| 415 |
+
self.model = LongcatFlashNgramModel(config)
|
| 416 |
+
|
| 417 |
+
@torch.no_grad()
|
| 418 |
+
def generate(self, inputs=None, generation_config=None, **kwargs):
|
| 419 |
+
"""Override to ensure NgramCache is used."""
|
| 420 |
+
|
| 421 |
+
if "past_key_values" not in kwargs or kwargs["past_key_values"] is None:
|
| 422 |
+
kwargs["past_key_values"] = NgramCache(config=self.config)
|
| 423 |
+
|
| 424 |
+
return super().generate(inputs=inputs, generation_config=generation_config, **kwargs)
|
| 425 |
+
|
| 426 |
+
__all__ = ["LongcatFlashNgramPreTrainedModel", "LongcatFlashNgramModel", "LongcatFlashNgramForCausalLM"]
|
modular_longcat_next.py
ADDED
|
@@ -0,0 +1,157 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
from flash_attn import flash_attn_varlen_func
|
| 6 |
+
|
| 7 |
+
from transformers.models.t5.modeling_t5 import T5LayerNorm as RMSNorm
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FlashVarLenAttention(nn.Module):
|
| 11 |
+
def __init__(self, embed_dim, num_heads, causal=False, window_size=(-1,-1)):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.embed_dim = embed_dim
|
| 14 |
+
self.num_heads = num_heads
|
| 15 |
+
self.head_dim = embed_dim // num_heads
|
| 16 |
+
|
| 17 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
| 18 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 19 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 20 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 21 |
+
|
| 22 |
+
self.causal = causal
|
| 23 |
+
self.window_size = window_size
|
| 24 |
+
|
| 25 |
+
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
|
| 26 |
+
bsz, _ = hidden_states.size()
|
| 27 |
+
|
| 28 |
+
query_states = self.q_proj(hidden_states)
|
| 29 |
+
query_states = query_states.view(bsz, self.num_heads, self.head_dim).contiguous()
|
| 30 |
+
key_states = self.k_proj(hidden_states)
|
| 31 |
+
key_states = key_states.view(bsz, self.num_heads, self.head_dim).contiguous()
|
| 32 |
+
value_states = self.v_proj(hidden_states)
|
| 33 |
+
value_states = value_states.view(bsz, self.num_heads, self.head_dim).contiguous()
|
| 34 |
+
|
| 35 |
+
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
|
| 36 |
+
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
|
| 37 |
+
|
| 38 |
+
attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen,
|
| 39 |
+
max_seqlen, causal=self.causal, window_size=self.window_size) # (bsz * qlen, nheads, headdim)
|
| 40 |
+
attn_output = attn_output.reshape(bsz, self.embed_dim)
|
| 41 |
+
attn_output = self.out_proj(attn_output)
|
| 42 |
+
return attn_output
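# Sketch of the varlen bookkeeping above with made-up lengths: hidden_states is a
# packed (total_tokens, embed_dim) tensor and flash_attn_varlen_func recovers the
# per-sequence boundaries from the cumulative lengths.
#
#     seq_len = torch.tensor([3, 2])                           # two sequences, 5 tokens total
#     cu_len  = F.pad(torch.cumsum(seq_len, dim=0), (1, 0))    # -> [0, 3, 5]
#     max_seqlen = 3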
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class CasualDepthTransformerLayer(nn.Module):
|
| 47 |
+
def __init__(self, depth, transformer_dim, transformer_ffn_scale):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.depth = depth
|
| 50 |
+
self.transformer_dim = transformer_dim
|
| 51 |
+
self.transformer_ffn_scale = transformer_ffn_scale
|
| 52 |
+
self.num_heads = self.transformer_dim // 128
|
| 53 |
+
|
| 54 |
+
assert self.transformer_dim % 128 == 0
|
| 55 |
+
assert self.transformer_dim % depth == 0
|
| 56 |
+
|
| 57 |
+
self.self_attention = FlashVarLenAttention(embed_dim=self.transformer_dim, num_heads=self.num_heads, causal=True)
|
| 58 |
+
|
| 59 |
+
self.layernorm1 = RMSNorm(self.transformer_dim)
|
| 60 |
+
self.layernorm2 = RMSNorm(self.transformer_dim)
|
| 61 |
+
|
| 62 |
+
self.linear1 = nn.Linear(self.transformer_dim, self.transformer_ffn_scale * self.transformer_dim)
|
| 63 |
+
self.linear2 = nn.Linear(self.transformer_ffn_scale * self.transformer_dim, self.transformer_dim)
|
| 64 |
+
|
| 65 |
+
def forward(self, x):
|
| 66 |
+
bsz = x.shape[0]
|
| 67 |
+
res = x
|
| 68 |
+
x = self.layernorm1(x)
|
| 69 |
+
seqlens = torch.tensor([self.depth] * bsz, dtype=torch.int32, device=x.device)
|
| 70 |
+
_x = self.self_attention(x.view(-1, self.transformer_dim), seqlens)
|
| 71 |
+
_x = _x.view(bsz, self.depth, self.transformer_dim).contiguous()
|
| 72 |
+
|
| 73 |
+
_res = _x + res # (bs, sl, d)
|
| 74 |
+
res = self.layernorm2(_res)
|
| 75 |
+
x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (self.transformer_ffn_scale * self.transformer_dim // self.depth, self.depth, self.transformer_dim)))
|
| 76 |
+
x = torch.nn.functional.gelu(x)
|
| 77 |
+
x = torch.einsum('blt,dlt->bld',x, torch.reshape(self.linear2.weight, (self.transformer_dim, self.depth, self.transformer_ffn_scale * self.transformer_dim // self.depth)))
|
| 78 |
+
return _res + x
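# Shape sketch for the two einsums above, with hypothetical sizes depth L = 4,
# transformer_dim d = 512, ffn scale s = 4: linear1.weight of shape (s*d, d) is
# viewed as (s*d // L, L, d) so every depth position l owns its own (s*d // L, d)
# slice; 'bld,tld->blt' applies that slice position-wise, and the second einsum
# maps the (s*d // L)-wide activation back to width d per position.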
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class CasualDepthTransformerHead(nn.Module):
|
| 82 |
+
"""
|
| 83 |
+
Depth-wise causal transformer head shared by image/audio heads.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
hidden_size,
|
| 89 |
+
codebook_sizes,
|
| 90 |
+
transformer_layer_num,
|
| 91 |
+
transformer_dim,
|
| 92 |
+
transformer_ffn_scale,
|
| 93 |
+
gradient_checkpointing=False,
|
| 94 |
+
):
|
| 95 |
+
super().__init__()
|
| 96 |
+
self.hidden_size = hidden_size
|
| 97 |
+
self.codebook_sizes = codebook_sizes
|
| 98 |
+
self.transformer_ffn_scale = transformer_ffn_scale
|
| 99 |
+
self.gradient_checkpointing = gradient_checkpointing
|
| 100 |
+
|
| 101 |
+
if self.transformer_ffn_scale > 0:
|
| 102 |
+
self.hidden_norm = RMSNorm(self.hidden_size)
|
| 103 |
+
self.hidden_proj = nn.Linear(self.hidden_size, transformer_dim, bias=False)
|
| 104 |
+
|
| 105 |
+
self.transformer_layers = nn.ModuleList(
|
| 106 |
+
[
|
| 107 |
+
CasualDepthTransformerLayer(len(codebook_sizes), transformer_dim, transformer_ffn_scale)
|
| 108 |
+
for _ in range(transformer_layer_num)
|
| 109 |
+
]
|
| 110 |
+
)
|
| 111 |
+
self.headnorm = RMSNorm(transformer_dim)
|
| 112 |
+
self.heads = nn.ModuleList(
|
| 113 |
+
[nn.Linear(transformer_dim, vq_size + 1) for vq_size in codebook_sizes]
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
for param in self.parameters():
|
| 117 |
+
param.requires_grad = False
|
| 118 |
+
|
| 119 |
+
def forward(self, x, visual_tokens, visual_emb_layers, level):
|
| 120 |
+
main_device = "cuda:0"
|
| 121 |
+
visual_tokens = visual_tokens.to(main_device)
|
| 122 |
+
visual_emb_layers = visual_emb_layers.to(main_device)
|
| 123 |
+
|
| 124 |
+
cumsum_visual_embed = torch.stack([
|
| 125 |
+
visual_emb_layers(visual_tokens[..., i])
|
| 126 |
+
for i, vq_size in enumerate(self.codebook_sizes[:-1])
|
| 127 |
+
], dim=1).to(x.device)
|
| 128 |
+
|
| 129 |
+
cumsum_visual_embed = torch.cumsum(cumsum_visual_embed, dim=1) # (bs, depth-1, d)
|
| 130 |
+
|
| 131 |
+
hidden_states = torch.concat([x.reshape(-1, 1, self.hidden_size), cumsum_visual_embed], dim=1) # (bs, depth, d)
|
| 132 |
+
assert hidden_states.size(1) == len(self.codebook_sizes)
|
| 133 |
+
|
| 134 |
+
if self.transformer_ffn_scale > 0:
|
| 135 |
+
hidden_states = self.hidden_norm(hidden_states)
|
| 136 |
+
hidden_states = self.hidden_proj(hidden_states)
|
| 137 |
+
|
| 138 |
+
for i, tlayer in enumerate(self.transformer_layers):
|
| 139 |
+
if self.gradient_checkpointing and self.training:
|
| 140 |
+
|
| 141 |
+
def create_custom_forward(module):
|
| 142 |
+
def custom_forward(*inputs):
|
| 143 |
+
# None for past_key_value
|
| 144 |
+
return module(*inputs)
|
| 145 |
+
|
| 146 |
+
return custom_forward
|
| 147 |
+
|
| 148 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 149 |
+
create_custom_forward(tlayer), hidden_states,
|
| 150 |
+
)
|
| 151 |
+
else:
|
| 152 |
+
hidden_states = tlayer(
|
| 153 |
+
hidden_states,
|
| 154 |
+
)
|
| 155 |
+
hidden_states = self.headnorm(hidden_states)
|
| 156 |
+
logits = self.heads[level](hidden_states[:, level])
|
| 157 |
+
return logits
|
modular_longcat_next_audio.py
ADDED
|
@@ -0,0 +1,2039 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import copy
|
| 3 |
+
from abc import ABC
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Dict, Optional
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torchaudio
|
| 10 |
+
from einops import pack, rearrange, repeat
|
| 11 |
+
from flash_attn import flash_attn_varlen_func
|
| 12 |
+
from torch import nn
|
| 13 |
+
from torch.cuda.amp import autocast
|
| 14 |
+
from torch.nn import functional as F
|
| 15 |
+
|
| 16 |
+
from diffusers.models.activations import get_activation
|
| 17 |
+
from diffusers.models.attention import (
|
| 18 |
+
GEGLU,
|
| 19 |
+
GELU,
|
| 20 |
+
AdaLayerNorm,
|
| 21 |
+
AdaLayerNormZero,
|
| 22 |
+
ApproximateGELU,
|
| 23 |
+
)
|
| 24 |
+
from diffusers.models.attention_processor import Attention
|
| 25 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
| 26 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 27 |
+
|
| 28 |
+
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.modeling_outputs import ModelOutput
|
| 30 |
+
from transformers.utils import logging
|
| 31 |
+
|
| 32 |
+
from .cosy24k_vocoder import Cosy24kVocoder
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def sinusoids(length, channels, max_timescale=10000):
|
| 38 |
+
"""Returns sinusoids for positional embedding"""
|
| 39 |
+
assert channels % 2 == 0
|
| 40 |
+
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
| 41 |
+
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
| 42 |
+
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
| 43 |
+
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
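A minimal usage sketch of sinusoids (sizes made up; it assumes the function above is used from this module): the first channels // 2 columns are sines at geometrically spaced timescales, the second half are the matching cosines.

pe = sinusoids(length=1500, channels=1280)
assert pe.shape == (1500, 1280)                      # (time, channels)
assert pe[0, :640].abs().max().item() == 0.0         # first row, sin half: sin(0) == 0
assert (pe[0, 640:] - 1).abs().max().item() == 0.0   # first row, cos half: cos(0) == 1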
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_sequence_mask(inputs, inputs_length):
|
| 47 |
+
if inputs.dim() == 3:
|
| 48 |
+
bsz, tgt_len, _ = inputs.size()
|
| 49 |
+
else:
|
| 50 |
+
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
|
| 51 |
+
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
|
| 52 |
+
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
|
| 53 |
+
unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1 # convert cumulative mask counts into flat gather indices
|
| 54 |
+
return sequence_mask, unpacking_index
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def unpack_hidden_states(hidden_states, lengths):
|
| 58 |
+
bsz = lengths.shape[0]
|
| 59 |
+
sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
|
| 60 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
| 61 |
+
bsz, torch.max(lengths), hidden_states.shape[-1]
|
| 62 |
+
)
|
| 63 |
+
hidden_states = torch.where(
|
| 64 |
+
sequence_mask, hidden_states, 0
|
| 65 |
+
) # 3d (bsz, max_input_len, d)
|
| 66 |
+
return hidden_states
|
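The two helpers above define the pack/unpack convention used throughout this file: valid frames of a padded batch are flattened into one (sum(lengths), d) tensor, and unpacking_index scatters them back into a zero-padded (bsz, max_len, d) tensor. A small sketch with made-up sizes, assuming the helpers are used from this module:

lengths = torch.tensor([3, 1])                                  # true lengths of two sequences
hidden = torch.arange(12, dtype=torch.float32).view(2, 3, 2)    # padded batch (bsz=2, max_len=3, d=2)

mask, _ = get_sequence_mask(hidden, lengths)                    # (2, 3, 1) boolean validity mask
packed = torch.masked_select(hidden, mask).view(-1, 2)          # (4, 2): only the 3 + 1 valid frames
restored = unpack_hidden_states(packed, lengths)                # (2, 3, 2) with padding zeroed
assert restored.shape == hidden.shape
assert restored[1, 1:].abs().sum().item() == 0.0                # padded frames of the short sequence are zero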
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def uniform_init(*shape):
|
| 70 |
+
t = torch.zeros(shape)
|
| 71 |
+
nn.init.kaiming_uniform_(t)
|
| 72 |
+
return t
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def cdist(x, y):
|
| 76 |
+
x2 = torch.sum(x ** 2, dim=-1, keepdims=True) # (b, 1)
|
| 77 |
+
y2 = torch.sum(y ** 2, dim=-1).reshape(1, -1) # (1, c)
|
| 78 |
+
xy = torch.einsum('bd,cd->bc', x, y) * -2
|
| 79 |
+
return (x2 + y2 + xy).clamp(min=0).sqrt() # (b, c)
|
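cdist above is the memory-friendly expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x·y, clamped at zero before the square root; up to floating-point error it agrees with torch.cdist. A quick check with made-up shapes:

x = torch.randn(5, 8)
y = torch.randn(7, 8)
assert torch.allclose(cdist(x, y), torch.cdist(x, y), atol=1e-4)   # (5, 7) pairwise Euclidean distances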
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 83 |
+
assert mask.dtype == torch.bool
|
| 84 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
| 85 |
+
mask = mask.to(dtype)
|
| 86 |
+
# attention mask bias
|
| 87 |
+
# NOTE(Mddct): torch.finfo jit issues
|
| 88 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
| 89 |
+
mask = (1.0 - mask) * torch.finfo(dtype).min
|
| 90 |
+
return mask
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def subsequent_chunk_mask(
|
| 94 |
+
size: int,
|
| 95 |
+
chunk_size: int,
|
| 96 |
+
num_left_chunks: int = -1,
|
| 97 |
+
device: torch.device = torch.device("cpu"),
|
| 98 |
+
) -> torch.Tensor:
|
| 99 |
+
"""Create mask for subsequent steps (size, size) with chunk size,
|
| 100 |
+
this is for streaming encoder
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
size (int): size of mask
|
| 104 |
+
chunk_size (int): size of chunk
|
| 105 |
+
num_left_chunks (int): number of left chunks
|
| 106 |
+
<0: use full chunk
|
| 107 |
+
>=0: use num_left_chunks
|
| 108 |
+
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
torch.Tensor: mask
|
| 112 |
+
|
| 113 |
+
Examples:
|
| 114 |
+
>>> subsequent_chunk_mask(4, 2)
|
| 115 |
+
[[1, 1, 0, 0],
|
| 116 |
+
[1, 1, 0, 0],
|
| 117 |
+
[1, 1, 1, 1],
|
| 118 |
+
[1, 1, 1, 1]]
|
| 119 |
+
"""
|
| 120 |
+
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
| 121 |
+
# actually this is not needed after we have inference cache implemented, will remove it later
|
| 122 |
+
pos_idx = torch.arange(size, device=device)
|
| 123 |
+
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
| 124 |
+
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
| 125 |
+
return ret
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def add_optional_chunk_mask(xs: torch.Tensor,
|
| 129 |
+
masks: torch.Tensor,
|
| 130 |
+
use_dynamic_chunk: bool,
|
| 131 |
+
use_dynamic_left_chunk: bool,
|
| 132 |
+
decoding_chunk_size: int,
|
| 133 |
+
static_chunk_size: int,
|
| 134 |
+
num_decoding_left_chunks: int,
|
| 135 |
+
enable_full_context: bool = True):
|
| 136 |
+
""" Apply optional mask for encoder.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
| 140 |
+
mask (torch.Tensor): mask for xs, (B, 1, L)
|
| 141 |
+
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
| 142 |
+
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
| 143 |
+
training.
|
| 144 |
+
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
| 145 |
+
0: default for training, use random dynamic chunk.
|
| 146 |
+
<0: for decoding, use full chunk.
|
| 147 |
+
>0: for decoding, use fixed chunk size as set.
|
| 148 |
+
static_chunk_size (int): chunk size for static chunk training/decoding
|
| 149 |
+
if it's greater than 0, if use_dynamic_chunk is true,
|
| 150 |
+
this parameter will be ignored
|
| 151 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
| 152 |
+
the chunk size is decoding_chunk_size.
|
| 153 |
+
>=0: use num_decoding_left_chunks
|
| 154 |
+
<0: use all left chunks
|
| 155 |
+
enable_full_context (bool):
|
| 156 |
+
True: chunk size is either [1, 25] or full context(max_len)
|
| 157 |
+
False: chunk size ~ U[1, 25]
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
torch.Tensor: chunk mask of the input xs.
|
| 161 |
+
"""
|
| 162 |
+
# Whether to use chunk mask or not
|
| 163 |
+
if use_dynamic_chunk:
|
| 164 |
+
max_len = xs.size(1)
|
| 165 |
+
if decoding_chunk_size < 0:
|
| 166 |
+
chunk_size = max_len
|
| 167 |
+
num_left_chunks = -1
|
| 168 |
+
elif decoding_chunk_size > 0:
|
| 169 |
+
chunk_size = decoding_chunk_size
|
| 170 |
+
num_left_chunks = num_decoding_left_chunks
|
| 171 |
+
else:
|
| 172 |
+
# chunk size is either [1, 25] or full context(max_len).
|
| 173 |
+
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
| 174 |
+
# delay, the maximum frame is 100 / 4 = 25.
|
| 175 |
+
chunk_size = torch.randint(1, max_len, (1, )).item()
|
| 176 |
+
num_left_chunks = -1
|
| 177 |
+
if chunk_size > max_len // 2 and enable_full_context:
|
| 178 |
+
chunk_size = max_len
|
| 179 |
+
else:
|
| 180 |
+
chunk_size = chunk_size % 25 + 1
|
| 181 |
+
if use_dynamic_left_chunk:
|
| 182 |
+
max_left_chunks = (max_len - 1) // chunk_size
|
| 183 |
+
num_left_chunks = torch.randint(0, max_left_chunks,
|
| 184 |
+
(1, )).item()
|
| 185 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
| 186 |
+
num_left_chunks,
|
| 187 |
+
xs.device) # (L, L)
|
| 188 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
| 189 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
| 190 |
+
elif static_chunk_size > 0:
|
| 191 |
+
num_left_chunks = num_decoding_left_chunks
|
| 192 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
| 193 |
+
num_left_chunks,
|
| 194 |
+
xs.device) # (L, L)
|
| 195 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
| 196 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
| 197 |
+
else:
|
| 198 |
+
chunk_masks = masks
|
| 199 |
+
return chunk_masks
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class EuclideanCodebook(nn.Module):
|
| 203 |
+
def __init__(
|
| 204 |
+
self,
|
| 205 |
+
dim,
|
| 206 |
+
codebook_size,
|
| 207 |
+
init_std=0.02,
|
| 208 |
+
):
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.init_std = init_std
|
| 211 |
+
self.dim = dim
|
| 212 |
+
self.codebook_size = codebook_size
|
| 213 |
+
|
| 214 |
+
embed = uniform_init(codebook_size, dim).to(torch.float32)
|
| 215 |
+
self.cluster_size = nn.Parameter(torch.ones(codebook_size))
|
| 216 |
+
self.embed_avg = nn.Parameter(embed.clone())
|
| 217 |
+
self.embed = nn.Parameter(embed)
|
| 218 |
+
del embed
|
| 219 |
+
|
| 220 |
+
@autocast(enabled=True, dtype=torch.float32)
|
| 221 |
+
@torch.no_grad()
|
| 222 |
+
def forward(self, x):
|
| 223 |
+
assert(len(x.shape) == 2)
|
| 224 |
+
assert(x.dtype == torch.float32)
|
| 225 |
+
embed = self.embed.detach().to(x.device)
|
| 226 |
+
dist = -cdist(x, embed) # dist((bs*sl, d), (c, d)) --> (bs*sl, c)
|
| 227 |
+
embed_ind = dist.argmax(dim=-1)
|
| 228 |
+
quantize = embed[embed_ind] # (bs*sl, d)
|
| 229 |
+
return quantize, embed_ind, dist
|
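The codebook is a plain nearest-neighbour lookup in Euclidean space: each input frame is replaced by the closest of codebook_size embedding vectors and the index of that vector is returned. An illustrative call with made-up sizes:

codebook = EuclideanCodebook(dim=4, codebook_size=16)
frames = torch.randn(10, 4)                      # 10 flattened (batch * time) frames, float32
quantized, indices, _ = codebook(frames)
assert quantized.shape == (10, 4) and indices.shape == (10,)
# every row of quantized is one of the 16 codebook vectors, picked by the matching index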
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class VectorQuantize(nn.Module):
|
| 233 |
+
def __init__(self, config, *args, **kwargs):
|
| 234 |
+
super().__init__(*args, **kwargs)
|
| 235 |
+
self.config = config
|
| 236 |
+
self.codebook = EuclideanCodebook(dim=config.dim, codebook_size=config.codebook_size)
|
| 237 |
+
|
| 238 |
+
def forward(self, x, input_length):
|
| 239 |
+
batch_size, seq_len, _ = x.shape
|
| 240 |
+
mask, unpacking_index = get_sequence_mask(x, input_length)
|
| 241 |
+
if x.dtype != torch.float32:
|
| 242 |
+
x = x.to(torch.float32)
|
| 243 |
+
x = torch.masked_select(x, mask).reshape(-1, self.config.dim) # (bs*sl?, d)
|
| 244 |
+
quantize, embed_ind, _ = self.codebook(x)
|
| 245 |
+
quantize = torch.index_select(quantize, 0, unpacking_index).view(batch_size, seq_len, self.config.dim)
|
| 246 |
+
quantize = torch.where(mask, quantize, 0)
|
| 247 |
+
embed_ind = torch.index_select(embed_ind.reshape(-1, 1), 0, unpacking_index).view(batch_size, seq_len, 1)
|
| 248 |
+
embed_ind = torch.where(mask, embed_ind, -1).squeeze()
|
| 249 |
+
return quantize, embed_ind
|
| 250 |
+
|
| 251 |
+
def get_output_from_indices(self, indices):
|
| 252 |
+
indices = indices.to(self.codebook.embed.device)
|
| 253 |
+
return self.codebook.embed[indices]
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class SnakeBeta(nn.Module):
|
| 257 |
+
"""
|
| 258 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
| 259 |
+
Shape:
|
| 260 |
+
- Input: (B, C, T)
|
| 261 |
+
- Output: (B, C, T), same shape as the input
|
| 262 |
+
Parameters:
|
| 263 |
+
- alpha - trainable parameter that controls frequency
|
| 264 |
+
- beta - trainable parameter that controls magnitude
|
| 265 |
+
References:
|
| 266 |
+
- This activation function is a modified version of the Snake activation proposed by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 267 |
+
https://arxiv.org/abs/2006.08195
|
| 268 |
+
Examples:
|
| 269 |
+
>>> a1 = SnakeBeta(256, 256)
|
| 270 |
+
>>> x = torch.randn(256)
|
| 271 |
+
>>> x = a1(x)
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
def __init__(
|
| 275 |
+
self,
|
| 276 |
+
in_features,
|
| 277 |
+
out_features,
|
| 278 |
+
alpha=1.0,
|
| 279 |
+
alpha_trainable=True,
|
| 280 |
+
alpha_logscale=True,
|
| 281 |
+
):
|
| 282 |
+
"""
|
| 283 |
+
Initialization.
|
| 284 |
+
INPUT:
|
| 285 |
+
- in_features: shape of the input
|
| 286 |
+
- alpha - trainable parameter that controls frequency
|
| 287 |
+
- beta - trainable parameter that controls magnitude
|
| 288 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 289 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
| 290 |
+
alpha will be trained along with the rest of your model.
|
| 291 |
+
"""
|
| 292 |
+
super().__init__()
|
| 293 |
+
self.in_features = (
|
| 294 |
+
out_features if isinstance(out_features, list) else [out_features]
|
| 295 |
+
)
|
| 296 |
+
self.proj = LoRACompatibleLinear(in_features, out_features)
|
| 297 |
+
|
| 298 |
+
# initialize alpha
|
| 299 |
+
self.alpha_logscale = alpha_logscale
|
| 300 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 301 |
+
self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
| 302 |
+
self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
| 303 |
+
else: # linear scale alphas initialized to ones
|
| 304 |
+
self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
|
| 305 |
+
self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
|
| 306 |
+
|
| 307 |
+
self.alpha.requires_grad = alpha_trainable
|
| 308 |
+
self.beta.requires_grad = alpha_trainable
|
| 309 |
+
|
| 310 |
+
self.no_div_by_zero = 0.000000001
|
| 311 |
+
|
| 312 |
+
def forward(self, x):
|
| 313 |
+
"""
|
| 314 |
+
Forward pass of the function.
|
| 315 |
+
Applies the function to the input elementwise.
|
| 316 |
+
SnakeBeta := x + 1/b * sin^2(x * a)
|
| 317 |
+
"""
|
| 318 |
+
x = self.proj(x)
|
| 319 |
+
if self.alpha_logscale:
|
| 320 |
+
alpha = torch.exp(self.alpha)
|
| 321 |
+
beta = torch.exp(self.beta)
|
| 322 |
+
else:
|
| 323 |
+
alpha = self.alpha
|
| 324 |
+
beta = self.beta
|
| 325 |
+
|
| 326 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
|
| 327 |
+
torch.sin(x * alpha), 2
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
return x
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class FeedForward(nn.Module):
|
| 334 |
+
r"""
|
| 335 |
+
A feed-forward layer.
|
| 336 |
+
|
| 337 |
+
Parameters:
|
| 338 |
+
dim (`int`): The number of channels in the input.
|
| 339 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
| 340 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
| 341 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 342 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 343 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
def __init__(
|
| 347 |
+
self,
|
| 348 |
+
dim: int,
|
| 349 |
+
dim_out: Optional[int] = None,
|
| 350 |
+
mult: int = 4,
|
| 351 |
+
dropout: float = 0.0,
|
| 352 |
+
activation_fn: str = "geglu",
|
| 353 |
+
final_dropout: bool = False,
|
| 354 |
+
):
|
| 355 |
+
super().__init__()
|
| 356 |
+
inner_dim = int(dim * mult)
|
| 357 |
+
dim_out = dim_out if dim_out is not None else dim
|
| 358 |
+
|
| 359 |
+
if activation_fn == "gelu":
|
| 360 |
+
act_fn = GELU(dim, inner_dim)
|
| 361 |
+
if activation_fn == "gelu-approximate":
|
| 362 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
| 363 |
+
elif activation_fn == "geglu":
|
| 364 |
+
act_fn = GEGLU(dim, inner_dim)
|
| 365 |
+
elif activation_fn == "geglu-approximate":
|
| 366 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
| 367 |
+
elif activation_fn == "snakebeta":
|
| 368 |
+
act_fn = SnakeBeta(dim, inner_dim)
|
| 369 |
+
|
| 370 |
+
self.net = nn.ModuleList([])
|
| 371 |
+
# project in
|
| 372 |
+
self.net.append(act_fn)
|
| 373 |
+
# project dropout
|
| 374 |
+
self.net.append(nn.Dropout(dropout))
|
| 375 |
+
# project out
|
| 376 |
+
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
|
| 377 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
| 378 |
+
if final_dropout:
|
| 379 |
+
self.net.append(nn.Dropout(dropout))
|
| 380 |
+
|
| 381 |
+
def forward(self, hidden_states):
|
| 382 |
+
for module in self.net:
|
| 383 |
+
hidden_states = module(hidden_states)
|
| 384 |
+
return hidden_states
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
@maybe_allow_in_graph
|
| 388 |
+
class BasicTransformerBlock(nn.Module):
|
| 389 |
+
r"""
|
| 390 |
+
A basic Transformer block.
|
| 391 |
+
|
| 392 |
+
Parameters:
|
| 393 |
+
dim (`int`): The number of channels in the input and output.
|
| 394 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
| 395 |
+
attention_head_dim (`int`): The number of channels in each head.
|
| 396 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 397 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
| 398 |
+
only_cross_attention (`bool`, *optional*):
|
| 399 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
| 400 |
+
double_self_attention (`bool`, *optional*):
|
| 401 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
| 402 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 403 |
+
num_embeds_ada_norm (:
|
| 404 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
| 405 |
+
attention_bias (:
|
| 406 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
| 407 |
+
"""
|
| 408 |
+
|
| 409 |
+
def __init__(
|
| 410 |
+
self,
|
| 411 |
+
dim: int,
|
| 412 |
+
num_attention_heads: int,
|
| 413 |
+
attention_head_dim: int,
|
| 414 |
+
dropout=0.0,
|
| 415 |
+
cross_attention_dim: Optional[int] = None,
|
| 416 |
+
activation_fn: str = "geglu",
|
| 417 |
+
num_embeds_ada_norm: Optional[int] = None,
|
| 418 |
+
attention_bias: bool = False,
|
| 419 |
+
only_cross_attention: bool = False,
|
| 420 |
+
double_self_attention: bool = False,
|
| 421 |
+
upcast_attention: bool = False,
|
| 422 |
+
norm_elementwise_affine: bool = True,
|
| 423 |
+
norm_type: str = "layer_norm",
|
| 424 |
+
final_dropout: bool = False,
|
| 425 |
+
use_omni_attn: bool = False,
|
| 426 |
+
):
|
| 427 |
+
super().__init__()
|
| 428 |
+
|
| 429 |
+
self.use_omni_attn = use_omni_attn
|
| 430 |
+
self.dim = dim
|
| 431 |
+
|
| 432 |
+
self.only_cross_attention = only_cross_attention
|
| 433 |
+
|
| 434 |
+
self.use_ada_layer_norm_zero = (
|
| 435 |
+
num_embeds_ada_norm is not None
|
| 436 |
+
) and norm_type == "ada_norm_zero"
|
| 437 |
+
self.use_ada_layer_norm = (
|
| 438 |
+
num_embeds_ada_norm is not None
|
| 439 |
+
) and norm_type == "ada_norm"
|
| 440 |
+
|
| 441 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
| 442 |
+
raise ValueError(
|
| 443 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
| 444 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
| 448 |
+
# 1. Self-Attn
|
| 449 |
+
if self.use_ada_layer_norm:
|
| 450 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
| 451 |
+
elif self.use_ada_layer_norm_zero:
|
| 452 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
| 453 |
+
else:
|
| 454 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
| 455 |
+
|
| 456 |
+
if self.use_omni_attn:
|
| 457 |
+
if only_cross_attention:
|
| 458 |
+
raise NotImplementedError
|
| 459 |
+
print(
|
| 460 |
+
"Use OmniWhisperAttention with flash attention. Dropout is ignored."
|
| 461 |
+
)
|
| 462 |
+
self.attn1 = OmniWhisperAttention(
|
| 463 |
+
embed_dim=dim, num_heads=num_attention_heads, causal=False
|
| 464 |
+
)
|
| 465 |
+
else:
|
| 466 |
+
self.attn1 = Attention(
|
| 467 |
+
query_dim=dim,
|
| 468 |
+
heads=num_attention_heads,
|
| 469 |
+
dim_head=attention_head_dim,
|
| 470 |
+
dropout=dropout,
|
| 471 |
+
bias=attention_bias,
|
| 472 |
+
cross_attention_dim=(
|
| 473 |
+
cross_attention_dim if only_cross_attention else None
|
| 474 |
+
),
|
| 475 |
+
upcast_attention=upcast_attention,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
# 2. Cross-Attn
|
| 479 |
+
if cross_attention_dim is not None or double_self_attention:
|
| 480 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
| 481 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
| 482 |
+
# the second cross attention block.
|
| 483 |
+
self.norm2 = (
|
| 484 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
| 485 |
+
if self.use_ada_layer_norm
|
| 486 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
| 487 |
+
)
|
| 488 |
+
self.attn2 = Attention(
|
| 489 |
+
query_dim=dim,
|
| 490 |
+
cross_attention_dim=(
|
| 491 |
+
cross_attention_dim if not double_self_attention else None
|
| 492 |
+
),
|
| 493 |
+
heads=num_attention_heads,
|
| 494 |
+
dim_head=attention_head_dim,
|
| 495 |
+
dropout=dropout,
|
| 496 |
+
bias=attention_bias,
|
| 497 |
+
upcast_attention=upcast_attention,
|
| 498 |
+
# scale_qk=False, # uncomment this to not to use flash attention
|
| 499 |
+
) # is self-attn if encoder_hidden_states is none
|
| 500 |
+
else:
|
| 501 |
+
self.norm2 = None
|
| 502 |
+
self.attn2 = None
|
| 503 |
+
|
| 504 |
+
# 3. Feed-forward
|
| 505 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
| 506 |
+
self.ff = FeedForward(
|
| 507 |
+
dim,
|
| 508 |
+
dropout=dropout,
|
| 509 |
+
activation_fn=activation_fn,
|
| 510 |
+
final_dropout=final_dropout,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# let chunk size default to None
|
| 514 |
+
self._chunk_size = None
|
| 515 |
+
self._chunk_dim = 0
|
| 516 |
+
|
| 517 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
| 518 |
+
# Sets chunk feed-forward
|
| 519 |
+
self._chunk_size = chunk_size
|
| 520 |
+
self._chunk_dim = dim
|
| 521 |
+
|
| 522 |
+
def forward(
|
| 523 |
+
self,
|
| 524 |
+
hidden_states: torch.FloatTensor,
|
| 525 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 526 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 527 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 528 |
+
timestep: Optional[torch.LongTensor] = None,
|
| 529 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
| 530 |
+
class_labels: Optional[torch.LongTensor] = None,
|
| 531 |
+
):
|
| 532 |
+
|
| 533 |
+
bsz, tgt_len, d_model = hidden_states.shape
|
| 534 |
+
|
| 535 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
| 536 |
+
# 1. Self-Attention
|
| 537 |
+
if self.use_ada_layer_norm:
|
| 538 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
| 539 |
+
elif self.use_ada_layer_norm_zero:
|
| 540 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
| 541 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
| 542 |
+
)
|
| 543 |
+
else:
|
| 544 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 545 |
+
|
| 546 |
+
cross_attention_kwargs = (
|
| 547 |
+
cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
if self.use_omni_attn:
|
| 551 |
+
seq_len = attention_mask[:, 0, :].float().long().sum(dim=1)
|
| 552 |
+
var_len_attention_mask, unpacking_index = get_sequence_mask(
|
| 553 |
+
norm_hidden_states, seq_len
|
| 554 |
+
)
|
| 555 |
+
norm_hidden_states = torch.masked_select(
|
| 556 |
+
norm_hidden_states, var_len_attention_mask
|
| 557 |
+
)
|
| 558 |
+
norm_hidden_states = norm_hidden_states.view(torch.sum(seq_len), self.dim)
|
| 559 |
+
attn_output = self.attn1(norm_hidden_states, seq_len)
|
| 560 |
+
# unpacking
|
| 561 |
+
attn_output = torch.index_select(attn_output, 0, unpacking_index).view(
|
| 562 |
+
bsz, tgt_len, d_model
|
| 563 |
+
)
|
| 564 |
+
attn_output = torch.where(var_len_attention_mask, attn_output, 0)
|
| 565 |
+
else:
|
| 566 |
+
attn_output = self.attn1(
|
| 567 |
+
norm_hidden_states,
|
| 568 |
+
encoder_hidden_states=(
|
| 569 |
+
encoder_hidden_states if self.only_cross_attention else None
|
| 570 |
+
),
|
| 571 |
+
attention_mask=(
|
| 572 |
+
encoder_attention_mask
|
| 573 |
+
if self.only_cross_attention
|
| 574 |
+
else attention_mask
|
| 575 |
+
),
|
| 576 |
+
**cross_attention_kwargs,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
if self.use_ada_layer_norm_zero:
|
| 580 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 581 |
+
hidden_states = attn_output + hidden_states
|
| 582 |
+
|
| 583 |
+
# 2. Cross-Attention
|
| 584 |
+
if self.attn2 is not None:
|
| 585 |
+
norm_hidden_states = (
|
| 586 |
+
self.norm2(hidden_states, timestep)
|
| 587 |
+
if self.use_ada_layer_norm
|
| 588 |
+
else self.norm2(hidden_states)
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
attn_output = self.attn2(
|
| 592 |
+
norm_hidden_states,
|
| 593 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 594 |
+
attention_mask=encoder_attention_mask,
|
| 595 |
+
**cross_attention_kwargs,
|
| 596 |
+
)
|
| 597 |
+
hidden_states = attn_output + hidden_states
|
| 598 |
+
|
| 599 |
+
# 3. Feed-forward
|
| 600 |
+
norm_hidden_states = self.norm3(hidden_states)
|
| 601 |
+
|
| 602 |
+
if self.use_ada_layer_norm_zero:
|
| 603 |
+
norm_hidden_states = (
|
| 604 |
+
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
if self._chunk_size is not None:
|
| 608 |
+
# "feed_forward_chunk_size" can be used to save memory
|
| 609 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
| 610 |
+
raise ValueError(
|
| 611 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
| 615 |
+
ff_output = torch.cat(
|
| 616 |
+
[
|
| 617 |
+
self.ff(hid_slice)
|
| 618 |
+
for hid_slice in norm_hidden_states.chunk(
|
| 619 |
+
num_chunks, dim=self._chunk_dim
|
| 620 |
+
)
|
| 621 |
+
],
|
| 622 |
+
dim=self._chunk_dim,
|
| 623 |
+
)
|
| 624 |
+
else:
|
| 625 |
+
ff_output = self.ff(norm_hidden_states)
|
| 626 |
+
|
| 627 |
+
if self.use_ada_layer_norm_zero:
|
| 628 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 629 |
+
|
| 630 |
+
hidden_states = ff_output + hidden_states
|
| 631 |
+
|
| 632 |
+
return hidden_states
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class Transpose(torch.nn.Module):
|
| 636 |
+
def __init__(self, dim0: int, dim1: int):
|
| 637 |
+
super().__init__()
|
| 638 |
+
self.dim0 = dim0
|
| 639 |
+
self.dim1 = dim1
|
| 640 |
+
|
| 641 |
+
def forward(self, x: torch.Tensor):
|
| 642 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
| 643 |
+
return x
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
class Block1D(torch.nn.Module):
|
| 647 |
+
def __init__(self, dim, dim_out, groups=8):
|
| 648 |
+
super().__init__()
|
| 649 |
+
self.block = torch.nn.Sequential(
|
| 650 |
+
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
|
| 651 |
+
torch.nn.GroupNorm(groups, dim_out),
|
| 652 |
+
nn.Mish(),
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
def forward(self, x, mask):
|
| 656 |
+
output = self.block(x * mask)
|
| 657 |
+
return output * mask
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
class ResnetBlock1D(torch.nn.Module):
|
| 661 |
+
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
| 662 |
+
super().__init__()
|
| 663 |
+
self.mlp = torch.nn.Sequential(
|
| 664 |
+
nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
self.block1 = Block1D(dim, dim_out, groups=groups)
|
| 668 |
+
self.block2 = Block1D(dim_out, dim_out, groups=groups)
|
| 669 |
+
|
| 670 |
+
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
|
| 671 |
+
|
| 672 |
+
def forward(self, x, mask, time_emb):
|
| 673 |
+
h = self.block1(x, mask)
|
| 674 |
+
h += self.mlp(time_emb).unsqueeze(-1)
|
| 675 |
+
h = self.block2(h, mask)
|
| 676 |
+
output = h + self.res_conv(x * mask)
|
| 677 |
+
return output
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
class CausalBlock1D(Block1D):
|
| 681 |
+
def __init__(self, dim: int, dim_out: int):
|
| 682 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
| 683 |
+
self.block = torch.nn.Sequential(
|
| 684 |
+
CausalConv1d(dim, dim_out, 3),
|
| 685 |
+
Transpose(1, 2),
|
| 686 |
+
nn.LayerNorm(dim_out),
|
| 687 |
+
Transpose(1, 2),
|
| 688 |
+
nn.Mish(),
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
| 692 |
+
output = self.block(x * mask)
|
| 693 |
+
return output * mask
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
| 697 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
| 698 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
| 699 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
| 700 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
class CausalConv1d(torch.nn.Conv1d):
|
| 704 |
+
def __init__(
|
| 705 |
+
self,
|
| 706 |
+
in_channels: int,
|
| 707 |
+
out_channels: int,
|
| 708 |
+
kernel_size: int,
|
| 709 |
+
stride: int = 1,
|
| 710 |
+
dilation: int = 1,
|
| 711 |
+
groups: int = 1,
|
| 712 |
+
bias: bool = True,
|
| 713 |
+
padding_mode: str = 'zeros',
|
| 714 |
+
device=None,
|
| 715 |
+
dtype=None
|
| 716 |
+
) -> None:
|
| 717 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
| 718 |
+
kernel_size, stride,
|
| 719 |
+
padding=0, dilation=dilation,
|
| 720 |
+
groups=groups, bias=bias,
|
| 721 |
+
padding_mode=padding_mode,
|
| 722 |
+
device=device, dtype=dtype)
|
| 723 |
+
assert stride == 1
|
| 724 |
+
self.causal_padding = (kernel_size - 1, 0)
|
| 725 |
+
|
| 726 |
+
def forward(self, x: torch.Tensor):
|
| 727 |
+
x = F.pad(x, self.causal_padding)
|
| 728 |
+
x = super(CausalConv1d, self).forward(x)
|
| 729 |
+
return x
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
class BASECFM(torch.nn.Module, ABC):
|
| 733 |
+
def __init__(
|
| 734 |
+
self,
|
| 735 |
+
n_feats,
|
| 736 |
+
cfm_params,
|
| 737 |
+
n_spks=1,
|
| 738 |
+
spk_emb_dim=128,
|
| 739 |
+
):
|
| 740 |
+
super().__init__()
|
| 741 |
+
self.n_feats = n_feats
|
| 742 |
+
self.n_spks = n_spks
|
| 743 |
+
self.spk_emb_dim = spk_emb_dim
|
| 744 |
+
self.solver = cfm_params.solver
|
| 745 |
+
if hasattr(cfm_params, "sigma_min"):
|
| 746 |
+
self.sigma_min = cfm_params.sigma_min
|
| 747 |
+
else:
|
| 748 |
+
self.sigma_min = 1e-4
|
| 749 |
+
|
| 750 |
+
self.estimator = None
|
| 751 |
+
|
| 752 |
+
@torch.inference_mode()
|
| 753 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
| 754 |
+
"""Forward diffusion
|
| 755 |
+
|
| 756 |
+
Args:
|
| 757 |
+
mu (torch.Tensor): output of encoder
|
| 758 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 759 |
+
mask (torch.Tensor): output_mask
|
| 760 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 761 |
+
n_timesteps (int): number of diffusion steps
|
| 762 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 763 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 764 |
+
shape: (batch_size, spk_emb_dim)
|
| 765 |
+
cond: Not used but kept for future purposes
|
| 766 |
+
|
| 767 |
+
Returns:
|
| 768 |
+
sample: generated mel-spectrogram
|
| 769 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 770 |
+
"""
|
| 771 |
+
z = torch.randn_like(mu) * temperature
|
| 772 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
| 773 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
| 774 |
+
|
| 775 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
| 776 |
+
"""
|
| 777 |
+
Fixed Euler solver for ODEs.
|
| 778 |
+
Args:
|
| 779 |
+
x (torch.Tensor): random noise
|
| 780 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
| 781 |
+
shape: (n_timesteps + 1,)
|
| 782 |
+
mu (torch.Tensor): output of encoder
|
| 783 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 784 |
+
mask (torch.Tensor): output_mask
|
| 785 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 786 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 787 |
+
shape: (batch_size, spk_emb_dim)
|
| 788 |
+
cond: Not used but kept for future purposes
|
| 789 |
+
"""
|
| 790 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
| 791 |
+
|
| 792 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
| 793 |
+
# Or in future might add like a return_all_steps flag
|
| 794 |
+
sol = []
|
| 795 |
+
|
| 796 |
+
for step in range(1, len(t_span)):
|
| 797 |
+
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
| 798 |
+
|
| 799 |
+
x = x + dt * dphi_dt
|
| 800 |
+
t = t + dt
|
| 801 |
+
sol.append(x)
|
| 802 |
+
if step < len(t_span) - 1:
|
| 803 |
+
dt = t_span[step + 1] - t
|
| 804 |
+
|
| 805 |
+
return sol[-1]
|
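In other words, solve_euler is a plain fixed-grid Euler integration of the learned ODE: starting from the noise sample x(t_0) = z, each step advances

    x_{k+1} = x_k + (t_{k+1} - t_k) * v_theta(x_k, t_k; mu, mask, spks)

where v_theta is self.estimator, and only the final state (the generated mel-spectrogram) is returned.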
| 806 |
+
|
| 807 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
| 808 |
+
"""Computes diffusion loss
|
| 809 |
+
|
| 810 |
+
Args:
|
| 811 |
+
x1 (torch.Tensor): Target
|
| 812 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 813 |
+
mask (torch.Tensor): target mask
|
| 814 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 815 |
+
mu (torch.Tensor): output of encoder
|
| 816 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 817 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
| 818 |
+
shape: (batch_size, spk_emb_dim)
|
| 819 |
+
|
| 820 |
+
Returns:
|
| 821 |
+
loss: conditional flow matching loss
|
| 822 |
+
y: conditional flow
|
| 823 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 824 |
+
"""
|
| 825 |
+
b, _, t = mu.shape
|
| 826 |
+
|
| 827 |
+
# random timestep
|
| 828 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 829 |
+
# sample noise p(x_0)
|
| 830 |
+
z = torch.randn_like(x1)
|
| 831 |
+
|
| 832 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
| 833 |
+
u = x1 - (1 - self.sigma_min) * z
|
| 834 |
+
|
| 835 |
+
loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
|
| 836 |
+
torch.sum(mask) * u.shape[1]
|
| 837 |
+
)
|
| 838 |
+
return loss, y
|
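The training objective in compute_loss is the standard optimal-transport conditional flow-matching construction: with noise x0 = z and data x1, a point on the (nearly) straight path between them is sampled and the estimator regresses the constant velocity along that path,

    y_t = (1 - (1 - sigma_min) * t) * x0 + t * x1
    u_t = d y_t / d t = x1 - (1 - sigma_min) * x0

so the loss is the mask-normalised MSE between estimator(y_t, t, ...) and u_t.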
| 839 |
+
|
| 840 |
+
|
| 841 |
+
class ConditionalDecoder(nn.Module):
|
| 842 |
+
def __init__(
|
| 843 |
+
self,
|
| 844 |
+
in_channels,
|
| 845 |
+
out_channels,
|
| 846 |
+
causal=False,
|
| 847 |
+
channels=(256, 256),
|
| 848 |
+
dropout=0.05,
|
| 849 |
+
attention_head_dim=64,
|
| 850 |
+
n_blocks=1,
|
| 851 |
+
num_mid_blocks=2,
|
| 852 |
+
num_heads=4,
|
| 853 |
+
act_fn="snake",
|
| 854 |
+
gradient_checkpointing=False,
|
| 855 |
+
):
|
| 856 |
+
"""
|
| 857 |
+
This decoder requires an input with the same shape as the target. So, if your text content
|
| 858 |
+
is shorter or longer than the outputs, please re-sample it before feeding it to the decoder.
|
| 859 |
+
"""
|
| 860 |
+
super().__init__()
|
| 861 |
+
channels = tuple(channels)
|
| 862 |
+
self.in_channels = in_channels
|
| 863 |
+
self.out_channels = out_channels
|
| 864 |
+
self.causal = causal
|
| 865 |
+
self.static_chunk_size = 2 * 25 * 2 # 2*input_frame_rate*token_mel_ratio
|
| 866 |
+
self.gradient_checkpointing = gradient_checkpointing
|
| 867 |
+
|
| 868 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
| 869 |
+
time_embed_dim = channels[0] * 4
|
| 870 |
+
self.time_mlp = TimestepEmbedding(
|
| 871 |
+
in_channels=in_channels,
|
| 872 |
+
time_embed_dim=time_embed_dim,
|
| 873 |
+
act_fn="silu",
|
| 874 |
+
)
|
| 875 |
+
self.down_blocks = nn.ModuleList([])
|
| 876 |
+
self.mid_blocks = nn.ModuleList([])
|
| 877 |
+
self.up_blocks = nn.ModuleList([])
|
| 878 |
+
|
| 879 |
+
output_channel = in_channels
|
| 880 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
| 881 |
+
input_channel = output_channel
|
| 882 |
+
output_channel = channels[i]
|
| 883 |
+
is_last = i == len(channels) - 1
|
| 884 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
| 885 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 886 |
+
transformer_blocks = nn.ModuleList(
|
| 887 |
+
[
|
| 888 |
+
BasicTransformerBlock(
|
| 889 |
+
dim=output_channel,
|
| 890 |
+
num_attention_heads=num_heads,
|
| 891 |
+
attention_head_dim=attention_head_dim,
|
| 892 |
+
dropout=dropout,
|
| 893 |
+
activation_fn=act_fn,
|
| 894 |
+
)
|
| 895 |
+
for _ in range(n_blocks)
|
| 896 |
+
]
|
| 897 |
+
)
|
| 898 |
+
downsample = (
|
| 899 |
+
Downsample1D(output_channel) if not is_last else
|
| 900 |
+
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 901 |
+
)
|
| 902 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
| 903 |
+
|
| 904 |
+
for _ in range(num_mid_blocks):
|
| 905 |
+
input_channel = channels[-1]
|
| 906 |
+
out_channels = channels[-1]
|
| 907 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
| 908 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 909 |
+
|
| 910 |
+
transformer_blocks = nn.ModuleList(
|
| 911 |
+
[
|
| 912 |
+
BasicTransformerBlock(
|
| 913 |
+
dim=output_channel,
|
| 914 |
+
num_attention_heads=num_heads,
|
| 915 |
+
attention_head_dim=attention_head_dim,
|
| 916 |
+
dropout=dropout,
|
| 917 |
+
activation_fn=act_fn,
|
| 918 |
+
)
|
| 919 |
+
for _ in range(n_blocks)
|
| 920 |
+
]
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
| 924 |
+
|
| 925 |
+
channels = channels[::-1] + (channels[0],)
|
| 926 |
+
for i in range(len(channels) - 1):
|
| 927 |
+
input_channel = channels[i] * 2
|
| 928 |
+
output_channel = channels[i + 1]
|
| 929 |
+
is_last = i == len(channels) - 2
|
| 930 |
+
resnet = CausalResnetBlock1D(
|
| 931 |
+
dim=input_channel,
|
| 932 |
+
dim_out=output_channel,
|
| 933 |
+
time_emb_dim=time_embed_dim,
|
| 934 |
+
) if self.causal else ResnetBlock1D(
|
| 935 |
+
dim=input_channel,
|
| 936 |
+
dim_out=output_channel,
|
| 937 |
+
time_emb_dim=time_embed_dim,
|
| 938 |
+
)
|
| 939 |
+
transformer_blocks = nn.ModuleList(
|
| 940 |
+
[
|
| 941 |
+
BasicTransformerBlock(
|
| 942 |
+
dim=output_channel,
|
| 943 |
+
num_attention_heads=num_heads,
|
| 944 |
+
attention_head_dim=attention_head_dim,
|
| 945 |
+
dropout=dropout,
|
| 946 |
+
activation_fn=act_fn,
|
| 947 |
+
)
|
| 948 |
+
for _ in range(n_blocks)
|
| 949 |
+
]
|
| 950 |
+
)
|
| 951 |
+
upsample = (
|
| 952 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
| 953 |
+
if not is_last
|
| 954 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 955 |
+
)
|
| 956 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
| 957 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
| 958 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
| 959 |
+
self.initialize_weights()
|
| 960 |
+
|
| 961 |
+
def initialize_weights(self):
|
| 962 |
+
for m in self.modules():
|
| 963 |
+
if isinstance(m, nn.Conv1d):
|
| 964 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 965 |
+
if m.bias is not None:
|
| 966 |
+
nn.init.constant_(m.bias, 0)
|
| 967 |
+
elif isinstance(m, nn.GroupNorm):
|
| 968 |
+
nn.init.constant_(m.weight, 1)
|
| 969 |
+
nn.init.constant_(m.bias, 0)
|
| 970 |
+
elif isinstance(m, nn.Linear):
|
| 971 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 972 |
+
if m.bias is not None:
|
| 973 |
+
nn.init.constant_(m.bias, 0)
|
| 974 |
+
|
| 975 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
| 976 |
+
"""Forward pass of the UNet1DConditional model.
|
| 977 |
+
|
| 978 |
+
Args:
|
| 979 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
| 980 |
+
mask (_type_): shape (batch_size, 1, time)
|
| 981 |
+
t (_type_): shape (batch_size)
|
| 982 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
| 983 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
| 984 |
+
|
| 985 |
+
Raises:
|
| 986 |
+
ValueError: _description_
|
| 987 |
+
ValueError: _description_
|
| 988 |
+
|
| 989 |
+
Returns:
|
| 990 |
+
_type_: _description_
|
| 991 |
+
"""
|
| 992 |
+
t = self.time_embeddings(t)
|
| 993 |
+
t = t.to(x.dtype)
|
| 994 |
+
t = self.time_mlp(t)
|
| 995 |
+
x = pack([x, mu], "b * t")[0]
|
| 996 |
+
mask = mask.to(x.dtype)
|
| 997 |
+
if spks is not None:
|
| 998 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
| 999 |
+
x = pack([x, spks], "b * t")[0]
|
| 1000 |
+
if cond is not None:
|
| 1001 |
+
x = pack([x, cond], "b * t")[0]
|
| 1002 |
+
|
| 1003 |
+
hiddens = []
|
| 1004 |
+
masks = [mask]
|
| 1005 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
| 1006 |
+
mask_down = masks[-1]
|
| 1007 |
+
x = resnet(x, mask_down, t)
|
| 1008 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 1009 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
| 1010 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 1011 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 1012 |
+
for transformer_block in transformer_blocks:
|
| 1013 |
+
if self.gradient_checkpointing and self.training:
|
| 1014 |
+
def create_custom_forward(module):
|
| 1015 |
+
def custom_forward(*inputs):
|
| 1016 |
+
return module(*inputs)
|
| 1017 |
+
return custom_forward
|
| 1018 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 1019 |
+
create_custom_forward(transformer_block),
|
| 1020 |
+
x,
|
| 1021 |
+
attn_mask,
|
| 1022 |
+
t,
|
| 1023 |
+
)
|
| 1024 |
+
else:
|
| 1025 |
+
x = transformer_block(
|
| 1026 |
+
hidden_states=x,
|
| 1027 |
+
attention_mask=attn_mask,
|
| 1028 |
+
timestep=t,
|
| 1029 |
+
)
|
| 1030 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 1031 |
+
hiddens.append(x) # Save hidden states for skip connections
|
| 1032 |
+
x = downsample(x * mask_down)
|
| 1033 |
+
masks.append(mask_down[:, :, ::2])
|
| 1034 |
+
masks = masks[:-1]
|
| 1035 |
+
mask_mid = masks[-1]
|
| 1036 |
+
|
| 1037 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
| 1038 |
+
x = resnet(x, mask_mid, t)
|
| 1039 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 1040 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
| 1041 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 1042 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 1043 |
+
for transformer_block in transformer_blocks:
|
| 1044 |
+
if self.gradient_checkpointing and self.training:
|
| 1045 |
+
def create_custom_forward(module):
|
| 1046 |
+
def custom_forward(*inputs):
|
| 1047 |
+
return module(*inputs)
|
| 1048 |
+
return custom_forward
|
| 1049 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 1050 |
+
create_custom_forward(transformer_block),
|
| 1051 |
+
x,
|
| 1052 |
+
attn_mask,
|
| 1053 |
+
t,
|
| 1054 |
+
)
|
| 1055 |
+
else:
|
| 1056 |
+
x = transformer_block(
|
| 1057 |
+
hidden_states=x,
|
| 1058 |
+
attention_mask=attn_mask,
|
| 1059 |
+
timestep=t,
|
| 1060 |
+
)
|
| 1061 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 1062 |
+
|
| 1063 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
| 1064 |
+
mask_up = masks.pop()
|
| 1065 |
+
skip = hiddens.pop()
|
| 1066 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
| 1067 |
+
x = resnet(x, mask_up, t)
|
| 1068 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 1069 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
| 1070 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 1071 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 1072 |
+
for transformer_block in transformer_blocks:
|
| 1073 |
+
if self.gradient_checkpointing and self.training:
|
| 1074 |
+
def create_custom_forward(module):
|
| 1075 |
+
def custom_forward(*inputs):
|
| 1076 |
+
return module(*inputs)
|
| 1077 |
+
return custom_forward
|
| 1078 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 1079 |
+
create_custom_forward(transformer_block),
|
| 1080 |
+
x,
|
| 1081 |
+
attn_mask,
|
| 1082 |
+
t,
|
| 1083 |
+
)
|
| 1084 |
+
else:
|
| 1085 |
+
x = transformer_block(
|
| 1086 |
+
hidden_states=x,
|
| 1087 |
+
attention_mask=attn_mask,
|
| 1088 |
+
timestep=t,
|
| 1089 |
+
)
|
| 1090 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 1091 |
+
x = upsample(x * mask_up)
|
| 1092 |
+
x = self.final_block(x, mask_up)
|
| 1093 |
+
output = self.final_proj(x * mask_up)
|
| 1094 |
+
return output * mask
|
| 1095 |
+
|
| 1096 |
+
|
| 1097 |
+
class ConditionalCFM(BASECFM):
|
| 1098 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
|
| 1099 |
+
super().__init__(
|
| 1100 |
+
n_feats=in_channels,
|
| 1101 |
+
cfm_params=cfm_params,
|
| 1102 |
+
n_spks=n_spks,
|
| 1103 |
+
spk_emb_dim=spk_emb_dim,
|
| 1104 |
+
)
|
| 1105 |
+
self.t_scheduler = cfm_params.t_scheduler
|
| 1106 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
| 1107 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
| 1108 |
+
|
| 1109 |
+
@torch.inference_mode()
|
| 1110 |
+
def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
| 1111 |
+
"""Forward diffusion
|
| 1112 |
+
|
| 1113 |
+
Args:
|
| 1114 |
+
mu (torch.Tensor): output of encoder
|
| 1115 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1116 |
+
mask (torch.Tensor): output_mask
|
| 1117 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 1118 |
+
n_timesteps (int): number of diffusion steps
|
| 1119 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 1120 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 1121 |
+
shape: (batch_size, spk_emb_dim)
|
| 1122 |
+
cond: Not used but kept for future purposes
|
| 1123 |
+
|
| 1124 |
+
Returns:
|
| 1125 |
+
sample: generated mel-spectrogram
|
| 1126 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1127 |
+
"""
|
| 1128 |
+
z = torch.randn_like(mu) * temperature
|
| 1129 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
| 1130 |
+
if self.t_scheduler == 'cosine':
|
| 1131 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
| 1132 |
+
return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)
|
| 1133 |
+
|
| 1134 |
+
def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
|
| 1135 |
+
"""
|
| 1136 |
+
Fixed Euler solver for ODEs.
|
| 1137 |
+
Args:
|
| 1138 |
+
x (torch.Tensor): random noise
|
| 1139 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
| 1140 |
+
shape: (n_timesteps + 1,)
|
| 1141 |
+
mu (torch.Tensor): output of encoder
|
| 1142 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1143 |
+
mask (torch.Tensor): output_mask
|
| 1144 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 1145 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 1146 |
+
shape: (batch_size, spk_emb_dim)
|
| 1147 |
+
cond: Not used but kept for future purposes
|
| 1148 |
+
"""
|
| 1149 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
| 1150 |
+
|
| 1151 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
| 1152 |
+
# Or in future might add like a return_all_steps flag
|
| 1153 |
+
sol = []
|
| 1154 |
+
|
| 1155 |
+
for step in range(1, len(t_span)):
|
| 1156 |
+
dphi_dt = estimator(x, mask, mu, t, spks, cond)
|
| 1157 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
| 1158 |
+
if self.inference_cfg_rate > 0:
|
| 1159 |
+
cfg_dphi_dt = estimator(
|
| 1160 |
+
x, mask,
|
| 1161 |
+
torch.zeros_like(mu), t,
|
| 1162 |
+
torch.zeros_like(spks) if spks is not None else None,
|
| 1163 |
+
cond=cond
|
| 1164 |
+
)
|
| 1165 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
|
| 1166 |
+
self.inference_cfg_rate * cfg_dphi_dt)
|
| 1167 |
+
x = x + dt * dphi_dt
|
| 1168 |
+
t = t + dt
|
| 1169 |
+
sol.append(x)
|
| 1170 |
+
if step < len(t_span) - 1:
|
| 1171 |
+
dt = t_span[step + 1] - t
|
| 1172 |
+
|
| 1173 |
+
return sol[-1]
|
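Relative to the base solver, the only change is classifier-free guidance: at each Euler step the estimator is also evaluated with the conditioning (mu, and spks if present) zeroed out, and the two velocities are combined as

    v = (1 + w) * v_cond - w * v_uncond,    w = inference_cfg_rate

so w = 0 recovers the unguided solver.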
| 1174 |
+
|
| 1175 |
+
def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
|
| 1176 |
+
"""Computes diffusion loss
|
| 1177 |
+
|
| 1178 |
+
Args:
|
| 1179 |
+
x1 (torch.Tensor): Target
|
| 1180 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1181 |
+
mask (torch.Tensor): target mask
|
| 1182 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 1183 |
+
mu (torch.Tensor): output of encoder
|
| 1184 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1185 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
| 1186 |
+
shape: (batch_size, spk_emb_dim)
|
| 1187 |
+
|
| 1188 |
+
Returns:
|
| 1189 |
+
loss: conditional flow matching loss
|
| 1190 |
+
y: conditional flow
|
| 1191 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1192 |
+
"""
|
| 1193 |
+
org_dtype = x1.dtype
|
| 1194 |
+
|
| 1195 |
+
b, _, t = mu.shape
|
| 1196 |
+
# random timestep
|
| 1197 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 1198 |
+
if self.t_scheduler == 'cosine':
|
| 1199 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
| 1200 |
+
# sample noise p(x_0)
|
| 1201 |
+
z = torch.randn_like(x1)
|
| 1202 |
+
|
| 1203 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
| 1204 |
+
u = x1 - (1 - self.sigma_min) * z
|
| 1205 |
+
|
| 1206 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
| 1207 |
+
if self.training_cfg_rate > 0:
|
| 1208 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
| 1209 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
| 1210 |
+
if spks is not None:
|
| 1211 |
+
spks = spks * cfg_mask.view(-1, 1)
|
| 1212 |
+
if cond is not None:
|
| 1213 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
| 1214 |
+
|
| 1215 |
+
pred = estimator(y, mask, mu, t.squeeze(), spks, cond)
|
| 1216 |
+
pred = pred.float()
|
| 1217 |
+
u = u.float()
|
| 1218 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
| 1219 |
+
loss = loss.to(org_dtype)
|
| 1220 |
+
return loss, y
|
| 1221 |
+
|
| 1222 |
+
|
| 1223 |
+
class SinusoidalPosEmb(torch.nn.Module):
|
| 1224 |
+
def __init__(self, dim):
|
| 1225 |
+
super().__init__()
|
| 1226 |
+
self.dim = dim
|
| 1227 |
+
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
| 1228 |
+
|
| 1229 |
+
def forward(self, x, scale=1000):
|
| 1230 |
+
if x.ndim < 1:
|
| 1231 |
+
x = x.unsqueeze(0)
|
| 1232 |
+
device = x.device
|
| 1233 |
+
half_dim = self.dim // 2
|
| 1234 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 1235 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
| 1236 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
| 1237 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
| 1238 |
+
return emb
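# Standalone version of the sinusoidal embedding above, assuming the same
# log-spaced frequencies and sin/cos concatenation; useful for checking shapes.
import math
import torch

def sinusoidal_emb(x, dim, scale=1000):
    half = dim // 2
    freqs = torch.exp(torch.arange(half).float() * -(math.log(10000) / (half - 1)))
    emb = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
    return torch.cat((emb.sin(), emb.cos()), dim=-1)

print(sinusoidal_emb(torch.tensor([0.1, 0.5]), dim=8).shape)   # torch.Size([2, 8])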
|
| 1239 |
+
|
| 1240 |
+
|
| 1241 |
+
class Downsample1D(nn.Module):
|
| 1242 |
+
def __init__(self, dim):
|
| 1243 |
+
super().__init__()
|
| 1244 |
+
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
|
| 1245 |
+
|
| 1246 |
+
def forward(self, x):
|
| 1247 |
+
return self.conv(x)
|
| 1248 |
+
|
| 1249 |
+
|
| 1250 |
+
class TimestepEmbedding(nn.Module):
|
| 1251 |
+
def __init__(
|
| 1252 |
+
self,
|
| 1253 |
+
in_channels: int,
|
| 1254 |
+
time_embed_dim: int,
|
| 1255 |
+
act_fn: str = "silu",
|
| 1256 |
+
out_dim: int = None,
|
| 1257 |
+
post_act_fn: Optional[str] = None,
|
| 1258 |
+
cond_proj_dim=None,
|
| 1259 |
+
):
|
| 1260 |
+
super().__init__()
|
| 1261 |
+
|
| 1262 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
| 1263 |
+
|
| 1264 |
+
if cond_proj_dim is not None:
|
| 1265 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
| 1266 |
+
else:
|
| 1267 |
+
self.cond_proj = None
|
| 1268 |
+
|
| 1269 |
+
self.act = get_activation(act_fn)
|
| 1270 |
+
|
| 1271 |
+
if out_dim is not None:
|
| 1272 |
+
time_embed_dim_out = out_dim
|
| 1273 |
+
else:
|
| 1274 |
+
time_embed_dim_out = time_embed_dim
|
| 1275 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
| 1276 |
+
|
| 1277 |
+
if post_act_fn is None:
|
| 1278 |
+
self.post_act = None
|
| 1279 |
+
else:
|
| 1280 |
+
self.post_act = get_activation(post_act_fn)
|
| 1281 |
+
|
| 1282 |
+
def forward(self, sample, condition=None):
|
| 1283 |
+
if condition is not None:
|
| 1284 |
+
sample = sample + self.cond_proj(condition)
|
| 1285 |
+
sample = self.linear_1(sample)
|
| 1286 |
+
|
| 1287 |
+
if self.act is not None:
|
| 1288 |
+
sample = self.act(sample)
|
| 1289 |
+
|
| 1290 |
+
sample = self.linear_2(sample)
|
| 1291 |
+
|
| 1292 |
+
if self.post_act is not None:
|
| 1293 |
+
sample = self.post_act(sample)
|
| 1294 |
+
return sample
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
class Upsample1D(nn.Module):
|
| 1298 |
+
"""A 1D upsampling layer with an optional convolution.
|
| 1299 |
+
|
| 1300 |
+
Parameters:
|
| 1301 |
+
channels (`int`):
|
| 1302 |
+
number of channels in the inputs and outputs.
|
| 1303 |
+
use_conv (`bool`, default `False`):
|
| 1304 |
+
option to use a convolution.
|
| 1305 |
+
use_conv_transpose (`bool`, default `False`):
|
| 1306 |
+
option to use a convolution transpose.
|
| 1307 |
+
out_channels (`int`, optional):
|
| 1308 |
+
number of output channels. Defaults to `channels`.
|
| 1309 |
+
"""
|
| 1310 |
+
|
| 1311 |
+
def __init__(
|
| 1312 |
+
self,
|
| 1313 |
+
channels,
|
| 1314 |
+
use_conv=False,
|
| 1315 |
+
use_conv_transpose=True,
|
| 1316 |
+
out_channels=None,
|
| 1317 |
+
name="conv",
|
| 1318 |
+
):
|
| 1319 |
+
super().__init__()
|
| 1320 |
+
self.channels = channels
|
| 1321 |
+
self.out_channels = out_channels or channels
|
| 1322 |
+
self.use_conv = use_conv
|
| 1323 |
+
self.use_conv_transpose = use_conv_transpose
|
| 1324 |
+
self.name = name
|
| 1325 |
+
|
| 1326 |
+
self.conv = None
|
| 1327 |
+
if use_conv_transpose:
|
| 1328 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
| 1329 |
+
elif use_conv:
|
| 1330 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
| 1331 |
+
|
| 1332 |
+
def forward(self, inputs):
|
| 1333 |
+
assert inputs.shape[1] == self.channels
|
| 1334 |
+
if self.use_conv_transpose:
|
| 1335 |
+
return self.conv(inputs)
|
| 1336 |
+
|
| 1337 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
| 1338 |
+
|
| 1339 |
+
if self.use_conv:
|
| 1340 |
+
outputs = self.conv(outputs)
|
| 1341 |
+
|
| 1342 |
+
return outputs
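# Quick shape check for the two upsampling paths above: a ConvTranspose1d with
# kernel 4, stride 2, padding 1 doubles the temporal length, matching the
# nearest-neighbour interpolation fallback.
import torch
import torch.nn as nn
import torch.nn.functional as F

conv_t = nn.ConvTranspose1d(8, 8, kernel_size=4, stride=2, padding=1)
x = torch.randn(1, 8, 50)
print(conv_t(x).shape)                                          # (1, 8, 100)
print(F.interpolate(x, scale_factor=2.0, mode="nearest").shape) # (1, 8, 100)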
|
| 1343 |
+
|
| 1344 |
+
|
| 1345 |
+
class RMSNorm(nn.Module):
|
| 1346 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 1347 |
+
"""
|
| 1348 |
+
RMSNorm is equivalent to T5LayerNorm
|
| 1349 |
+
"""
|
| 1350 |
+
super().__init__()
|
| 1351 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 1352 |
+
self.variance_epsilon = eps
|
| 1353 |
+
|
| 1354 |
+
def forward(self, hidden_states):
|
| 1355 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
| 1356 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 1357 |
+
|
| 1358 |
+
# convert into half-precision if necessary
|
| 1359 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
| 1360 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
| 1361 |
+
|
| 1362 |
+
return self.weight * hidden_states
|
| 1363 |
+
|
| 1364 |
+
|
| 1365 |
+
class OmniWhisperAttention(nn.Module):
|
| 1366 |
+
def __init__(self, embed_dim, num_heads, causal=False):
|
| 1367 |
+
super().__init__()
|
| 1368 |
+
self.embed_dim = embed_dim
|
| 1369 |
+
self.num_heads = num_heads
|
| 1370 |
+
self.head_dim = embed_dim // num_heads
|
| 1371 |
+
|
| 1372 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
| 1373 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 1374 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 1375 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 1376 |
+
|
| 1377 |
+
self.causal = causal
|
| 1378 |
+
|
| 1379 |
+
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
|
| 1380 |
+
bsz, _ = hidden_states.size()
|
| 1381 |
+
|
| 1382 |
+
query_states = self.q_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
| 1383 |
+
key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
| 1384 |
+
value_states = self.v_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
| 1385 |
+
|
| 1386 |
+
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
|
| 1387 |
+
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
|
| 1388 |
+
attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen, max_seqlen, causal=self.causal) # (bsz * qlen, nheads, headdim)
|
| 1389 |
+
attn_output = attn_output.reshape(bsz, self.embed_dim)
|
| 1390 |
+
attn_output = self.out_proj(attn_output)
|
| 1391 |
+
return attn_output
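# How the packed-attention bookkeeping above is built: `cu_len` holds the
# cumulative boundaries of the variable-length sequences that were packed
# into one long batch, and `max_seqlen` is the longest of them.
import torch
import torch.nn.functional as F

seq_len = torch.tensor([3, 5, 2])   # three packed sequences
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
max_seqlen = int(seq_len.max())
print(cu_len)       # tensor([ 0,  3,  8, 10], dtype=torch.int32)
print(max_seqlen)   # 5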
|
| 1392 |
+
|
| 1393 |
+
|
| 1394 |
+
class OmniWhisperTransformerLayer(nn.Module):
|
| 1395 |
+
def __init__(
|
| 1396 |
+
self,
|
| 1397 |
+
act,
|
| 1398 |
+
d_model,
|
| 1399 |
+
encoder_attention_heads,
|
| 1400 |
+
encoder_ffn_dim,
|
| 1401 |
+
causal,
|
| 1402 |
+
ln_type="LayerNorm",
|
| 1403 |
+
):
|
| 1404 |
+
super().__init__()
|
| 1405 |
+
self.embed_dim = d_model
|
| 1406 |
+
self.self_attn = OmniWhisperAttention(
|
| 1407 |
+
self.embed_dim, encoder_attention_heads, causal
|
| 1408 |
+
)
|
| 1409 |
+
|
| 1410 |
+
if ln_type == "LayerNorm":
|
| 1411 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 1412 |
+
elif ln_type == "RMSNorm":
|
| 1413 |
+
self.self_attn_layer_norm = RMSNorm(self.embed_dim)
|
| 1414 |
+
else:
|
| 1415 |
+
raise ValueError(f"Unknown ln_type: {ln_type}")
|
| 1416 |
+
|
| 1417 |
+
self.activation_fn = act
|
| 1418 |
+
self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
|
| 1419 |
+
self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
|
| 1420 |
+
|
| 1421 |
+
if ln_type == "LayerNorm":
|
| 1422 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 1423 |
+
elif ln_type == "RMSNorm":
|
| 1424 |
+
self.final_layer_norm = RMSNorm(self.embed_dim)
|
| 1425 |
+
else:
|
| 1426 |
+
raise ValueError(f"Unknown ln_type: {ln_type}")
|
| 1427 |
+
|
| 1428 |
+
def forward(
|
| 1429 |
+
self, hidden_states: torch.Tensor, seq_len: torch.Tensor
|
| 1430 |
+
) -> torch.Tensor:
|
| 1431 |
+
residual = hidden_states
|
| 1432 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 1433 |
+
hidden_states = self.self_attn(hidden_states, seq_len)
|
| 1434 |
+
hidden_states = residual + hidden_states
|
| 1435 |
+
residual = hidden_states
|
| 1436 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 1437 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 1438 |
+
hidden_states = self.fc2(hidden_states)
|
| 1439 |
+
hidden_states = residual + hidden_states
|
| 1440 |
+
|
| 1441 |
+
if (
|
| 1442 |
+
hidden_states.dtype == torch.float16
|
| 1443 |
+
or hidden_states.dtype == torch.bfloat16
|
| 1444 |
+
) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
|
| 1445 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
| 1446 |
+
hidden_states = torch.clamp(
|
| 1447 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
| 1448 |
+
)
|
| 1449 |
+
return hidden_states
|
| 1450 |
+
|
| 1451 |
+
|
| 1452 |
+
|
| 1453 |
+
class LongcatNextAudioEncoder(nn.Module):
|
| 1454 |
+
def __init__(self, config):
|
| 1455 |
+
super().__init__()
|
| 1456 |
+
self.config = config
|
| 1457 |
+
self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
|
| 1458 |
+
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
| 1459 |
+
|
| 1460 |
+
self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
|
| 1461 |
+
self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size,
|
| 1462 |
+
stride=config.stride_size, padding=1)
|
| 1463 |
+
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model)) # 1500 * d
|
| 1464 |
+
|
| 1465 |
+
self.layers = nn.ModuleList([OmniWhisperTransformerLayer(
|
| 1466 |
+
ACT2FN[config.activation_function],
|
| 1467 |
+
config.d_model,
|
| 1468 |
+
config.encoder_attention_heads,
|
| 1469 |
+
config.encoder_ffn_dim,
|
| 1470 |
+
False) for _ in range(config.encoder_layers)])
|
| 1471 |
+
self.layer_norm = nn.LayerNorm(config.d_model)
|
| 1472 |
+
|
| 1473 |
+
def forward(
|
| 1474 |
+
self,
|
| 1475 |
+
input_features,
|
| 1476 |
+
output_length,
|
| 1477 |
+
):
|
| 1478 |
+
input_features = input_features.to(self.conv1.weight.dtype)
|
| 1479 |
+
inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (bs, channels, frames)
|
| 1480 |
+
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (bs, channels, frames // 2)
|
| 1481 |
+
inputs_embeds = inputs_embeds.permute(0, 2, 1) # (bs, frames, channels)
|
| 1482 |
+
bsz, tgt_len, _ = inputs_embeds.size()
|
| 1483 |
+
if tgt_len < self.positional_embedding.shape[0]:
|
| 1484 |
+
current_positional_embedding = self.positional_embedding[:tgt_len]
|
| 1485 |
+
else:
|
| 1486 |
+
current_positional_embedding = self.positional_embedding
|
| 1487 |
+
hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)
|
| 1488 |
+
|
| 1489 |
+
# packing hidden states
|
| 1490 |
+
attention_mask, unpacking_index = get_sequence_mask(hidden_states, output_length)
|
| 1491 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length),
|
| 1492 |
+
self.config.d_model)
|
| 1493 |
+
|
| 1494 |
+
for idx, encoder_layer in enumerate(self.layers):
|
| 1495 |
+
hidden_states = encoder_layer(hidden_states, output_length)
|
| 1496 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 1497 |
+
# unpacking
|
| 1498 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
|
| 1499 |
+
hidden_states = torch.where(attention_mask, hidden_states, 0)
|
| 1500 |
+
return hidden_states
|
| 1501 |
+
|
| 1502 |
+
|
| 1503 |
+
class CasualConvTranspose1d(nn.Module):
|
| 1504 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride):
|
| 1505 |
+
super().__init__()
|
| 1506 |
+
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
|
| 1507 |
+
self.norm = nn.GroupNorm(1, out_channels)
|
| 1508 |
+
self.in_channels = in_channels
|
| 1509 |
+
self.out_channels = out_channels
|
| 1510 |
+
|
| 1511 |
+
def forward(self, hidden_states, input_length, output_dim=None):
|
| 1512 |
+
kernel_size = self.conv.kernel_size[0]
|
| 1513 |
+
stride = self.conv.stride[0]
|
| 1514 |
+
bsz = input_length.shape[0]
|
| 1515 |
+
|
| 1516 |
+
if output_dim is None:
|
| 1517 |
+
output_dim = hidden_states.dim()
|
| 1518 |
+
if hidden_states.dim() <= 2: # unpack sequence to 3d
|
| 1519 |
+
sequence_mask, unpacking_index = get_sequence_mask(hidden_states, input_length)
|
| 1520 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, torch.max(input_length),
|
| 1521 |
+
self.in_channels)
|
| 1522 |
+
hidden_states = torch.where(sequence_mask, hidden_states, 0) # 3d (bsz, max_input_len, d)
|
| 1523 |
+
|
| 1524 |
+
hidden_states = hidden_states.transpose(2, 1) # (N, L, C) -> (N, C, L)
|
| 1525 |
+
hidden_states = self.conv(hidden_states)
|
| 1526 |
+
hidden_states = self.norm(hidden_states)
|
| 1527 |
+
hidden_states = hidden_states.transpose(2, 1) # (N, C, L) -> (N, L, C)
|
| 1528 |
+
|
| 1529 |
+
casual_padding_right = max(0, kernel_size - stride)
|
| 1530 |
+
hidden_states = hidden_states[:, :hidden_states.shape[1] - casual_padding_right,
|
| 1531 |
+
:]
|
| 1532 |
+
output_length = (input_length - 1) * stride + kernel_size - casual_padding_right
|
| 1533 |
+
sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
|
| 1534 |
+
if output_dim <= 2:
|
| 1535 |
+
hidden_states = torch.masked_select(hidden_states, sequence_mask).view(-1, self.out_channels)
|
| 1536 |
+
else:
|
| 1537 |
+
hidden_states = torch.where(sequence_mask, hidden_states, 0)
|
| 1538 |
+
hidden_states = hidden_states[:, :torch.max(output_length), :]
|
| 1539 |
+
return hidden_states, output_length
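# A small check of the causal length arithmetic above: ConvTranspose1d with no
# padding produces (L - 1) * stride + kernel frames, and trimming
# max(0, kernel - stride) frames on the right gives L * stride when
# kernel >= stride.
import torch
import torch.nn as nn

k, s, L = 5, 2, 10
conv = nn.ConvTranspose1d(4, 4, k, s)
full = conv(torch.randn(1, 4, L))
trim = max(0, k - s)
print(full.shape[-1], full.shape[-1] - trim)   # 23 20  (= L * s)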
|
| 1540 |
+
|
| 1541 |
+
|
| 1542 |
+
class MelSpecRefineNet(nn.Module):
|
| 1543 |
+
"""
|
| 1544 |
+
# post net, coarse to refined mel-spectrogram frames
|
| 1545 |
+
# ref1: Autoregressive Speech Synthesis without Vector Quantization
|
| 1546 |
+
# ref2: CosyVoice length_regulator.py
|
| 1547 |
+
# ref3: Neural Speech Synthesis with Transformer Network https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
|
| 1548 |
+
"""
|
| 1549 |
+
|
| 1550 |
+
def __init__(self, encoder_config, vocoder_config):
|
| 1551 |
+
super().__init__()
|
| 1552 |
+
self.encoder_config = encoder_config
|
| 1553 |
+
self.vocoder_config = vocoder_config
|
| 1554 |
+
|
| 1555 |
+
layers = nn.ModuleList([])
|
| 1556 |
+
in_channels = self.vocoder_config.num_mel_bins
|
| 1557 |
+
for i, out_channels in enumerate(self.vocoder_config.channels[:-1]):
|
| 1558 |
+
module = nn.Conv1d(in_channels, out_channels, 5, 1, 2) # cosyvoice kernel=3, stride=1, pad=1
|
| 1559 |
+
in_channels = out_channels
|
| 1560 |
+
norm = nn.GroupNorm(1, out_channels)
|
| 1561 |
+
act = nn.Mish()
|
| 1562 |
+
layers.extend([module, norm, act])
|
| 1563 |
+
layers.append(nn.Conv1d(in_channels, self.vocoder_config.num_mel_bins, 1, 1)) # projector
|
| 1564 |
+
self.layers = nn.Sequential(*layers)
|
| 1565 |
+
|
| 1566 |
+
def compute_output_length(self, input_length):
|
| 1567 |
+
output_length = input_length.to(
|
| 1568 |
+
torch.float32) * self.encoder_config.hop_length / self.encoder_config.sampling_rate
|
| 1569 |
+
output_length = output_length * self.vocoder_config.sampling_rate / self.vocoder_config.hop_length
|
| 1570 |
+
return output_length.to(torch.int64)
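# A numeric check of the frame-rate conversion above, using hypothetical hop
# lengths and sampling rates (the real values come from encoder_config and
# vocoder_config): encoder frames are converted to seconds, then to vocoder
# mel frames.
import torch

enc_hop, enc_sr = 160, 16000   # assumed 10 ms encoder frames
voc_hop, voc_sr = 480, 24000   # assumed 20 ms vocoder frames

input_length = torch.tensor([100, 250])                        # encoder frames
seconds = input_length.float() * enc_hop / enc_sr
output_length = (seconds * voc_sr / voc_hop).to(torch.int64)
print(output_length)   # tensor([ 50, 125])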
|
| 1571 |
+
|
| 1572 |
+
def forward(self, coarse_mel, input_length, output_length=None):
|
| 1573 |
+
bsz, _, d = coarse_mel.shape
|
| 1574 |
+
assert (d == self.vocoder_config.num_mel_bins)
|
| 1575 |
+
if output_length is None or not self.training:
|
| 1576 |
+
output_length = self.compute_output_length(input_length)
|
| 1577 |
+
coarse_mel, default_dtype = coarse_mel[:, :torch.max(input_length), :], coarse_mel.dtype
|
| 1578 |
+
coarse_mel = F.interpolate(coarse_mel.to(torch.float32).transpose(1, 2).contiguous(), size=output_length.max(),
|
| 1579 |
+
mode='nearest').to(default_dtype)
|
| 1580 |
+
refined_mel = self.layers(coarse_mel).transpose(1, 2).contiguous() # (bs, t, d)
|
| 1581 |
+
coarse_mel = coarse_mel.transpose(1, 2) # (bs, max(output_length), d)
|
| 1582 |
+
refined_mel += coarse_mel # residual connection
|
| 1583 |
+
sequence_mask, _ = get_sequence_mask(refined_mel, output_length)
|
| 1584 |
+
coarse_mel = torch.where(sequence_mask, coarse_mel, 0)
|
| 1585 |
+
refined_mel = torch.where(sequence_mask, refined_mel, 0)
|
| 1586 |
+
return refined_mel, coarse_mel, output_length
|
| 1587 |
+
|
| 1588 |
+
|
| 1589 |
+
@dataclass
|
| 1590 |
+
class OmniAudioDecoderOutput(ModelOutput):
|
| 1591 |
+
refined_mel: Optional[torch.FloatTensor] = None
|
| 1592 |
+
coarse_mel: Optional[torch.FloatTensor] = None
|
| 1593 |
+
mel_length: Optional[torch.Tensor] = None
|
| 1594 |
+
hidden_states_before_dconv2: Optional[torch.FloatTensor] = None
|
| 1595 |
+
output_length_before_dconv2: Optional[torch.Tensor] = None
|
| 1596 |
+
|
| 1597 |
+
|
| 1598 |
+
class LongcatNextAudioDecoder(nn.Module):
|
| 1599 |
+
def __init__(self, config):
|
| 1600 |
+
super().__init__()
|
| 1601 |
+
self.config = config
|
| 1602 |
+
self.vocoder_config = config.vocoder_config
|
| 1603 |
+
self.max_source_positions = self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length
|
| 1604 |
+
|
| 1605 |
+
self.dconv1 = CasualConvTranspose1d(
|
| 1606 |
+
self.config.d_model,
|
| 1607 |
+
self.config.d_model,
|
| 1608 |
+
self.config.decoder_kernel_size,
|
| 1609 |
+
self.config.avg_pooler,
|
| 1610 |
+
)
|
| 1611 |
+
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, self.config.d_model))
|
| 1612 |
+
# causal transformer layers
|
| 1613 |
+
self.layers = nn.ModuleList(
|
| 1614 |
+
[OmniWhisperTransformerLayer(
|
| 1615 |
+
ACT2FN[self.config.activation_function],
|
| 1616 |
+
self.config.d_model,
|
| 1617 |
+
self.config.decoder_attention_heads,
|
| 1618 |
+
self.config.decoder_ffn_dim,
|
| 1619 |
+
True # causal
|
| 1620 |
+
) for _ in range(self.config.decoder_layers)
|
| 1621 |
+
])
|
| 1622 |
+
self.layer_norm = nn.LayerNorm(self.config.d_model)
|
| 1623 |
+
self.dconv2 = CasualConvTranspose1d(
|
| 1624 |
+
self.config.d_model,
|
| 1625 |
+
self.vocoder_config.num_mel_bins,
|
| 1626 |
+
self.config.decoder_kernel_size,
|
| 1627 |
+
self.config.decoder_stride_size
|
| 1628 |
+
)
|
| 1629 |
+
self.post_net = MelSpecRefineNet(self.config, self.vocoder_config)
|
| 1630 |
+
self.gradient_checkpointing = False
|
| 1631 |
+
|
| 1632 |
+
def forward(self,
|
| 1633 |
+
audio_embed,
|
| 1634 |
+
input_length,
|
| 1635 |
+
mel_labels=None,
|
| 1636 |
+
mel_labels_length=None,
|
| 1637 |
+
):
|
| 1638 |
+
assert (audio_embed.shape[-1] == self.config.d_model)
|
| 1639 |
+
audio_embed = audio_embed.to(self.layer_norm.weight) # device and type
|
| 1640 |
+
audio_embed, output_length = self.dconv1(audio_embed, input_length, output_dim=3) # (b, l*2, d_model)
|
| 1641 |
+
_, tgt_len, _ = audio_embed.size()
|
| 1642 |
+
if tgt_len < self.positional_embedding.shape[0]:
|
| 1643 |
+
current_positional_embedding = self.positional_embedding[:tgt_len]
|
| 1644 |
+
else:
|
| 1645 |
+
current_positional_embedding = self.positional_embedding
|
| 1646 |
+
hidden_states = (audio_embed.to(torch.float32) + current_positional_embedding).to(audio_embed.dtype)
|
| 1647 |
+
|
| 1648 |
+
# packing hidden states
|
| 1649 |
+
attention_mask, _ = get_sequence_mask(hidden_states, output_length)
|
| 1650 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)
|
| 1651 |
+
|
| 1652 |
+
for idx, encoder_layer in enumerate(self.layers):
|
| 1653 |
+
hidden_states = encoder_layer(hidden_states, output_length)
|
| 1654 |
+
|
| 1655 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 1656 |
+
hidden_states_before_dconv2 = hidden_states
|
| 1657 |
+
output_length_before_dconv2 = output_length
|
| 1658 |
+
|
| 1659 |
+
coarse_mel, output_length = self.dconv2(hidden_states, output_length, output_dim=3)
|
| 1660 |
+
refined_mel, coarse_mel, mel_labels_length = self.post_net(coarse_mel, output_length, mel_labels_length)
|
| 1661 |
+
|
| 1662 |
+
return OmniAudioDecoderOutput(
|
| 1663 |
+
refined_mel=refined_mel,
|
| 1664 |
+
coarse_mel=coarse_mel,
|
| 1665 |
+
mel_length=mel_labels_length,
|
| 1666 |
+
hidden_states_before_dconv2=hidden_states_before_dconv2,
|
| 1667 |
+
output_length_before_dconv2=output_length_before_dconv2,
|
| 1668 |
+
)
|
| 1669 |
+
|
| 1670 |
+
|
| 1671 |
+
class LongcatNextAudioVQBridger(nn.Module):
|
| 1672 |
+
def __init__(self, config):
|
| 1673 |
+
super().__init__()
|
| 1674 |
+
self.config = config
|
| 1675 |
+
self.gradient_checkpointing = False
|
| 1676 |
+
self.intermediate_dim = self.config.d_model * self.config.avg_pooler
|
| 1677 |
+
self.gate_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
|
| 1678 |
+
self.up_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
|
| 1679 |
+
|
| 1680 |
+
self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
|
| 1681 |
+
self.act_fn = ACT2FN['silu']
|
| 1682 |
+
self.layer_norm = nn.LayerNorm(self.intermediate_dim)
|
| 1683 |
+
self.proj_decoder = nn.Linear(self.intermediate_dim, self.config.d_model)
|
| 1684 |
+
|
| 1685 |
+
self.vq_list = nn.ModuleList([])
|
| 1686 |
+
for idx, codebook_size in enumerate(self.config.vq_config.codebook_sizes):
|
| 1687 |
+
vq_config = copy.deepcopy(self.config.vq_config)
|
| 1688 |
+
vq_config.dim = self.intermediate_dim
|
| 1689 |
+
vq_config.codebook_size = codebook_size
|
| 1690 |
+
self.vq_list.append(VectorQuantize(vq_config))
|
| 1691 |
+
|
| 1692 |
+
def rvq_op(self, inputs, output_length):
|
| 1693 |
+
def rvq_layer_op(vq_layer, residual_encoding, output_length):
|
| 1694 |
+
q_v_i, code_ids_i = vq_layer(residual_encoding, output_length)
|
| 1695 |
+
residual_encoding = residual_encoding.float() - q_v_i.float()
|
| 1696 |
+
residual_encoding = residual_encoding.to(inputs.dtype)
|
| 1697 |
+
return residual_encoding, code_ids_i
|
| 1698 |
+
|
| 1699 |
+
cmt_loss, residual_encoding = 0, inputs
|
| 1700 |
+
code_ids_list = []
|
| 1701 |
+
for i, vq_layer in enumerate(self.vq_list):
|
| 1702 |
+
residual_encoding, code_ids_i = rvq_layer_op(vq_layer, residual_encoding, output_length)
|
| 1703 |
+
code_ids_list.append(code_ids_i)
|
| 1704 |
+
return torch.stack(code_ids_list, -1)
|
| 1705 |
+
|
| 1706 |
+
def forward(self, x, output_length):
|
| 1707 |
+
batch_size, _, _ = x.shape
|
| 1708 |
+
output_length = output_length.to(x.device)
|
| 1709 |
+
|
| 1710 |
+
if x.shape[1] % self.config.avg_pooler != 0:
|
| 1711 |
+
x = F.pad(x, (0, 0, 0, self.config.avg_pooler - x.shape[1] % self.config.avg_pooler), "constant", 0)
|
| 1712 |
+
xt = x.permute(0, 2, 1)
|
| 1713 |
+
g = self.gate_proj(xt).permute(0, 2, 1) # (bs, sl//avg_pooler+1, d*2)
|
| 1714 |
+
u = self.up_proj(xt).permute(0, 2, 1)
|
| 1715 |
+
x = x.reshape(batch_size, -1, self.intermediate_dim) # (bs, sl//avg_pooler+1, d*2)
|
| 1716 |
+
|
| 1717 |
+
c = self.down_proj(self.act_fn(g) * u)
|
| 1718 |
+
res = self.layer_norm(c + x)
|
| 1719 |
+
valid_mask, _ = get_sequence_mask(res, output_length)
|
| 1720 |
+
code_ids = self.rvq_op(res, output_length)
|
| 1721 |
+
code_ids = torch.masked_select(code_ids, valid_mask).reshape(-1, len(self.vq_list)) # (sum(valid_sequence_length), vq_num)
|
| 1722 |
+
return code_ids
|
| 1723 |
+
|
| 1724 |
+
@torch.no_grad()
|
| 1725 |
+
def decode(self, code_ids):
|
| 1726 |
+
vq_num = code_ids.shape[-1]
|
| 1727 |
+
res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
|
| 1728 |
+
decoder_emb = self.proj_decoder(res.to(self.proj_decoder.weight))
|
| 1729 |
+
return decoder_emb
|
| 1730 |
+
|
| 1731 |
+
@torch.no_grad()
|
| 1732 |
+
def recover(self, code_ids):
|
| 1733 |
+
vq_num = code_ids.shape[-1]
|
| 1734 |
+
res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
|
| 1735 |
+
return res
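# A toy residual-VQ sketch mirroring rvq_op / decode / recover above: each
# codebook quantizes the residual left by the previous ones, and decoding sums
# the per-codebook vectors. `Codebook` is a hypothetical nearest-neighbour
# quantizer, not the VectorQuantize module used by this file.
import torch

class Codebook:
    def __init__(self, n_codes, dim):
        self.embed = torch.randn(n_codes, dim)

    def __call__(self, x):                 # x: (n, dim)
        idx = torch.cdist(x, self.embed).argmin(dim=-1)
        return self.embed[idx], idx

def rvq_encode(x, codebooks):
    residual, codes = x.clone(), []
    for cb in codebooks:
        q, idx = cb(residual)
        residual = residual - q
        codes.append(idx)
    return torch.stack(codes, dim=-1)      # (n, num_codebooks)

def rvq_decode(codes, codebooks):
    return sum(cb.embed[codes[:, i]] for i, cb in enumerate(codebooks))

cbs = [Codebook(16, 8) for _ in range(3)]
x = torch.randn(4, 8)
x_hat = rvq_decode(rvq_encode(x, cbs), cbs)   # coarse reconstruction of x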
|
| 1736 |
+
|
| 1737 |
+
|
| 1738 |
+
class FlowmatchingPrenet(nn.Module):
|
| 1739 |
+
def __init__(
|
| 1740 |
+
self,
|
| 1741 |
+
input_feat_dim,
|
| 1742 |
+
out_feat_dim,
|
| 1743 |
+
d_model,
|
| 1744 |
+
attention_heads,
|
| 1745 |
+
ffn_dim,
|
| 1746 |
+
nlayers,
|
| 1747 |
+
activation_function,
|
| 1748 |
+
max_source_positions,
|
| 1749 |
+
target_mel_length_scale_ratio,
|
| 1750 |
+
):
|
| 1751 |
+
super().__init__()
|
| 1752 |
+
|
| 1753 |
+
self.d_model = d_model
|
| 1754 |
+
self.target_mel_length_scale_ratio = target_mel_length_scale_ratio
|
| 1755 |
+
self.gradient_checkpointing = False
|
| 1756 |
+
|
| 1757 |
+
self.register_buffer(
|
| 1758 |
+
"positional_embedding", sinusoids(max_source_positions, d_model)
|
| 1759 |
+
)
|
| 1760 |
+
|
| 1761 |
+
self.in_mlp = nn.Sequential(
|
| 1762 |
+
nn.Linear(input_feat_dim, d_model * 4),
|
| 1763 |
+
nn.SiLU(),
|
| 1764 |
+
nn.Linear(d_model * 4, d_model),
|
| 1765 |
+
)
|
| 1766 |
+
|
| 1767 |
+
self.transformer_layers = nn.ModuleList(
|
| 1768 |
+
[
|
| 1769 |
+
OmniWhisperTransformerLayer(
|
| 1770 |
+
act=ACT2FN[activation_function],
|
| 1771 |
+
d_model=d_model,
|
| 1772 |
+
encoder_attention_heads=attention_heads,
|
| 1773 |
+
encoder_ffn_dim=ffn_dim,
|
| 1774 |
+
causal=True, # causal
|
| 1775 |
+
ln_type="RMSNorm",
|
| 1776 |
+
)
|
| 1777 |
+
for _ in range(nlayers)
|
| 1778 |
+
]
|
| 1779 |
+
)
|
| 1780 |
+
|
| 1781 |
+
self.final_norm = RMSNorm(self.d_model)
|
| 1782 |
+
self.out_proj = nn.Linear(d_model, out_feat_dim, bias=False)
|
| 1783 |
+
|
| 1784 |
+
def compute_output_length(self, input_length):
|
| 1785 |
+
output_length = input_length.float() * self.target_mel_length_scale_ratio
|
| 1786 |
+
return output_length.to(torch.int64)
|
| 1787 |
+
|
| 1788 |
+
def forward(self, input_feat, input_length, output_length=None):
|
| 1789 |
+
"""
|
| 1790 |
+
Args:
|
| 1791 |
+
input_feat: [B, T, input_feat_dim]
|
| 1792 |
+
input_length: [B]
|
| 1793 |
+
output_length: [B]
|
| 1794 |
+
|
| 1795 |
+
"""
|
| 1796 |
+
if output_length is None or not self.training:
|
| 1797 |
+
output_length = self.compute_output_length(input_length)
|
| 1798 |
+
|
| 1799 |
+
input_feat = input_feat[:, : input_length.max(), :] # [B, T, D]
|
| 1800 |
+
orig_dtype = input_feat.dtype
|
| 1801 |
+
|
| 1802 |
+
input_feat = F.interpolate(
|
| 1803 |
+
input=input_feat.to(torch.float32).transpose(1, 2).contiguous(),
|
| 1804 |
+
size=output_length.max(),
|
| 1805 |
+
mode="nearest",
|
| 1806 |
+
).to(orig_dtype)
|
| 1807 |
+
input_feat = input_feat.transpose(1, 2).contiguous() # [B, T, D]
|
| 1808 |
+
hidden_states = self.in_mlp(input_feat)
|
| 1809 |
+
|
| 1810 |
+
# packing hidden states
|
| 1811 |
+
bsz, tgt_len, d_model = hidden_states.shape
|
| 1812 |
+
attention_mask, unpacking_index = get_sequence_mask(
|
| 1813 |
+
hidden_states, output_length
|
| 1814 |
+
)
|
| 1815 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
|
| 1816 |
+
torch.sum(output_length), self.d_model
|
| 1817 |
+
)
|
| 1818 |
+
|
| 1819 |
+
for idx, encoder_layer in enumerate(self.transformer_layers):
|
| 1820 |
+
hidden_states = encoder_layer(hidden_states, output_length)
|
| 1821 |
+
|
| 1822 |
+
# unpacking
|
| 1823 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
| 1824 |
+
bsz, tgt_len, d_model
|
| 1825 |
+
)
|
| 1826 |
+
hidden_states = torch.where(attention_mask, hidden_states, 0)
|
| 1827 |
+
|
| 1828 |
+
hidden_states = self.final_norm(hidden_states)
|
| 1829 |
+
output = self.out_proj(hidden_states)
|
| 1830 |
+
return output, output_length
|
| 1831 |
+
|
| 1832 |
+
|
| 1833 |
+
@dataclass
|
| 1834 |
+
class OmniAudioFlowMatchingDecoderOutput(ModelOutput):
|
| 1835 |
+
flow_matching_mel: Optional[torch.FloatTensor] = None
|
| 1836 |
+
flow_matching_mel_lengths: Optional[torch.FloatTensor] = None
|
| 1837 |
+
|
| 1838 |
+
|
| 1839 |
+
class LongcatNextAudioFlowMatchingDecoder(nn.Module):
|
| 1840 |
+
def __init__(self, config):
|
| 1841 |
+
super().__init__()
|
| 1842 |
+
self.config = config.flow_matching_config
|
| 1843 |
+
self.in_channels = self.config.in_channels
|
| 1844 |
+
self.spk_emb_dim = self.config.spk_emb_dim
|
| 1845 |
+
self.diffusion_steps = self.config.diffusion_steps
|
| 1846 |
+
self.cal_mel_mae = self.config.cal_mel_mae
|
| 1847 |
+
self.forward_step = -1
|
| 1848 |
+
|
| 1849 |
+
self.prenet = FlowmatchingPrenet(
|
| 1850 |
+
input_feat_dim=self.config.prenet_in_dim,
|
| 1851 |
+
out_feat_dim=self.config.prenet_out_dim,
|
| 1852 |
+
d_model=self.config.prenet_d_model,
|
| 1853 |
+
attention_heads=self.config.prenet_attention_heads,
|
| 1854 |
+
ffn_dim=self.config.prenet_ffn_dim,
|
| 1855 |
+
nlayers=self.config.prenet_nlayers,
|
| 1856 |
+
activation_function=self.config.prenet_activation_function,
|
| 1857 |
+
max_source_positions=self.config.prenet_max_source_positions,
|
| 1858 |
+
target_mel_length_scale_ratio=self.config.prenet_target_mel_length_scale_ratio,
|
| 1859 |
+
)
|
| 1860 |
+
|
| 1861 |
+
self.conditional_decoder = ConditionalDecoder(
|
| 1862 |
+
in_channels=self.in_channels * 2 + self.spk_emb_dim,
|
| 1863 |
+
out_channels=self.in_channels,
|
| 1864 |
+
causal=True,
|
| 1865 |
+
channels=self.config.channels,
|
| 1866 |
+
dropout=self.config.dropout,
|
| 1867 |
+
attention_head_dim=self.config.attention_head_dim,
|
| 1868 |
+
n_blocks=self.config.n_blocks,
|
| 1869 |
+
num_mid_blocks=self.config.num_mid_blocks,
|
| 1870 |
+
num_heads=self.config.num_heads,
|
| 1871 |
+
act_fn=self.config.act_fn,
|
| 1872 |
+
)
|
| 1873 |
+
|
| 1874 |
+
self.cfm = ConditionalCFM(
|
| 1875 |
+
in_channels=self.in_channels,
|
| 1876 |
+
cfm_params=self.config.cfm_params,
|
| 1877 |
+
n_spks=0,
|
| 1878 |
+
spk_emb_dim=self.spk_emb_dim,
|
| 1879 |
+
)
|
| 1880 |
+
|
| 1881 |
+
|
| 1882 |
+
def unpack_hidden_states(self, hidden_states, output_length):
|
| 1883 |
+
unpacked = unpack_hidden_states(hidden_states, output_length)
|
| 1884 |
+
return unpacked, output_length
|
| 1885 |
+
|
| 1886 |
+
def forward(
|
| 1887 |
+
self, refined_mel, input_length, mel_labels=None, mel_labels_length=None
|
| 1888 |
+
):
|
| 1889 |
+
"""
|
| 1890 |
+
:param refined_mel: [bs, max_input_len, mel_bin]
|
| 1891 |
+
:param input_length: [batch_size]
|
| 1892 |
+
:param refined_mel: [bs, mel_bin, max_input_len]
|
| 1893 |
+
:return:
|
| 1894 |
+
"""
|
| 1895 |
+
self.forward_step += 1
|
| 1896 |
+
|
| 1897 |
+
orig_dtype = refined_mel.dtype
|
| 1898 |
+
prenet_mae_metric = torch.tensor(0.0).to(refined_mel.device)
|
| 1899 |
+
prenet_regression_loss = torch.tensor(0.0).to(refined_mel.device)
|
| 1900 |
+
|
| 1901 |
+
if self.prenet is not None:
|
| 1902 |
+
refined_mel = refined_mel[:, : torch.max(input_length), :]
|
| 1903 |
+
if mel_labels_length is None:
|
| 1904 |
+
mel_labels_length = self.prenet.compute_output_length(input_length)
|
| 1905 |
+
refined_mel, input_length = self.prenet(
|
| 1906 |
+
refined_mel, input_length, mel_labels_length
|
| 1907 |
+
)
|
| 1908 |
+
|
| 1909 |
+
float_dtype = refined_mel.dtype
|
| 1910 |
+
refined_mel = refined_mel.float()
|
| 1911 |
+
input_length = input_length.long()
|
| 1912 |
+
|
| 1913 |
+
refined_mel = refined_mel[:, : torch.max(input_length), :]
|
| 1914 |
+
sequence_mask, unpacking_index = get_sequence_mask(refined_mel, input_length)
|
| 1915 |
+
refined_mel = refined_mel.transpose(1, 2) # (bs, mel_bin, max_input_len)
|
| 1916 |
+
sequence_mask = sequence_mask.transpose(2, 1) # (bs, 1, sl)
|
| 1917 |
+
|
| 1918 |
+
fm_mel = self.cfm.forward(
|
| 1919 |
+
estimator=self.conditional_decoder,
|
| 1920 |
+
mu=refined_mel.to(float_dtype),
|
| 1921 |
+
mask=sequence_mask.float(),
|
| 1922 |
+
n_timesteps=self.diffusion_steps,
|
| 1923 |
+
)
|
| 1924 |
+
return OmniAudioFlowMatchingDecoderOutput(
|
| 1925 |
+
flow_matching_mel=fm_mel.transpose(1, 2),
|
| 1926 |
+
flow_matching_mel_lengths=mel_labels_length,
|
| 1927 |
+
)
|
| 1928 |
+
|
| 1929 |
+
|
| 1930 |
+
@torch.no_grad()
|
| 1931 |
+
def decode_wave_vocoder2(response, vocoder, audio_tokenizer):
|
| 1932 |
+
response_len = (response[:,:,0] == audio_tokenizer.config.audio_config.vq_config.codebook_sizes[0]).long().argmax(dim=1)
|
| 1933 |
+
valid_response_list = [response[i, :response_len[i], :] for i in range(response.shape[0]) if int(response_len[i])>0]
|
| 1934 |
+
|
| 1935 |
+
if len(valid_response_list)==0:
|
| 1936 |
+
return []
|
| 1937 |
+
flatten_response = torch.cat(valid_response_list, dim=0) if len(valid_response_list)>1 else valid_response_list[0]
|
| 1938 |
+
valid_response_len = response_len[response_len>0]
|
| 1939 |
+
ret = audio_tokenizer.decode(flatten_response.view(-1,response.shape[-1]),
|
| 1940 |
+
bridge_length=valid_response_len)
|
| 1941 |
+
batch_size = response.shape[0]
|
| 1942 |
+
valid_start = 0
|
| 1943 |
+
r = []
|
| 1944 |
+
for i in range(batch_size):
|
| 1945 |
+
if response_len[i]==0:
|
| 1946 |
+
r.append(None)
|
| 1947 |
+
continue
|
| 1948 |
+
if isinstance(ret, torch.Tensor):
|
| 1949 |
+
r.append(ret[valid_start:valid_start+1])
|
| 1950 |
+
valid_start+=1
|
| 1951 |
+
continue
|
| 1952 |
+
decode_wave = vocoder.decode(ret.flow_matching_mel[valid_start][:ret.flow_matching_mel_lengths[valid_start], :].transpose(0, 1).to(torch.float32).unsqueeze(0))
|
| 1953 |
+
r.append(decode_wave.cpu())
|
| 1954 |
+
valid_start+=1
|
| 1955 |
+
return r
|
| 1956 |
+
|
| 1957 |
+
|
| 1958 |
+
@torch.no_grad()
|
| 1959 |
+
def decode_save_concat2(response_list, vocoder, model, path, sampling_rate=16000, wave_concat_overlap=800):
|
| 1960 |
+
wave_list = []
|
| 1961 |
+
for response in response_list:
|
| 1962 |
+
wave_list.extend([wave_i for wave_i in decode_wave_vocoder2(response, vocoder, model) if wave_i is not None])
|
| 1963 |
+
new_wave_list = [wave_list[0]]
|
| 1964 |
+
for w in wave_list[1:]:
|
| 1965 |
+
if new_wave_list[-1].shape[1] > wave_concat_overlap and w.shape[1] > wave_concat_overlap:
|
| 1966 |
+
new_wave_list.append((new_wave_list[-1][:, -wave_concat_overlap:] * torch.linspace(1.0, 0.0, wave_concat_overlap, device=new_wave_list[-1].device)[None, :]
|
| 1967 |
+
+ w[:, :wave_concat_overlap] * torch.linspace(0.0, 1.0, wave_concat_overlap, device=new_wave_list[-1].device)[None, :]))
|
| 1968 |
+
new_wave_list.append(w)
|
| 1969 |
+
full_wave = torch.cat(new_wave_list, dim=1) if len(new_wave_list) > 1 else new_wave_list[0]
|
| 1970 |
+
torchaudio.save(path, full_wave, sampling_rate)
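# A common cross-fade variant of the chunk concatenation above: the tail of
# the previous chunk fades out while the head of the next fades in (the
# function above instead inserts the blended segment between the full chunks).
import torch

def crossfade(prev, nxt, overlap):
    fade_out = torch.linspace(1.0, 0.0, overlap)[None, :]
    fade_in = torch.linspace(0.0, 1.0, overlap)[None, :]
    blended = prev[:, -overlap:] * fade_out + nxt[:, :overlap] * fade_in
    return torch.cat([prev[:, :-overlap], blended, nxt[:, overlap:]], dim=1)

a, b = torch.randn(1, 2400), torch.randn(1, 2400)
print(crossfade(a, b, 800).shape)   # torch.Size([1, 4000])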
|
| 1971 |
+
|
| 1972 |
+
|
| 1973 |
+
class LongcatNextAudioTokenizer(nn.Module):
|
| 1974 |
+
|
| 1975 |
+
def __init__(self, config):
|
| 1976 |
+
super().__init__()
|
| 1977 |
+
self.config = config
|
| 1978 |
+
self.audio_model = LongcatNextAudioEncoder(config.audio_config)
|
| 1979 |
+
self.audio_bridge_model = LongcatNextAudioVQBridger(config.audio_config)
|
| 1980 |
+
self.audio_decoder = LongcatNextAudioDecoder(config.audio_config)
|
| 1981 |
+
self.audio_flow_matching_decoder = LongcatNextAudioFlowMatchingDecoder(config.audio_config)
|
| 1982 |
+
self.cosy24kvocoder = None
|
| 1983 |
+
|
| 1984 |
+
@torch.no_grad()
|
| 1985 |
+
def encode(self, x, encoder_length: Optional[torch.Tensor] = None, bridge_length: Optional[torch.Tensor] = None):
|
| 1986 |
+
audio_emb = self.audio_model(x, encoder_length)
|
| 1987 |
+
audio_tokens = self.audio_bridge_model(audio_emb, bridge_length)
|
| 1988 |
+
return audio_tokens
|
| 1989 |
+
|
| 1990 |
+
@torch.no_grad()
|
| 1991 |
+
def decode(self, audio_ids, bridge_length: Optional[torch.Tensor] = None):
|
| 1992 |
+
audio_emb = self.audio_bridge_model.decode(audio_ids)
|
| 1993 |
+
audio_dec = self.audio_decoder(
|
| 1994 |
+
audio_emb.to(next(self.audio_decoder.parameters())), bridge_length
|
| 1995 |
+
)
|
| 1996 |
+
if self.config.audio_config.flow_matching_config.use_hidden_states_before_dconv2:
|
| 1997 |
+
hidden_states, hidden_states_length = (
|
| 1998 |
+
self.audio_flow_matching_decoder.unpack_hidden_states(
|
| 1999 |
+
audio_dec.hidden_states_before_dconv2,
|
| 2000 |
+
audio_dec.output_length_before_dconv2,
|
| 2001 |
+
)
|
| 2002 |
+
)
|
| 2003 |
+
audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
|
| 2004 |
+
hidden_states, hidden_states_length
|
| 2005 |
+
)
|
| 2006 |
+
else:
|
| 2007 |
+
audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
|
| 2008 |
+
audio_dec.refined_mel, audio_dec.mel_length
|
| 2009 |
+
)
|
| 2010 |
+
return audio_flow_matching_decoder_ret
|
| 2011 |
+
|
| 2012 |
+
@torch.no_grad()
|
| 2013 |
+
def lazy_decode_and_save(self, audio_ids, sampling_rate, wave_concat_overlap, save_path):
|
| 2014 |
+
if self.cosy24kvocoder is None:
|
| 2015 |
+
print("lazy load cosy24kvocoder ...")
|
| 2016 |
+
device = next(self.parameters()).device
|
| 2017 |
+
self.cosy24kvocoder = Cosy24kVocoder.from_pretrained(self.config.audio_config.cosy24kvocoder_config.weight_path).to(device)
|
| 2018 |
+
|
| 2019 |
+
if audio_ids[-1, 0] != self.config.audio_config.vq_config.codebook_sizes[0]: # exceed max_new_tokens
|
| 2020 |
+
audio_ids = F.pad(audio_ids, (0, 0, 0, 1), value=self.config.audio_config.vq_config.codebook_sizes[0])
|
| 2021 |
+
|
| 2022 |
+
audio_end_pos = [-1] + (audio_ids[:, 0] == self.config.audio_config.vq_config.codebook_sizes[0]).nonzero().view(-1).tolist()
|
| 2023 |
+
|
| 2024 |
+
audio_ids_chunk = []
|
| 2025 |
+
for i in range(len(audio_end_pos) - 1):
|
| 2026 |
+
start = audio_end_pos[i] + 1
|
| 2027 |
+
end = audio_end_pos[i+1] + 1
|
| 2028 |
+
audio_ids_chunk.append(audio_ids[start:end].unsqueeze(0))
|
| 2029 |
+
|
| 2030 |
+
audio_ids = audio_ids_chunk
|
| 2031 |
+
|
| 2032 |
+
decode_save_concat2(
|
| 2033 |
+
response_list=audio_ids,
|
| 2034 |
+
vocoder=self.cosy24kvocoder,
|
| 2035 |
+
model=self,
|
| 2036 |
+
path=save_path,
|
| 2037 |
+
sampling_rate=sampling_rate,
|
| 2038 |
+
wave_concat_overlap=wave_concat_overlap,
|
| 2039 |
+
)
|
modular_longcat_next_visual.py
ADDED
|
@@ -0,0 +1,1077 @@
| 1 |
+
from typing import Iterable, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils.checkpoint
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.amp import autocast
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from flash_attn import flash_attn_varlen_func
|
| 13 |
+
|
| 14 |
+
from transformers.activations import ACT2FN
|
| 15 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 16 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
| 17 |
+
Qwen2RMSNorm,
|
| 18 |
+
Qwen2_5_VisionTransformerPretrainedModel,
|
| 19 |
+
)
|
| 20 |
+
from transformers.utils import logging
|
| 21 |
+
|
| 22 |
+
from .image_refiner import (
|
| 23 |
+
ImageRefinerContainer,
|
| 24 |
+
RefinerImageProcessor,
|
| 25 |
+
RefinerPipeline,
|
| 26 |
+
de_transform,
|
| 27 |
+
tensor2pil,
|
| 28 |
+
)
|
| 29 |
+
from .refiner_modules import FlowMatchEulerDiscreteScheduler
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def uniform_init(*shape):
|
| 35 |
+
t = torch.zeros(shape)
|
| 36 |
+
nn.init.kaiming_uniform_(t)
|
| 37 |
+
return t
|
| 38 |
+
|
| 39 |
+
class VQEmbedding(nn.Module):
|
| 40 |
+
"""VQ embedding module with ema update."""
|
| 41 |
+
|
| 42 |
+
def __init__(self, n_embed, embed_dim, ema=True, decay=0.99, restart_unused_codes=True, eps=1e-5, init_std=0.02):
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
self.ema = ema
|
| 46 |
+
self.decay = decay
|
| 47 |
+
self.eps = eps
|
| 48 |
+
self.restart_unused_codes = restart_unused_codes
|
| 49 |
+
self.n_embed = n_embed
|
| 50 |
+
self.init_std = init_std
|
| 51 |
+
|
| 52 |
+
assert self.ema
|
| 53 |
+
embed = uniform_init(n_embed + 1, embed_dim).to(torch.float32)
|
| 54 |
+
self.embed = nn.Parameter(embed)
|
| 55 |
+
self.embed_ema = nn.Parameter(embed[:-1, :].clone())
|
| 56 |
+
self.cluster_size_ema = nn.Parameter(torch.ones(n_embed))
|
| 57 |
+
del embed
|
| 58 |
+
_ = [p.requires_grad_(False) for p in self.parameters()]
|
| 59 |
+
|
| 60 |
+
@torch.no_grad()
|
| 61 |
+
def compute_distances(self, inputs):
|
| 62 |
+
codebook_t = self.embed[:-1, :].t()
|
| 63 |
+
|
| 64 |
+
(embed_dim, _) = codebook_t.shape
|
| 65 |
+
inputs_shape = inputs.shape
|
| 66 |
+
assert inputs_shape[-1] == embed_dim
|
| 67 |
+
|
| 68 |
+
inputs_flat = inputs.reshape(-1, embed_dim)
|
| 69 |
+
|
| 70 |
+
inputs_norm_sq = inputs_flat.pow(2.).sum(dim=1, keepdim=True)
|
| 71 |
+
codebook_t_norm_sq = codebook_t.pow(2.).sum(dim=0, keepdim=True)
|
| 72 |
+
distances = torch.addmm(
|
| 73 |
+
inputs_norm_sq + codebook_t_norm_sq,
|
| 74 |
+
inputs_flat,
|
| 75 |
+
codebook_t,
|
| 76 |
+
alpha=-2.0,
|
| 77 |
+
)
|
| 78 |
+
distances = distances.reshape(*inputs_shape[:-1], -1) # [B, h, w, n_embed or n_embed+1]
|
| 79 |
+
return distances
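# Why the addmm call above yields squared distances: it expands
# ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x . c, so every pairwise distance
# comes out of a single matrix multiply.
import torch

x = torch.randn(5, 8)             # flattened inputs
codebook_t = torch.randn(8, 16)   # codebook transposed (embed_dim, n_embed)

d_addmm = torch.addmm(
    x.pow(2).sum(1, keepdim=True) + codebook_t.pow(2).sum(0, keepdim=True),
    x, codebook_t, alpha=-2.0,
)
d_naive = torch.cdist(x, codebook_t.t()).pow(2)
print(torch.allclose(d_addmm, d_naive, atol=1e-4))   # True (up to rounding)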
|
| 80 |
+
|
| 81 |
+
@torch.no_grad()
|
| 82 |
+
def find_nearest_embedding(self, inputs):
|
| 83 |
+
distances = self.compute_distances(inputs) # [B, h, w, n_embed or n_embed+1]
|
| 84 |
+
embed_idxs = distances.argmin(dim=-1) # use padding index or not
|
| 85 |
+
|
| 86 |
+
return embed_idxs
|
| 87 |
+
|
| 88 |
+
@autocast('cuda', enabled=True, dtype=torch.float32)
|
| 89 |
+
@torch.no_grad()
|
| 90 |
+
def forward(self, inputs):
|
| 91 |
+
if inputs.dtype != torch.float32:
|
| 92 |
+
inputs = inputs.to(torch.float32)
|
| 93 |
+
embed_idxs = self.find_nearest_embedding(inputs)
|
| 94 |
+
embeds = self.embed[embed_idxs]
|
| 95 |
+
return embeds, embed_idxs
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class RQBottleneck(nn.Module):
|
| 99 |
+
"""
|
| 100 |
+
Quantization bottleneck via Residual Quantization.
|
| 101 |
+
|
| 102 |
+
Arguments:
|
| 103 |
+
latent_shape (Tuple[int, int, int]): the shape of latents, denoted (H, W, D)
|
| 104 |
+
code_shape (Tuple[int, int, int]): the shape of codes, denoted (h, w, d)
|
| 105 |
+
n_embed (int, List, or Tuple): the number of embeddings (i.e., the size of codebook)
|
| 106 |
+
If isinstance(n_embed, int), the sizes of all codebooks are the same.
|
| 107 |
+
shared_codebook (bool): If True, codebooks are shared in all location. If False,
|
| 108 |
+
uses separate codebooks along the ``depth'' dimension. (default: False)
|
| 109 |
+
restart_unused_codes (bool): If True, it randomly assigns a feature vector in the current batch
|
| 110 |
+
as the new embedding of unused codes in training. (default: True)
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
def __init__(self,
|
| 114 |
+
latent_shape,
|
| 115 |
+
code_shape,
|
| 116 |
+
n_embed,
|
| 117 |
+
decay=0.99,
|
| 118 |
+
shared_codebook=False,
|
| 119 |
+
restart_unused_codes=True,
|
| 120 |
+
commitment_loss='cumsum'
|
| 121 |
+
):
|
| 122 |
+
super().__init__()
|
| 123 |
+
|
| 124 |
+
if not len(code_shape) == len(latent_shape) == 3:
|
| 125 |
+
raise ValueError("incompatible code shape or latent shape")
|
| 126 |
+
if any([y % x != 0 for x, y in zip(code_shape[:2], latent_shape[:2])]):
|
| 127 |
+
raise ValueError("incompatible code shape or latent shape")
|
| 128 |
+
|
| 129 |
+
#residual quantization does not divide feature dims for quantization.
|
| 130 |
+
embed_dim = np.prod(latent_shape[:2]) // np.prod(code_shape[:2]) * latent_shape[2]
|
| 131 |
+
|
| 132 |
+
self.latent_shape = torch.Size(latent_shape)
|
| 133 |
+
self.code_shape = torch.Size(code_shape)
|
| 134 |
+
self.shape_divisor = torch.Size([latent_shape[i] // code_shape[i] for i in range(len(latent_shape))])
|
| 135 |
+
|
| 136 |
+
self.shared_codebook = shared_codebook
|
| 137 |
+
if self.shared_codebook:
|
| 138 |
+
if isinstance(n_embed, Iterable) or isinstance(decay, Iterable):
|
| 139 |
+
raise ValueError("Shared codebooks are incompatible \
|
| 140 |
+
with list types of momentums or sizes: Change it into int")
|
| 141 |
+
|
| 142 |
+
self.restart_unused_codes = restart_unused_codes
|
| 143 |
+
self.n_embed = n_embed if isinstance(n_embed, Iterable) else [n_embed for _ in range(self.code_shape[-1])]
|
| 144 |
+
self.decay = decay if isinstance(decay, Iterable) else [decay for _ in range(self.code_shape[-1])]
|
| 145 |
+
assert len(self.n_embed) == self.code_shape[-1]
|
| 146 |
+
assert len(self.decay) == self.code_shape[-1]
|
| 147 |
+
|
| 148 |
+
if self.shared_codebook:
|
| 149 |
+
codebook0 = VQEmbedding(self.n_embed[0],
|
| 150 |
+
embed_dim,
|
| 151 |
+
decay=self.decay[0],
|
| 152 |
+
restart_unused_codes=restart_unused_codes,
|
| 153 |
+
).to(torch.float32)
|
| 154 |
+
self.codebooks = nn.ModuleList([codebook0 for _ in range(self.code_shape[-1])])
|
| 155 |
+
else:
|
| 156 |
+
codebooks = [VQEmbedding(self.n_embed[idx],
|
| 157 |
+
embed_dim,
|
| 158 |
+
decay=self.decay[idx],
|
| 159 |
+
restart_unused_codes=restart_unused_codes,
|
| 160 |
+
).to(torch.float32) for idx in range(self.code_shape[-1])]
|
| 161 |
+
self.codebooks = nn.ModuleList(codebooks)
|
| 162 |
+
|
| 163 |
+
self.commitment_loss = commitment_loss
|
| 164 |
+
|
| 165 |
+
def to_code_shape(self, x):
|
| 166 |
+
(B, H, W, D) = x.shape
|
| 167 |
+
(rH, rW, _) = self.shape_divisor
|
| 168 |
+
|
| 169 |
+
x = x.reshape(B, H//rH, rH, W//rW, rW, D)
|
| 170 |
+
x = x.permute(0, 1, 3, 2, 4, 5)
|
| 171 |
+
x = x.reshape(B, H//rH, W//rW, -1)
|
| 172 |
+
|
| 173 |
+
return x
|
| 174 |
+
|
| 175 |
+
def to_latent_shape(self, x):
|
| 176 |
+
(B, h, w, _) = x.shape
|
| 177 |
+
(_, _, D) = self.latent_shape
|
| 178 |
+
(rH, rW, _) = self.shape_divisor
|
| 179 |
+
|
| 180 |
+
x = x.reshape(B, h, w, rH, rW, D)
|
| 181 |
+
x = x.permute(0, 1, 3, 2, 4, 5)
|
| 182 |
+
x = x.reshape(B, h*rH, w*rW, D)
|
| 183 |
+
|
| 184 |
+
return x
|
| 185 |
+
|
| 186 |
+
def quantize(self, x):
|
| 187 |
+
r"""
|
| 188 |
+
Return list of quantized features and the selected codewords by the residual quantization.
|
| 189 |
+
The code is selected by the residuals between x and quantized features by the previous codebooks.
|
| 190 |
+
|
| 191 |
+
Arguments:
|
| 192 |
+
x (Tensor): bottleneck feature maps to quantize.
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
quant_list (list): list of sequentially aggregated and quantized feature maps by codebooks.
|
| 196 |
+
codes (LongTensor): codewords index, corresponding to quants.
|
| 197 |
+
|
| 198 |
+
Shape:
|
| 199 |
+
- x: (B, h, w, embed_dim)
|
| 200 |
+
- quant_list[i]: (B, h, w, embed_dim)
|
| 201 |
+
- codes: (B, h, w, d)
|
| 202 |
+
"""
|
| 203 |
+
B, h, w, embed_dim = x.shape
|
| 204 |
+
ori_dtype = x.dtype
|
| 205 |
+
x = x.to(torch.float32)
|
| 206 |
+
self.codebooks = self.codebooks.to(torch.float32)
|
| 207 |
+
|
| 208 |
+
residual_feature = x.detach().clone()
|
| 209 |
+
|
| 210 |
+
quant_list = []
|
| 211 |
+
code_list = []
|
| 212 |
+
aggregated_quants = torch.zeros_like(x)
|
| 213 |
+
for i in range(self.code_shape[-1]):
|
| 214 |
+
quant, code = self.codebooks[i](residual_feature)
|
| 215 |
+
residual_feature.sub_(quant)
|
| 216 |
+
aggregated_quants.add_(quant)
|
| 217 |
+
quant_list.append(aggregated_quants.clone().to(dtype=ori_dtype))
|
| 218 |
+
code_list.append(code.unsqueeze(-1))
|
| 219 |
+
|
| 220 |
+
codes = torch.cat(code_list, dim=-1)
|
| 221 |
+
return quant_list, codes
|
| 222 |
+
|
| 223 |
+
def forward(self, x):
|
| 224 |
+
x_reshaped = self.to_code_shape(x)
|
| 225 |
+
# Force execution in float32 precision
|
| 226 |
+
quant_list, codes = self.quantize(x_reshaped)
|
| 227 |
+
# quant_list, codes = self.quantize(x_reshaped)
|
| 228 |
+
|
| 229 |
+
commitment_loss = self.compute_commitment_loss(x_reshaped, quant_list)
|
| 230 |
+
quants_trunc = self.to_latent_shape(quant_list[-1])
|
| 231 |
+
quants_trunc = x + (quants_trunc - x).detach()
|
| 232 |
+
|
| 233 |
+
'''
|
| 234 |
+
if self.shared_codebook:
|
| 235 |
+
cur_len = codes.view(-1).shape[0]
|
| 236 |
+
self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
|
| 237 |
+
self.codebook_used[-cur_len:] = codes.view(-1)
|
| 238 |
+
codebook_usage = len(torch.unique(self.codebook_used)) / self.n_embed[0]
|
| 239 |
+
else:
|
| 240 |
+
# info|code: torch.Size([10, 16, 16, 4])
|
| 241 |
+
codebook_usage = 0
|
| 242 |
+
for idx in range(self.code_shape[-1]):
|
| 243 |
+
cur_len = codes[..., idx].view(-1).shape[0]
|
| 244 |
+
self.codebook_used[idx, :-cur_len] = self.codebook_used[idx, cur_len:].clone()
|
| 245 |
+
self.codebook_used[idx, -cur_len:] = codes[..., idx].view(-1)
|
| 246 |
+
codebook_usage += len(torch.unique(self.codebook_used[idx]))
|
| 247 |
+
codebook_usage /= (self.n_embed[0] * self.code_shape[-1])
|
| 248 |
+
'''
|
| 249 |
+
codebook_usage = 0
|
| 250 |
+
# (vq_loss, commit_loss, entropy_loss, codebook_usage)  # keep the output format aligned
|
| 251 |
+
codebook_loss = [0, commitment_loss, 0, codebook_usage]
|
| 252 |
+
|
| 253 |
+
return quants_trunc, codebook_loss, codes
|
| 254 |
+
|
| 255 |
+
def compute_commitment_loss(self, x, quant_list):
|
| 256 |
+
r"""
|
| 257 |
+
Compute the commitment loss for the residual quantization.
|
| 258 |
+
The loss is iteratively computed by aggregating quantized features.
|
| 259 |
+
"""
|
| 260 |
+
loss_list = []
|
| 261 |
+
|
| 262 |
+
for idx, quant in enumerate(quant_list):
|
| 263 |
+
partial_loss = (x-quant.detach()).pow(2.0).mean()
|
| 264 |
+
loss_list.append(partial_loss)
|
| 265 |
+
|
| 266 |
+
commitment_loss = torch.mean(torch.stack(loss_list))
|
| 267 |
+
return commitment_loss
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class Qwen2_5_VisionRotaryEmbedding_Modified(nn.Module):
|
| 272 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
| 273 |
+
super().__init__()
|
| 274 |
+
self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
| 275 |
+
# self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 276 |
+
|
| 277 |
+
def forward(self, seqlen: int, device: torch.device) -> torch.Tensor:
|
| 278 |
+
self.inv_freq = self.inv_freq.to(device)
|
| 279 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
| 280 |
+
freqs = torch.outer(seq, self.inv_freq)
|
| 281 |
+
return freqs
|
| 282 |
+
|
| 283 |
+
class VisualEncoder(Qwen2_5_VisionTransformerPretrainedModel):
|
| 284 |
+
|
| 285 |
+
def __init__(self, config):
|
| 286 |
+
config._attn_implementation = 'flash_attention_2'
|
| 287 |
+
super().__init__(config)
|
| 288 |
+
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding_Modified(config.hidden_size // config.num_heads // 2)
|
| 289 |
+
self.gradient_checkpointing = False
|
| 290 |
+
self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
|
| 291 |
+
self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
|
| 292 |
+
del self.merger # register visual.merger in visual_bridge_model
|
| 293 |
+
|
| 294 |
+
def get_dtype(self) -> torch.dtype:
|
| 295 |
+
return self.blocks[0].mlp.down_proj.weight.dtype
|
| 296 |
+
|
| 297 |
+
def get_device(self) -> torch.device:
|
| 298 |
+
return self.blocks[0].mlp.down_proj.weight.device
|
| 299 |
+
|
| 300 |
+
def rot_pos_emb(self, grid_thw):
|
| 301 |
+
pos_ids = []
|
| 302 |
+
for t, h, w in grid_thw:
|
| 303 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 304 |
+
hpos_ids = hpos_ids.reshape(
|
| 305 |
+
h // self.spatial_merge_size,
|
| 306 |
+
self.spatial_merge_size,
|
| 307 |
+
w // self.spatial_merge_size,
|
| 308 |
+
self.spatial_merge_size,
|
| 309 |
+
)
|
| 310 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
| 311 |
+
hpos_ids = hpos_ids.flatten()
|
| 312 |
+
|
| 313 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 314 |
+
wpos_ids = wpos_ids.reshape(
|
| 315 |
+
h // self.spatial_merge_size,
|
| 316 |
+
self.spatial_merge_size,
|
| 317 |
+
w // self.spatial_merge_size,
|
| 318 |
+
self.spatial_merge_size,
|
| 319 |
+
)
|
| 320 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
| 321 |
+
wpos_ids = wpos_ids.flatten()
|
| 322 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
| 323 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
| 324 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 325 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device)
|
| 326 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
| 327 |
+
return rotary_pos_emb
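# pos_ids holds the (row, col) index of every patch, reordered so that each
# spatial_merge_size x spatial_merge_size block is contiguous; indexing the
# precomputed frequency table with these ids gives per-patch 2D rotary embeddings.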
|
| 328 |
+
|
| 329 |
+
def forward(
|
| 330 |
+
self,
|
| 331 |
+
pixel_values: torch.Tensor,
|
| 332 |
+
grid_thw: torch.Tensor,
|
| 333 |
+
require_window_index: bool = False,
|
| 334 |
+
):
|
| 335 |
+
'''
|
| 336 |
+
pixel_values.shape=[NumOfPatches, 1176]
|
| 337 |
+
grid_thw.shape=[NumOfSamples, 3]. [grid_t,grid_h,grid_w]
|
| 338 |
+
'''
|
| 339 |
+
hidden_states = pixel_values.to(torch.bfloat16)
|
| 340 |
+
grid_thw = grid_thw.to(pixel_values.device)
|
| 341 |
+
|
| 342 |
+
hidden_states = self.patch_embed(hidden_states)
|
| 343 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
| 344 |
+
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
| 345 |
+
cu_window_seqlens = torch.tensor(
|
| 346 |
+
cu_window_seqlens,
|
| 347 |
+
device=hidden_states.device,
|
| 348 |
+
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
| 349 |
+
)
|
| 350 |
+
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
| 351 |
+
|
| 352 |
+
seq_len, _ = hidden_states.size()
|
| 353 |
+
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
| 354 |
+
hidden_states = hidden_states[window_index, :, :]
|
| 355 |
+
hidden_states = hidden_states.reshape(seq_len, -1)
|
| 356 |
+
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
| 357 |
+
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
| 358 |
+
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
| 359 |
+
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 360 |
+
position_embeddings = (emb.cos(), emb.sin())
|
| 361 |
+
|
| 362 |
+
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
| 363 |
+
dim=0,
|
| 364 |
+
# Select dtype based on the following factors:
|
| 365 |
+
# - FA2 requires that cu_seqlens_q must have dtype int32
|
| 366 |
+
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
|
| 367 |
+
# See https://github.com/huggingface/transformers/pull/34852 for more information
|
| 368 |
+
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
| 369 |
+
)
|
| 370 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
| 371 |
+
|
| 372 |
+
for layer_num, blk in enumerate(self.blocks):
|
| 373 |
+
if layer_num in self.fullatt_block_indexes:
|
| 374 |
+
cu_seqlens_now = cu_seqlens
|
| 375 |
+
else:
|
| 376 |
+
cu_seqlens_now = cu_window_seqlens
|
| 377 |
+
if self.gradient_checkpointing and self.training:
|
| 378 |
+
hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings)
|
| 379 |
+
else:
|
| 380 |
+
hidden_states = blk(
|
| 381 |
+
hidden_states,
|
| 382 |
+
cu_seqlens=cu_seqlens_now,
|
| 383 |
+
position_embeddings=position_embeddings,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
if require_window_index:
|
| 387 |
+
return hidden_states, window_index
|
| 388 |
+
return hidden_states
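# Shape sketch (assuming patch_size=14, temporal_patch_size=2, merge_size=2): a single
# 448x448 image yields grid_thw=[[1, 32, 32]] and pixel_values of shape [1024, 1176]
# (3 * 2 * 14 * 14 = 1176); the returned hidden_states keep one row per patch and are
# later folded 2x2 into visual tokens by the bridge module below.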
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class OmniVisualBridge(nn.Module):
|
| 392 |
+
def __init__(self, config):
|
| 393 |
+
super().__init__()
|
| 394 |
+
self.config = config
|
| 395 |
+
self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
|
| 396 |
+
self.hidden_size = self.config.hidden_size * (self.merge_size**2)
|
| 397 |
+
self.window_index = self.config.window_size
|
| 398 |
+
self.ln_q = Qwen2RMSNorm(self.config.hidden_size, eps=1e-6)
|
| 399 |
+
self.mlp = nn.Sequential(
|
| 400 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
| 401 |
+
nn.GELU(),
|
| 402 |
+
nn.Linear(self.hidden_size, self.config.out_hidden_size),
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
def forward(self, x: torch.Tensor, window_index) -> torch.Tensor:
|
| 406 |
+
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
|
| 407 |
+
reverse_indices = torch.argsort(window_index)
|
| 408 |
+
x = x[reverse_indices, :]
|
| 409 |
+
|
| 410 |
+
return x
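# The ln_q + MLP path folds each merge_size x merge_size group of patch features
# (hidden_size * merge_size**2 inputs) into a single token, and argsort(window_index)
# restores the original patch order that the windowed encoder permuted.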
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class VisualQuantizer(nn.Module):
|
| 414 |
+
def __init__(self, quantizer_config):
|
| 415 |
+
super().__init__()
|
| 416 |
+
|
| 417 |
+
self.config = quantizer_config
|
| 418 |
+
self.depth = self.config.depth
|
| 419 |
+
self.decay = self.config.decay
|
| 420 |
+
self.codebook_size = self.config.codebook_size
|
| 421 |
+
self.codebook_dim = self.config.codebook_dim
|
| 422 |
+
self.shared_codebook = self.config.shared_codebook
|
| 423 |
+
self.restart_unused_codes = self.config.restart_unused_codes
|
| 424 |
+
self.in_channels = self.config.in_channels
|
| 425 |
+
|
| 426 |
+
self.vq_loss_ratio = self.config.vq_loss_ratio
|
| 427 |
+
self.entropy_loss_ratio = self.config.entropy_loss_ratio
|
| 428 |
+
self.commit_loss_ratio = self.config.commit_loss_ratio
|
| 429 |
+
|
| 430 |
+
code_h_w = int(448 / 14)
|
| 431 |
+
latent_shape = [code_h_w, code_h_w, self.codebook_dim]
|
| 432 |
+
code_shape = [code_h_w, code_h_w, self.depth]
|
| 433 |
+
|
| 434 |
+
self.quantize = RQBottleneck(
|
| 435 |
+
latent_shape=latent_shape,
|
| 436 |
+
code_shape=code_shape,
|
| 437 |
+
n_embed=self.codebook_size,
|
| 438 |
+
decay=self.decay,
|
| 439 |
+
shared_codebook=self.shared_codebook,
|
| 440 |
+
restart_unused_codes=self.restart_unused_codes,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
if self.config.quant_conv:
|
| 444 |
+
self.quant_conv = nn.Sequential(
|
| 445 |
+
nn.LayerNorm(self.in_channels),
|
| 446 |
+
nn.Linear(self.in_channels, self.in_channels),
|
| 447 |
+
nn.GELU(),
|
| 448 |
+
nn.Linear(self.in_channels, self.codebook_dim)
|
| 449 |
+
)
|
| 450 |
+
else:
|
| 451 |
+
self.quant_conv = None
|
| 452 |
+
|
| 453 |
+
def encode(self, x):
|
| 454 |
+
L, D = x.shape
|
| 455 |
+
to_qnt_feat = x.clone()
|
| 456 |
+
to_qnt_feat = to_qnt_feat.unsqueeze(0) # [L, D] -> [1, L, D]
|
| 457 |
+
N = 1
|
| 458 |
+
|
| 459 |
+
if self.quant_conv is not None:
|
| 460 |
+
to_qnt_feat = self.quant_conv(to_qnt_feat)
|
| 461 |
+
|
| 462 |
+
# quantizer needs nchw format. N,L,d -> N,1,L,d -> N,d,1,L
|
| 463 |
+
to_qnt_feat = to_qnt_feat.reshape(N, 1, L, self.codebook_dim).permute(0,3,1,2)
|
| 464 |
+
if self.config.quantizer_type == "rq":
|
| 465 |
+
to_qnt_feat = to_qnt_feat.permute(0, 2, 3, 1).contiguous() # N,d,1,L -> N,1,L,d
|
| 466 |
+
quant, emb_loss, info = self.quantize(to_qnt_feat)
|
| 467 |
+
info = info.reshape(-1, info.shape[-1]) # n,h,w,lv -> n*h*w,lv
|
| 468 |
+
info = [None, None, info]
|
| 469 |
+
quant = quant.permute(0, 3, 1, 2).contiguous() # N,1,L,d -> N,d,1,L
|
| 470 |
+
else:
|
| 471 |
+
quant, emb_loss, info = self.quantize(to_qnt_feat)
|
| 472 |
+
return quant, emb_loss, info, x.detach()
|
| 473 |
+
|
| 474 |
+
def forward(self, x):
|
| 475 |
+
quant, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices), align_feature = \
|
| 476 |
+
self.encode(x)
|
| 477 |
+
return min_encoding_indices
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class MLP(nn.Module):
|
| 481 |
+
def __init__(
|
| 482 |
+
self,
|
| 483 |
+
hidden_size: int,
|
| 484 |
+
intermediate_size: int,
|
| 485 |
+
hidden_act: str,
|
| 486 |
+
):
|
| 487 |
+
super().__init__()
|
| 488 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 489 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
| 490 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 491 |
+
self.act_fn = ACT2FN[hidden_act]
|
| 492 |
+
|
| 493 |
+
def forward(self, x):
|
| 494 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
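# Gated (SwiGLU-style) MLP: down_proj(act_fn(gate_proj(x)) * up_proj(x)).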
|
| 495 |
+
|
| 496 |
+
class DecoderLayer(nn.Module):
|
| 497 |
+
def __init__(self, config):
|
| 498 |
+
super().__init__()
|
| 499 |
+
self.hidden_size = config.hidden_size
|
| 500 |
+
self.mlp = MLP(
|
| 501 |
+
hidden_size=self.hidden_size,
|
| 502 |
+
intermediate_size=config.visual_embedding_layer_intermediate_size,
|
| 503 |
+
hidden_act=config.visual_embedding_layer_hidden_act,
|
| 504 |
+
)
|
| 505 |
+
self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 506 |
+
|
| 507 |
+
def forward(
|
| 508 |
+
self,
|
| 509 |
+
hidden_states: torch.Tensor,
|
| 510 |
+
):
|
| 511 |
+
residual = hidden_states
|
| 512 |
+
hidden_states = self.pre_layernorm(hidden_states)
|
| 513 |
+
hidden_states = self.mlp(hidden_states)
|
| 514 |
+
hidden_states = residual + hidden_states
|
| 515 |
+
|
| 516 |
+
return hidden_states
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class VisualEmbeddingBridge(nn.Module):
|
| 520 |
+
def __init__(self, config):
|
| 521 |
+
super().__init__()
|
| 522 |
+
self.pre_buffer = DecoderLayer(config)
|
| 523 |
+
|
| 524 |
+
def forward(self, embedding):
|
| 525 |
+
return self.pre_buffer(embedding)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class VisualVQBridge(nn.Module):
|
| 529 |
+
def __init__(self, visual_config):
|
| 530 |
+
super().__init__()
|
| 531 |
+
self.bridge = OmniVisualBridge(visual_config)
|
| 532 |
+
self.quantizer = VisualQuantizer(visual_config.vq_config)
|
| 533 |
+
|
| 534 |
+
def forward(
|
| 535 |
+
self,
|
| 536 |
+
visual_embed: torch.Tensor,
|
| 537 |
+
window_index: torch.Tensor,
|
| 538 |
+
):
|
| 539 |
+
visual_embed = self.bridge(visual_embed, window_index)
|
| 540 |
+
indices = self.quantizer(visual_embed)
|
| 541 |
+
return indices
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
class LongcatNextVisualTokenizer(nn.Module):
|
| 545 |
+
|
| 546 |
+
def __init__(self, config):
|
| 547 |
+
super().__init__()
|
| 548 |
+
self.config = config
|
| 549 |
+
self.visual_model = VisualEncoder(config.visual_config)
|
| 550 |
+
self.visual_bridge_model = VisualVQBridge(config.visual_config)
|
| 551 |
+
self.visual_embedding_layer = VisualEmbeddingBridge(config)
|
| 552 |
+
self.image_decoder = None
|
| 553 |
+
self._refiner_pipeline = None
|
| 554 |
+
|
| 555 |
+
@torch.no_grad()
|
| 556 |
+
def encode(self, pixel_values: torch.Tensor, visual_grid_thw: torch.Tensor):
|
| 557 |
+
visual_embed, window_index = self.visual_model(pixel_values, grid_thw=visual_grid_thw, require_window_index=True)
|
| 558 |
+
indices = self.visual_bridge_model(visual_embed, window_index)
|
| 559 |
+
return indices
|
| 560 |
+
|
| 561 |
+
@torch.no_grad()
|
| 562 |
+
def lazy_decode_and_save(self, visual_ids, tokens_h, tokens_w, save_path):
|
| 563 |
+
device = next(self.parameters()).device
|
| 564 |
+
if self.image_decoder is None:
|
| 565 |
+
print("lazy load image_decoder / image_refiner / _refiner_pipeline ...")
|
| 566 |
+
vdc = self.config.visual_config.visual_decoder_config
|
| 567 |
+
self.image_decoder = VisionTransformerDecoder.from_pretrained(
|
| 568 |
+
vdc.image_decoder_config,
|
| 569 |
+
vdc.weight_path,
|
| 570 |
+
).to(device=device, dtype=torch.bfloat16)
|
| 571 |
+
image_refiner = ImageRefinerContainer.from_pretrained(vdc, vdc.weight_path).to(device=device, dtype=torch.bfloat16)
|
| 572 |
+
|
| 573 |
+
sc = vdc.scheduler_config
|
| 574 |
+
scheduler = FlowMatchEulerDiscreteScheduler(
|
| 575 |
+
num_train_timesteps=sc.num_train_timesteps,
|
| 576 |
+
dynamic_time_shift=sc.dynamic_time_shift)
|
| 577 |
+
self._refiner_pipeline = RefinerPipeline(
|
| 578 |
+
vae=image_refiner.vae,
|
| 579 |
+
transformer=image_refiner.base_transformer,
|
| 580 |
+
scheduler=scheduler,
|
| 581 |
+
cond_proj=image_refiner.cond_proj,
|
| 582 |
+
)
|
| 583 |
+
self._refiner_pipeline.set_progress_bar_config(disable=False)
|
| 584 |
+
|
| 585 |
+
data = torch.as_tensor(visual_ids, dtype=torch.long)
|
| 586 |
+
if data.ndim == 1:
|
| 587 |
+
data = data.view(-1, len(self.config.visual_config.vq_config.codebook_sizes))
|
| 588 |
+
if data.ndim == 2:
|
| 589 |
+
data = data.unsqueeze(0)
|
| 590 |
+
batch_size = data.shape[0]
|
| 591 |
+
|
| 592 |
+
quant_features = None
|
| 593 |
+
for idx in range(len(self.config.visual_config.vq_config.codebook_sizes)):
|
| 594 |
+
embed = self.visual_bridge_model.quantizer.quantize.codebooks[idx].embed
|
| 595 |
+
feat = embed[data[..., idx].to(embed.device)]
|
| 596 |
+
quant_features = feat if quant_features is None else quant_features + feat
|
| 597 |
+
quant_features = quant_features.to(device)
|
| 598 |
+
|
| 599 |
+
# tokens_h/tokens_w are the merged grid; expand to the full (unmerged) grid
|
| 600 |
+
s = self.image_decoder.spatial_merge_size
|
| 601 |
+
grid_thw_list = [(1, tokens_h * s, tokens_w * s)]
|
| 602 |
+
grid_thw_batch = list(grid_thw_list) * batch_size
|
| 603 |
+
|
| 604 |
+
image_mean = [0.48145466, 0.4578275, 0.40821073]
|
| 605 |
+
image_std = [0.26862954, 0.26130258, 0.27577711]
|
| 606 |
+
|
| 607 |
+
emb_2d = quant_features.reshape(-1, quant_features.shape[-1]).contiguous()
|
| 608 |
+
device_type = "cuda" if str(device).startswith("cuda") else str(device)
|
| 609 |
+
with torch.amp.autocast(device_type=device_type, enabled=True, dtype=torch.float32):
|
| 610 |
+
decoder_out = self.image_decoder(emb_2d, grid_thw_batch, return_pixel_features=False)
|
| 611 |
+
|
| 612 |
+
decoded_tensors = decoder_out.get("images") or []
|
| 613 |
+
decoded_images = [tensor2pil(t, image_mean, image_std) for t in decoded_tensors]
|
| 614 |
+
decoded_path = save_path.replace(".png", "_decoded.png")
|
| 615 |
+
# decoded_images[0].save(decoded_path)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
ref_input = []
|
| 619 |
+
for t in decoded_tensors:
|
| 620 |
+
img_01 = de_transform(t, mean=image_mean, std=image_std, rescale_factor=1 / 255)
|
| 621 |
+
img_norm = RefinerImageProcessor.normalize(img_01)
|
| 622 |
+
ref_input.append(img_norm.squeeze(0).to(device))
|
| 623 |
+
|
| 624 |
+
generators = [torch.Generator(device=device).manual_seed(42 + b) for b in range(batch_size)]
|
| 625 |
+
out = self._refiner_pipeline(
|
| 626 |
+
encoder_hidden_states=quant_features,
|
| 627 |
+
grid_thw_list=grid_thw_list,
|
| 628 |
+
image=ref_input,
|
| 629 |
+
generator=generators[0] if batch_size == 1 else generators,
|
| 630 |
+
output_type="pil",
|
| 631 |
+
return_dict=True,
|
| 632 |
+
)
|
| 633 |
+
refined_images = out.images
|
| 634 |
+
refined_path = save_path.replace(".png", "_refined.png")
|
| 635 |
+
refined_images[0].save(refined_path)
|
| 636 |
+
|
| 637 |
+
return [refined_path]
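# Illustrative usage sketch (variable names are hypothetical, not part of this file):
#   tok = LongcatNextVisualTokenizer(config).eval().to("cuda")
#   ids = tok.encode(pixel_values, visual_grid_thw)                # quantized code indices
#   paths = tok.lazy_decode_and_save(ids, tokens_h, tokens_w, "out.png")
#   # -> ["out_refined.png"], the refined reconstruction written to disk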
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
# ---------------------------------------------------------------------------
|
| 641 |
+
# Vision Transformer Decoder
|
| 642 |
+
# ---------------------------------------------------------------------------
|
| 643 |
+
|
| 644 |
+
def _rotate_half(x):
|
| 645 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
| 646 |
+
x1, x2 = x.unbind(dim=-1)
|
| 647 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 648 |
+
return rearrange(x, "... d r -> ... (d r)")
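# Interleaved rotate-half used by RoPE: for last-dim pairs (x0, x1), (x2, x3), ...
# the output is (-x1, x0, -x3, x2, ...).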
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
class VisionRoPE2D(nn.Module):
|
| 652 |
+
"""2D Rotary Position Embedding for Q/K in vision decoder attention."""
|
| 653 |
+
|
| 654 |
+
def __init__(self, theta: float = 10000.0):
|
| 655 |
+
super().__init__()
|
| 656 |
+
self.theta = theta
|
| 657 |
+
|
| 658 |
+
def _rope_half(self, x_half, pos_1d, theta):
|
| 659 |
+
BH, T, d_half = x_half.shape
|
| 660 |
+
idx = torch.arange(0, d_half, 2, device=x_half.device, dtype=torch.float32)
|
| 661 |
+
inv_freq = (1.0 / (theta ** (idx / d_half))).to(x_half.dtype)
|
| 662 |
+
angles = pos_1d.to(x_half.dtype)[:, None] * inv_freq[None, :]
|
| 663 |
+
cos = torch.repeat_interleave(torch.cos(angles), 2, dim=-1).unsqueeze(0)
|
| 664 |
+
sin = torch.repeat_interleave(torch.sin(angles), 2, dim=-1).unsqueeze(0)
|
| 665 |
+
return x_half * cos + _rotate_half(x_half) * sin
|
| 666 |
+
|
| 667 |
+
def forward(self, x, positions_2d):
|
| 668 |
+
d_half = x.shape[-1] // 2
|
| 669 |
+
x_y = self._rope_half(x[:, :, :d_half], positions_2d[:, 0], self.theta)
|
| 670 |
+
x_x = self._rope_half(x[:, :, d_half:], positions_2d[:, 1], self.theta)
|
| 671 |
+
return torch.cat([x_y, x_x], dim=-1)
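# The first half of each head dimension is rotated by the row (y) position and the
# second half by the column (x) position, giving a 2D-aware rotary embedding.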
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
class VisionAttention(nn.Module):
|
| 675 |
+
"""Multi-headed attention with 2D RoPE + FlashAttention varlen."""
|
| 676 |
+
|
| 677 |
+
def __init__(self, config, rope=None, rope_shift=0):
|
| 678 |
+
super().__init__()
|
| 679 |
+
self.config = config
|
| 680 |
+
self.embed_dim = config.hidden_size
|
| 681 |
+
self.num_heads = config.num_attention_heads
|
| 682 |
+
self.head_dim = self.embed_dim // self.num_heads
|
| 683 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
| 684 |
+
raise ValueError(
|
| 685 |
+
f"embed_dim must be divisible by num_heads (got embed_dim={self.embed_dim}, num_heads={self.num_heads})"
|
| 686 |
+
)
|
| 687 |
+
self.scale = self.head_dim ** -0.5
|
| 688 |
+
self.dropout = config.attention_dropout
|
| 689 |
+
self.subln = config.subln
|
| 690 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "k_bias", True))
|
| 691 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "v_bias", True))
|
| 692 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "q_bias", True))
|
| 693 |
+
self.inner_attn_ln = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps) if config.subln else nn.Identity()
|
| 694 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
|
| 695 |
+
self.rope = rope
|
| 696 |
+
self.rope_shift = int(rope_shift)
|
| 697 |
+
|
| 698 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 699 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 700 |
+
|
| 701 |
+
def _maybe_flash_attention(self, query_states, key_states, value_states, seq_lens, training):
|
| 702 |
+
if not (query_states.is_cuda and (query_states.dtype in (torch.float16, torch.bfloat16, torch.float32))):
|
| 703 |
+
return None
|
| 704 |
+
if seq_lens is None:
|
| 705 |
+
return None
|
| 706 |
+
try:
|
| 707 |
+
BxH, T, hd = query_states.shape
|
| 708 |
+
H = self.num_heads
|
| 709 |
+
assert BxH % H == 0
|
| 710 |
+
B = BxH // H
|
| 711 |
+
if int(seq_lens.sum().item()) != T:
|
| 712 |
+
return None
|
| 713 |
+
q = query_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
|
| 714 |
+
k = key_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
|
| 715 |
+
v = value_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
|
| 716 |
+
cu_q = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device=seq_lens.device)
|
| 717 |
+
cu_q[1:] = torch.cumsum(seq_lens.to(torch.int32), dim=0)
|
| 718 |
+
cu_k = cu_q
|
| 719 |
+
max_seqlen = int(seq_lens.max().item())
|
| 720 |
+
orig_dtype = q.dtype
|
| 721 |
+
use_dtype = q.dtype if q.dtype in (torch.float16, torch.bfloat16) else torch.float16
|
| 722 |
+
if q.dtype != use_dtype:
|
| 723 |
+
q = q.to(use_dtype)
|
| 724 |
+
k = k.to(use_dtype)
|
| 725 |
+
v = v.to(use_dtype)
|
| 726 |
+
out = flash_attn_varlen_func(
|
| 727 |
+
q, k, v, cu_q, cu_k, max_seqlen, max_seqlen,
|
| 728 |
+
dropout_p=self.dropout if training else 0.0,
|
| 729 |
+
softmax_scale=None, causal=False, return_attn_probs=False
|
| 730 |
+
)
|
| 731 |
+
if out.dtype != orig_dtype:
|
| 732 |
+
out = out.to(orig_dtype)
|
| 733 |
+
return out.view(B, -1, H, hd).transpose(1, 2).contiguous().view(B * H, T, hd)
|
| 734 |
+
except Exception:
|
| 735 |
+
return None
|
| 736 |
+
|
| 737 |
+
def forward(
|
| 738 |
+
self,
|
| 739 |
+
hidden_states: torch.Tensor,
|
| 740 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 741 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
| 742 |
+
output_attentions: Optional[bool] = False,
|
| 743 |
+
positions_2d: Optional[torch.Tensor] = None,
|
| 744 |
+
seq_lens: Optional[torch.Tensor] = None,
|
| 745 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 746 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
| 747 |
+
query_states = self.q_proj(hidden_states) * self.scale
|
| 748 |
+
key_states = self.k_proj(hidden_states)
|
| 749 |
+
value_states = self.v_proj(hidden_states)
|
| 750 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
|
| 751 |
+
key_states = self._shape(key_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
|
| 752 |
+
value_states = self._shape(value_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
|
| 753 |
+
if self.rope is not None and positions_2d is not None:
|
| 754 |
+
if self.rope_shift > 0:
|
| 755 |
+
q_pref = query_states[:, :self.rope_shift, :]
|
| 756 |
+
k_pref = key_states[:, :self.rope_shift, :]
|
| 757 |
+
q_rot = self.rope(query_states[:, self.rope_shift:, :], positions_2d[self.rope_shift:])
|
| 758 |
+
k_rot = self.rope(key_states[:, self.rope_shift:, :], positions_2d[self.rope_shift:])
|
| 759 |
+
query_states = torch.cat([q_pref, q_rot], dim=1).type_as(value_states)
|
| 760 |
+
key_states = torch.cat([k_pref, k_rot], dim=1).type_as(value_states)
|
| 761 |
+
else:
|
| 762 |
+
query_states = self.rope(query_states, positions_2d).type_as(value_states)
|
| 763 |
+
key_states = self.rope(key_states, positions_2d).type_as(value_states)
|
| 764 |
+
attn_output = self._maybe_flash_attention(
|
| 765 |
+
query_states, key_states, value_states, seq_lens=seq_lens, training=self.training
|
| 766 |
+
)
|
| 767 |
+
if attn_output is not None:
|
| 768 |
+
attn_weights_reshaped = None
|
| 769 |
+
else:
|
| 770 |
+
src_len = key_states.size(1)
|
| 771 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
| 772 |
+
if causal_attention_mask is not None:
|
| 773 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
| 774 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 775 |
+
if attention_mask is not None:
|
| 776 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
| 777 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 778 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 779 |
+
if output_attentions:
|
| 780 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 781 |
+
else:
|
| 782 |
+
attn_weights_reshaped = None
|
| 783 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
| 784 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
| 785 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
| 786 |
+
attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
|
| 787 |
+
attn_output = self.inner_attn_ln(attn_output)
|
| 788 |
+
attn_output = self.out_proj(attn_output)
|
| 789 |
+
return attn_output, attn_weights_reshaped
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
class VisionSwiGLU(nn.Module):
|
| 793 |
+
def __init__(self, config):
|
| 794 |
+
super().__init__()
|
| 795 |
+
self.config = config
|
| 796 |
+
self.hidden_size = config.hidden_size
|
| 797 |
+
self.intermediate_size = config.intermediate_size
|
| 798 |
+
self.w1 = nn.Linear(self.hidden_size, self.intermediate_size)
|
| 799 |
+
self.w2 = nn.Linear(self.hidden_size, self.intermediate_size)
|
| 800 |
+
self.w3 = nn.Linear(self.intermediate_size, self.hidden_size)
|
| 801 |
+
self.act_fn = nn.SiLU()
|
| 802 |
+
self.ffn_ln = Qwen2RMSNorm(self.intermediate_size, eps=config.layer_norm_eps) if config.subln else nn.Identity()
|
| 803 |
+
|
| 804 |
+
def forward(self, x):
|
| 805 |
+
x1 = self.w1(x)
|
| 806 |
+
x2 = self.w2(x)
|
| 807 |
+
hidden = self.act_fn(x1) * x2
|
| 808 |
+
x = self.ffn_ln(hidden)
|
| 809 |
+
x = self.w3(x)
|
| 810 |
+
return x
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
class VisionMLP(nn.Module):
|
| 814 |
+
def __init__(self, config):
|
| 815 |
+
super().__init__()
|
| 816 |
+
self.config = config
|
| 817 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 818 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 819 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 820 |
+
self.ffn_ln = Qwen2RMSNorm(config.intermediate_size, eps=config.layer_norm_eps) if config.subln else nn.Identity()
|
| 821 |
+
|
| 822 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 823 |
+
hidden_states = self.fc1(hidden_states)
|
| 824 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 825 |
+
hidden_states = self.ffn_ln(hidden_states)
|
| 826 |
+
hidden_states = self.fc2(hidden_states)
|
| 827 |
+
return hidden_states
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
class VisionEncoderLayer(nn.Module):
|
| 831 |
+
def __init__(self, config, rope=None, rope_shift=0):
|
| 832 |
+
super().__init__()
|
| 833 |
+
self.embed_dim = config.hidden_size
|
| 834 |
+
self.self_attn = VisionAttention(config, rope=rope, rope_shift=rope_shift)
|
| 835 |
+
self.layer_norm1 = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 836 |
+
self.mlp = VisionSwiGLU(config) if config.swiglu else VisionMLP(config)
|
| 837 |
+
self.layer_norm2 = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 838 |
+
|
| 839 |
+
def forward(
|
| 840 |
+
self,
|
| 841 |
+
hidden_states: torch.Tensor,
|
| 842 |
+
attention_mask: Optional[torch.Tensor],
|
| 843 |
+
causal_attention_mask: Optional[torch.Tensor],
|
| 844 |
+
output_attentions: Optional[bool] = False,
|
| 845 |
+
positions_2d: Optional[torch.Tensor] = None,
|
| 846 |
+
seq_lens: Optional[torch.Tensor] = None,
|
| 847 |
+
) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
|
| 848 |
+
residual = hidden_states
|
| 849 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 850 |
+
hidden_states, attn_weights = self.self_attn(
|
| 851 |
+
hidden_states=hidden_states,
|
| 852 |
+
attention_mask=attention_mask,
|
| 853 |
+
causal_attention_mask=causal_attention_mask,
|
| 854 |
+
output_attentions=output_attentions,
|
| 855 |
+
positions_2d=positions_2d,
|
| 856 |
+
seq_lens=seq_lens,
|
| 857 |
+
)
|
| 858 |
+
hidden_states = residual + hidden_states
|
| 859 |
+
residual = hidden_states
|
| 860 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 861 |
+
hidden_states = self.mlp(hidden_states)
|
| 862 |
+
hidden_states = residual + hidden_states
|
| 863 |
+
outputs = (hidden_states,)
|
| 864 |
+
if output_attentions:
|
| 865 |
+
outputs += (attn_weights,)
|
| 866 |
+
return outputs
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
class VisionEncoder(nn.Module):
|
| 870 |
+
def __init__(self, config, rope=None, rope_shift=0):
|
| 871 |
+
super().__init__()
|
| 872 |
+
self.config = config
|
| 873 |
+
self.layers = nn.ModuleList(
|
| 874 |
+
[VisionEncoderLayer(config, rope=rope, rope_shift=rope_shift) for _ in range(config.num_hidden_layers)]
|
| 875 |
+
)
|
| 876 |
+
self.gradient_checkpointing = False
|
| 877 |
+
self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
|
| 878 |
+
|
| 879 |
+
def forward(
|
| 880 |
+
self,
|
| 881 |
+
inputs_embeds: torch.Tensor,
|
| 882 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 883 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
| 884 |
+
output_attentions: Optional[bool] = None,
|
| 885 |
+
output_hidden_states: Optional[bool] = None,
|
| 886 |
+
return_dict: Optional[bool] = None,
|
| 887 |
+
positions_2d: Optional[torch.Tensor] = None,
|
| 888 |
+
seq_lens: Optional[torch.Tensor] = None,
|
| 889 |
+
):
|
| 890 |
+
output_attentions = output_attentions if output_attentions is not None else False
|
| 891 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
| 892 |
+
return_dict = True if return_dict is None else return_dict
|
| 893 |
+
|
| 894 |
+
encoder_states = () if output_hidden_states else None
|
| 895 |
+
all_attentions = () if output_attentions else None
|
| 896 |
+
hidden_states = inputs_embeds
|
| 897 |
+
|
| 898 |
+
for layer in self.layers:
|
| 899 |
+
if output_hidden_states:
|
| 900 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 901 |
+
if self.gradient_checkpointing and self.training:
|
| 902 |
+
def custom_forward(hs, attn, causal, pos2d, seqlens):
|
| 903 |
+
return layer(
|
| 904 |
+
hs,
|
| 905 |
+
attention_mask=attn,
|
| 906 |
+
causal_attention_mask=causal,
|
| 907 |
+
output_attentions=False,
|
| 908 |
+
positions_2d=pos2d,
|
| 909 |
+
seq_lens=seqlens,
|
| 910 |
+
)[0]
|
| 911 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 912 |
+
custom_forward,
|
| 913 |
+
hidden_states,
|
| 914 |
+
attention_mask if attention_mask is not None else torch.tensor(0., device=hidden_states.device),
|
| 915 |
+
causal_attention_mask if causal_attention_mask is not None else torch.tensor(0., device=hidden_states.device),
|
| 916 |
+
positions_2d,
|
| 917 |
+
seq_lens if seq_lens is not None else torch.tensor([], device=hidden_states.device),
|
| 918 |
+
use_reentrant=False,
|
| 919 |
+
)
|
| 920 |
+
else:
|
| 921 |
+
layer_outputs = layer(
|
| 922 |
+
hidden_states,
|
| 923 |
+
attention_mask,
|
| 924 |
+
causal_attention_mask,
|
| 925 |
+
output_attentions=output_attentions,
|
| 926 |
+
positions_2d=positions_2d,
|
| 927 |
+
seq_lens=seq_lens,
|
| 928 |
+
)
|
| 929 |
+
hidden_states = layer_outputs[0]
|
| 930 |
+
if output_attentions:
|
| 931 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 932 |
+
|
| 933 |
+
if output_hidden_states:
|
| 934 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 935 |
+
|
| 936 |
+
if not return_dict:
|
| 937 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
| 938 |
+
|
| 939 |
+
return BaseModelOutput(
|
| 940 |
+
last_hidden_state=hidden_states,
|
| 941 |
+
hidden_states=encoder_states,
|
| 942 |
+
attentions=all_attentions,
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
|
| 946 |
+
class PatchUnMerger(nn.Module):
|
| 947 |
+
"""Learnable inverse of Qwen2_5_VLPatchMerger."""
|
| 948 |
+
def __init__(self, dim, context_dim, spatial_merge_size=2):
|
| 949 |
+
super().__init__()
|
| 950 |
+
self.spatial_merge_size = spatial_merge_size
|
| 951 |
+
self.context_dim = context_dim
|
| 952 |
+
hidden = context_dim * (spatial_merge_size ** 2)
|
| 953 |
+
self.ln_q = Qwen2RMSNorm(dim, eps=1e-6)
|
| 954 |
+
self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, hidden))
|
| 955 |
+
|
| 956 |
+
def forward(self, x):
|
| 957 |
+
x = self.mlp(self.ln_q(x))
|
| 958 |
+
return x.view(x.shape[0] * (self.spatial_merge_size ** 2), self.context_dim)
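# Illustrative shapes: with spatial_merge_size=2, an input of [N, dim] merged tokens is
# expanded by the MLP to [N, 4 * context_dim] and reshaped to [4 * N, context_dim],
# i.e. one feature vector per unmerged patch.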
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
def restore_spatial_structure_and_convert_to_images(patches, grid_thw_list, patch_size,
|
| 962 |
+
channel_dim=3, temporal_patch_size=2, merge_size=2):
|
| 963 |
+
"""Convert decoder pixel features back to image tensors [3, H, W]."""
|
| 964 |
+
if isinstance(patches, tuple):
|
| 965 |
+
patches = patches[0]
|
| 966 |
+
image_tensors = []
|
| 967 |
+
ptr = 0
|
| 968 |
+
for grid in grid_thw_list:
|
| 969 |
+
gt, gh, gw = (int(x) for x in (grid if not isinstance(grid, torch.Tensor) else grid.tolist()))
|
| 970 |
+
n = gt * gh * gw
|
| 971 |
+
chunk = patches[ptr:ptr + n]
|
| 972 |
+
ptr += n
|
| 973 |
+
r = chunk.reshape(gt, gh // merge_size, gw // merge_size, merge_size, merge_size,
|
| 974 |
+
channel_dim, temporal_patch_size, patch_size, patch_size)
|
| 975 |
+
r = r.permute(0, 6, 5, 1, 3, 7, 2, 4, 8)
|
| 976 |
+
image_tensors.append(r.reshape(gt * temporal_patch_size, channel_dim, gh * patch_size, gw * patch_size)[0])
|
| 977 |
+
return image_tensors
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
class VisionTransformerDecoder(nn.Module):
|
| 981 |
+
def __init__(self, config):
|
| 982 |
+
super().__init__()
|
| 983 |
+
self.config = config
|
| 984 |
+
self.embed_dim = config.hidden_size
|
| 985 |
+
self.patch_size = config.patch_size
|
| 986 |
+
self.spatial_merge_size = config.spatial_merge_size
|
| 987 |
+
self.codebook_dim = config.codebook_dim
|
| 988 |
+
self.temporal_patch_size = config.temporal_patch_size
|
| 989 |
+
|
| 990 |
+
self.rope2d = VisionRoPE2D(theta=10000.0)
|
| 991 |
+
self.post_quant_conv = nn.Linear(self.codebook_dim, self.embed_dim)
|
| 992 |
+
self.post_quant_norm = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 993 |
+
self.patch_unmerger = PatchUnMerger(self.embed_dim, self.embed_dim, self.spatial_merge_size)
|
| 994 |
+
self.norm_in = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 995 |
+
self.encoder = VisionEncoder(config, rope=self.rope2d, rope_shift=0)
|
| 996 |
+
self.norm_out = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 997 |
+
self.decoder_head = nn.Sequential(
|
| 998 |
+
nn.Linear(self.embed_dim, config.intermediate_size), nn.GELU(),
|
| 999 |
+
nn.Linear(config.intermediate_size, 3 * self.patch_size * self.patch_size * self.temporal_patch_size),
|
| 1000 |
+
)
|
| 1001 |
+
|
| 1002 |
+
@classmethod
|
| 1003 |
+
def from_pretrained(cls, config, model_path: str):
|
| 1004 |
+
"""Load a pretrained model from a checkpoint."""
|
| 1005 |
+
model = cls(config)
|
| 1006 |
+
weight_dict = load_file(model_path, device="cpu")
|
| 1007 |
+
model.load_state_dict({k.removeprefix("image_decoder."): v for k, v in weight_dict.items() if k.startswith("image_decoder.")}, strict=True)
|
| 1008 |
+
model.eval()
|
| 1009 |
+
return model
|
| 1010 |
+
|
| 1011 |
+
def _build_2d_positions(self, grid_thw_list):
|
| 1012 |
+
pos_list = []
|
| 1013 |
+
for (t, gh, gw) in grid_thw_list:
|
| 1014 |
+
for _ in range(int(t)):
|
| 1015 |
+
for y in range(int(gh)):
|
| 1016 |
+
for x in range(int(gw)):
|
| 1017 |
+
pos_list.append([y, x])
|
| 1018 |
+
return torch.tensor(pos_list, dtype=torch.long)
|
| 1019 |
+
|
| 1020 |
+
def _build_attention_mask(self, grid_thw_list, device, dtype, B, num_heads):
|
| 1021 |
+
counts = [int(t) * int(h) * int(w) for (t, h, w) in grid_thw_list]
|
| 1022 |
+
L = sum(counts)
|
| 1023 |
+
mask = torch.zeros((B, num_heads, L, L), device=device, dtype=dtype)
|
| 1024 |
+
s = 0
|
| 1025 |
+
for c in counts:
|
| 1026 |
+
e = s + c
|
| 1027 |
+
if s > 0:
|
| 1028 |
+
mask[:, :, s:e, :s] = float("-inf")
|
| 1029 |
+
if e < L:
|
| 1030 |
+
mask[:, :, s:e, e:] = float("-inf")
|
| 1031 |
+
s = e
|
| 1032 |
+
return mask
|
| 1033 |
+
|
| 1034 |
+
def forward(self, embeddings, grid_thw, return_pixel_features=False, return_last_latent=False):
|
| 1035 |
+
device = embeddings.device
|
| 1036 |
+
grid_thw_list = ([(int(t), int(h), int(w)) for t, h, w in grid_thw.detach().cpu().numpy()]
|
| 1037 |
+
if isinstance(grid_thw, torch.Tensor) else list(grid_thw))
|
| 1038 |
+
|
| 1039 |
+
if embeddings.shape[-1] == self.codebook_dim:
|
| 1040 |
+
embeddings = self.post_quant_conv(embeddings)
|
| 1041 |
+
embeddings = self.post_quant_norm(embeddings)
|
| 1042 |
+
|
| 1043 |
+
unmerged = self.patch_unmerger(embeddings)
|
| 1044 |
+
if unmerged.dim() == 2:
|
| 1045 |
+
unmerged = unmerged.unsqueeze(0)
|
| 1046 |
+
B, L, D = unmerged.shape
|
| 1047 |
+
hidden_states = self.norm_in(unmerged)
|
| 1048 |
+
|
| 1049 |
+
positions_2d = self._build_2d_positions(grid_thw_list).to(device)
|
| 1050 |
+
seq_lens = torch.tensor([int(t) * int(h) * int(w) for (t, h, w) in grid_thw_list],
|
| 1051 |
+
device=device, dtype=torch.int32)
|
| 1052 |
+
assert positions_2d.shape[0] == L, f"positions_2d {positions_2d.shape[0]} != L {L}"
|
| 1053 |
+
|
| 1054 |
+
last_latent = hidden_states.detach().squeeze(0) if return_last_latent else None
|
| 1055 |
+
enc_out = self.encoder(
|
| 1056 |
+
inputs_embeds=hidden_states,
|
| 1057 |
+
attention_mask=None,
|
| 1058 |
+
causal_attention_mask=None,
|
| 1059 |
+
output_attentions=False,
|
| 1060 |
+
output_hidden_states=False,
|
| 1061 |
+
return_dict=True,
|
| 1062 |
+
positions_2d=positions_2d,
|
| 1063 |
+
seq_lens=seq_lens,
|
| 1064 |
+
)
|
| 1065 |
+
hidden_states = enc_out.last_hidden_state
|
| 1066 |
+
|
| 1067 |
+
hidden_states = self.norm_out(hidden_states)
|
| 1068 |
+
pixel_features = self.decoder_head(hidden_states).squeeze(0)
|
| 1069 |
+
|
| 1070 |
+
out_imgs = (None if return_pixel_features else
|
| 1071 |
+
restore_spatial_structure_and_convert_to_images(
|
| 1072 |
+
pixel_features, grid_thw_list, self.patch_size,
|
| 1073 |
+
temporal_patch_size=self.temporal_patch_size, merge_size=self.spatial_merge_size))
|
| 1074 |
+
ret = {"images": out_imgs, "pixel_features": pixel_features}
|
| 1075 |
+
if last_latent is not None:
|
| 1076 |
+
ret["last_latent"] = last_latent
|
| 1077 |
+
return ret
|
parse_model_response.py
ADDED
|
@@ -0,0 +1,158 @@
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
import uuid
|
| 4 |
+
|
| 5 |
+
def parse_arguments(json_value):
|
| 6 |
+
"""
|
| 7 |
+
Attempt to parse a string as JSON
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
json_value: String to parse
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
tuple: (parsed_value, is_valid_json)
|
| 14 |
+
"""
|
| 15 |
+
try:
|
| 16 |
+
parsed_value = json.loads(json_value)
|
| 17 |
+
return parsed_value, True
|
| 18 |
+
except (json.JSONDecodeError, TypeError):
|
| 19 |
+
return json_value, False
|
| 20 |
+
|
| 21 |
+
def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
|
| 22 |
+
"""
|
| 23 |
+
Get the type definition of a tool parameter
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
func_name: Name of the function/tool
|
| 27 |
+
arg_key: Parameter key name
|
| 28 |
+
defined_tools: List of tool definitions
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
str or None: Type of the parameter ('string', 'object', 'array', 'integer', 'number', 'boolean')
|
| 32 |
+
"""
|
| 33 |
+
name2tool = {tool["name"]: tool for tool in defined_tools}
|
| 34 |
+
if func_name not in name2tool:
|
| 35 |
+
return None
|
| 36 |
+
tool = name2tool[func_name]
|
| 37 |
+
if "parameters" not in tool or "properties" not in tool["parameters"]:
|
| 38 |
+
return None
|
| 39 |
+
if arg_key not in tool["parameters"]["properties"]:
|
| 40 |
+
return None
|
| 41 |
+
return tool["parameters"]["properties"][arg_key].get("type")
|
| 42 |
+
|
| 43 |
+
def parse_model_response(response: str, defined_tools: list=[]):
|
| 44 |
+
"""
|
| 45 |
+
Parse model response to extract reasoning_content, content, and tool_calls
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
response: Raw response text from the model
|
| 49 |
+
defined_tools: List of tool definitions
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
dict: Message containing role, reasoning_content (optional), content (optional),
|
| 53 |
+
and tool_calls (optional)
|
| 54 |
+
"""
|
| 55 |
+
text = response
|
| 56 |
+
reasoning_content = None
|
| 57 |
+
content = None
|
| 58 |
+
tool_calls = []
|
| 59 |
+
|
| 60 |
+
formatted_tools = []
|
| 61 |
+
for tool in defined_tools:
|
| 62 |
+
if "function" in tool:
|
| 63 |
+
formatted_tools.append(tool['function'])
|
| 64 |
+
else:
|
| 65 |
+
formatted_tools.append(tool)
|
| 66 |
+
|
| 67 |
+
if '</longcat_think>' in text:
|
| 68 |
+
text = text.replace('<longcat_think>', '')
|
| 69 |
+
thinking_end = text.find('</longcat_think>')
|
| 70 |
+
reasoning_content = text[: thinking_end].strip()
|
| 71 |
+
text = text[thinking_end + len('</longcat_think>'):].lstrip()
|
| 72 |
+
|
| 73 |
+
assert '<longcat_think>' not in text, "Unclosed <longcat_think> tag found in remaining text"
|
| 74 |
+
assert '</longcat_think>' not in text, "Unexpected </longcat_think> tag found without opening tag"
|
| 75 |
+
|
| 76 |
+
if '<longcat_tool_call>' in text:
|
| 77 |
+
index = text.find('<longcat_tool_call>')
|
| 78 |
+
content = text[:index]
|
| 79 |
+
text = text[index:].strip()
|
| 80 |
+
else:
|
| 81 |
+
content = text
|
| 82 |
+
text = ""
|
| 83 |
+
|
| 84 |
+
open_tags = text.count('<longcat_tool_call>')
|
| 85 |
+
close_tags = text.count('</longcat_tool_call>')
|
| 86 |
+
assert open_tags == close_tags, \
|
| 87 |
+
f"Mismatched tool_call tags: {open_tags} opening tags, {close_tags} closing tags"
|
| 88 |
+
|
| 89 |
+
tool_call_strs = re.findall(
|
| 90 |
+
r'<longcat_tool_call>(.*?)</longcat_tool_call>',
|
| 91 |
+
text,
|
| 92 |
+
re.DOTALL
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
for call in tool_call_strs:
|
| 96 |
+
func_name_match = re.match(r'([^\n<]+)', call.strip())
|
| 97 |
+
assert func_name_match, f"Missing function name in tool call: {call[:100]}"
|
| 98 |
+
|
| 99 |
+
func_name = func_name_match.group(1).strip()
|
| 100 |
+
assert func_name, "Empty function name in tool call"
|
| 101 |
+
|
| 102 |
+
# Verify argument tags are properly paired
|
| 103 |
+
arg_key_count = call.count('<longcat_arg_key>')
|
| 104 |
+
arg_key_close_count = call.count('</longcat_arg_key>')
|
| 105 |
+
arg_value_count = call.count('<longcat_arg_value>')
|
| 106 |
+
arg_value_close_count = call.count('</longcat_arg_value>')
|
| 107 |
+
|
| 108 |
+
assert arg_key_count == arg_key_close_count, \
|
| 109 |
+
f"Mismatched arg_key tags in function {func_name}: {arg_key_count} opening, {arg_key_close_count} closing"
|
| 110 |
+
assert arg_value_count == arg_value_close_count, \
|
| 111 |
+
f"Mismatched arg_value tags in function {func_name}: {arg_value_count} opening, {arg_value_close_count} closing"
|
| 112 |
+
assert arg_key_count == arg_value_count, \
|
| 113 |
+
f"Mismatched arg_key and arg_value count in function {func_name}: {arg_key_count} keys, {arg_value_count} values"
|
| 114 |
+
|
| 115 |
+
pairs = re.findall(
|
| 116 |
+
r'<longcat_arg_key>(.*?)</longcat_arg_key>\s*<longcat_arg_value>(.*?)</longcat_arg_value>',
|
| 117 |
+
call,
|
| 118 |
+
re.DOTALL
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
assert len(pairs) == arg_key_count, \
|
| 122 |
+
f"Failed to parse all arguments in function {func_name}: expected {arg_key_count}, got {len(pairs)}"
|
| 123 |
+
|
| 124 |
+
arguments = {}
|
| 125 |
+
for arg_key, arg_value in pairs:
|
| 126 |
+
arg_key = arg_key.strip()
|
| 127 |
+
arg_value = arg_value.strip()
|
| 128 |
+
|
| 129 |
+
assert arg_key, f"Empty argument key in function {func_name}"
|
| 130 |
+
assert arg_key not in arguments, \
|
| 131 |
+
f"Duplicate argument key '{arg_key}' in function {func_name}"
|
| 132 |
+
|
| 133 |
+
arg_type = get_argument_type(func_name, arg_key, formatted_tools)
|
| 134 |
+
|
| 135 |
+
if arg_type and arg_type != 'string':
|
| 136 |
+
parsed_value, is_good_json = parse_arguments(arg_value)
|
| 137 |
+
arg_value = parsed_value
|
| 138 |
+
|
| 139 |
+
arguments[arg_key] = arg_value
|
| 140 |
+
|
| 141 |
+
tool_calls.append({
|
| 142 |
+
'id': "tool-call-" + str(uuid.uuid4()),
|
| 143 |
+
'type': "function",
|
| 144 |
+
'function': {
|
| 145 |
+
'name': func_name,
|
| 146 |
+
'arguments': arguments
|
| 147 |
+
}
|
| 148 |
+
})
|
| 149 |
+
|
| 150 |
+
message = {'role': 'assistant'}
|
| 151 |
+
|
| 152 |
+
if reasoning_content:
|
| 153 |
+
message['reasoning_content'] = reasoning_content
|
| 154 |
+
message['content'] = content
|
| 155 |
+
if tool_calls:
|
| 156 |
+
message['tool_calls'] = tool_calls
|
| 157 |
+
|
| 158 |
+
return message
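# Illustrative example (tag contents abbreviated):
#   resp = ("<longcat_think>check the weather</longcat_think>"
#           "<longcat_tool_call>get_weather\n"
#           "<longcat_arg_key>city</longcat_arg_key>"
#           "<longcat_arg_value>Paris</longcat_arg_value>"
#           "</longcat_tool_call>")
#   parse_model_response(resp)
#   # -> {'role': 'assistant', 'reasoning_content': 'check the weather', 'content': '',
#   #     'tool_calls': [{'id': 'tool-call-<uuid>', 'type': 'function',
#   #                     'function': {'name': 'get_weather', 'arguments': {'city': 'Paris'}}}]}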
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_longcat_next.LongcatNextProcessor"
|
| 4 |
+
},
|
| 5 |
+
"avg_pooler": 4,
|
| 6 |
+
"feature_extractor_type": "LongcatNextAudioProcessor",
|
| 7 |
+
"hop_length": 160,
|
| 8 |
+
"kernel_size": 3,
|
| 9 |
+
"max_audio_seconds": 30,
|
| 10 |
+
"max_pixels": 3211264,
|
| 11 |
+
"min_pixels": 50176,
|
| 12 |
+
"n_fft": 400,
|
| 13 |
+
"num_mel_bins": 128,
|
| 14 |
+
"processor_class": "LongcatNextProcessor",
|
| 15 |
+
"sampling_rate": 16000,
|
| 16 |
+
"spatial_merge_size": 2,
|
| 17 |
+
"split_overlap": 0.0,
|
| 18 |
+
"stride_size": 2
|
| 19 |
+
}
|
processing_longcat_next.py
ADDED
|
@@ -0,0 +1,279 @@
|
| 1 |
+
import re
|
| 2 |
+
from typing import Union, List
|
| 3 |
+
from types import SimpleNamespace
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import librosa
|
| 7 |
+
import soundfile as sf
|
| 8 |
+
import numpy as np
|
| 9 |
+
from transformers import AutoFeatureExtractor
|
| 10 |
+
from transformers.audio_utils import mel_filter_bank
|
| 11 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 12 |
+
from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
| 13 |
+
from transformers.processing_utils import (
|
| 14 |
+
AudioKwargs,
|
| 15 |
+
ImagesKwargs,
|
| 16 |
+
ProcessingKwargs,
|
| 17 |
+
ProcessorMixin,
|
| 18 |
+
VideosKwargs,
|
| 19 |
+
)
|
| 20 |
+
from transformers.utils import logging
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class LongcatNextProcessorKwargs(ProcessingKwargs, total=False):
|
| 26 |
+
images_kwargs: ImagesKwargs
|
| 27 |
+
videos_kwargs: VideosKwargs
|
| 28 |
+
audio_kwargs: AudioKwargs
|
| 29 |
+
_defaults = {
|
| 30 |
+
"text_kwargs": {
|
| 31 |
+
"padding": False,
|
| 32 |
+
"padding_side": "left",
|
| 33 |
+
"return_attention_mask": False,
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LongcatNextAudioProcessor(FeatureExtractionMixin):
|
| 39 |
+
|
| 40 |
+
def __init__(self, **kwargs):
|
| 41 |
+
super().__init__(**kwargs)
|
| 42 |
+
self.mel_filters = mel_filter_bank(
|
| 43 |
+
num_frequency_bins=1 + self.n_fft // 2,
|
| 44 |
+
num_mel_filters=self.num_mel_bins,
|
| 45 |
+
min_frequency=0.0,
|
| 46 |
+
max_frequency=self.sampling_rate / 2.0,
|
| 47 |
+
sampling_rate=self.sampling_rate,
|
| 48 |
+
norm="slaney",
|
| 49 |
+
mel_scale="slaney",
|
| 50 |
+
)
|
| 51 |
+
self.window = torch.hann_window(self.n_fft)
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def zero_mean_unit_var_norm(x):
|
| 55 |
+
return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
|
| 56 |
+
|
| 57 |
+
def load_audio_waveform(self, uri, metadata=None, waveform_tensor=None, return_tensors=True, do_normalize=False):
|
| 58 |
+
if metadata is None or waveform_tensor is None:
|
| 59 |
+
# Use librosa to handle all audio formats uniformly (mp3, wav, flac, etc.)
|
| 60 |
+
# librosa.load already returns normalized float32 data
|
| 61 |
+
waveform_np, sample_rate = librosa.load(uri, sr=None, mono=False)
|
| 62 |
+
|
| 63 |
+
# 转换为 tensor,确保维度为 (channels, samples)
|
| 64 |
+
if waveform_np.ndim == 1:
|
| 65 |
+
waveform_tensor = torch.from_numpy(waveform_np).unsqueeze(0)
|
| 66 |
+
else:
|
| 67 |
+
waveform_tensor = torch.from_numpy(waveform_np)
|
| 68 |
+
|
| 69 |
+
# Fetch audio metadata
|
| 70 |
+
try:
|
| 71 |
+
sf_info = sf.info(uri)
|
| 72 |
+
metadata = SimpleNamespace(
|
| 73 |
+
sample_rate=sample_rate,
|
| 74 |
+
num_frames=waveform_tensor.shape[1],
|
| 75 |
+
num_channels=waveform_tensor.shape[0],
|
| 76 |
+
bits_per_sample=getattr(sf_info, 'bits_per_sample', 16),
|
| 77 |
+
encoding=getattr(sf_info, 'subtype', 'PCM_F')
|
| 78 |
+
)
|
| 79 |
+
except Exception:
|
| 80 |
+
# If soundfile.info fails, fall back to the information provided by librosa
|
| 81 |
+
metadata = SimpleNamespace(
|
| 82 |
+
sample_rate=sample_rate,
|
| 83 |
+
num_frames=waveform_tensor.shape[1],
|
| 84 |
+
num_channels=waveform_tensor.shape[0],
|
| 85 |
+
bits_per_sample=16,
|
| 86 |
+
encoding='PCM_F'
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # Whisper only accepts mono audio; multi-channel input is downmixed below
|
| 90 |
+
|
| 91 |
+
if self.sampling_rate != metadata.sample_rate:
|
| 92 |
+
# Resample with torch.nn.functional.interpolate
|
| 93 |
+
waveform_tensor = torch.nn.functional.interpolate(
|
| 94 |
+
waveform_tensor.unsqueeze(0),
|
| 95 |
+
size=int(waveform_tensor.shape[1] * self.sampling_rate / metadata.sample_rate),
|
| 96 |
+
mode='linear',
|
| 97 |
+
align_corners=False
|
| 98 |
+
).squeeze(0)
|
| 99 |
+
|
| 100 |
+
# downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
|
| 101 |
+
if metadata.num_channels > 1:
|
| 102 |
+
waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
|
| 103 |
+
|
| 104 |
+
# Normalize to zero mean (not done in Qwen Audio, but part of the official Whisper implementation)
|
| 105 |
+
if do_normalize:
|
| 106 |
+
waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
|
| 107 |
+
|
| 108 |
+
if return_tensors: # (channels, samples)
|
| 109 |
+
return waveform_tensor
|
| 110 |
+
else:
|
| 111 |
+
return waveform_tensor.numpy()
|
| 112 |
+
|
| 113 |
+
def split_with_overlap(self, waveform):  # if the waveform exceeds the maximum length, split it into overlapping segments
|
| 114 |
+
channels, wave_samples = waveform.shape
|
| 115 |
+
max_audio_samples = self.max_audio_seconds * self.sampling_rate
|
| 116 |
+
if wave_samples <= max_audio_samples or self.split_overlap < 0:
|
| 117 |
+
return [waveform]  # not over the maximum length and no truncation needed; always return a list
|
| 118 |
+
|
| 119 |
+
split_waveform, start = [], 0
|
| 120 |
+
while start < wave_samples:  # align the overlap in whole seconds
|
| 121 |
+
if start > int(self.sampling_rate * self.split_overlap):
|
| 122 |
+
start -= int(self.sampling_rate * self.split_overlap)  # 0 means no overlap; > 0 is the overlap in seconds
|
| 123 |
+
end = min(start + max_audio_samples, wave_samples)
|
| 124 |
+
if end - start >= self.n_fft:  # ensure there is at least one frame of data
|
| 125 |
+
split_waveform.append(waveform[:, start:end])  # note: this may produce very short segments, which should be detected and dropped during preprocessing
|
| 126 |
+
start = end
|
| 127 |
+
return split_waveform
|
| 128 |
+
|
| 129 |
+
@classmethod
|
| 130 |
+
def inference_output_length(self, input_length, kernel_size, stride_size, avg_pooler):
|
| 131 |
+
# for whisper + bridge
|
| 132 |
+
encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
|
| 133 |
+
encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
|
| 134 |
+
if avg_pooler > 1:
|
| 135 |
+
bridge_length = encoder_length // avg_pooler
|
| 136 |
+
return encoder_length, bridge_length
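# Worked example with the shipped preprocessor defaults (kernel_size=3, stride_size=2,
# avg_pooler=4, hop_length=160, sampling_rate=16000): a full 30 s clip has
# input_length = 30 * 16000 / 160 = 3000 mel frames, so conv1 gives
# (3000 + 2 - 3) // 1 + 1 = 3000, conv2 gives (3000 + 2 - 3) // 2 + 1 = 1500, and
# bridge_length = 1500 // 4 = 375 audio tokens.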
|
| 137 |
+
|
| 138 |
+
def extract_fbank_features(self, waveform):
|
| 139 |
+
# ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
|
| 140 |
+
channels, wave_samples = waveform.shape
|
| 141 |
+
assert(wave_samples >= self.n_fft)
|
| 142 |
+
valid_frame_nums = min(self.max_audio_seconds * self.sampling_rate // self.hop_length, wave_samples // self.hop_length + 1)
|
| 143 |
+
if wave_samples < self.max_audio_seconds * self.sampling_rate:
|
| 144 |
+
waveform = torch.nn.functional.pad(waveform, (0, self.max_audio_seconds * self.sampling_rate - wave_samples), "constant", 0)
|
| 145 |
+
else:
|
| 146 |
+
waveform = waveform[:, :self.max_audio_seconds * self.sampling_rate]
|
| 147 |
+
|
| 148 |
+
# window = torch.hann_window(self.n_fft)
|
| 149 |
+
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=self.window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
|
| 150 |
+
magnitudes = stft[..., :-1].abs() ** 2
|
| 151 |
+
|
| 152 |
+
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
|
| 153 |
+
mel_spec = mel_filters.T @ magnitudes
|
| 154 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
| 155 |
+
if waveform.dim() == 2:
|
| 156 |
+
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
|
| 157 |
+
log_spec = torch.maximum(log_spec, max_val - 8.0)
|
| 158 |
+
else:
|
| 159 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
| 160 |
+
log_spec = (log_spec + 4.0) / 4.0
|
| 161 |
+
|
| 162 |
+
log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
|
| 163 |
+
log_spec[:, valid_frame_nums:] = 0.0 # pad0
|
| 164 |
+
|
| 165 |
+
return log_spec, valid_frame_nums
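# Whisper-style log-mel pipeline: pad or trim to max_audio_seconds, take the STFT power
# spectrum, apply the 128-bin mel filterbank, clamp to (max - 8.0) in log10 space,
# rescale with (x + 4) / 4, and zero out the padded frames.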
|
| 166 |
+
|
| 167 |
+
def process(self, audio_path, **kwargs):
|
| 168 |
+
metadata, waveform_tensors = None, None
|
| 169 |
+
waveforms = self.load_audio_waveform(audio_path, metadata, waveform_tensors, True)
|
| 170 |
+
waveforms = self.split_with_overlap(waveforms)
|
| 171 |
+
|
| 172 |
+
ret_audio, ret_encoder_length, ret_bridge_length = [], [], []
|
| 173 |
+
for i, waveform in enumerate(waveforms):
|
| 174 |
+
audio, input_length = self.extract_fbank_features(waveform)
|
| 175 |
+
encoder_length, bridge_length = self.inference_output_length(input_length, self.kernel_size, self.stride_size, self.avg_pooler)
|
| 176 |
+
if bridge_length <= 0:
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
ret_audio.append(audio)
|
| 180 |
+
ret_encoder_length.append(encoder_length)
|
| 181 |
+
ret_bridge_length.append(bridge_length)
|
| 182 |
+
return ret_audio, ret_encoder_length, ret_bridge_length
|
| 183 |
+
|
| 184 |
+
def __call__(self, audio: Union[str, List[str]], **kwargs):
|
| 185 |
+
if isinstance(audio, str):
|
| 186 |
+
audio = [audio]
|
| 187 |
+
results = {
|
| 188 |
+
"audio": [],
|
| 189 |
+
"encoder_length": [],
|
| 190 |
+
"bridge_length": [],
|
| 191 |
+
}
|
| 192 |
+
for audio_path in audio:
|
| 193 |
+
audio, encoder_length, bridge_length = self.process(audio_path, **kwargs)
|
| 194 |
+
results["audio"].append(audio)
|
| 195 |
+
results["encoder_length"].append(encoder_length)
|
| 196 |
+
results["bridge_length"].append(bridge_length)
|
| 197 |
+
return results
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class LongcatNextProcessor(ProcessorMixin):
|
| 201 |
+
|
| 202 |
+
attributes = ["image_processor", "video_processor", "audio_processor", "tokenizer"]
|
| 203 |
+
|
| 204 |
+
image_processor_class = "Qwen2VLImageProcessor"
|
| 205 |
+
video_processor_class = "Qwen2VLImageProcessor"
|
| 206 |
+
audio_processor_class = "LongcatNextAudioProcessor"
|
| 207 |
+
tokenizer_class = "AutoTokenizer"
|
| 208 |
+
|
| 209 |
+
def __init__(self, image_processor=None, video_processor=None, audio_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
| 210 |
+
super().__init__(image_processor, video_processor, audio_processor, tokenizer, chat_template=chat_template)
|
| 211 |
+
init_token_list = [
|
| 212 |
+
"image_start_token", "image_end_token", "image_pad_token", "image_newline_token",
|
| 213 |
+
"audio_start_token", "audio_end_token", "audio_pad_token",
|
| 214 |
+
]
|
| 215 |
+
for attr in init_token_list:
|
| 216 |
+
token_str = self.tokenizer.init_kwargs.get(attr)
|
| 217 |
+
token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
|
| 218 |
+
assert len(token_ids) == 1, (f"{attr}='{token_str}' encodes to {len(token_ids)} id(s) {token_ids}, expected exactly 1 id")
|
| 219 |
+
setattr(self, f"{attr}", token_str)
|
| 220 |
+
setattr(self, f"{attr}_id", token_ids[0])
|
| 221 |
+
|
| 222 |
+
def __call__(
|
| 223 |
+
self,
|
| 224 |
+
text: str,
|
| 225 |
+
**kwargs,
|
| 226 |
+
) -> List["LongcatNextProcessorOutput"]:
|
| 227 |
+
|
| 228 |
+
if text is None:
|
| 229 |
+
raise ValueError("You need to specify either a `text` input to process.")
|
| 230 |
+
|
| 231 |
+
output_kwargs = self._merge_kwargs(
|
| 232 |
+
LongcatNextProcessorKwargs,
|
| 233 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 234 |
+
**kwargs,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
assert isinstance(text, str)
|
| 238 |
+
|
| 239 |
+
image_path_list = re.findall(rf"{self.image_start_token}(.*?){self.image_end_token}", text)
|
| 240 |
+
audio_path_list = re.findall(rf"{self.audio_start_token}(.*?){self.audio_end_token}", text)
|
| 241 |
+
|
| 242 |
+
if len(image_path_list) > 0:
|
| 243 |
+
images_inputs = self.image_processor(images=image_path_list, **output_kwargs["images_kwargs"])
|
| 244 |
+
image_grid_thw = images_inputs["image_grid_thw"]
|
| 245 |
+
for i, image_path in enumerate(image_path_list):
|
| 246 |
+
image_token_num = image_grid_thw[i][0] * (image_grid_thw[i][1]//self.image_processor.spatial_merge_size) * (image_grid_thw[i][2]//self.image_processor.spatial_merge_size)
|
| 247 |
+
text = text.replace(f"{self.image_start_token}{image_path}{self.image_end_token}", f"{self.image_start_token}{self.image_pad_token * image_token_num}{self.image_end_token}")
|
| 248 |
+
else:
|
| 249 |
+
images_inputs = {}
|
| 250 |
+
|
| 251 |
+
if len(audio_path_list) > 0:
|
| 252 |
+
audio_inputs = self.audio_processor(audio=audio_path_list, **output_kwargs["audio_kwargs"])
|
| 253 |
+
for i, audio_path in enumerate(audio_path_list):
|
| 254 |
+
audio_token_num = np.sum(audio_inputs["bridge_length"][i])
|
| 255 |
+
text = text.replace(f"{self.audio_start_token}{audio_path}{self.audio_end_token}", f"{self.audio_start_token}{self.audio_pad_token * audio_token_num}{self.audio_end_token}")
|
| 256 |
+
for key in audio_inputs:
|
| 257 |
+
audio_inputs[key] = [val for b_val in audio_inputs[key] for val in b_val]
|
| 258 |
+
else:
|
| 259 |
+
audio_inputs = {}
|
| 260 |
+
|
| 261 |
+
texts_inputs = self.tokenizer([text], **output_kwargs["text_kwargs"])
|
| 262 |
+
|
| 263 |
+
batch_feature_func = lambda x: BatchFeature(
|
| 264 |
+
data={**x},
|
| 265 |
+
tensor_type=kwargs.get("return_tensors"),
|
| 266 |
+
)
|
| 267 |
+
return (
|
| 268 |
+
batch_feature_func(texts_inputs),
|
| 269 |
+
batch_feature_func({k.replace("image", "visual"): v for k, v in images_inputs.items()}) if len(images_inputs) > 0 else None,
|
| 270 |
+
batch_feature_func(audio_inputs) if len(audio_inputs) > 0 else None,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class LongcatNextAudioProcessorConfig(PretrainedConfig):
|
| 275 |
+
pass
|
| 276 |
+
AutoFeatureExtractor.register(LongcatNextAudioProcessorConfig, LongcatNextAudioProcessor)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
__all__ = ["LongcatNextAudioProcessor", "LongcatNextProcessor"]
|
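A minimal usage sketch for the processor defined above (illustrative only: the checkpoint path, image path, and prompt are placeholders; `AutoProcessor` resolution relies on the `auto_map` entry in processor_config.json below, loaded with `trust_remote_code=True`):

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("path/to/checkpoint", trust_remote_code=True)

    # Media are referenced inline: a local file path wrapped in the start/end tokens that
    # the tokenizer defines (image_start_token/image_end_token, audio_*_token). __call__
    # replaces each path with the matching number of pad tokens before tokenization.
    prompt = (
        f"{processor.image_start_token}/tmp/example.png{processor.image_end_token}"
        " Describe this image."
    )
    text_inputs, visual_inputs, audio_inputs = processor(prompt, return_tensors="pt")
    # visual_inputs / audio_inputs are None when the prompt contains no media of that type.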
processor_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
| 1 |
+
{
|
| 2 |
+
"auto_map": {
|
| 3 |
+
"AutoProcessor": "processing_longcat_next.LongcatNextProcessor"
|
| 4 |
+
},
|
| 5 |
+
"processor_class": "LongcatNextProcessor"
|
| 6 |
+
}
|
quantization_config.json
ADDED
|
@@ -0,0 +1,75 @@
|
| 1 |
+
{
|
| 2 |
+
"bits": 4,
|
| 3 |
+
"data_type": "int",
|
| 4 |
+
"group_size": 128,
|
| 5 |
+
"sym": true,
|
| 6 |
+
"batch_size": 1,
|
| 7 |
+
"gradient_accumulate_steps": 8,
|
| 8 |
+
"seqlen": 512,
|
| 9 |
+
"autoround_version": "0.13.0",
|
| 10 |
+
"block_name_to_quantize": "model.layers",
|
| 11 |
+
"quant_method": "auto-round",
|
| 12 |
+
"packing_format": "auto_round:auto_gptq",
|
| 13 |
+
"extra_config": {
|
| 14 |
+
"model.layers.0.mlp.router.classifier": {
|
| 15 |
+
"bits": 16,
|
| 16 |
+
"data_type": "float"
|
| 17 |
+
},
|
| 18 |
+
"model.layers.1.mlp.router.classifier": {
|
| 19 |
+
"bits": 16,
|
| 20 |
+
"data_type": "float"
|
| 21 |
+
},
|
| 22 |
+
"model.layers.2.mlp.router.classifier": {
|
| 23 |
+
"bits": 16,
|
| 24 |
+
"data_type": "float"
|
| 25 |
+
},
|
| 26 |
+
"model.layers.3.mlp.router.classifier": {
|
| 27 |
+
"bits": 16,
|
| 28 |
+
"data_type": "float"
|
| 29 |
+
},
|
| 30 |
+
"model.layers.4.mlp.router.classifier": {
|
| 31 |
+
"bits": 16,
|
| 32 |
+
"data_type": "float"
|
| 33 |
+
},
|
| 34 |
+
"model.layers.5.mlp.router.classifier": {
|
| 35 |
+
"bits": 16,
|
| 36 |
+
"data_type": "float"
|
| 37 |
+
},
|
| 38 |
+
"model.layers.6.mlp.router.classifier": {
|
| 39 |
+
"bits": 16,
|
| 40 |
+
"data_type": "float"
|
| 41 |
+
},
|
| 42 |
+
"model.layers.7.mlp.router.classifier": {
|
| 43 |
+
"bits": 16,
|
| 44 |
+
"data_type": "float"
|
| 45 |
+
},
|
| 46 |
+
"model.layers.8.mlp.router.classifier": {
|
| 47 |
+
"bits": 16,
|
| 48 |
+
"data_type": "float"
|
| 49 |
+
},
|
| 50 |
+
"model.layers.9.mlp.router.classifier": {
|
| 51 |
+
"bits": 16,
|
| 52 |
+
"data_type": "float"
|
| 53 |
+
},
|
| 54 |
+
"model.layers.10.mlp.router.classifier": {
|
| 55 |
+
"bits": 16,
|
| 56 |
+
"data_type": "float"
|
| 57 |
+
},
|
| 58 |
+
"model.layers.11.mlp.router.classifier": {
|
| 59 |
+
"bits": 16,
|
| 60 |
+
"data_type": "float"
|
| 61 |
+
},
|
| 62 |
+
"model.layers.12.mlp.router.classifier": {
|
| 63 |
+
"bits": 16,
|
| 64 |
+
"data_type": "float"
|
| 65 |
+
},
|
| 66 |
+
"model.layers.13.mlp.router.classifier": {
|
| 67 |
+
"bits": 16,
|
| 68 |
+
"data_type": "float"
|
| 69 |
+
},
|
| 70 |
+
".*classifier.*": {
|
| 71 |
+
"bits": 16,
|
| 72 |
+
"data_type": "float"
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
}
|
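The config above describes symmetric 4-bit integer weight quantization with group size 128, produced by AutoRound 0.13.0 and packed in the auto_round:auto_gptq format; the extra_config block keeps every MoE router classifier in 16-bit float. A minimal loading sketch, assuming the auto-round runtime and a compatible GPTQ kernel are installed and that the quantization settings are resolved from the checkpoint (the path is a placeholder):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "path/to/checkpoint",    # placeholder path
        trust_remote_code=True,  # custom LongcatNext modeling code shipped in this repo
        device_map="auto",
        torch_dtype="auto",
    )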
refiner_modules.py
ADDED
|
@@ -0,0 +1,1330 @@
|
| 1 |
+
# ---------------------------------------------------------------------------
|
| 2 |
+
# Standard / third-party imports shared by all sections
|
| 3 |
+
# ---------------------------------------------------------------------------
|
| 4 |
+
|
| 5 |
+
import itertools
|
| 6 |
+
import math
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
from flash_attn import flash_attn_varlen_func # type: ignore
|
| 11 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # type: ignore
|
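# NOTE (added comment, not part of the original file): flash_attn is imported
# unconditionally here, so this module cannot be imported without the flash-attn
# package, even though TransformerBlock below tries to fall back to the SDPA-based
# AttnProcessor; the try/except around AttnProcessorFlash2Varlen() cannot catch a
# failure of these top-level imports.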
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch.nn import RMSNorm
|
| 17 |
+
|
| 18 |
+
from einops import rearrange, repeat
|
| 19 |
+
|
| 20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from diffusers.loaders import PeftAdapterMixin
|
| 22 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 23 |
+
from diffusers.models.activations import get_activation
|
| 24 |
+
from diffusers.models.attention_processor import Attention
|
| 25 |
+
from diffusers.models.embeddings import Timesteps, get_1d_rotary_pos_embed
|
| 26 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 27 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 28 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 29 |
+
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def swiglu(x, y):
|
| 35 |
+
return F.silu(x.float(), inplace=False).to(x.dtype) * y
|
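# Added note: swiglu(x, y) == SiLU(x) * y, with the SiLU evaluated in float32 for
# numerical stability before casting back to x.dtype; this is the gated activation
# used by LuminaFeedForward below.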
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TimestepEmbedding(nn.Module):
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
in_channels: int,
|
| 42 |
+
time_embed_dim: int,
|
| 43 |
+
act_fn: str = "silu",
|
| 44 |
+
out_dim: int = None,
|
| 45 |
+
post_act_fn: Optional[str] = None,
|
| 46 |
+
cond_proj_dim=None,
|
| 47 |
+
sample_proj_bias=True,
|
| 48 |
+
):
|
| 49 |
+
super().__init__()
|
| 50 |
+
|
| 51 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
|
| 52 |
+
|
| 53 |
+
if cond_proj_dim is not None:
|
| 54 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
| 55 |
+
else:
|
| 56 |
+
self.cond_proj = None
|
| 57 |
+
|
| 58 |
+
self.act = get_activation(act_fn)
|
| 59 |
+
|
| 60 |
+
if out_dim is not None:
|
| 61 |
+
time_embed_dim_out = out_dim
|
| 62 |
+
else:
|
| 63 |
+
time_embed_dim_out = time_embed_dim
|
| 64 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
|
| 65 |
+
|
| 66 |
+
if post_act_fn is None:
|
| 67 |
+
self.post_act = None
|
| 68 |
+
else:
|
| 69 |
+
self.post_act = get_activation(post_act_fn)
|
| 70 |
+
|
| 71 |
+
self.initialize_weights()
|
| 72 |
+
|
| 73 |
+
def initialize_weights(self):
|
| 74 |
+
nn.init.normal_(self.linear_1.weight, std=0.02)
|
| 75 |
+
nn.init.zeros_(self.linear_1.bias)
|
| 76 |
+
nn.init.normal_(self.linear_2.weight, std=0.02)
|
| 77 |
+
nn.init.zeros_(self.linear_2.bias)
|
| 78 |
+
|
| 79 |
+
def forward(self, sample, condition=None):
|
| 80 |
+
if condition is not None:
|
| 81 |
+
sample = sample + self.cond_proj(condition)
|
| 82 |
+
sample = self.linear_1(sample)
|
| 83 |
+
if self.act is not None:
|
| 84 |
+
sample = self.act(sample)
|
| 85 |
+
sample = self.linear_2(sample)
|
| 86 |
+
if self.post_act is not None:
|
| 87 |
+
sample = self.post_act(sample)
|
| 88 |
+
return sample
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def apply_rotary_emb(
|
| 92 |
+
x: torch.Tensor,
|
| 93 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 94 |
+
use_real: bool = True,
|
| 95 |
+
use_real_unbind_dim: int = -1,
|
| 96 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 97 |
+
"""
|
| 98 |
+
Apply rotary embeddings to input tensors using the given frequency tensor.
|
| 99 |
+
"""
|
| 100 |
+
if use_real:
|
| 101 |
+
cos, sin = freqs_cis # [S, D]
|
| 102 |
+
cos = cos[None, None]
|
| 103 |
+
sin = sin[None, None]
|
| 104 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 105 |
+
|
| 106 |
+
if use_real_unbind_dim == -1:
|
| 107 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
| 108 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 109 |
+
elif use_real_unbind_dim == -2:
|
| 110 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
|
| 111 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 112 |
+
else:
|
| 113 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 114 |
+
|
| 115 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 116 |
+
return out
|
| 117 |
+
else:
|
| 118 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
|
| 119 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
| 120 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 121 |
+
return x_out.type_as(x)
|
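# Added note, worked shape example for the use_real=True path (head_dim = 4 is illustrative):
#   x:        (batch, heads, seq, 4), read as interleaved (real, imag) pairs
#   cos, sin: (seq, 4) from freqs_cis, broadcast to (1, 1, seq, 4)
#   per pair: (x_r, x_i) -> (x_r*cos - x_i*sin, x_i*cos + x_r*sin),
#   i.e. each pair is rotated by the per-position angle encoded in freqs_cis.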
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@dataclass
|
| 125 |
+
class TeaCacheParams:
|
| 126 |
+
"""
|
| 127 |
+
TeaCache parameters for Transformer2DModel.
|
| 128 |
+
See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding.
|
| 129 |
+
"""
|
| 130 |
+
previous_residual: Optional[torch.Tensor] = None
|
| 131 |
+
previous_modulated_inp: Optional[torch.Tensor] = None
|
| 132 |
+
accumulated_rel_l1_distance: float = 0
|
| 133 |
+
is_first_or_last_step: bool = False
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def derivative_approximation(*args, **kwargs):
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def taylor_formula(*args, **kwargs):
|
| 141 |
+
pass
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def taylor_cache_init(*args, **kwargs):
|
| 145 |
+
pass
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def cache_init(*args, **kwargs):
|
| 149 |
+
pass
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def cal_type(*args, **kwargs):
|
| 153 |
+
pass
|
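# Added note: derivative_approximation, taylor_formula, taylor_cache_init, cache_init and
# cal_type above are no-op placeholders. The enable_taylorseer / TeaCache paths in
# TransformerBlock and Transformer2DModel assume real implementations (e.g. from the
# TeaCache project) are patched in before those flags are enabled; with these stubs,
# taylor_formula(...) returns None and would propagate into hidden_states.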
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class LuminaRMSNormZero(nn.Module):
|
| 157 |
+
"""
|
| 158 |
+
Adaptive RMS normalization with zero-initialized modulation (AdaLN-Zero style).
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
def __init__(
|
| 162 |
+
self,
|
| 163 |
+
embedding_dim: int,
|
| 164 |
+
norm_eps: float,
|
| 165 |
+
norm_elementwise_affine: bool,
|
| 166 |
+
):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.silu = nn.SiLU()
|
| 169 |
+
self.linear = nn.Linear(
|
| 170 |
+
min(embedding_dim, 1024),
|
| 171 |
+
4 * embedding_dim,
|
| 172 |
+
bias=True,
|
| 173 |
+
)
|
| 174 |
+
self.norm = RMSNorm(embedding_dim, eps=norm_eps)
|
| 175 |
+
|
| 176 |
+
def forward(
|
| 177 |
+
self,
|
| 178 |
+
x: torch.Tensor,
|
| 179 |
+
emb: Optional[torch.Tensor] = None,
|
| 180 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 181 |
+
emb = self.linear(self.silu(emb))
|
| 182 |
+
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
| 183 |
+
x = self.norm(x) * (1 + scale_msa[:, None])
|
| 184 |
+
return x, gate_msa, scale_mlp, gate_mlp
|
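# Added note: this is AdaLN-Zero style modulation. The conditioning embedding is projected
# to 4 * embedding_dim and split into (scale_msa, gate_msa, scale_mlp, gate_mlp); the
# attention input is RMS-normalized and scaled by (1 + scale_msa), while the two gates
# weight the attention and MLP residual branches inside TransformerBlock.forward.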
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class LuminaLayerNormContinuous(nn.Module):
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
embedding_dim: int,
|
| 191 |
+
conditioning_embedding_dim: int,
|
| 192 |
+
elementwise_affine=True,
|
| 193 |
+
eps=1e-5,
|
| 194 |
+
bias=True,
|
| 195 |
+
norm_type="layer_norm",
|
| 196 |
+
out_dim: Optional[int] = None,
|
| 197 |
+
):
|
| 198 |
+
super().__init__()
|
| 199 |
+
|
| 200 |
+
self.silu = nn.SiLU()
|
| 201 |
+
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
|
| 202 |
+
|
| 203 |
+
if norm_type == "layer_norm":
|
| 204 |
+
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
| 205 |
+
elif norm_type == "rms_norm":
|
| 206 |
+
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
| 207 |
+
else:
|
| 208 |
+
raise ValueError(f"unknown norm_type {norm_type}")
|
| 209 |
+
|
| 210 |
+
self.linear_2 = None
|
| 211 |
+
if out_dim is not None:
|
| 212 |
+
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
|
| 213 |
+
|
| 214 |
+
def forward(
|
| 215 |
+
self,
|
| 216 |
+
x: torch.Tensor,
|
| 217 |
+
conditioning_embedding: torch.Tensor,
|
| 218 |
+
) -> torch.Tensor:
|
| 219 |
+
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
| 220 |
+
scale = emb
|
| 221 |
+
x = self.norm(x) * (1 + scale)[:, None, :]
|
| 222 |
+
if self.linear_2 is not None:
|
| 223 |
+
x = self.linear_2(x)
|
| 224 |
+
return x
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class LuminaFeedForward(nn.Module):
|
| 228 |
+
def __init__(
|
| 229 |
+
self,
|
| 230 |
+
dim: int,
|
| 231 |
+
inner_dim: int,
|
| 232 |
+
multiple_of: Optional[int] = 256,
|
| 233 |
+
ffn_dim_multiplier: Optional[float] = None,
|
| 234 |
+
):
|
| 235 |
+
super().__init__()
|
| 236 |
+
|
| 237 |
+
if ffn_dim_multiplier is not None:
|
| 238 |
+
inner_dim = int(ffn_dim_multiplier * inner_dim)
|
| 239 |
+
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
| 240 |
+
|
| 241 |
+
self.linear_1 = nn.Linear(dim, inner_dim, bias=False)
|
| 242 |
+
self.linear_2 = nn.Linear(inner_dim, dim, bias=False)
|
| 243 |
+
self.linear_3 = nn.Linear(dim, inner_dim, bias=False)
|
| 244 |
+
|
| 245 |
+
def forward(self, x):
|
| 246 |
+
h1, h2 = self.linear_1(x), self.linear_3(x)
|
| 247 |
+
return self.linear_2(swiglu(h1, h2))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
| 251 |
+
def __init__(
|
| 252 |
+
self,
|
| 253 |
+
hidden_size: int = 4096,
|
| 254 |
+
text_feat_dim: int = 2048,
|
| 255 |
+
frequency_embedding_size: int = 256,
|
| 256 |
+
norm_eps: float = 1e-5,
|
| 257 |
+
timestep_scale: float = 1.0,
|
| 258 |
+
) -> None:
|
| 259 |
+
super().__init__()
|
| 260 |
+
|
| 261 |
+
self.time_proj = Timesteps(
|
| 262 |
+
num_channels=frequency_embedding_size,
|
| 263 |
+
flip_sin_to_cos=True,
|
| 264 |
+
downscale_freq_shift=0.0,
|
| 265 |
+
scale=timestep_scale,
|
| 266 |
+
)
|
| 267 |
+
self.timestep_embedder = TimestepEmbedding(
|
| 268 |
+
in_channels=frequency_embedding_size,
|
| 269 |
+
time_embed_dim=min(hidden_size, 1024),
|
| 270 |
+
)
|
| 271 |
+
self.caption_embedder = nn.Sequential(
|
| 272 |
+
RMSNorm(text_feat_dim, eps=norm_eps),
|
| 273 |
+
nn.Linear(text_feat_dim, hidden_size, bias=True),
|
| 274 |
+
)
|
| 275 |
+
self._initialize_weights()
|
| 276 |
+
|
| 277 |
+
def _initialize_weights(self):
|
| 278 |
+
nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
|
| 279 |
+
nn.init.zeros_(self.caption_embedder[1].bias)
|
| 280 |
+
|
| 281 |
+
def forward(
|
| 282 |
+
self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
|
| 283 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 284 |
+
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
|
| 285 |
+
time_embed = self.timestep_embedder(timestep_proj)
|
| 286 |
+
caption_embed = self.caption_embedder(text_hidden_states)
|
| 287 |
+
return time_embed, caption_embed
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class AttnProcessorFlash2Varlen:
|
| 291 |
+
"""
|
| 292 |
+
Processor for implementing scaled dot-product attention with flash attention
|
| 293 |
+
and variable length sequences.
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
def __init__(self) -> None:
|
| 297 |
+
pass
|
| 298 |
+
# if not is_flash_attn_available():
|
| 299 |
+
# raise ImportError(
|
| 300 |
+
# "AttnProcessorFlash2Varlen requires flash_attn. "
|
| 301 |
+
# "Please install flash_attn."
|
| 302 |
+
# )
|
| 303 |
+
|
| 304 |
+
def _upad_input(
|
| 305 |
+
self,
|
| 306 |
+
query_layer: torch.Tensor,
|
| 307 |
+
key_layer: torch.Tensor,
|
| 308 |
+
value_layer: torch.Tensor,
|
| 309 |
+
attention_mask: torch.Tensor,
|
| 310 |
+
query_length: int,
|
| 311 |
+
num_heads: int,
|
| 312 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
| 313 |
+
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 314 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 315 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 316 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 317 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 318 |
+
return indices, cu_seqlens, max_seqlen_in_batch
|
| 319 |
+
|
| 320 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
| 321 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
| 322 |
+
|
| 323 |
+
key_layer = index_first_axis(
|
| 324 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k,
|
| 325 |
+
)
|
| 326 |
+
value_layer = index_first_axis(
|
| 327 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
if query_length == kv_seq_len:
|
| 331 |
+
query_layer = index_first_axis(
|
| 332 |
+
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k,
|
| 333 |
+
)
|
| 334 |
+
cu_seqlens_q = cu_seqlens_k
|
| 335 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
| 336 |
+
indices_q = indices_k
|
| 337 |
+
elif query_length == 1:
|
| 338 |
+
max_seqlen_in_batch_q = 1
|
| 339 |
+
cu_seqlens_q = torch.arange(
|
| 340 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
| 341 |
+
)
|
| 342 |
+
indices_q = cu_seqlens_q[:-1]
|
| 343 |
+
query_layer = query_layer.squeeze(1)
|
| 344 |
+
else:
|
| 345 |
+
attention_mask = attention_mask[:, -query_length:]
|
| 346 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
|
| 347 |
+
query_layer, attention_mask
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
return (
|
| 351 |
+
query_layer, key_layer, value_layer, indices_q,
|
| 352 |
+
(cu_seqlens_q, cu_seqlens_k),
|
| 353 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
def __call__(
|
| 357 |
+
self,
|
| 358 |
+
attn: Attention,
|
| 359 |
+
hidden_states: torch.Tensor,
|
| 360 |
+
encoder_hidden_states: torch.Tensor,
|
| 361 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 362 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 363 |
+
base_sequence_length: Optional[int] = None,
|
| 364 |
+
) -> torch.Tensor:
|
| 365 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 366 |
+
|
| 367 |
+
query = attn.to_q(hidden_states)
|
| 368 |
+
key = attn.to_k(encoder_hidden_states)
|
| 369 |
+
value = attn.to_v(encoder_hidden_states)
|
| 370 |
+
|
| 371 |
+
query_dim = query.shape[-1]
|
| 372 |
+
inner_dim = key.shape[-1]
|
| 373 |
+
head_dim = query_dim // attn.heads
|
| 374 |
+
dtype = query.dtype
|
| 375 |
+
kv_heads = inner_dim // head_dim
|
| 376 |
+
|
| 377 |
+
query = query.view(batch_size, -1, attn.heads, head_dim)
|
| 378 |
+
key = key.view(batch_size, -1, kv_heads, head_dim)
|
| 379 |
+
value = value.view(batch_size, -1, kv_heads, head_dim)
|
| 380 |
+
|
| 381 |
+
if attn.norm_q is not None:
|
| 382 |
+
query = attn.norm_q(query)
|
| 383 |
+
if attn.norm_k is not None:
|
| 384 |
+
key = attn.norm_k(key)
|
| 385 |
+
|
| 386 |
+
if image_rotary_emb is not None:
|
| 387 |
+
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
|
| 388 |
+
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
|
| 389 |
+
|
| 390 |
+
query, key = query.to(dtype), key.to(dtype)
|
| 391 |
+
|
| 392 |
+
if base_sequence_length is not None:
|
| 393 |
+
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
| 394 |
+
else:
|
| 395 |
+
softmax_scale = attn.scale
|
| 396 |
+
|
| 397 |
+
(
|
| 398 |
+
query_states, key_states, value_states, indices_q,
|
| 399 |
+
cu_seq_lens, max_seq_lens,
|
| 400 |
+
) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
|
| 401 |
+
|
| 402 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
| 403 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
| 404 |
+
|
| 405 |
+
if kv_heads < attn.heads:
|
| 406 |
+
key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
|
| 407 |
+
value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
|
| 408 |
+
|
| 409 |
+
attn_output_unpad = flash_attn_varlen_func(
|
| 410 |
+
query_states, key_states, value_states,
|
| 411 |
+
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
| 412 |
+
max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
|
| 413 |
+
dropout_p=0.0, causal=False, softmax_scale=softmax_scale,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
|
| 417 |
+
hidden_states = hidden_states.flatten(-2)
|
| 418 |
+
hidden_states = hidden_states.type_as(query)
|
| 419 |
+
|
| 420 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 421 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 422 |
+
return hidden_states
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class AttnProcessor:
|
| 426 |
+
"""
|
| 427 |
+
Processor for implementing scaled dot-product attention (PyTorch 2.0+).
|
| 428 |
+
"""
|
| 429 |
+
|
| 430 |
+
def __init__(self) -> None:
|
| 431 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 432 |
+
raise ImportError(
|
| 433 |
+
"AttnProcessor requires PyTorch 2.0. "
|
| 434 |
+
"Please upgrade PyTorch to version 2.0 or later."
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
def __call__(
|
| 438 |
+
self,
|
| 439 |
+
attn: Attention,
|
| 440 |
+
hidden_states: torch.Tensor,
|
| 441 |
+
encoder_hidden_states: torch.Tensor,
|
| 442 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 443 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 444 |
+
base_sequence_length: Optional[int] = None,
|
| 445 |
+
) -> torch.Tensor:
|
| 446 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 447 |
+
|
| 448 |
+
query = attn.to_q(hidden_states)
|
| 449 |
+
key = attn.to_k(encoder_hidden_states)
|
| 450 |
+
value = attn.to_v(encoder_hidden_states)
|
| 451 |
+
|
| 452 |
+
query_dim = query.shape[-1]
|
| 453 |
+
inner_dim = key.shape[-1]
|
| 454 |
+
head_dim = query_dim // attn.heads
|
| 455 |
+
dtype = query.dtype
|
| 456 |
+
kv_heads = inner_dim // head_dim
|
| 457 |
+
|
| 458 |
+
query = query.view(batch_size, -1, attn.heads, head_dim)
|
| 459 |
+
key = key.view(batch_size, -1, kv_heads, head_dim)
|
| 460 |
+
value = value.view(batch_size, -1, kv_heads, head_dim)
|
| 461 |
+
|
| 462 |
+
if attn.norm_q is not None:
|
| 463 |
+
query = attn.norm_q(query)
|
| 464 |
+
if attn.norm_k is not None:
|
| 465 |
+
key = attn.norm_k(key)
|
| 466 |
+
|
| 467 |
+
if image_rotary_emb is not None:
|
| 468 |
+
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
|
| 469 |
+
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
|
| 470 |
+
|
| 471 |
+
query, key = query.to(dtype), key.to(dtype)
|
| 472 |
+
|
| 473 |
+
if base_sequence_length is not None:
|
| 474 |
+
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
| 475 |
+
else:
|
| 476 |
+
softmax_scale = attn.scale
|
| 477 |
+
|
| 478 |
+
if attention_mask is not None:
|
| 479 |
+
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
|
| 480 |
+
|
| 481 |
+
query = query.transpose(1, 2)
|
| 482 |
+
key = key.transpose(1, 2)
|
| 483 |
+
value = value.transpose(1, 2)
|
| 484 |
+
|
| 485 |
+
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
|
| 486 |
+
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
|
| 487 |
+
|
| 488 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 489 |
+
query, key, value, attn_mask=attention_mask, scale=softmax_scale
|
| 490 |
+
)
|
| 491 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 492 |
+
hidden_states = hidden_states.type_as(query)
|
| 493 |
+
|
| 494 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 495 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 496 |
+
return hidden_states
|
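# Added note: both attention processors share two details. Grouped-query attention is
# handled by repeating K/V heads up to the query head count (repeat / repeat_interleave),
# and when base_sequence_length is given the softmax scale is multiplied by
# sqrt(log(sequence_length) / log(base_sequence_length)), an entropy-aware rescaling for
# sequences longer than the base resolution.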
| 497 |
+
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
class RotaryPosEmbed(nn.Module):
|
| 501 |
+
def __init__(
|
| 502 |
+
self,
|
| 503 |
+
theta: int,
|
| 504 |
+
axes_dim: Tuple[int, int, int],
|
| 505 |
+
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
| 506 |
+
patch_size: int = 2,
|
| 507 |
+
):
|
| 508 |
+
super().__init__()
|
| 509 |
+
self.theta = theta
|
| 510 |
+
self.axes_dim = axes_dim
|
| 511 |
+
self.axes_lens = axes_lens
|
| 512 |
+
self.patch_size = patch_size
|
| 513 |
+
|
| 514 |
+
@staticmethod
|
| 515 |
+
def get_freqs_cis(
|
| 516 |
+
axes_dim: Tuple[int, int, int],
|
| 517 |
+
axes_lens: Tuple[int, int, int],
|
| 518 |
+
theta: int,
|
| 519 |
+
) -> List[torch.Tensor]:
|
| 520 |
+
freqs_cis = []
|
| 521 |
+
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
| 522 |
+
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
|
| 523 |
+
emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
|
| 524 |
+
freqs_cis.append(emb)
|
| 525 |
+
return freqs_cis
|
| 526 |
+
|
| 527 |
+
def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
|
| 528 |
+
device = ids.device
|
| 529 |
+
if ids.device.type == "mps":
|
| 530 |
+
ids = ids.to("cpu")
|
| 531 |
+
|
| 532 |
+
result = []
|
| 533 |
+
for i in range(len(self.axes_dim)):
|
| 534 |
+
freqs = freqs_cis[i].to(ids.device)
|
| 535 |
+
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
|
| 536 |
+
result.append(
|
| 537 |
+
torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)
|
| 538 |
+
)
|
| 539 |
+
return torch.cat(result, dim=-1).to(device)
|
| 540 |
+
|
| 541 |
+
def forward(
|
| 542 |
+
self,
|
| 543 |
+
freqs_cis,
|
| 544 |
+
attention_mask,
|
| 545 |
+
l_effective_ref_img_len,
|
| 546 |
+
l_effective_img_len,
|
| 547 |
+
ref_img_sizes,
|
| 548 |
+
img_sizes,
|
| 549 |
+
device,
|
| 550 |
+
):
|
| 551 |
+
batch_size = len(attention_mask)
|
| 552 |
+
p = self.patch_size
|
| 553 |
+
|
| 554 |
+
encoder_seq_len = attention_mask.shape[1]
|
| 555 |
+
l_effective_cap_len = attention_mask.sum(dim=1).tolist()
|
| 556 |
+
|
| 557 |
+
seq_lengths = [
|
| 558 |
+
cap_len + sum(ref_img_len) + img_len
|
| 559 |
+
for cap_len, ref_img_len, img_len in zip(
|
| 560 |
+
l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len
|
| 561 |
+
)
|
| 562 |
+
]
|
| 563 |
+
|
| 564 |
+
max_seq_len = max(seq_lengths)
|
| 565 |
+
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
| 566 |
+
max_img_len = max(l_effective_img_len)
|
| 567 |
+
|
| 568 |
+
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
|
| 569 |
+
|
| 570 |
+
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
| 571 |
+
position_ids[i, :cap_seq_len] = repeat(
|
| 572 |
+
torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3"
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
pe_shift = cap_seq_len
|
| 576 |
+
pe_shift_len = cap_seq_len
|
| 577 |
+
|
| 578 |
+
if ref_img_sizes[i] is not None:
|
| 579 |
+
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
|
| 580 |
+
H, W = ref_img_size
|
| 581 |
+
ref_H_tokens, ref_W_tokens = H // p, W // p
|
| 582 |
+
assert ref_H_tokens * ref_W_tokens == ref_img_len
|
| 583 |
+
|
| 584 |
+
row_ids = repeat(
|
| 585 |
+
torch.arange(ref_H_tokens, dtype=torch.int32, device=device),
|
| 586 |
+
"h -> h w", w=ref_W_tokens,
|
| 587 |
+
).flatten()
|
| 588 |
+
col_ids = repeat(
|
| 589 |
+
torch.arange(ref_W_tokens, dtype=torch.int32, device=device),
|
| 590 |
+
"w -> h w", h=ref_H_tokens,
|
| 591 |
+
).flatten()
|
| 592 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
|
| 593 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
|
| 594 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
|
| 595 |
+
|
| 596 |
+
pe_shift += max(ref_H_tokens, ref_W_tokens)
|
| 597 |
+
pe_shift_len += ref_img_len
|
| 598 |
+
|
| 599 |
+
H, W = img_sizes[i]
|
| 600 |
+
H_tokens, W_tokens = H // p, W // p
|
| 601 |
+
assert H_tokens * W_tokens == l_effective_img_len[i]
|
| 602 |
+
|
| 603 |
+
row_ids = repeat(
|
| 604 |
+
torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens
|
| 605 |
+
).flatten()
|
| 606 |
+
col_ids = repeat(
|
| 607 |
+
torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens
|
| 608 |
+
).flatten()
|
| 609 |
+
|
| 610 |
+
assert pe_shift_len + l_effective_img_len[i] == seq_len
|
| 611 |
+
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
|
| 612 |
+
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
|
| 613 |
+
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
|
| 614 |
+
|
| 615 |
+
freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
|
| 616 |
+
|
| 617 |
+
cap_freqs_cis = torch.zeros(
|
| 618 |
+
batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 619 |
+
)
|
| 620 |
+
ref_img_freqs_cis = torch.zeros(
|
| 621 |
+
batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 622 |
+
)
|
| 623 |
+
img_freqs_cis = torch.zeros(
|
| 624 |
+
batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(
|
| 628 |
+
zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)
|
| 629 |
+
):
|
| 630 |
+
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
|
| 631 |
+
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[
|
| 632 |
+
i, cap_seq_len:cap_seq_len + sum(ref_img_len)
|
| 633 |
+
]
|
| 634 |
+
img_freqs_cis[i, :img_len] = freqs_cis[
|
| 635 |
+
i,
|
| 636 |
+
cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len,
|
| 637 |
+
]
|
| 638 |
+
|
| 639 |
+
return (
|
| 640 |
+
cap_freqs_cis,
|
| 641 |
+
ref_img_freqs_cis,
|
| 642 |
+
img_freqs_cis,
|
| 643 |
+
freqs_cis,
|
| 644 |
+
l_effective_cap_len,
|
| 645 |
+
seq_lengths,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
class TransformerBlock(nn.Module):
|
| 650 |
+
"""
|
| 651 |
+
Transformer block for the refiner model.
|
| 652 |
+
"""
|
| 653 |
+
|
| 654 |
+
def __init__(
|
| 655 |
+
self,
|
| 656 |
+
dim: int,
|
| 657 |
+
num_attention_heads: int,
|
| 658 |
+
num_kv_heads: int,
|
| 659 |
+
multiple_of: int,
|
| 660 |
+
ffn_dim_multiplier: float,
|
| 661 |
+
norm_eps: float,
|
| 662 |
+
modulation: bool = True,
|
| 663 |
+
) -> None:
|
| 664 |
+
super().__init__()
|
| 665 |
+
self.head_dim = dim // num_attention_heads
|
| 666 |
+
self.modulation = modulation
|
| 667 |
+
|
| 668 |
+
try:
|
| 669 |
+
processor = AttnProcessorFlash2Varlen()
|
| 670 |
+
except ImportError:
|
| 671 |
+
processor = AttnProcessor()
|
| 672 |
+
|
| 673 |
+
self.attn = Attention(
|
| 674 |
+
query_dim=dim,
|
| 675 |
+
cross_attention_dim=None,
|
| 676 |
+
dim_head=dim // num_attention_heads,
|
| 677 |
+
qk_norm="rms_norm",
|
| 678 |
+
heads=num_attention_heads,
|
| 679 |
+
kv_heads=num_kv_heads,
|
| 680 |
+
eps=1e-5,
|
| 681 |
+
bias=False,
|
| 682 |
+
out_bias=False,
|
| 683 |
+
processor=processor,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
self.feed_forward = LuminaFeedForward(
|
| 687 |
+
dim=dim,
|
| 688 |
+
inner_dim=4 * dim,
|
| 689 |
+
multiple_of=multiple_of,
|
| 690 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
if modulation:
|
| 694 |
+
self.norm1 = LuminaRMSNormZero(
|
| 695 |
+
embedding_dim=dim,
|
| 696 |
+
norm_eps=norm_eps,
|
| 697 |
+
norm_elementwise_affine=True,
|
| 698 |
+
)
|
| 699 |
+
else:
|
| 700 |
+
self.norm1 = RMSNorm(dim, eps=norm_eps)
|
| 701 |
+
|
| 702 |
+
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
| 703 |
+
self.norm2 = RMSNorm(dim, eps=norm_eps)
|
| 704 |
+
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
| 705 |
+
|
| 706 |
+
self.initialize_weights()
|
| 707 |
+
|
| 708 |
+
def initialize_weights(self) -> None:
|
| 709 |
+
nn.init.xavier_uniform_(self.attn.to_q.weight)
|
| 710 |
+
nn.init.xavier_uniform_(self.attn.to_k.weight)
|
| 711 |
+
nn.init.xavier_uniform_(self.attn.to_v.weight)
|
| 712 |
+
nn.init.xavier_uniform_(self.attn.to_out[0].weight)
|
| 713 |
+
|
| 714 |
+
nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
|
| 715 |
+
nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
|
| 716 |
+
nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
|
| 717 |
+
|
| 718 |
+
if self.modulation:
|
| 719 |
+
nn.init.zeros_(self.norm1.linear.weight)
|
| 720 |
+
nn.init.zeros_(self.norm1.linear.bias)
|
| 721 |
+
|
| 722 |
+
def forward(
|
| 723 |
+
self,
|
| 724 |
+
hidden_states: torch.Tensor,
|
| 725 |
+
attention_mask: torch.Tensor,
|
| 726 |
+
image_rotary_emb: torch.Tensor,
|
| 727 |
+
temb: Optional[torch.Tensor] = None,
|
| 728 |
+
) -> torch.Tensor:
|
| 729 |
+
enable_taylorseer = getattr(self, 'enable_taylorseer', False)
|
| 730 |
+
if enable_taylorseer:
|
| 731 |
+
if self.modulation:
|
| 732 |
+
if temb is None:
|
| 733 |
+
raise ValueError("temb must be provided when modulation is enabled")
|
| 734 |
+
|
| 735 |
+
if self.current['type'] == 'full':
|
| 736 |
+
self.current['module'] = 'total'
|
| 737 |
+
taylor_cache_init(cache_dic=self.cache_dic, current=self.current)
|
| 738 |
+
|
| 739 |
+
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
| 740 |
+
attn_output = self.attn(
|
| 741 |
+
hidden_states=norm_hidden_states,
|
| 742 |
+
encoder_hidden_states=norm_hidden_states,
|
| 743 |
+
attention_mask=attention_mask,
|
| 744 |
+
image_rotary_emb=image_rotary_emb,
|
| 745 |
+
)
|
| 746 |
+
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
| 747 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
| 748 |
+
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
| 749 |
+
|
| 750 |
+
derivative_approximation(cache_dic=self.cache_dic, current=self.current, feature=hidden_states)
|
| 751 |
+
|
| 752 |
+
elif self.current['type'] == 'Taylor':
|
| 753 |
+
self.current['module'] = 'total'
|
| 754 |
+
hidden_states = taylor_formula(cache_dic=self.cache_dic, current=self.current)
|
| 755 |
+
else:
|
| 756 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 757 |
+
attn_output = self.attn(
|
| 758 |
+
hidden_states=norm_hidden_states,
|
| 759 |
+
encoder_hidden_states=norm_hidden_states,
|
| 760 |
+
attention_mask=attention_mask,
|
| 761 |
+
image_rotary_emb=image_rotary_emb,
|
| 762 |
+
)
|
| 763 |
+
hidden_states = hidden_states + self.norm2(attn_output)
|
| 764 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
| 765 |
+
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
| 766 |
+
else:
|
| 767 |
+
if self.modulation:
|
| 768 |
+
if temb is None:
|
| 769 |
+
raise ValueError("temb must be provided when modulation is enabled")
|
| 770 |
+
|
| 771 |
+
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
| 772 |
+
attn_output = self.attn(
|
| 773 |
+
hidden_states=norm_hidden_states,
|
| 774 |
+
encoder_hidden_states=norm_hidden_states,
|
| 775 |
+
attention_mask=attention_mask,
|
| 776 |
+
image_rotary_emb=image_rotary_emb,
|
| 777 |
+
)
|
| 778 |
+
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
| 779 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
| 780 |
+
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
| 781 |
+
else:
|
| 782 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 783 |
+
attn_output = self.attn(
|
| 784 |
+
hidden_states=norm_hidden_states,
|
| 785 |
+
encoder_hidden_states=norm_hidden_states,
|
| 786 |
+
attention_mask=attention_mask,
|
| 787 |
+
image_rotary_emb=image_rotary_emb,
|
| 788 |
+
)
|
| 789 |
+
hidden_states = hidden_states + self.norm2(attn_output)
|
| 790 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
| 791 |
+
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
| 792 |
+
|
| 793 |
+
return hidden_states
|
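# Added note on the block layout (both variants): pre-norm attention whose output passes
# through a second RMSNorm (norm2) on the residual branch, followed by a SwiGLU
# feed-forward wrapped in ffn_norm1 / ffn_norm2 ("sandwich" normalization). When
# modulation=True, the timestep/caption embedding supplies the per-sample scale and the
# tanh-gated residual weights used above.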
| 794 |
+
|
| 795 |
+
|
| 796 |
+
class Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
| 797 |
+
"""
|
| 798 |
+
Transformer 2D Model.
|
| 799 |
+
"""
|
| 800 |
+
|
| 801 |
+
_supports_gradient_checkpointing = True
|
| 802 |
+
_no_split_modules = ["TransformerBlock"]
|
| 803 |
+
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
|
| 804 |
+
|
| 805 |
+
@register_to_config
|
| 806 |
+
def __init__(
|
| 807 |
+
self,
|
| 808 |
+
patch_size: int = 2,
|
| 809 |
+
in_channels: int = 16,
|
| 810 |
+
out_channels: Optional[int] = None,
|
| 811 |
+
hidden_size: int = 2304,
|
| 812 |
+
num_layers: int = 26,
|
| 813 |
+
num_refiner_layers: int = 2,
|
| 814 |
+
num_attention_heads: int = 24,
|
| 815 |
+
num_kv_heads: int = 8,
|
| 816 |
+
multiple_of: int = 256,
|
| 817 |
+
ffn_dim_multiplier: Optional[float] = None,
|
| 818 |
+
norm_eps: float = 1e-5,
|
| 819 |
+
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
| 820 |
+
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
| 821 |
+
text_feat_dim: int = 1024,
|
| 822 |
+
timestep_scale: float = 1.0,
|
| 823 |
+
) -> None:
|
| 824 |
+
super().__init__()
|
| 825 |
+
|
| 826 |
+
if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
|
| 827 |
+
raise ValueError(
|
| 828 |
+
f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
|
| 829 |
+
f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
self.out_channels = out_channels or in_channels
|
| 833 |
+
|
| 834 |
+
self.rope_embedder = RotaryPosEmbed(
|
| 835 |
+
theta=10000,
|
| 836 |
+
axes_dim=axes_dim_rope,
|
| 837 |
+
axes_lens=axes_lens,
|
| 838 |
+
patch_size=patch_size,
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
self.x_embedder = nn.Linear(
|
| 842 |
+
in_features=patch_size * patch_size * in_channels,
|
| 843 |
+
out_features=hidden_size,
|
| 844 |
+
)
|
| 845 |
+
|
| 846 |
+
self.ref_image_patch_embedder = nn.Linear(
|
| 847 |
+
in_features=patch_size * patch_size * in_channels,
|
| 848 |
+
out_features=hidden_size,
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
| 852 |
+
hidden_size=hidden_size,
|
| 853 |
+
text_feat_dim=text_feat_dim,
|
| 854 |
+
norm_eps=norm_eps,
|
| 855 |
+
timestep_scale=timestep_scale,
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
self.noise_refiner = nn.ModuleList([
|
| 859 |
+
TransformerBlock(
|
| 860 |
+
hidden_size, num_attention_heads, num_kv_heads,
|
| 861 |
+
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True,
|
| 862 |
+
)
|
| 863 |
+
for _ in range(num_refiner_layers)
|
| 864 |
+
])
|
| 865 |
+
|
| 866 |
+
self.ref_image_refiner = nn.ModuleList([
|
| 867 |
+
TransformerBlock(
|
| 868 |
+
hidden_size, num_attention_heads, num_kv_heads,
|
| 869 |
+
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True,
|
| 870 |
+
)
|
| 871 |
+
for _ in range(num_refiner_layers)
|
| 872 |
+
])
|
| 873 |
+
|
| 874 |
+
self.context_refiner = nn.ModuleList([
|
| 875 |
+
TransformerBlock(
|
| 876 |
+
hidden_size, num_attention_heads, num_kv_heads,
|
| 877 |
+
multiple_of, ffn_dim_multiplier, norm_eps, modulation=False,
|
| 878 |
+
)
|
| 879 |
+
for _ in range(num_refiner_layers)
|
| 880 |
+
])
|
| 881 |
+
|
| 882 |
+
self.layers = nn.ModuleList([
|
| 883 |
+
TransformerBlock(
|
| 884 |
+
hidden_size, num_attention_heads, num_kv_heads,
|
| 885 |
+
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True,
|
| 886 |
+
)
|
| 887 |
+
for _ in range(num_layers)
|
| 888 |
+
])
|
| 889 |
+
|
| 890 |
+
self.norm_out = LuminaLayerNormContinuous(
|
| 891 |
+
embedding_dim=hidden_size,
|
| 892 |
+
conditioning_embedding_dim=min(hidden_size, 1024),
|
| 893 |
+
elementwise_affine=False,
|
| 894 |
+
eps=1e-6,
|
| 895 |
+
bias=True,
|
| 896 |
+
out_dim=patch_size * patch_size * self.out_channels,
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size))
|
| 900 |
+
|
| 901 |
+
self.gradient_checkpointing = False
|
| 902 |
+
|
| 903 |
+
self.initialize_weights()
|
| 904 |
+
|
| 905 |
+
self.enable_teacache = False
|
| 906 |
+
self.teacache_rel_l1_thresh = 0.05
|
| 907 |
+
self.teacache_params = TeaCacheParams()
|
| 908 |
+
|
| 909 |
+
coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
|
| 910 |
+
self.rescale_func = np.poly1d(coefficients)
|
| 911 |
+
|
| 912 |
+
def initialize_weights(self) -> None:
|
| 913 |
+
nn.init.xavier_uniform_(self.x_embedder.weight)
|
| 914 |
+
nn.init.constant_(self.x_embedder.bias, 0.0)
|
| 915 |
+
|
| 916 |
+
nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
|
| 917 |
+
nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
|
| 918 |
+
|
| 919 |
+
nn.init.zeros_(self.norm_out.linear_1.weight)
|
| 920 |
+
nn.init.zeros_(self.norm_out.linear_1.bias)
|
| 921 |
+
nn.init.zeros_(self.norm_out.linear_2.weight)
|
| 922 |
+
nn.init.zeros_(self.norm_out.linear_2.bias)
|
| 923 |
+
|
| 924 |
+
nn.init.normal_(self.image_index_embedding, std=0.02)
|
| 925 |
+
|
| 926 |
+
def img_patch_embed_and_refine(
|
| 927 |
+
self,
|
| 928 |
+
hidden_states,
|
| 929 |
+
ref_image_hidden_states,
|
| 930 |
+
padded_img_mask,
|
| 931 |
+
padded_ref_img_mask,
|
| 932 |
+
noise_rotary_emb,
|
| 933 |
+
ref_img_rotary_emb,
|
| 934 |
+
l_effective_ref_img_len,
|
| 935 |
+
l_effective_img_len,
|
| 936 |
+
temb,
|
| 937 |
+
):
|
| 938 |
+
batch_size = len(hidden_states)
|
| 939 |
+
max_combined_img_len = max([
|
| 940 |
+
img_len + sum(ref_img_len)
|
| 941 |
+
for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)
|
| 942 |
+
])
|
| 943 |
+
|
| 944 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 945 |
+
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
|
| 946 |
+
|
| 947 |
+
for i in range(batch_size):
|
| 948 |
+
shift = 0
|
| 949 |
+
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
|
| 950 |
+
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = (
|
| 951 |
+
ref_image_hidden_states[i, shift:shift + ref_img_len, :]
|
| 952 |
+
+ self.image_index_embedding[j]
|
| 953 |
+
)
|
| 954 |
+
shift += ref_img_len
|
| 955 |
+
|
| 956 |
+
for layer in self.noise_refiner:
|
| 957 |
+
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
| 958 |
+
|
| 959 |
+
flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
|
| 960 |
+
num_ref_images = len(flat_l_effective_ref_img_len)
|
| 961 |
+
max_ref_img_len = max(flat_l_effective_ref_img_len)
|
| 962 |
+
|
| 963 |
+
batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
|
| 964 |
+
batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(
|
| 965 |
+
num_ref_images, max_ref_img_len, self.config.hidden_size
|
| 966 |
+
)
|
| 967 |
+
batch_ref_img_rotary_emb = hidden_states.new_zeros(
|
| 968 |
+
num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype
|
| 969 |
+
)
|
| 970 |
+
batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
|
| 971 |
+
|
| 972 |
+
idx = 0
|
| 973 |
+
for i in range(batch_size):
|
| 974 |
+
shift = 0
|
| 975 |
+
for ref_img_len in l_effective_ref_img_len[i]:
|
| 976 |
+
batch_ref_img_mask[idx, :ref_img_len] = True
|
| 977 |
+
batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
|
| 978 |
+
batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
|
| 979 |
+
batch_temb[idx] = temb[i]
|
| 980 |
+
shift += ref_img_len
|
| 981 |
+
idx += 1
|
| 982 |
+
|
| 983 |
+
for layer in self.ref_image_refiner:
|
| 984 |
+
batch_ref_image_hidden_states = layer(
|
| 985 |
+
batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb
|
| 986 |
+
)
|
| 987 |
+
|
| 988 |
+
idx = 0
|
| 989 |
+
for i in range(batch_size):
|
| 990 |
+
shift = 0
|
| 991 |
+
for ref_img_len in l_effective_ref_img_len[i]:
|
| 992 |
+
ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
|
| 993 |
+
shift += ref_img_len
|
| 994 |
+
idx += 1
|
| 995 |
+
|
| 996 |
+
combined_img_hidden_states = hidden_states.new_zeros(
|
| 997 |
+
batch_size, max_combined_img_len, self.config.hidden_size
|
| 998 |
+
)
|
| 999 |
+
for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
|
| 1000 |
+
combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
|
| 1001 |
+
combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
|
| 1002 |
+
|
| 1003 |
+
return combined_img_hidden_states
|
| 1004 |
+
|
| 1005 |
+
    def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
        batch_size = len(hidden_states)
        p = self.config.patch_size
        device = hidden_states[0].device

        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
        l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]

        if ref_image_hidden_states is not None and len(ref_image_hidden_states) > 0:
            ref_img_sizes = [
                [(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None
                for imgs in ref_image_hidden_states
            ]
            l_effective_ref_img_len = [
                [(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes]
                if _ref_img_sizes is not None else [0]
                for _ref_img_sizes in ref_img_sizes
            ]
        else:
            ref_img_sizes = [None for _ in range(batch_size)]
            l_effective_ref_img_len = [[0] for _ in range(batch_size)]

        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        flat_ref_img_hidden_states = []
        for i in range(batch_size):
            if ref_img_sizes[i] is not None:
                imgs = []
                for ref_img in ref_image_hidden_states[i]:
                    C, H, W = ref_img.size()
                    ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
                    imgs.append(ref_img)
                flat_ref_img_hidden_states.append(torch.cat(imgs, dim=0))
            else:
                flat_ref_img_hidden_states.append(None)

        flat_hidden_states = []
        for i in range(batch_size):
            img = hidden_states[i]
            C, H, W = img.size()
            img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
            flat_hidden_states.append(img)

        padded_ref_img_hidden_states = torch.zeros(
            batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1],
            device=device, dtype=flat_hidden_states[0].dtype,
        )
        padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
        for i in range(batch_size):
            if ref_img_sizes[i] is not None:
                padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
                padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True

        padded_hidden_states = torch.zeros(
            batch_size, max_img_len, flat_hidden_states[0].shape[-1],
            device=device, dtype=flat_hidden_states[0].dtype,
        )
        padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
        for i in range(batch_size):
            padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
            padded_img_mask[i, :l_effective_img_len[i]] = True

        return (
            padded_hidden_states,
            padded_ref_img_hidden_states,
            padded_img_mask,
            padded_ref_img_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
        )
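    # --- Illustrative note (not part of the original file) -------------------
    # `flat_and_pad_to_seq` patchifies each (C, H, W) image into a sequence of
    # (num_patches, p*p*C) tokens via einops `rearrange`. A minimal standalone
    # sketch of that reshape, with made-up sizes, would be:
    #
    #   import torch
    #   from einops import rearrange
    #   p = 2                                  # hypothetical patch size
    #   img = torch.randn(16, 8, 8)            # (C, H, W)
    #   tokens = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
    #   assert tokens.shape == (16, 64)        # (H//p * W//p, p*p*C) = (4*4, 2*2*16)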
    def forward(
        self,
        hidden_states: Union[torch.Tensor, List[torch.Tensor]],
        timestep: torch.Tensor,
        text_hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        text_attention_mask: torch.Tensor,
        ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        enable_taylorseer = getattr(self, 'enable_taylorseer', False)
        if enable_taylorseer:
            cal_type(self.cache_dic, self.current)

        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size = len(hidden_states)
        is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)

        if is_hidden_states_tensor:
            assert hidden_states.ndim == 4
            hidden_states = [_hidden_states for _hidden_states in hidden_states]

        device = hidden_states[0].device

        assert isinstance(text_hidden_states, torch.Tensor), \
            f"text_hidden_states must be Tensor, got {type(text_hidden_states)}. " \
            f"Check if freqs_cis and text_hidden_states are swapped in the caller."

        temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)

        (
            hidden_states,
            ref_image_hidden_states,
            img_mask,
            ref_img_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
        ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)

        (
            context_rotary_emb,
            ref_img_rotary_emb,
            noise_rotary_emb,
            rotary_emb,
            encoder_seq_lengths,
            seq_lengths,
        ) = self.rope_embedder(
            freqs_cis,
            text_attention_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
            device,
        )

        # 2. Context refinement
        for layer in self.context_refiner:
            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)

        combined_img_hidden_states = self.img_patch_embed_and_refine(
            hidden_states,
            ref_image_hidden_states,
            img_mask,
            ref_img_mask,
            noise_rotary_emb,
            ref_img_rotary_emb,
            l_effective_ref_img_len,
            l_effective_img_len,
            temb,
        )

        # 3. Joint Transformer blocks
        max_seq_len = max(seq_lengths)

        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            attention_mask[i, :seq_len] = True
            joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
            joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]

        hidden_states = joint_hidden_states

        if self.enable_teacache:
            teacache_hidden_states = hidden_states.clone()
            teacache_temb = temb.clone()
            modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb)
            if self.teacache_params.is_first_or_last_step:
                should_calc = True
                self.teacache_params.accumulated_rel_l1_distance = 0
            else:
                self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
                    ((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean()
                     / self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item()
                )
                if self.teacache_params.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    self.teacache_params.accumulated_rel_l1_distance = 0
            self.teacache_params.previous_modulated_inp = modulated_inp

        if self.enable_teacache:
            if not should_calc:
                hidden_states += self.teacache_params.previous_residual
            else:
                ori_hidden_states = hidden_states.clone()
                for layer_idx, layer in enumerate(self.layers):
                    if torch.is_grad_enabled() and self.gradient_checkpointing:
                        hidden_states = self._gradient_checkpointing_func(
                            layer, hidden_states, attention_mask, rotary_emb, temb
                        )
                    else:
                        hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
                self.teacache_params.previous_residual = hidden_states - ori_hidden_states
        else:
            if enable_taylorseer:
                self.current['stream'] = 'layers_stream'

            for layer_idx, layer in enumerate(self.layers):
                if enable_taylorseer:
                    layer.current = self.current
                    layer.cache_dic = self.cache_dic
                    layer.enable_taylorseer = True
                    self.current['layer'] = layer_idx

                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states = self._gradient_checkpointing_func(
                        layer, hidden_states, attention_mask, rotary_emb, temb
                    )
                else:
                    hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)

        hidden_states = self.norm_out(hidden_states, temb)

        p = self.config.patch_size
        output = []
        for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
            height, width = img_size
            output.append(rearrange(
                hidden_states[i][seq_len - img_len:seq_len],
                '(h w) (p1 p2 c) -> c (h p1) (w p2)',
                h=height // p, w=width // p, p1=p, p2=p,
            ))
        if is_hidden_states_tensor:
            output = torch.stack(output, dim=0)

        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)

        if enable_taylorseer:
            self.current['step'] += 1

        if not return_dict:
            return output
        return Transformer2DModelOutput(sample=output)
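# --- Illustrative note (not part of the original file) -----------------------
# The TeaCache branch in `forward` skips the transformer stack whenever the
# accumulated relative-L1 change of the modulated input stays below
# `teacache_rel_l1_thresh`; otherwise it recomputes and stores a fresh residual.
# A minimal standalone sketch of that criterion (names here are hypothetical):
#
#   import torch
#   def should_recompute(curr, prev, accumulated, thresh):
#       accumulated += ((curr - prev).abs().mean() / prev.abs().mean()).item()
#       if accumulated < thresh:
#           return False, accumulated      # reuse the cached residual
#       return True, 0.0                   # recompute layers, reset the counter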
# ---------------------------------------------------------------------------
# FlowMatch Euler Discrete Scheduler (merged from scheduling_flow_match_euler_discrete.py)
# ---------------------------------------------------------------------------

@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    _compatibles = []
    order = 1

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, dynamic_time_shift: bool = False):
        timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
        self.timesteps = timesteps
        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self):
        return self._step_index

    @property
    def begin_index(self):
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        self._begin_index = begin_index

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self._timesteps
        indices = (schedule_timesteps == timestep).nonzero()
        pos = 1 if len(indices) > 1 else 0
        return indices[pos].item()

    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None, num_tokens=None):
        if timesteps is None:
            self.num_inference_steps = num_inference_steps
            timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
        if self.config.dynamic_time_shift and num_tokens is not None:
            m = np.sqrt(num_tokens) / 40
            timesteps = timesteps / (m - m * timesteps + timesteps)
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
        _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
        self.timesteps = timesteps
        self._timesteps = _timesteps
        self._step_index = None
        self._begin_index = None

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError("Pass scheduler.timesteps values, not integer indices.")
        if self.step_index is None:
            self._init_step_index(timestep)
        sample = sample.to(torch.float32)
        t = self._timesteps[self.step_index]
        t_next = self._timesteps[self.step_index + 1]
        prev_sample = sample + (t_next - t) * model_output
        prev_sample = prev_sample.to(model_output.dtype)
        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
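# --- Illustrative note (not part of the original file) -----------------------
# A minimal sketch of how this FlowMatch Euler scheduler is typically driven.
# `velocity_fn` is a hypothetical stand-in for the transformer call; the real
# pipeline wiring lives elsewhere in the repository.
#
#   import torch
#   scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
#   scheduler.set_timesteps(num_inference_steps=30, device="cpu")
#   sample = torch.randn(1, 16, 64, 64)
#   velocity_fn = lambda x, t: torch.zeros_like(x)   # placeholder model call
#   for t in scheduler.timesteps:
#       model_output = velocity_fn(sample, t)
#       sample = scheduler.step(model_output, t, sample).prev_sample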
special_tokens_map.json
ADDED
@@ -0,0 +1,85 @@
{
  "additional_special_tokens": [
    "<mask_131048>",
    "<mask_131049>",
    "<mask_131050>",
    "<mask_131051>",
    "<mask_131052>",
    "<mask_131053>",
    "<mask_131054>",
    "<mask_131055>",
    "<mask_131056>",
    "<mask_131057>",
    "<mask_131058>",
    "<mask_131059>",
    "<mask_131060>",
    "<mask_131061>",
    "<mask_131062>",
    "<mask_131063>",
    "<mask_131064>",
    "<mask_131065>",
    "<longcat_img_token_size>",
    "</longcat_img_token_size>",
    "<mask_131068>",
    "<mask_131069>",
    "<mask_131070>",
    "<mask_131071>",
    "<longcat_point_start>",
    "<longcat_point_end>",
    "<longcat_point_delim>",
    "<longcat_polygon_start>",
    "<longcat_polygon_end>",
    "<mask_131077>",
    "<mask_131078>",
    "<longcat_audio_start>",
    "<longcat_audio_end>",
    "<longcat_audio_pad>",
    "<longcat_img_start>",
    "<longcat_img_end>",
    "<longcat_img_pad>",
    "<longcat_img_newline>",
    "<longcat_box_start>",
    "<longcat_box_end>",
    "<longcat_box_delim>",
    "<longcat_ref_start>",
    "<longcat_ref_end>",
    "<longcat_img_delim>",
    "<longcat_audio_delim>",
    "<longcat_video_palce>",
    "<longcat_video_start>",
    "<longcat_video_end>",
    "<longcat_audiotext_start>",
    "<longcat_audiotext_end>",
    "<longcat_audiotext_pad>",
    "<longcat_audiogen_start>",
    "<longcat_audiogen_end>"
  ],
  "bos_token": {
    "content": "<longcat_s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</longcat_s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<longcat_pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<longcat_unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
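As a quick sanity check after downloading the checkpoint, the special-token map above can be inspected through transformers. This is an illustrative sketch only; the repo id below is a placeholder, not the actual upload path.

from transformers import AutoTokenizer

# Placeholder repo id; substitute the real model path. trust_remote_code is
# needed because the repo ships custom processing code.
tokenizer = AutoTokenizer.from_pretrained("your-org/longcat-checkpoint", trust_remote_code=True)
print(tokenizer.special_tokens_map["bos_token"])   # expected: <longcat_s>
print(tokenizer.special_tokens_map["eos_token"])   # expected: </longcat_s>
assert "<longcat_img_start>" in tokenizer.additional_special_tokens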
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
tokenizer_config.json
ADDED
@@ -0,0 +1,2300 @@
{
  "add_bos_token": false,
  "add_eos_token": true,
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<longcat_unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<longcat_s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</longcat_s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<longcat_pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "4": {
      "content": "<shift_unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "5": {
      "content": "<shift_s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "6": {
      "content": "</shift_s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "7": {
      "content": "<shift_pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "8": {
      "content": "<mask_0>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "9": {
      "content": "<reponame>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "10": {
      "content": "<filename>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "11": {
      "content": "<gh_stars>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "12": {
      "content": "<issue_start>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "13": {
      "content": "<issue_comment>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "14": {
      "content": "<issue_closed>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "15": {
      "content": "<jupyter_start>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "16": {
      "content": "<jupyter_text>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "17": {
      "content": "<jupyter_code>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "18": {
      "content": "<jupyter_output>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "19": {
      "content": "<empty_output>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "20": {
      "content": "<commit_before>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "21": {
      "content": "<commit_msg>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "22": {
      "content": "<commit_after>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "23": {
      "content": "<program_lang>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "24": {
      "content": "<|image_placeholder|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "25": {
      "content": "<|url_placeholder|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "26": {
      "content": "<|hyperlink_placeholder|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "27": {
      "content": "<|table_placeholder|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "28": {
      "content": "<|equation_placeholder|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "29": {
      "content": "<|code_placeholder|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "30": {
      "content": "<|reference_placeholder|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "31": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "32": {
      "content": "<fim_prefix>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "33": {
      "content": "<fim_middle>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "34": {
      "content": "<fim_suffix>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "35": {
      "content": "<fim_pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "36": {
      "content": "<longcat_think>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "37": {
      "content": "</longcat_think>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "38": {
      "content": "<longcat_answer>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "39": {
      "content": "</longcat_answer>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "40": {
      "content": "<longcat_files>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "41": {
      "content": "</longcat_files>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "42": {
      "content": "<longcat_tool_call>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "43": {
      "content": "</longcat_tool_call>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "44": {
      "content": "<longcat_tool_declare>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "45": {
      "content": "</longcat_tool_declare>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "46": {
      "content": "<longcat_system>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "47": {
      "content": "<longcat_user>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "48": {
      "content": "<longcat_assistant>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "49": {
      "content": "<longcat_tool_response>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50": {
      "content": "</longcat_tool_response>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "51": {
      "content": "<longcat_arg_key>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "52": {
      "content": "</longcat_arg_key>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "53": {
      "content": "<longcat_arg_value>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "54": {
      "content": "</longcat_arg_value>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "55": {
      "content": "<mask_31>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "56": {
      "content": "<mask_32>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "57": {
      "content": "<mask_33>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "58": {
      "content": "<mask_34>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "59": {
      "content": "<mask_35>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "60": {
      "content": "<mask_36>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "61": {
      "content": "<mask_37>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "62": {
      "content": "<mask_38>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "63": {
      "content": "<mask_39>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "64": {
      "content": "<mask_40>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "65": {
      "content": "<mask_41>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "66": {
      "content": "<mask_42>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "67": {
      "content": "<mask_43>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "68": {
      "content": "<mask_44>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "69": {
      "content": "<mask_45>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "70": {
      "content": "<mask_46>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "71": {
      "content": "<mask_47>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "72": {
      "content": "<mask_48>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "73": {
      "content": "<mask_49>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "74": {
      "content": "<mask_50>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "75": {
      "content": "<mask_51>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "76": {
      "content": "<mask_52>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "77": {
      "content": "<mask_53>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "78": {
      "content": "<mask_54>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "79": {
      "content": "<mask_55>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "80": {
      "content": "<mask_56>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "81": {
      "content": "<mask_57>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "82": {
      "content": "<mask_58>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "83": {
      "content": "<mask_59>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "84": {
      "content": "<mask_60>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "85": {
      "content": "<mask_61>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "86": {
      "content": "<mask_62>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "87": {
      "content": "<mask_63>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "88": {
      "content": "<mask_64>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "89": {
      "content": "<mask_65>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "90": {
      "content": "<mask_66>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "91": {
      "content": "<mask_67>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "92": {
      "content": "<mask_68>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "93": {
      "content": "<mask_69>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "94": {
      "content": "<mask_70>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "95": {
      "content": "<mask_71>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "96": {
      "content": "<mask_72>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "97": {
      "content": "<mask_73>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "98": {
      "content": "<mask_74>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "99": {
      "content": "<mask_75>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "<mask_76>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "<mask_77>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "<mask_78>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "<mask_79>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "104": {
      "content": "<mask_80>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "105": {
      "content": "<mask_81>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "106": {
      "content": "<mask_82>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "107": {
      "content": "<mask_83>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "108": {
      "content": "<mask_84>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "109": {
      "content": "<mask_85>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "110": {
      "content": "<mask_86>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "111": {
      "content": "<mask_87>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "112": {
      "content": "<mask_88>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "113": {
      "content": "<mask_89>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "114": {
      "content": "<mask_90>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "115": {
      "content": "<mask_91>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "116": {
      "content": "<mask_92>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "117": {
      "content": "<mask_93>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "118": {
      "content": "<mask_94>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "119": {
      "content": "<mask_95>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "120": {
      "content": "<mask_96>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "121": {
      "content": "<mask_97>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "122": {
      "content": "<mask_98>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "123": {
      "content": "<mask_99>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "124": {
      "content": "<mask_100>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "125": {
      "content": "<mask_101>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "126": {
      "content": "<mask_102>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "127": {
      "content": "<mask_103>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
+
},
|
| 1030 |
+
"128": {
|
| 1031 |
+
"content": "<mask_104>",
|
| 1032 |
+
"lstrip": false,
|
| 1033 |
+
"normalized": false,
|
| 1034 |
+
"rstrip": false,
|
| 1035 |
+
"single_word": false,
|
| 1036 |
+
"special": true
|
| 1037 |
+
},
|
| 1038 |
+
"129": {
|
| 1039 |
+
"content": "<mask_105>",
|
| 1040 |
+
"lstrip": false,
|
| 1041 |
+
"normalized": false,
|
| 1042 |
+
"rstrip": false,
|
| 1043 |
+
"single_word": false,
|
| 1044 |
+
"special": true
|
| 1045 |
+
},
|
| 1046 |
+
"130": {
|
| 1047 |
+
"content": "<mask_106>",
|
| 1048 |
+
"lstrip": false,
|
| 1049 |
+
"normalized": false,
|
| 1050 |
+
"rstrip": false,
|
| 1051 |
+
"single_word": false,
|
| 1052 |
+
"special": true
|
| 1053 |
+
},
|
| 1054 |
+
"131": {
|
| 1055 |
+
"content": "<mask_107>",
|
| 1056 |
+
"lstrip": false,
|
| 1057 |
+
"normalized": false,
|
| 1058 |
+
"rstrip": false,
|
| 1059 |
+
"single_word": false,
|
| 1060 |
+
"special": true
|
| 1061 |
+
},
|
| 1062 |
+
"132": {
|
| 1063 |
+
"content": "<mask_108>",
|
| 1064 |
+
"lstrip": false,
|
| 1065 |
+
"normalized": false,
|
| 1066 |
+
"rstrip": false,
|
| 1067 |
+
"single_word": false,
|
| 1068 |
+
"special": true
|
| 1069 |
+
},
|
| 1070 |
+
"133": {
|
| 1071 |
+
"content": "<mask_109>",
|
| 1072 |
+
"lstrip": false,
|
| 1073 |
+
"normalized": false,
|
| 1074 |
+
"rstrip": false,
|
| 1075 |
+
"single_word": false,
|
| 1076 |
+
"special": true
|
| 1077 |
+
},
|
| 1078 |
+
"134": {
|
| 1079 |
+
"content": "<mask_110>",
|
| 1080 |
+
"lstrip": false,
|
| 1081 |
+
"normalized": false,
|
| 1082 |
+
"rstrip": false,
|
| 1083 |
+
"single_word": false,
|
| 1084 |
+
"special": true
|
| 1085 |
+
},
|
| 1086 |
+
"135": {
|
| 1087 |
+
"content": "<mask_111>",
|
| 1088 |
+
"lstrip": false,
|
| 1089 |
+
"normalized": false,
|
| 1090 |
+
"rstrip": false,
|
| 1091 |
+
"single_word": false,
|
| 1092 |
+
"special": true
|
| 1093 |
+
},
|
| 1094 |
+
"136": {
|
| 1095 |
+
"content": "<mask_112>",
|
| 1096 |
+
"lstrip": false,
|
| 1097 |
+
"normalized": false,
|
| 1098 |
+
"rstrip": false,
|
| 1099 |
+
"single_word": false,
|
| 1100 |
+
"special": true
|
| 1101 |
+
},
|
| 1102 |
+
"137": {
|
| 1103 |
+
"content": "<mask_113>",
|
| 1104 |
+
"lstrip": false,
|
| 1105 |
+
"normalized": false,
|
| 1106 |
+
"rstrip": false,
|
| 1107 |
+
"single_word": false,
|
| 1108 |
+
"special": true
|
| 1109 |
+
},
|
| 1110 |
+
"138": {
|
| 1111 |
+
"content": "<mask_114>",
|
| 1112 |
+
"lstrip": false,
|
| 1113 |
+
"normalized": false,
|
| 1114 |
+
"rstrip": false,
|
| 1115 |
+
"single_word": false,
|
| 1116 |
+
"special": true
|
| 1117 |
+
},
|
| 1118 |
+
"139": {
|
| 1119 |
+
"content": "<mask_115>",
|
| 1120 |
+
"lstrip": false,
|
| 1121 |
+
"normalized": false,
|
| 1122 |
+
"rstrip": false,
|
| 1123 |
+
"single_word": false,
|
| 1124 |
+
"special": true
|
| 1125 |
+
},
|
| 1126 |
+
"140": {
|
| 1127 |
+
"content": "<mask_116>",
|
| 1128 |
+
"lstrip": false,
|
| 1129 |
+
"normalized": false,
|
| 1130 |
+
"rstrip": false,
|
| 1131 |
+
"single_word": false,
|
| 1132 |
+
"special": true
|
| 1133 |
+
},
|
| 1134 |
+
"141": {
|
| 1135 |
+
"content": "<mask_117>",
|
| 1136 |
+
"lstrip": false,
|
| 1137 |
+
"normalized": false,
|
| 1138 |
+
"rstrip": false,
|
| 1139 |
+
"single_word": false,
|
| 1140 |
+
"special": true
|
| 1141 |
+
},
|
| 1142 |
+
"142": {
|
| 1143 |
+
"content": "<mask_118>",
|
| 1144 |
+
"lstrip": false,
|
| 1145 |
+
"normalized": false,
|
| 1146 |
+
"rstrip": false,
|
| 1147 |
+
"single_word": false,
|
| 1148 |
+
"special": true
|
| 1149 |
+
},
|
| 1150 |
+
"143": {
|
| 1151 |
+
"content": "<mask_119>",
|
| 1152 |
+
"lstrip": false,
|
| 1153 |
+
"normalized": false,
|
| 1154 |
+
"rstrip": false,
|
| 1155 |
+
"single_word": false,
|
| 1156 |
+
"special": true
|
| 1157 |
+
},
|
| 1158 |
+
"144": {
|
| 1159 |
+
"content": "<mask_120>",
|
| 1160 |
+
"lstrip": false,
|
| 1161 |
+
"normalized": false,
|
| 1162 |
+
"rstrip": false,
|
| 1163 |
+
"single_word": false,
|
| 1164 |
+
"special": true
|
| 1165 |
+
},
|
| 1166 |
+
"145": {
|
| 1167 |
+
"content": "<mask_121>",
|
| 1168 |
+
"lstrip": false,
|
| 1169 |
+
"normalized": false,
|
| 1170 |
+
"rstrip": false,
|
| 1171 |
+
"single_word": false,
|
| 1172 |
+
"special": true
|
| 1173 |
+
},
|
| 1174 |
+
"146": {
|
| 1175 |
+
"content": "<mask_122>",
|
| 1176 |
+
"lstrip": false,
|
| 1177 |
+
"normalized": false,
|
| 1178 |
+
"rstrip": false,
|
| 1179 |
+
"single_word": false,
|
| 1180 |
+
"special": true
|
| 1181 |
+
},
|
| 1182 |
+
"147": {
|
| 1183 |
+
"content": "<mask_123>",
|
| 1184 |
+
"lstrip": false,
|
| 1185 |
+
"normalized": false,
|
| 1186 |
+
"rstrip": false,
|
| 1187 |
+
"single_word": false,
|
| 1188 |
+
"special": true
|
| 1189 |
+
},
|
| 1190 |
+
"148": {
|
| 1191 |
+
"content": "<mask_124>",
|
| 1192 |
+
"lstrip": false,
|
| 1193 |
+
"normalized": false,
|
| 1194 |
+
"rstrip": false,
|
| 1195 |
+
"single_word": false,
|
| 1196 |
+
"special": true
|
| 1197 |
+
},
|
| 1198 |
+
"149": {
|
| 1199 |
+
"content": "<mask_125>",
|
| 1200 |
+
"lstrip": false,
|
| 1201 |
+
"normalized": false,
|
| 1202 |
+
"rstrip": false,
|
| 1203 |
+
"single_word": false,
|
| 1204 |
+
"special": true
|
| 1205 |
+
},
|
| 1206 |
+
"150": {
|
| 1207 |
+
"content": "<mask_126>",
|
| 1208 |
+
"lstrip": false,
|
| 1209 |
+
"normalized": false,
|
| 1210 |
+
"rstrip": false,
|
| 1211 |
+
"single_word": false,
|
| 1212 |
+
"special": true
|
| 1213 |
+
},
|
| 1214 |
+
"151": {
|
| 1215 |
+
"content": "<mask_127>",
|
| 1216 |
+
"lstrip": false,
|
| 1217 |
+
"normalized": false,
|
| 1218 |
+
"rstrip": false,
|
| 1219 |
+
"single_word": false,
|
| 1220 |
+
"special": true
|
| 1221 |
+
},
|
| 1222 |
+
"152": {
|
| 1223 |
+
"content": "<mask_128>",
|
| 1224 |
+
"lstrip": false,
|
| 1225 |
+
"normalized": false,
|
| 1226 |
+
"rstrip": false,
|
| 1227 |
+
"single_word": false,
|
| 1228 |
+
"special": true
|
| 1229 |
+
},
|
| 1230 |
+
"153": {
|
| 1231 |
+
"content": "<mask_129>",
|
| 1232 |
+
"lstrip": false,
|
| 1233 |
+
"normalized": false,
|
| 1234 |
+
"rstrip": false,
|
| 1235 |
+
"single_word": false,
|
| 1236 |
+
"special": true
|
| 1237 |
+
},
|
| 1238 |
+
"154": {
|
| 1239 |
+
"content": "<mask_130>",
|
| 1240 |
+
"lstrip": false,
|
| 1241 |
+
"normalized": false,
|
| 1242 |
+
"rstrip": false,
|
| 1243 |
+
"single_word": false,
|
| 1244 |
+
"special": true
|
| 1245 |
+
},
|
| 1246 |
+
"155": {
|
| 1247 |
+
"content": "<mask_131>",
|
| 1248 |
+
"lstrip": false,
|
| 1249 |
+
"normalized": false,
|
| 1250 |
+
"rstrip": false,
|
| 1251 |
+
"single_word": false,
|
| 1252 |
+
"special": true
|
| 1253 |
+
},
|
| 1254 |
+
"156": {
|
| 1255 |
+
"content": "<mask_132>",
|
| 1256 |
+
"lstrip": false,
|
| 1257 |
+
"normalized": false,
|
| 1258 |
+
"rstrip": false,
|
| 1259 |
+
"single_word": false,
|
| 1260 |
+
"special": true
|
| 1261 |
+
},
|
| 1262 |
+
"157": {
|
| 1263 |
+
"content": "<mask_133>",
|
| 1264 |
+
"lstrip": false,
|
| 1265 |
+
"normalized": false,
|
| 1266 |
+
"rstrip": false,
|
| 1267 |
+
"single_word": false,
|
| 1268 |
+
"special": true
|
| 1269 |
+
},
|
| 1270 |
+
"158": {
|
| 1271 |
+
"content": "<mask_134>",
|
| 1272 |
+
"lstrip": false,
|
| 1273 |
+
"normalized": false,
|
| 1274 |
+
"rstrip": false,
|
| 1275 |
+
"single_word": false,
|
| 1276 |
+
"special": true
|
| 1277 |
+
},
|
| 1278 |
+
"159": {
|
| 1279 |
+
"content": "<mask_135>",
|
| 1280 |
+
"lstrip": false,
|
| 1281 |
+
"normalized": false,
|
| 1282 |
+
"rstrip": false,
|
| 1283 |
+
"single_word": false,
|
| 1284 |
+
"special": true
|
| 1285 |
+
},
|
| 1286 |
+
"160": {
|
| 1287 |
+
"content": "<mask_136>",
|
| 1288 |
+
"lstrip": false,
|
| 1289 |
+
"normalized": false,
|
| 1290 |
+
"rstrip": false,
|
| 1291 |
+
"single_word": false,
|
| 1292 |
+
"special": true
|
| 1293 |
+
},
|
| 1294 |
+
"161": {
|
| 1295 |
+
"content": "<mask_137>",
|
| 1296 |
+
"lstrip": false,
|
| 1297 |
+
"normalized": false,
|
| 1298 |
+
"rstrip": false,
|
| 1299 |
+
"single_word": false,
|
| 1300 |
+
"special": true
|
| 1301 |
+
},
|
| 1302 |
+
"162": {
|
| 1303 |
+
"content": "<mask_138>",
|
| 1304 |
+
"lstrip": false,
|
| 1305 |
+
"normalized": false,
|
| 1306 |
+
"rstrip": false,
|
| 1307 |
+
"single_word": false,
|
| 1308 |
+
"special": true
|
| 1309 |
+
},
|
| 1310 |
+
"163": {
|
| 1311 |
+
"content": "<mask_139>",
|
| 1312 |
+
"lstrip": false,
|
| 1313 |
+
"normalized": false,
|
| 1314 |
+
"rstrip": false,
|
| 1315 |
+
"single_word": false,
|
| 1316 |
+
"special": true
|
| 1317 |
+
},
|
| 1318 |
+
"164": {
|
| 1319 |
+
"content": "<mask_140>",
|
| 1320 |
+
"lstrip": false,
|
| 1321 |
+
"normalized": false,
|
| 1322 |
+
"rstrip": false,
|
| 1323 |
+
"single_word": false,
|
| 1324 |
+
"special": true
|
| 1325 |
+
},
|
| 1326 |
+
"165": {
|
| 1327 |
+
"content": "<mask_141>",
|
| 1328 |
+
"lstrip": false,
|
| 1329 |
+
"normalized": false,
|
| 1330 |
+
"rstrip": false,
|
| 1331 |
+
"single_word": false,
|
| 1332 |
+
"special": true
|
| 1333 |
+
},
|
| 1334 |
+
"166": {
|
| 1335 |
+
"content": "<mask_142>",
|
| 1336 |
+
"lstrip": false,
|
| 1337 |
+
"normalized": false,
|
| 1338 |
+
"rstrip": false,
|
| 1339 |
+
"single_word": false,
|
| 1340 |
+
"special": true
|
| 1341 |
+
},
|
| 1342 |
+
"167": {
|
| 1343 |
+
"content": "<mask_143>",
|
| 1344 |
+
"lstrip": false,
|
| 1345 |
+
"normalized": false,
|
| 1346 |
+
"rstrip": false,
|
| 1347 |
+
"single_word": false,
|
| 1348 |
+
"special": true
|
| 1349 |
+
},
|
| 1350 |
+
"168": {
|
| 1351 |
+
"content": "<mask_144>",
|
| 1352 |
+
"lstrip": false,
|
| 1353 |
+
"normalized": false,
|
| 1354 |
+
"rstrip": false,
|
| 1355 |
+
"single_word": false,
|
| 1356 |
+
"special": true
|
| 1357 |
+
},
|
| 1358 |
+
"169": {
|
| 1359 |
+
"content": "<mask_145>",
|
| 1360 |
+
"lstrip": false,
|
| 1361 |
+
"normalized": false,
|
| 1362 |
+
"rstrip": false,
|
| 1363 |
+
"single_word": false,
|
| 1364 |
+
"special": true
|
| 1365 |
+
},
|
| 1366 |
+
"170": {
|
| 1367 |
+
"content": "<mask_146>",
|
| 1368 |
+
"lstrip": false,
|
| 1369 |
+
"normalized": false,
|
| 1370 |
+
"rstrip": false,
|
| 1371 |
+
"single_word": false,
|
| 1372 |
+
"special": true
|
| 1373 |
+
},
|
| 1374 |
+
"171": {
|
| 1375 |
+
"content": "<mask_147>",
|
| 1376 |
+
"lstrip": false,
|
| 1377 |
+
"normalized": false,
|
| 1378 |
+
"rstrip": false,
|
| 1379 |
+
"single_word": false,
|
| 1380 |
+
"special": true
|
| 1381 |
+
},
|
| 1382 |
+
"172": {
|
| 1383 |
+
"content": "<mask_148>",
|
| 1384 |
+
"lstrip": false,
|
| 1385 |
+
"normalized": false,
|
| 1386 |
+
"rstrip": false,
|
| 1387 |
+
"single_word": false,
|
| 1388 |
+
"special": true
|
| 1389 |
+
},
|
| 1390 |
+
"173": {
|
| 1391 |
+
"content": "<mask_149>",
|
| 1392 |
+
"lstrip": false,
|
| 1393 |
+
"normalized": false,
|
| 1394 |
+
"rstrip": false,
|
| 1395 |
+
"single_word": false,
|
| 1396 |
+
"special": true
|
| 1397 |
+
},
|
| 1398 |
+
"174": {
|
| 1399 |
+
"content": "<mask_150>",
|
| 1400 |
+
"lstrip": false,
|
| 1401 |
+
"normalized": false,
|
| 1402 |
+
"rstrip": false,
|
| 1403 |
+
"single_word": false,
|
| 1404 |
+
"special": true
|
| 1405 |
+
},
|
| 1406 |
+
"175": {
|
| 1407 |
+
"content": "<mask_151>",
|
| 1408 |
+
"lstrip": false,
|
| 1409 |
+
"normalized": false,
|
| 1410 |
+
"rstrip": false,
|
| 1411 |
+
"single_word": false,
|
| 1412 |
+
"special": true
|
| 1413 |
+
},
|
| 1414 |
+
"176": {
|
| 1415 |
+
"content": "<mask_152>",
|
| 1416 |
+
"lstrip": false,
|
| 1417 |
+
"normalized": false,
|
| 1418 |
+
"rstrip": false,
|
| 1419 |
+
"single_word": false,
|
| 1420 |
+
"special": true
|
| 1421 |
+
},
|
| 1422 |
+
"177": {
|
| 1423 |
+
"content": "<mask_153>",
|
| 1424 |
+
"lstrip": false,
|
| 1425 |
+
"normalized": false,
|
| 1426 |
+
"rstrip": false,
|
| 1427 |
+
"single_word": false,
|
| 1428 |
+
"special": true
|
| 1429 |
+
},
|
| 1430 |
+
"178": {
|
| 1431 |
+
"content": "<mask_154>",
|
| 1432 |
+
"lstrip": false,
|
| 1433 |
+
"normalized": false,
|
| 1434 |
+
"rstrip": false,
|
| 1435 |
+
"single_word": false,
|
| 1436 |
+
"special": true
|
| 1437 |
+
},
|
| 1438 |
+
"179": {
|
| 1439 |
+
"content": "<mask_155>",
|
| 1440 |
+
"lstrip": false,
|
| 1441 |
+
"normalized": false,
|
| 1442 |
+
"rstrip": false,
|
| 1443 |
+
"single_word": false,
|
| 1444 |
+
"special": true
|
| 1445 |
+
},
|
| 1446 |
+
"180": {
|
| 1447 |
+
"content": "<mask_156>",
|
| 1448 |
+
"lstrip": false,
|
| 1449 |
+
"normalized": false,
|
| 1450 |
+
"rstrip": false,
|
| 1451 |
+
"single_word": false,
|
| 1452 |
+
"special": true
|
| 1453 |
+
},
|
| 1454 |
+
"181": {
|
| 1455 |
+
"content": "<mask_157>",
|
| 1456 |
+
"lstrip": false,
|
| 1457 |
+
"normalized": false,
|
| 1458 |
+
"rstrip": false,
|
| 1459 |
+
"single_word": false,
|
| 1460 |
+
"special": true
|
| 1461 |
+
},
|
| 1462 |
+
"182": {
|
| 1463 |
+
"content": "<mask_158>",
|
| 1464 |
+
"lstrip": false,
|
| 1465 |
+
"normalized": false,
|
| 1466 |
+
"rstrip": false,
|
| 1467 |
+
"single_word": false,
|
| 1468 |
+
"special": true
|
| 1469 |
+
},
|
| 1470 |
+
"183": {
|
| 1471 |
+
"content": "<mask_159>",
|
| 1472 |
+
"lstrip": false,
|
| 1473 |
+
"normalized": false,
|
| 1474 |
+
"rstrip": false,
|
| 1475 |
+
"single_word": false,
|
| 1476 |
+
"special": true
|
| 1477 |
+
},
|
| 1478 |
+
"184": {
|
| 1479 |
+
"content": "<mask_160>",
|
| 1480 |
+
"lstrip": false,
|
| 1481 |
+
"normalized": false,
|
| 1482 |
+
"rstrip": false,
|
| 1483 |
+
"single_word": false,
|
| 1484 |
+
"special": true
|
| 1485 |
+
},
|
| 1486 |
+
"185": {
|
| 1487 |
+
"content": "<mask_161>",
|
| 1488 |
+
"lstrip": false,
|
| 1489 |
+
"normalized": false,
|
| 1490 |
+
"rstrip": false,
|
| 1491 |
+
"single_word": false,
|
| 1492 |
+
"special": true
|
| 1493 |
+
},
|
| 1494 |
+
"186": {
|
| 1495 |
+
"content": "<mask_162>",
|
| 1496 |
+
"lstrip": false,
|
| 1497 |
+
"normalized": false,
|
| 1498 |
+
"rstrip": false,
|
| 1499 |
+
"single_word": false,
|
| 1500 |
+
"special": true
|
| 1501 |
+
},
|
| 1502 |
+
"187": {
|
| 1503 |
+
"content": "<mask_163>",
|
| 1504 |
+
"lstrip": false,
|
| 1505 |
+
"normalized": false,
|
| 1506 |
+
"rstrip": false,
|
| 1507 |
+
"single_word": false,
|
| 1508 |
+
"special": true
|
| 1509 |
+
},
|
| 1510 |
+
"188": {
|
| 1511 |
+
"content": "<mask_164>",
|
| 1512 |
+
"lstrip": false,
|
| 1513 |
+
"normalized": false,
|
| 1514 |
+
"rstrip": false,
|
| 1515 |
+
"single_word": false,
|
| 1516 |
+
"special": true
|
| 1517 |
+
},
|
| 1518 |
+
"189": {
|
| 1519 |
+
"content": "<mask_165>",
|
| 1520 |
+
"lstrip": false,
|
| 1521 |
+
"normalized": false,
|
| 1522 |
+
"rstrip": false,
|
| 1523 |
+
"single_word": false,
|
| 1524 |
+
"special": true
|
| 1525 |
+
},
|
| 1526 |
+
"190": {
|
| 1527 |
+
"content": "<mask_166>",
|
| 1528 |
+
"lstrip": false,
|
| 1529 |
+
"normalized": false,
|
| 1530 |
+
"rstrip": false,
|
| 1531 |
+
"single_word": false,
|
| 1532 |
+
"special": true
|
| 1533 |
+
},
|
| 1534 |
+
"191": {
|
| 1535 |
+
"content": "<mask_167>",
|
| 1536 |
+
"lstrip": false,
|
| 1537 |
+
"normalized": false,
|
| 1538 |
+
"rstrip": false,
|
| 1539 |
+
"single_word": false,
|
| 1540 |
+
"special": true
|
| 1541 |
+
},
|
| 1542 |
+
"192": {
|
| 1543 |
+
"content": "<mask_168>",
|
| 1544 |
+
"lstrip": false,
|
| 1545 |
+
"normalized": false,
|
| 1546 |
+
"rstrip": false,
|
| 1547 |
+
"single_word": false,
|
| 1548 |
+
"special": true
|
| 1549 |
+
},
|
| 1550 |
+
"193": {
|
| 1551 |
+
"content": "<mask_169>",
|
| 1552 |
+
"lstrip": false,
|
| 1553 |
+
"normalized": false,
|
| 1554 |
+
"rstrip": false,
|
| 1555 |
+
"single_word": false,
|
| 1556 |
+
"special": true
|
| 1557 |
+
},
|
| 1558 |
+
"194": {
|
| 1559 |
+
"content": "<mask_170>",
|
| 1560 |
+
"lstrip": false,
|
| 1561 |
+
"normalized": false,
|
| 1562 |
+
"rstrip": false,
|
| 1563 |
+
"single_word": false,
|
| 1564 |
+
"special": true
|
| 1565 |
+
},
|
| 1566 |
+
"195": {
|
| 1567 |
+
"content": "<mask_171>",
|
| 1568 |
+
"lstrip": false,
|
| 1569 |
+
"normalized": false,
|
| 1570 |
+
"rstrip": false,
|
| 1571 |
+
"single_word": false,
|
| 1572 |
+
"special": true
|
| 1573 |
+
},
|
| 1574 |
+
"196": {
|
| 1575 |
+
"content": "<mask_172>",
|
| 1576 |
+
"lstrip": false,
|
| 1577 |
+
"normalized": false,
|
| 1578 |
+
"rstrip": false,
|
| 1579 |
+
"single_word": false,
|
| 1580 |
+
"special": true
|
| 1581 |
+
},
|
| 1582 |
+
"197": {
|
| 1583 |
+
"content": "<mask_173>",
|
| 1584 |
+
"lstrip": false,
|
| 1585 |
+
"normalized": false,
|
| 1586 |
+
"rstrip": false,
|
| 1587 |
+
"single_word": false,
|
| 1588 |
+
"special": true
|
| 1589 |
+
},
|
| 1590 |
+
"198": {
|
| 1591 |
+
"content": "<mask_174>",
|
| 1592 |
+
"lstrip": false,
|
| 1593 |
+
"normalized": false,
|
| 1594 |
+
"rstrip": false,
|
| 1595 |
+
"single_word": false,
|
| 1596 |
+
"special": true
|
| 1597 |
+
},
|
| 1598 |
+
"199": {
|
| 1599 |
+
"content": "<mask_175>",
|
| 1600 |
+
"lstrip": false,
|
| 1601 |
+
"normalized": false,
|
| 1602 |
+
"rstrip": false,
|
| 1603 |
+
"single_word": false,
|
| 1604 |
+
"special": true
|
| 1605 |
+
},
|
| 1606 |
+
"200": {
|
| 1607 |
+
"content": "<mask_176>",
|
| 1608 |
+
"lstrip": false,
|
| 1609 |
+
"normalized": false,
|
| 1610 |
+
"rstrip": false,
|
| 1611 |
+
"single_word": false,
|
| 1612 |
+
"special": true
|
| 1613 |
+
},
|
| 1614 |
+
"201": {
|
| 1615 |
+
"content": "<mask_177>",
|
| 1616 |
+
"lstrip": false,
|
| 1617 |
+
"normalized": false,
|
| 1618 |
+
"rstrip": false,
|
| 1619 |
+
"single_word": false,
|
| 1620 |
+
"special": true
|
| 1621 |
+
},
|
| 1622 |
+
"202": {
|
| 1623 |
+
"content": "<mask_178>",
|
| 1624 |
+
"lstrip": false,
|
| 1625 |
+
"normalized": false,
|
| 1626 |
+
"rstrip": false,
|
| 1627 |
+
"single_word": false,
|
| 1628 |
+
"special": true
|
| 1629 |
+
},
|
| 1630 |
+
"203": {
|
| 1631 |
+
"content": "<mask_179>",
|
| 1632 |
+
"lstrip": false,
|
| 1633 |
+
"normalized": false,
|
| 1634 |
+
"rstrip": false,
|
| 1635 |
+
"single_word": false,
|
| 1636 |
+
"special": true
|
| 1637 |
+
},
|
| 1638 |
+
"204": {
|
| 1639 |
+
"content": "<mask_180>",
|
| 1640 |
+
"lstrip": false,
|
| 1641 |
+
"normalized": false,
|
| 1642 |
+
"rstrip": false,
|
| 1643 |
+
"single_word": false,
|
| 1644 |
+
"special": true
|
| 1645 |
+
},
|
| 1646 |
+
"205": {
|
| 1647 |
+
"content": "<mask_181>",
|
| 1648 |
+
"lstrip": false,
|
| 1649 |
+
"normalized": false,
|
| 1650 |
+
"rstrip": false,
|
| 1651 |
+
"single_word": false,
|
| 1652 |
+
"special": true
|
| 1653 |
+
},
|
| 1654 |
+
"206": {
|
| 1655 |
+
"content": "<mask_182>",
|
| 1656 |
+
"lstrip": false,
|
| 1657 |
+
"normalized": false,
|
| 1658 |
+
"rstrip": false,
|
| 1659 |
+
"single_word": false,
|
| 1660 |
+
"special": true
|
| 1661 |
+
},
|
| 1662 |
+
"207": {
|
| 1663 |
+
"content": "<mask_183>",
|
| 1664 |
+
"lstrip": false,
|
| 1665 |
+
"normalized": false,
|
| 1666 |
+
"rstrip": false,
|
| 1667 |
+
"single_word": false,
|
| 1668 |
+
"special": true
|
| 1669 |
+
},
|
| 1670 |
+
"208": {
|
| 1671 |
+
"content": "<mask_184>",
|
| 1672 |
+
"lstrip": false,
|
| 1673 |
+
"normalized": false,
|
| 1674 |
+
"rstrip": false,
|
| 1675 |
+
"single_word": false,
|
| 1676 |
+
"special": true
|
| 1677 |
+
},
|
| 1678 |
+
"209": {
|
| 1679 |
+
"content": "<mask_185>",
|
| 1680 |
+
"lstrip": false,
|
| 1681 |
+
"normalized": false,
|
| 1682 |
+
"rstrip": false,
|
| 1683 |
+
"single_word": false,
|
| 1684 |
+
"special": true
|
| 1685 |
+
},
|
| 1686 |
+
"210": {
|
| 1687 |
+
"content": "<mask_186>",
|
| 1688 |
+
"lstrip": false,
|
| 1689 |
+
"normalized": false,
|
| 1690 |
+
"rstrip": false,
|
| 1691 |
+
"single_word": false,
|
| 1692 |
+
"special": true
|
| 1693 |
+
},
|
| 1694 |
+
"211": {
|
| 1695 |
+
"content": "<mask_187>",
|
| 1696 |
+
"lstrip": false,
|
| 1697 |
+
"normalized": false,
|
| 1698 |
+
"rstrip": false,
|
| 1699 |
+
"single_word": false,
|
| 1700 |
+
"special": true
|
| 1701 |
+
},
|
| 1702 |
+
"212": {
|
| 1703 |
+
"content": "<mask_188>",
|
| 1704 |
+
"lstrip": false,
|
| 1705 |
+
"normalized": false,
|
| 1706 |
+
"rstrip": false,
|
| 1707 |
+
"single_word": false,
|
| 1708 |
+
"special": true
|
| 1709 |
+
},
|
| 1710 |
+
"213": {
|
| 1711 |
+
"content": "<mask_189>",
|
| 1712 |
+
"lstrip": false,
|
| 1713 |
+
"normalized": false,
|
| 1714 |
+
"rstrip": false,
|
| 1715 |
+
"single_word": false,
|
| 1716 |
+
"special": true
|
| 1717 |
+
},
|
| 1718 |
+
"214": {
|
| 1719 |
+
"content": "<mask_190>",
|
| 1720 |
+
"lstrip": false,
|
| 1721 |
+
"normalized": false,
|
| 1722 |
+
"rstrip": false,
|
| 1723 |
+
"single_word": false,
|
| 1724 |
+
"special": true
|
| 1725 |
+
},
|
| 1726 |
+
"215": {
|
| 1727 |
+
"content": "<mask_191>",
|
| 1728 |
+
"lstrip": false,
|
| 1729 |
+
"normalized": false,
|
| 1730 |
+
"rstrip": false,
|
| 1731 |
+
"single_word": false,
|
| 1732 |
+
"special": true
|
| 1733 |
+
},
|
| 1734 |
+
"216": {
|
| 1735 |
+
"content": "<mask_192>",
|
| 1736 |
+
"lstrip": false,
|
| 1737 |
+
"normalized": false,
|
| 1738 |
+
"rstrip": false,
|
| 1739 |
+
"single_word": false,
|
| 1740 |
+
"special": true
|
| 1741 |
+
},
|
| 1742 |
+
"217": {
|
| 1743 |
+
"content": "<mask_193>",
|
| 1744 |
+
"lstrip": false,
|
| 1745 |
+
"normalized": false,
|
| 1746 |
+
"rstrip": false,
|
| 1747 |
+
"single_word": false,
|
| 1748 |
+
"special": true
|
| 1749 |
+
},
|
| 1750 |
+
"218": {
|
| 1751 |
+
"content": "<mask_194>",
|
| 1752 |
+
"lstrip": false,
|
| 1753 |
+
"normalized": false,
|
| 1754 |
+
"rstrip": false,
|
| 1755 |
+
"single_word": false,
|
| 1756 |
+
"special": true
|
| 1757 |
+
},
|
| 1758 |
+
"219": {
|
| 1759 |
+
"content": "<mask_195>",
|
| 1760 |
+
"lstrip": false,
|
| 1761 |
+
"normalized": false,
|
| 1762 |
+
"rstrip": false,
|
| 1763 |
+
"single_word": false,
|
| 1764 |
+
"special": true
|
| 1765 |
+
},
|
| 1766 |
+
"220": {
|
| 1767 |
+
"content": "<mask_196>",
|
| 1768 |
+
"lstrip": false,
|
| 1769 |
+
"normalized": false,
|
| 1770 |
+
"rstrip": false,
|
| 1771 |
+
"single_word": false,
|
| 1772 |
+
"special": true
|
| 1773 |
+
},
|
| 1774 |
+
"221": {
|
| 1775 |
+
"content": "<mask_197>",
|
| 1776 |
+
"lstrip": false,
|
| 1777 |
+
"normalized": false,
|
| 1778 |
+
"rstrip": false,
|
| 1779 |
+
"single_word": false,
|
| 1780 |
+
"special": true
|
| 1781 |
+
},
|
| 1782 |
+
"222": {
|
| 1783 |
+
"content": "<mask_198>",
|
| 1784 |
+
"lstrip": false,
|
| 1785 |
+
"normalized": false,
|
| 1786 |
+
"rstrip": false,
|
| 1787 |
+
"single_word": false,
|
| 1788 |
+
"special": true
|
| 1789 |
+
},
|
| 1790 |
+
"223": {
|
| 1791 |
+
"content": "<mask_199>",
|
| 1792 |
+
"lstrip": false,
|
| 1793 |
+
"normalized": false,
|
| 1794 |
+
"rstrip": false,
|
| 1795 |
+
"single_word": false,
|
| 1796 |
+
"special": true
|
| 1797 |
+
},
|
| 1798 |
+
"131072": {
|
| 1799 |
+
"content": "<mask_131048>",
|
| 1800 |
+
"lstrip": false,
|
| 1801 |
+
"normalized": false,
|
| 1802 |
+
"rstrip": false,
|
| 1803 |
+
"single_word": false,
|
| 1804 |
+
"special": true
|
| 1805 |
+
},
|
| 1806 |
+
"131073": {
|
| 1807 |
+
"content": "<mask_131049>",
|
| 1808 |
+
"lstrip": false,
|
| 1809 |
+
"normalized": false,
|
| 1810 |
+
"rstrip": false,
|
| 1811 |
+
"single_word": false,
|
| 1812 |
+
"special": true
|
| 1813 |
+
},
|
| 1814 |
+
"131074": {
|
| 1815 |
+
"content": "<mask_131050>",
|
| 1816 |
+
"lstrip": false,
|
| 1817 |
+
"normalized": false,
|
| 1818 |
+
"rstrip": false,
|
| 1819 |
+
"single_word": false,
|
| 1820 |
+
"special": true
|
| 1821 |
+
},
|
| 1822 |
+
"131075": {
|
| 1823 |
+
"content": "<mask_131051>",
|
| 1824 |
+
"lstrip": false,
|
| 1825 |
+
"normalized": false,
|
| 1826 |
+
"rstrip": false,
|
| 1827 |
+
"single_word": false,
|
| 1828 |
+
"special": true
|
| 1829 |
+
},
|
| 1830 |
+
"131076": {
|
| 1831 |
+
"content": "<mask_131052>",
|
| 1832 |
+
"lstrip": false,
|
| 1833 |
+
"normalized": false,
|
| 1834 |
+
"rstrip": false,
|
| 1835 |
+
"single_word": false,
|
| 1836 |
+
"special": true
|
| 1837 |
+
},
|
| 1838 |
+
"131077": {
|
| 1839 |
+
"content": "<mask_131053>",
|
| 1840 |
+
"lstrip": false,
|
| 1841 |
+
"normalized": false,
|
| 1842 |
+
"rstrip": false,
|
| 1843 |
+
"single_word": false,
|
| 1844 |
+
"special": true
|
| 1845 |
+
},
|
| 1846 |
+
"131078": {
|
| 1847 |
+
"content": "<mask_131054>",
|
| 1848 |
+
"lstrip": false,
|
| 1849 |
+
"normalized": false,
|
| 1850 |
+
"rstrip": false,
|
| 1851 |
+
"single_word": false,
|
| 1852 |
+
"special": true
|
| 1853 |
+
},
|
| 1854 |
+
"131079": {
|
| 1855 |
+
"content": "<mask_131055>",
|
| 1856 |
+
"lstrip": false,
|
| 1857 |
+
"normalized": false,
|
| 1858 |
+
"rstrip": false,
|
| 1859 |
+
"single_word": false,
|
| 1860 |
+
"special": true
|
| 1861 |
+
},
|
| 1862 |
+
"131080": {
|
| 1863 |
+
"content": "<mask_131056>",
|
| 1864 |
+
"lstrip": false,
|
| 1865 |
+
"normalized": false,
|
| 1866 |
+
"rstrip": false,
|
| 1867 |
+
"single_word": false,
|
| 1868 |
+
"special": true
|
| 1869 |
+
},
|
| 1870 |
+
"131081": {
|
| 1871 |
+
"content": "<mask_131057>",
|
| 1872 |
+
"lstrip": false,
|
| 1873 |
+
"normalized": false,
|
| 1874 |
+
"rstrip": false,
|
| 1875 |
+
"single_word": false,
|
| 1876 |
+
"special": true
|
| 1877 |
+
},
|
| 1878 |
+
"131082": {
|
| 1879 |
+
"content": "<mask_131058>",
|
| 1880 |
+
"lstrip": false,
|
| 1881 |
+
"normalized": false,
|
| 1882 |
+
"rstrip": false,
|
| 1883 |
+
"single_word": false,
|
| 1884 |
+
"special": true
|
| 1885 |
+
},
|
| 1886 |
+
"131083": {
|
| 1887 |
+
"content": "<mask_131059>",
|
| 1888 |
+
"lstrip": false,
|
| 1889 |
+
"normalized": false,
|
| 1890 |
+
"rstrip": false,
|
| 1891 |
+
"single_word": false,
|
| 1892 |
+
"special": true
|
| 1893 |
+
},
|
| 1894 |
+
"131084": {
|
| 1895 |
+
"content": "<mask_131060>",
|
| 1896 |
+
"lstrip": false,
|
| 1897 |
+
"normalized": false,
|
| 1898 |
+
"rstrip": false,
|
| 1899 |
+
"single_word": false,
|
| 1900 |
+
"special": true
|
| 1901 |
+
},
|
| 1902 |
+
"131085": {
|
| 1903 |
+
"content": "<mask_131061>",
|
| 1904 |
+
"lstrip": false,
|
| 1905 |
+
"normalized": false,
|
| 1906 |
+
"rstrip": false,
|
| 1907 |
+
"single_word": false,
|
| 1908 |
+
"special": true
|
| 1909 |
+
},
|
| 1910 |
+
"131086": {
|
| 1911 |
+
"content": "<mask_131062>",
|
| 1912 |
+
"lstrip": false,
|
| 1913 |
+
"normalized": false,
|
| 1914 |
+
"rstrip": false,
|
| 1915 |
+
"single_word": false,
|
| 1916 |
+
"special": true
|
| 1917 |
+
},
|
| 1918 |
+
"131087": {
|
| 1919 |
+
"content": "<mask_131063>",
|
| 1920 |
+
"lstrip": false,
|
| 1921 |
+
"normalized": false,
|
| 1922 |
+
"rstrip": false,
|
| 1923 |
+
"single_word": false,
|
| 1924 |
+
"special": true
|
| 1925 |
+
},
|
| 1926 |
+
"131088": {
|
| 1927 |
+
"content": "<mask_131064>",
|
| 1928 |
+
"lstrip": false,
|
| 1929 |
+
"normalized": false,
|
| 1930 |
+
"rstrip": false,
|
| 1931 |
+
"single_word": false,
|
| 1932 |
+
"special": true
|
| 1933 |
+
},
|
| 1934 |
+
"131089": {
|
| 1935 |
+
"content": "<mask_131065>",
|
| 1936 |
+
"lstrip": false,
|
| 1937 |
+
"normalized": false,
|
| 1938 |
+
"rstrip": false,
|
| 1939 |
+
"single_word": false,
|
| 1940 |
+
"special": true
|
| 1941 |
+
},
|
| 1942 |
+
"131090": {
|
| 1943 |
+
"content": "<longcat_img_token_size>",
|
| 1944 |
+
"lstrip": false,
|
| 1945 |
+
"normalized": false,
|
| 1946 |
+
"rstrip": false,
|
| 1947 |
+
"single_word": false,
|
| 1948 |
+
"special": true
|
| 1949 |
+
},
|
| 1950 |
+
"131091": {
|
| 1951 |
+
"content": "</longcat_img_token_size>",
|
| 1952 |
+
"lstrip": false,
|
| 1953 |
+
"normalized": false,
|
| 1954 |
+
"rstrip": false,
|
| 1955 |
+
"single_word": false,
|
| 1956 |
+
"special": true
|
| 1957 |
+
},
|
| 1958 |
+
"131092": {
|
| 1959 |
+
"content": "<mask_131068>",
|
| 1960 |
+
"lstrip": false,
|
| 1961 |
+
"normalized": false,
|
| 1962 |
+
"rstrip": false,
|
| 1963 |
+
"single_word": false,
|
| 1964 |
+
"special": true
|
| 1965 |
+
},
|
| 1966 |
+
"131093": {
|
| 1967 |
+
"content": "<mask_131069>",
|
| 1968 |
+
"lstrip": false,
|
| 1969 |
+
"normalized": false,
|
| 1970 |
+
"rstrip": false,
|
| 1971 |
+
"single_word": false,
|
| 1972 |
+
"special": true
|
| 1973 |
+
},
|
| 1974 |
+
"131094": {
|
| 1975 |
+
"content": "<mask_131070>",
|
| 1976 |
+
"lstrip": false,
|
| 1977 |
+
"normalized": false,
|
| 1978 |
+
"rstrip": false,
|
| 1979 |
+
"single_word": false,
|
| 1980 |
+
"special": true
|
| 1981 |
+
},
|
| 1982 |
+
"131095": {
|
| 1983 |
+
"content": "<mask_131071>",
|
| 1984 |
+
"lstrip": false,
|
| 1985 |
+
"normalized": false,
|
| 1986 |
+
"rstrip": false,
|
| 1987 |
+
"single_word": false,
|
| 1988 |
+
"special": true
|
| 1989 |
+
},
|
| 1990 |
+
"131096": {
|
| 1991 |
+
"content": "<longcat_point_start>",
|
| 1992 |
+
"lstrip": false,
|
| 1993 |
+
"normalized": false,
|
| 1994 |
+
"rstrip": false,
|
| 1995 |
+
"single_word": false,
|
| 1996 |
+
"special": true
|
| 1997 |
+
},
|
| 1998 |
+
"131097": {
|
| 1999 |
+
"content": "<longcat_point_end>",
|
| 2000 |
+
"lstrip": false,
|
| 2001 |
+
"normalized": false,
|
| 2002 |
+
"rstrip": false,
|
| 2003 |
+
"single_word": false,
|
| 2004 |
+
"special": true
|
| 2005 |
+
},
|
| 2006 |
+
"131098": {
|
| 2007 |
+
"content": "<longcat_point_delim>",
|
| 2008 |
+
"lstrip": false,
|
| 2009 |
+
"normalized": false,
|
| 2010 |
+
"rstrip": false,
|
| 2011 |
+
"single_word": false,
|
| 2012 |
+
"special": true
|
| 2013 |
+
},
|
| 2014 |
+
"131099": {
|
| 2015 |
+
"content": "<longcat_polygon_start>",
|
| 2016 |
+
"lstrip": false,
|
| 2017 |
+
"normalized": false,
|
| 2018 |
+
"rstrip": false,
|
| 2019 |
+
"single_word": false,
|
| 2020 |
+
"special": true
|
| 2021 |
+
},
|
| 2022 |
+
"131100": {
|
| 2023 |
+
"content": "<longcat_polygon_end>",
|
| 2024 |
+
"lstrip": false,
|
| 2025 |
+
"normalized": false,
|
| 2026 |
+
"rstrip": false,
|
| 2027 |
+
"single_word": false,
|
| 2028 |
+
"special": true
|
| 2029 |
+
},
|
| 2030 |
+
"131101": {
|
| 2031 |
+
"content": "<mask_131077>",
|
| 2032 |
+
"lstrip": false,
|
| 2033 |
+
"normalized": false,
|
| 2034 |
+
"rstrip": false,
|
| 2035 |
+
"single_word": false,
|
| 2036 |
+
"special": true
|
| 2037 |
+
},
|
| 2038 |
+
"131102": {
|
| 2039 |
+
"content": "<mask_131078>",
|
| 2040 |
+
"lstrip": false,
|
| 2041 |
+
"normalized": false,
|
| 2042 |
+
"rstrip": false,
|
| 2043 |
+
"single_word": false,
|
| 2044 |
+
"special": true
|
| 2045 |
+
},
|
| 2046 |
+
"131103": {
|
| 2047 |
+
"content": "<longcat_audio_start>",
|
| 2048 |
+
"lstrip": false,
|
| 2049 |
+
"normalized": false,
|
| 2050 |
+
"rstrip": false,
|
| 2051 |
+
"single_word": false,
|
| 2052 |
+
"special": true
|
| 2053 |
+
},
|
| 2054 |
+
"131104": {
|
| 2055 |
+
"content": "<longcat_audio_end>",
|
| 2056 |
+
"lstrip": false,
|
| 2057 |
+
"normalized": false,
|
| 2058 |
+
"rstrip": false,
|
| 2059 |
+
"single_word": false,
|
| 2060 |
+
"special": true
|
| 2061 |
+
},
|
| 2062 |
+
"131105": {
|
| 2063 |
+
"content": "<longcat_audio_pad>",
|
| 2064 |
+
"lstrip": false,
|
| 2065 |
+
"normalized": false,
|
| 2066 |
+
"rstrip": false,
|
| 2067 |
+
"single_word": false,
|
| 2068 |
+
"special": true
|
| 2069 |
+
},
|
| 2070 |
+
"131106": {
|
| 2071 |
+
"content": "<longcat_img_start>",
|
| 2072 |
+
"lstrip": false,
|
| 2073 |
+
"normalized": false,
|
| 2074 |
+
"rstrip": false,
|
| 2075 |
+
"single_word": false,
|
| 2076 |
+
"special": true
|
| 2077 |
+
},
|
| 2078 |
+
"131107": {
|
| 2079 |
+
"content": "<longcat_img_end>",
|
| 2080 |
+
"lstrip": false,
|
| 2081 |
+
"normalized": false,
|
| 2082 |
+
"rstrip": false,
|
| 2083 |
+
"single_word": false,
|
| 2084 |
+
"special": true
|
| 2085 |
+
},
|
| 2086 |
+
"131108": {
|
| 2087 |
+
"content": "<longcat_img_pad>",
|
| 2088 |
+
"lstrip": false,
|
| 2089 |
+
"normalized": false,
|
| 2090 |
+
"rstrip": false,
|
| 2091 |
+
"single_word": false,
|
| 2092 |
+
"special": true
|
| 2093 |
+
},
|
| 2094 |
+
"131109": {
|
| 2095 |
+
"content": "<longcat_img_newline>",
|
| 2096 |
+
"lstrip": false,
|
| 2097 |
+
"normalized": false,
|
| 2098 |
+
"rstrip": false,
|
| 2099 |
+
"single_word": false,
|
| 2100 |
+
"special": true
|
| 2101 |
+
},
|
| 2102 |
+
"131110": {
|
| 2103 |
+
"content": "<longcat_box_start>",
|
| 2104 |
+
"lstrip": false,
|
| 2105 |
+
"normalized": false,
|
| 2106 |
+
"rstrip": false,
|
| 2107 |
+
"single_word": false,
|
| 2108 |
+
"special": true
|
| 2109 |
+
},
|
| 2110 |
+
"131111": {
|
| 2111 |
+
"content": "<longcat_box_end>",
|
| 2112 |
+
"lstrip": false,
|
| 2113 |
+
"normalized": false,
|
| 2114 |
+
"rstrip": false,
|
| 2115 |
+
"single_word": false,
|
| 2116 |
+
"special": true
|
| 2117 |
+
},
|
| 2118 |
+
"131112": {
|
| 2119 |
+
"content": "<longcat_box_delim>",
|
| 2120 |
+
"lstrip": false,
|
| 2121 |
+
"normalized": false,
|
| 2122 |
+
"rstrip": false,
|
| 2123 |
+
"single_word": false,
|
| 2124 |
+
"special": true
|
| 2125 |
+
},
|
| 2126 |
+
"131113": {
|
| 2127 |
+
"content": "<longcat_ref_start>",
|
| 2128 |
+
"lstrip": false,
|
| 2129 |
+
"normalized": false,
|
| 2130 |
+
"rstrip": false,
|
| 2131 |
+
"single_word": false,
|
| 2132 |
+
"special": true
|
| 2133 |
+
},
|
| 2134 |
+
"131114": {
|
| 2135 |
+
"content": "<longcat_ref_end>",
|
| 2136 |
+
"lstrip": false,
|
| 2137 |
+
"normalized": false,
|
| 2138 |
+
"rstrip": false,
|
| 2139 |
+
"single_word": false,
|
| 2140 |
+
"special": true
|
| 2141 |
+
},
|
| 2142 |
+
"131115": {
|
| 2143 |
+
"content": "<longcat_img_delim>",
|
| 2144 |
+
"lstrip": false,
|
| 2145 |
+
"normalized": false,
|
| 2146 |
+
"rstrip": false,
|
| 2147 |
+
"single_word": false,
|
| 2148 |
+
"special": true
|
| 2149 |
+
},
|
| 2150 |
+
"131116": {
|
| 2151 |
+
"content": "<longcat_audio_delim>",
|
| 2152 |
+
"lstrip": false,
|
| 2153 |
+
"normalized": false,
|
| 2154 |
+
"rstrip": false,
|
| 2155 |
+
"single_word": false,
|
| 2156 |
+
"special": true
|
| 2157 |
+
},
|
| 2158 |
+
"131117": {
|
| 2159 |
+
"content": "<longcat_video_palce>",
|
| 2160 |
+
"lstrip": false,
|
| 2161 |
+
"normalized": false,
|
| 2162 |
+
"rstrip": false,
|
| 2163 |
+
"single_word": false,
|
| 2164 |
+
"special": true
|
| 2165 |
+
},
|
| 2166 |
+
"131118": {
|
| 2167 |
+
"content": "<longcat_video_start>",
|
| 2168 |
+
"lstrip": false,
|
| 2169 |
+
"normalized": false,
|
| 2170 |
+
"rstrip": false,
|
| 2171 |
+
"single_word": false,
|
| 2172 |
+
"special": true
|
| 2173 |
+
},
|
| 2174 |
+
"131119": {
|
| 2175 |
+
"content": "<longcat_video_end>",
|
| 2176 |
+
"lstrip": false,
|
| 2177 |
+
"normalized": false,
|
| 2178 |
+
"rstrip": false,
|
| 2179 |
+
"single_word": false,
|
| 2180 |
+
"special": true
|
| 2181 |
+
},
|
| 2182 |
+
"131120": {
|
| 2183 |
+
"content": "<longcat_audiotext_start>",
|
| 2184 |
+
"lstrip": false,
|
| 2185 |
+
"normalized": false,
|
| 2186 |
+
"rstrip": false,
|
| 2187 |
+
"single_word": false,
|
| 2188 |
+
"special": true
|
| 2189 |
+
},
|
| 2190 |
+
"131121": {
|
| 2191 |
+
"content": "<longcat_audiotext_end>",
|
| 2192 |
+
"lstrip": false,
|
| 2193 |
+
"normalized": false,
|
| 2194 |
+
"rstrip": false,
|
| 2195 |
+
"single_word": false,
|
| 2196 |
+
"special": true
|
| 2197 |
+
},
|
| 2198 |
+
"131122": {
|
| 2199 |
+
"content": "<longcat_audiotext_pad>",
|
| 2200 |
+
"lstrip": false,
|
| 2201 |
+
"normalized": false,
|
| 2202 |
+
"rstrip": false,
|
| 2203 |
+
"single_word": false,
|
| 2204 |
+
"special": true
|
| 2205 |
+
},
|
| 2206 |
+
"131123": {
|
| 2207 |
+
"content": "<longcat_audiogen_start>",
|
| 2208 |
+
"lstrip": false,
|
| 2209 |
+
"normalized": false,
|
| 2210 |
+
"rstrip": false,
|
| 2211 |
+
"single_word": false,
|
| 2212 |
+
"special": true
|
| 2213 |
+
},
|
| 2214 |
+
"131124": {
|
| 2215 |
+
"content": "<longcat_audiogen_end>",
|
| 2216 |
+
"lstrip": false,
|
| 2217 |
+
"normalized": false,
|
| 2218 |
+
"rstrip": false,
|
| 2219 |
+
"single_word": false,
|
| 2220 |
+
"special": true
|
| 2221 |
+
}
|
| 2222 |
+
},
|
| 2223 |
+
"additional_special_tokens": [
|
| 2224 |
+
"<mask_131048>",
|
| 2225 |
+
"<mask_131049>",
|
| 2226 |
+
"<mask_131050>",
|
| 2227 |
+
"<mask_131051>",
|
| 2228 |
+
"<mask_131052>",
|
| 2229 |
+
"<mask_131053>",
|
| 2230 |
+
"<mask_131054>",
|
| 2231 |
+
"<mask_131055>",
|
| 2232 |
+
"<mask_131056>",
|
| 2233 |
+
"<mask_131057>",
|
| 2234 |
+
"<mask_131058>",
|
| 2235 |
+
"<mask_131059>",
|
| 2236 |
+
"<mask_131060>",
|
| 2237 |
+
"<mask_131061>",
|
| 2238 |
+
"<mask_131062>",
|
| 2239 |
+
"<mask_131063>",
|
| 2240 |
+
"<mask_131064>",
|
| 2241 |
+
"<mask_131065>",
|
| 2242 |
+
"<longcat_img_token_size>",
|
| 2243 |
+
"</longcat_img_token_size>",
|
| 2244 |
+
"<mask_131068>",
|
| 2245 |
+
"<mask_131069>",
|
| 2246 |
+
"<mask_131070>",
|
| 2247 |
+
"<mask_131071>",
|
| 2248 |
+
"<longcat_point_start>",
|
| 2249 |
+
"<longcat_point_end>",
|
| 2250 |
+
"<longcat_point_delim>",
|
| 2251 |
+
"<longcat_polygon_start>",
|
| 2252 |
+
"<longcat_polygon_end>",
|
| 2253 |
+
"<mask_131077>",
|
| 2254 |
+
"<mask_131078>",
|
| 2255 |
+
"<longcat_audio_start>",
|
| 2256 |
+
"<longcat_audio_end>",
|
| 2257 |
+
"<longcat_audio_pad>",
|
| 2258 |
+
"<longcat_img_start>",
|
| 2259 |
+
"<longcat_img_end>",
|
| 2260 |
+
"<longcat_img_pad>",
|
| 2261 |
+
"<longcat_img_newline>",
|
| 2262 |
+
"<longcat_box_start>",
|
| 2263 |
+
"<longcat_box_end>",
|
| 2264 |
+
"<longcat_box_delim>",
|
| 2265 |
+
"<longcat_ref_start>",
|
| 2266 |
+
"<longcat_ref_end>",
|
| 2267 |
+
"<longcat_img_delim>",
|
| 2268 |
+
"<longcat_audio_delim>",
|
| 2269 |
+
"<longcat_video_palce>",
|
| 2270 |
+
"<longcat_video_start>",
|
| 2271 |
+
"<longcat_video_end>",
|
| 2272 |
+
"<longcat_audiotext_start>",
|
| 2273 |
+
"<longcat_audiotext_end>",
|
| 2274 |
+
"<longcat_audiotext_pad>",
|
| 2275 |
+
"<longcat_audiogen_start>",
|
| 2276 |
+
"<longcat_audiogen_end>"
|
| 2277 |
+
],
|
| 2278 |
+
"audio_end_token": "<longcat_audio_end>",
|
| 2279 |
+
"audio_pad_token": "<longcat_audio_pad>",
|
| 2280 |
+
"audio_start_token": "<longcat_audio_start>",
|
| 2281 |
+
"auto_map": {
|
| 2282 |
+
"AutoProcessor": "processing_longcat_next.LongcatNextProcessor"
|
| 2283 |
+
},
|
| 2284 |
+
"bos_token": "<longcat_s>",
|
| 2285 |
+
"clean_up_tokenization_spaces": false,
|
| 2286 |
+
"eos_token": "</longcat_s>",
|
| 2287 |
+
"extra_special_tokens": {},
|
| 2288 |
+
"image_end_token": "<longcat_img_end>",
|
| 2289 |
+
"image_newline_token": "<longcat_img_newline>",
|
| 2290 |
+
"image_pad_token": "<longcat_img_pad>",
|
| 2291 |
+
"image_start_token": "<longcat_img_start>",
|
| 2292 |
+
"merges_file": null,
|
| 2293 |
+
"model_max_length": 131072,
|
| 2294 |
+
"pad_token": "<longcat_pad>",
|
| 2295 |
+
"processor_class": "LongcatNextProcessor",
|
| 2296 |
+
"sp_model_kwargs": {},
|
| 2297 |
+
"tokenizer_class": "BloomTokenizer",
|
| 2298 |
+
"unk_token": "<longcat_unk>",
|
| 2299 |
+
"vocab_file": null
|
| 2300 |
+
}
|
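The tokenizer configuration above declares the multimodal special tokens (image, audio, video, box, point, and reference delimiters) as `special`, sets `bos_token`/`eos_token` to `<longcat_s>`/`</longcat_s>`, and caps `model_max_length` at 131072. A minimal sketch of how these settings are consumed is shown below; the local path is a placeholder for wherever this repository is downloaded, not a published model id, and the snippet only illustrates standard `transformers` tokenizer calls rather than any method specific to this upload.

```python
# Minimal sketch, assuming the repository has been downloaded to ./LongCat-Next
# (hypothetical path). It loads the tokenizer described by tokenizer_config.json
# and checks that the declared special tokens resolve to single ids.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./LongCat-Next")

# Tokens marked "special": true in the config should never be split by the BPE.
for token in ["<longcat_img_pad>", "<longcat_audio_start>", "<mask_73>"]:
    token_id = tok.convert_tokens_to_ids(token)
    print(f"{token} -> id {token_id}")

# Matches "model_max_length": 131072 from the config above.
print("model_max_length:", tok.model_max_length)
```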