CalamitousFelicitousness committed on
Commit df9529d · verified · 1 Parent(s): 6943141

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
llm_adapter/config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "_class_name": "AnimaLLMAdapter",
+   "_diffusers_version": "0.37.0",
+   "source_dim": 1024,
+   "target_dim": 1024,
+   "model_dim": 1024,
+   "num_layers": 6,
+   "num_heads": 16,
+   "mlp_ratio": 4.0,
+   "vocab_size": 32128,
+   "use_self_attn": true
+ }
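
Note: for reference, a minimal sketch of how this config is consumed when loading the adapter through the diffusers ModelMixin API. The repo id "your-namespace/anima" is a placeholder (the real repository name is not shown in this commit), and the import assumes modeling_llm_adapter.py from this commit is on the Python path.

# Sketch: loading the adapter declared by llm_adapter/config.json.
import torch
from modeling_llm_adapter import AnimaLLMAdapter  # module added in this commit

adapter = AnimaLLMAdapter.from_pretrained(
    "your-namespace/anima",      # placeholder repo id
    subfolder="llm_adapter",
    torch_dtype=torch.bfloat16,
)
print(adapter.config.source_dim, adapter.config.vocab_size)  # 1024, 32128
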
llm_adapter/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:149d3c0ae9a1b76c5a02a722288a7eadeec306769e2a60f5b34513155c8a2105
+ size 269339368
llm_adapter/modeling_llm_adapter.py ADDED
@@ -0,0 +1,215 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+
+
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     return (x * cos) + (rotate_half(x) * sin)
+
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, head_dim):
+         super().__init__()
+         self.rope_theta = 10000
+         inv_freq = 1.0 / (
+             self.rope_theta
+             ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)
+         )
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+     @torch.no_grad()
+     def forward(self, x, position_ids):
+         inv_freq_expanded = (
+             self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+         )
+         position_ids_expanded = position_ids[:, None, :].float()
+
+         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+         with torch.autocast(device_type=device_type, enabled=False):
+             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos()
+             sin = emb.sin()
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+ class Attention(nn.Module):
+     def __init__(self, query_dim, context_dim, n_heads, head_dim):
+         super().__init__()
+         inner_dim = head_dim * n_heads
+         self.n_heads = n_heads
+         self.head_dim = head_dim
+
+         self.q_proj = nn.Linear(query_dim, inner_dim, bias=False)
+         self.q_norm = nn.RMSNorm(head_dim, eps=1e-6)
+         self.k_proj = nn.Linear(context_dim, inner_dim, bias=False)
+         self.k_norm = nn.RMSNorm(head_dim, eps=1e-6)
+         self.v_proj = nn.Linear(context_dim, inner_dim, bias=False)
+         self.o_proj = nn.Linear(inner_dim, query_dim, bias=False)
+
+     def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None):
+         context = x if context is None else context
+         input_shape = x.shape[:-1]
+         q_shape = (*input_shape, self.n_heads, self.head_dim)
+         context_shape = context.shape[:-1]
+         kv_shape = (*context_shape, self.n_heads, self.head_dim)
+
+         query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2)
+         key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2)
+         value_states = self.v_proj(context).view(kv_shape).transpose(1, 2)
+
+         if position_embeddings is not None:
+             assert position_embeddings_context is not None
+             cos, sin = position_embeddings
+             query_states = apply_rotary_pos_emb(query_states, cos, sin)
+             cos, sin = position_embeddings_context
+             key_states = apply_rotary_pos_emb(key_states, cos, sin)
+
+         attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask)
+         attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
+         return self.o_proj(attn_output)
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=True):
+         super().__init__()
+         self.use_self_attn = use_self_attn
+
+         if self.use_self_attn:
+             self.norm_self_attn = nn.RMSNorm(model_dim, eps=1e-6)
+             self.self_attn = Attention(
+                 query_dim=model_dim,
+                 context_dim=model_dim,
+                 n_heads=num_heads,
+                 head_dim=model_dim // num_heads,
+             )
+
+         self.norm_cross_attn = nn.RMSNorm(model_dim, eps=1e-6)
+         self.cross_attn = Attention(
+             query_dim=model_dim,
+             context_dim=source_dim,
+             n_heads=num_heads,
+             head_dim=model_dim // num_heads,
+         )
+
+         self.norm_mlp = nn.RMSNorm(model_dim, eps=1e-6)
+         self.mlp = nn.Sequential(
+             nn.Linear(model_dim, int(model_dim * mlp_ratio)),
+             nn.GELU(),
+             nn.Linear(int(model_dim * mlp_ratio), model_dim),
+         )
+
+     def forward(
+         self,
+         x,
+         context,
+         target_attention_mask=None,
+         source_attention_mask=None,
+         position_embeddings=None,
+         position_embeddings_context=None,
+     ):
+         if self.use_self_attn:
+             normed = self.norm_self_attn(x)
+             attn_out = self.self_attn(
+                 normed,
+                 mask=target_attention_mask,
+                 position_embeddings=position_embeddings,
+                 position_embeddings_context=position_embeddings,
+             )
+             x = x + attn_out
+
+         normed = self.norm_cross_attn(x)
+         attn_out = self.cross_attn(
+             normed,
+             mask=source_attention_mask,
+             context=context,
+             position_embeddings=position_embeddings,
+             position_embeddings_context=position_embeddings_context,
+         )
+         x = x + attn_out
+         x = x + self.mlp(self.norm_mlp(x))
+         return x
+
+
+ class AnimaLLMAdapter(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         source_dim: int = 1024,
+         target_dim: int = 1024,
+         model_dim: int = 1024,
+         num_layers: int = 6,
+         num_heads: int = 16,
+         mlp_ratio: float = 4.0,
+         vocab_size: int = 32128,
+         use_self_attn: bool = True,
+     ):
+         super().__init__()
+
+         self.embed = nn.Embedding(vocab_size, target_dim)
+         if model_dim != target_dim:
+             self.in_proj = nn.Linear(target_dim, model_dim)
+         else:
+             self.in_proj = nn.Identity()
+         self.rotary_emb = RotaryEmbedding(model_dim // num_heads)
+         self.blocks = nn.ModuleList(
+             [
+                 TransformerBlock(
+                     source_dim,
+                     model_dim,
+                     num_heads=num_heads,
+                     mlp_ratio=mlp_ratio,
+                     use_self_attn=use_self_attn,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.out_proj = nn.Linear(model_dim, target_dim)
+         self.norm = nn.RMSNorm(target_dim, eps=1e-6)
+
+     def forward(
+         self,
+         source_hidden_states: torch.Tensor,
+         target_input_ids: torch.Tensor,
+         target_attention_mask: torch.Tensor = None,
+         source_attention_mask: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if target_attention_mask is not None:
+             target_attention_mask = target_attention_mask.to(torch.bool)
+             if target_attention_mask.ndim == 2:
+                 target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1)
+
+         if source_attention_mask is not None:
+             source_attention_mask = source_attention_mask.to(torch.bool)
+             if source_attention_mask.ndim == 2:
+                 source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
+
+         x = self.in_proj(self.embed(target_input_ids))
+         context = source_hidden_states
+
+         position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
+         position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
+         position_embeddings = self.rotary_emb(x, position_ids)
+         position_embeddings_context = self.rotary_emb(x, position_ids_context)
+
+         for block in self.blocks:
+             x = block(
+                 x,
+                 context,
+                 target_attention_mask=target_attention_mask,
+                 source_attention_mask=source_attention_mask,
+                 position_embeddings=position_embeddings,
+                 position_embeddings_context=position_embeddings_context,
+             )
+
+         return self.norm(self.out_proj(x))
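
Note: a minimal shape-level sketch of the adapter's forward pass with random inputs, using the dimensions from llm_adapter/config.json. No pretrained weights are loaded; it assumes torch >= 2.4 for nn.RMSNorm and that modeling_llm_adapter.py is importable.

# Sketch: T5 token ids (target) cross-attend to Qwen3 hidden states (source).
import torch
from modeling_llm_adapter import AnimaLLMAdapter

adapter = AnimaLLMAdapter()                       # defaults mirror the shipped config
qwen_hidden = torch.randn(1, 20, 1024)            # stands in for Qwen3 last_hidden_state
t5_ids = torch.randint(0, 32128, (1, 17))         # stands in for T5 input_ids

with torch.no_grad():
    out = adapter(source_hidden_states=qwen_hidden, target_input_ids=t5_ids)

print(out.shape)  # torch.Size([1, 17, 1024]): one 1024-d embedding per T5 token
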
model_index.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "_class_name": [
+     "pipeline",
+     "AnimaTextToImagePipeline"
+   ],
+   "_diffusers_version": "0.37.0",
+   "text_encoder": [
+     "transformers",
+     "Qwen3Model"
+   ],
+   "tokenizer": [
+     "transformers",
+     "PreTrainedTokenizerFast"
+   ],
+   "t5_tokenizer": [
+     "transformers",
+     "T5TokenizerFast"
+   ],
+   "llm_adapter": [
+     "modeling_llm_adapter",
+     "AnimaLLMAdapter"
+   ],
+   "transformer": [
+     "diffusers",
+     "CosmosTransformer3DModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKLWan"
+   ],
+   "scheduler": [
+     "diffusers",
+     "FlowMatchEulerDiscreteScheduler"
+   ]
+ }
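
Note: model_index.json maps each component to a (library or local module, class) pair and a subfolder of the same name; pipeline.py and modeling_llm_adapter.py are the custom pieces. A plausible loading sketch follows; the exact flags for remote-code pipelines depend on the diffusers version, and "your-namespace/anima" is a placeholder repo id.

# Sketch: loading the full pipeline described by this index (flags may vary by version).
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "your-namespace/anima",                   # placeholder repo id
    custom_pipeline="your-namespace/anima",   # pulls pipeline.py from the repo
    trust_remote_code=True,                   # allows modeling_llm_adapter.AnimaLLMAdapter
    torch_dtype=torch.bfloat16,
)
print(type(pipe).__name__)  # AnimaTextToImagePipeline
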
pipeline.py ADDED
@@ -0,0 +1,371 @@
+ import inspect
+ from typing import Callable, Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from transformers import PreTrainedModel, PreTrainedTokenizerFast
+
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+ from diffusers.models import AutoencoderKLWan, CosmosTransformer3DModel
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+ from diffusers.utils import logging
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.video_processor import VideoProcessor
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ from diffusers.pipelines.cosmos.pipeline_output import CosmosImagePipelineOutput
+
+ logger = logging.get_logger(__name__)
+
+
+ def retrieve_timesteps(scheduler, num_inference_steps=None, device=None, timesteps=None, sigmas=None, **kwargs):
+     if timesteps is not None and sigmas is not None:
+         raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
+     if timesteps is not None:
+         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     elif sigmas is not None:
+         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
+
+
+ class AnimaTextToImagePipeline(DiffusionPipeline):
+     """Pipeline for text-to-image generation using the Anima model.
+
+     Anima uses a Cosmos Predict2 backbone with a Qwen3 text encoder and an LLM adapter
+     that cross-attends T5 token embeddings to Qwen3 hidden states.
+     """
+
+     model_cpu_offload_seq = "text_encoder->llm_adapter->transformer->vae"
+     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+     def __init__(
+         self,
+         text_encoder: PreTrainedModel,
+         tokenizer: PreTrainedTokenizerFast,
+         t5_tokenizer: PreTrainedTokenizerFast,
+         llm_adapter,
+         transformer: CosmosTransformer3DModel,
+         vae: AutoencoderKLWan,
+         scheduler: FlowMatchEulerDiscreteScheduler,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             t5_tokenizer=t5_tokenizer,
+             llm_adapter=llm_adapter,
+             transformer=transformer,
+             vae=vae,
+             scheduler=scheduler,
+         )
+
+         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+     def _encode_prompt(
+         self,
+         prompt: Union[str, List[str]],
+         device: torch.device,
+         dtype: torch.dtype,
+         max_sequence_length: int = 512,
+     ):
+         """Encode prompt through Qwen3 and run LLM adapter with T5 token IDs."""
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         batch_size = len(prompt)
+
+         # Check for empty prompts - return zero embeddings directly
+         all_empty = all(p.strip() == "" for p in prompt)
+         if all_empty:
+             return torch.zeros(batch_size, 512, self.llm_adapter.config.target_dim, device=device, dtype=dtype)
+
+         # Tokenize with Qwen3 tokenizer
+         qwen_inputs = self.tokenizer(
+             prompt,
+             padding=True,
+             truncation=True,
+             max_length=max_sequence_length,
+             return_tensors="pt",
+         )
+         qwen_input_ids = qwen_inputs.input_ids.to(device)
+         qwen_attention_mask = qwen_inputs.attention_mask.to(device)
+
+         # Get Qwen3 hidden states
+         qwen_outputs = self.text_encoder(
+             input_ids=qwen_input_ids,
+             attention_mask=qwen_attention_mask,
+         )
+         qwen_hidden_states = qwen_outputs.last_hidden_state.to(dtype=dtype)
+
+         # Tokenize with T5 tokenizer (we only need the IDs for the adapter embedding)
+         t5_inputs = self.t5_tokenizer(
+             prompt,
+             padding=True,
+             truncation=True,
+             max_length=max_sequence_length,
+             return_tensors="pt",
+         )
+         t5_input_ids = t5_inputs.input_ids.to(device)
+
+         # Run LLM adapter: T5 token embeddings attend to Qwen3 hidden states
+         adapted_embeds = self.llm_adapter(
+             source_hidden_states=qwen_hidden_states,
+             target_input_ids=t5_input_ids,
+         )
+
+         # Pad to 512 sequence length if shorter
+         if adapted_embeds.shape[1] < 512:
+             adapted_embeds = torch.nn.functional.pad(
+                 adapted_embeds, (0, 0, 0, 512 - adapted_embeds.shape[1])
+             )
+
+         return adapted_embeds
+
+     def encode_prompt(
+         self,
+         prompt: Union[str, List[str]],
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         do_classifier_free_guidance: bool = True,
+         num_images_per_prompt: int = 1,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         max_sequence_length: int = 512,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         device = device or self._execution_device
+         dtype = dtype or self.text_encoder.dtype
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+
+         if prompt is not None:
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         if prompt_embeds is None:
+             prompt_embeds = self._encode_prompt(prompt, device, dtype, max_sequence_length)
+         _, seq_len, _ = prompt_embeds.shape
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+         if do_classifier_free_guidance and negative_prompt_embeds is None:
+             negative_prompt = negative_prompt or ""
+             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+             negative_prompt_embeds = self._encode_prompt(negative_prompt, device, dtype, max_sequence_length)
+             _, seq_len, _ = negative_prompt_embeds.shape
+             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+         return prompt_embeds, negative_prompt_embeds
+
+     def prepare_latents(
+         self,
+         batch_size: int,
+         num_channels_latents: int,
+         height: int,
+         width: int,
+         num_frames: int = 1,
+         dtype: torch.dtype = None,
+         device: torch.device = None,
+         generator=None,
+         latents: torch.Tensor = None,
+     ):
+         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+         latent_height = height // self.vae_scale_factor_spatial
+         latent_width = width // self.vae_scale_factor_spatial
+
+         if latents is not None:
+             return latents.to(device=device, dtype=dtype)
+
+         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         return latents
+
+     def check_inputs(self, prompt, height, width, prompt_embeds=None):
+         if height % 16 != 0 or width % 16 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+         if prompt is not None and prompt_embeds is not None:
+             raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
+         elif prompt is None and prompt_embeds is None:
+             raise ValueError("Provide either `prompt` or `prompt_embeds`.")
+
+     @property
+     def guidance_scale(self):
+         return self._guidance_scale
+
+     @property
+     def do_classifier_free_guidance(self):
+         return self._guidance_scale > 1.0
+
+     @property
+     def num_timesteps(self):
+         return self._num_timesteps
+
+     @property
+     def interrupt(self):
+         return self._interrupt
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         height: int = 768,
+         width: int = 1360,
+         num_inference_steps: int = 35,
+         guidance_scale: float = 7.0,
+         num_images_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback_on_step_end: Optional[
+             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+         ] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         max_sequence_length: int = 512,
+     ):
+         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+         num_frames = 1
+
+         self.check_inputs(prompt, height, width, prompt_embeds)
+         self._guidance_scale = guidance_scale
+         self._current_timestep = None
+         self._interrupt = False
+
+         device = self._execution_device
+
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         # Encode prompt
+         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             do_classifier_free_guidance=self.do_classifier_free_guidance,
+             num_images_per_prompt=num_images_per_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             device=device,
+             max_sequence_length=max_sequence_length,
+         )
+
+         # Prepare timesteps - use default descending schedule (1→0)
+         timesteps, num_inference_steps = retrieve_timesteps(
+             self.scheduler, num_inference_steps=num_inference_steps, device=device
+         )
+
+         # Prepare latents
+         transformer_dtype = self.transformer.dtype
+         num_channels_latents = self.transformer.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             num_frames,
+             torch.float32,
+             device,
+             generator,
+             latents,
+         )
+
+         padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
+
+         # Denoising loop using CONST preconditioning (flow matching velocity model):
+         # - c_in = 1.0 (no input scaling)
+         # - timestep = sigma (passed directly)
+         # - model output is the velocity: denoised = x - velocity * sigma
+         # - CFG applied to velocity (equivalent to applying to denoised for linear preconditioning)
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         self._num_timesteps = len(timesteps)
+
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+
+                 self._current_timestep = t
+                 sigma = self.scheduler.sigmas[i]
+
+                 # Pass sigma directly as timestep (CONST preconditioning)
+                 timestep = sigma.expand(latents.shape[0]).to(transformer_dtype)
+                 latent_model_input = latents.to(transformer_dtype)
+
+                 # Model predicts velocity (raw output IS the velocity for CONST)
+                 velocity = self.transformer(
+                     hidden_states=latent_model_input,
+                     timestep=timestep,
+                     encoder_hidden_states=prompt_embeds,
+                     padding_mask=padding_mask,
+                     return_dict=False,
+                 )[0].float()
+
+                 if self.do_classifier_free_guidance:
+                     velocity_uncond = self.transformer(
+                         hidden_states=latent_model_input,
+                         timestep=timestep,
+                         encoder_hidden_states=negative_prompt_embeds,
+                         padding_mask=padding_mask,
+                         return_dict=False,
+                     )[0].float()
+                     velocity = velocity_uncond + self.guidance_scale * (velocity - velocity_uncond)
+
+                 # Euler step: scheduler computes x_next = x + (sigma_next - sigma) * velocity
+                 latents = self.scheduler.step(velocity, t, latents, return_dict=False)[0]
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                     latents = callback_outputs.pop("latents", latents)
+                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+
+         self._current_timestep = None
+
+         if not output_type == "latent":
+             latents_mean = (
+                 torch.tensor(self.vae.config.latents_mean)
+                 .view(1, self.vae.config.z_dim, 1, 1, 1)
+                 .to(latents.device, latents.dtype)
+             )
+             latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                 latents.device, latents.dtype
+             )
+             latents = latents / latents_std + latents_mean
+             video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
+             video = self.video_processor.postprocess_video(video, output_type=output_type)
+             image = [batch[0] for batch in video]
+             if isinstance(video, torch.Tensor):
+                 image = torch.stack(image)
+             elif isinstance(video, np.ndarray):
+                 image = np.stack(image)
+         else:
+             image = latents[:, :, 0]
+
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (image,)
+
+         return CosmosImagePipelineOutput(images=image)
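
Note: a short usage sketch of the pipeline call, with arguments taken straight from the __call__ signature above. It assumes the `pipe` object from the loading sketch after model_index.json and a CUDA device; the prompt text is only an example.

# Sketch: generating one 768x1360 image with the defaults used by __call__.
import torch

pipe = pipe.to("cuda")  # `pipe` comes from the earlier loading sketch
image = pipe(
    prompt="a watercolor fox in a snowy forest",
    negative_prompt="blurry, low quality",
    height=768,
    width=1360,
    num_inference_steps=35,
    guidance_scale=7.0,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("anima_fox.png")
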
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_class_name": "FlowMatchEulerDiscreteScheduler",
+   "_diffusers_version": "0.37.0",
+   "num_train_timesteps": 1000,
+   "shift": 3.0
+ }
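
Note: a toy illustration of what shift=3.0 does to the sigma schedule and of the Euler update the pipeline's denoising loop relies on. The shift formula shown is the static time shift commonly used by flow-match schedulers; treat it as an assumption about the library internals, not a quote of them.

# Sketch: shifted sigmas concentrate more steps at high noise; one Euler step on a scalar.
import numpy as np

num_inference_steps, shift = 35, 3.0
sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
shifted = shift * sigmas / (1 + (shift - 1) * sigmas)   # assumed static time-shift

x, velocity = 1.0, -0.8
sigma, sigma_next = shifted[0], shifted[1]
x_next = x + (sigma_next - sigma) * velocity            # x_next = x + (sigma_next - sigma) * v
print(round(float(shifted[0]), 3), round(float(x_next), 3))
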
t5_tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
t5_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,113 @@
+ {
+   "backend": "tokenizers",
+   "clean_up_tokenization_spaces": true,
+   "eos_token": "</s>",
+   "extra_ids": 100,
+   "extra_special_tokens": [
+     "<extra_id_0>",
+     "<extra_id_1>",
+     "<extra_id_2>",
+     "<extra_id_3>",
+     "<extra_id_4>",
+     "<extra_id_5>",
+     "<extra_id_6>",
+     "<extra_id_7>",
+     "<extra_id_8>",
+     "<extra_id_9>",
+     "<extra_id_10>",
+     "<extra_id_11>",
+     "<extra_id_12>",
+     "<extra_id_13>",
+     "<extra_id_14>",
+     "<extra_id_15>",
+     "<extra_id_16>",
+     "<extra_id_17>",
+     "<extra_id_18>",
+     "<extra_id_19>",
+     "<extra_id_20>",
+     "<extra_id_21>",
+     "<extra_id_22>",
+     "<extra_id_23>",
+     "<extra_id_24>",
+     "<extra_id_25>",
+     "<extra_id_26>",
+     "<extra_id_27>",
+     "<extra_id_28>",
+     "<extra_id_29>",
+     "<extra_id_30>",
+     "<extra_id_31>",
+     "<extra_id_32>",
+     "<extra_id_33>",
+     "<extra_id_34>",
+     "<extra_id_35>",
+     "<extra_id_36>",
+     "<extra_id_37>",
+     "<extra_id_38>",
+     "<extra_id_39>",
+     "<extra_id_40>",
+     "<extra_id_41>",
+     "<extra_id_42>",
+     "<extra_id_43>",
+     "<extra_id_44>",
+     "<extra_id_45>",
+     "<extra_id_46>",
+     "<extra_id_47>",
+     "<extra_id_48>",
+     "<extra_id_49>",
+     "<extra_id_50>",
+     "<extra_id_51>",
+     "<extra_id_52>",
+     "<extra_id_53>",
+     "<extra_id_54>",
+     "<extra_id_55>",
+     "<extra_id_56>",
+     "<extra_id_57>",
+     "<extra_id_58>",
+     "<extra_id_59>",
+     "<extra_id_60>",
+     "<extra_id_61>",
+     "<extra_id_62>",
+     "<extra_id_63>",
+     "<extra_id_64>",
+     "<extra_id_65>",
+     "<extra_id_66>",
+     "<extra_id_67>",
+     "<extra_id_68>",
+     "<extra_id_69>",
+     "<extra_id_70>",
+     "<extra_id_71>",
+     "<extra_id_72>",
+     "<extra_id_73>",
+     "<extra_id_74>",
+     "<extra_id_75>",
+     "<extra_id_76>",
+     "<extra_id_77>",
+     "<extra_id_78>",
+     "<extra_id_79>",
+     "<extra_id_80>",
+     "<extra_id_81>",
+     "<extra_id_82>",
+     "<extra_id_83>",
+     "<extra_id_84>",
+     "<extra_id_85>",
+     "<extra_id_86>",
+     "<extra_id_87>",
+     "<extra_id_88>",
+     "<extra_id_89>",
+     "<extra_id_90>",
+     "<extra_id_91>",
+     "<extra_id_92>",
+     "<extra_id_93>",
+     "<extra_id_94>",
+     "<extra_id_95>",
+     "<extra_id_96>",
+     "<extra_id_97>",
+     "<extra_id_98>",
+     "<extra_id_99>"
+   ],
+   "is_local": false,
+   "model_max_length": 512,
+   "pad_token": "<pad>",
+   "tokenizer_class": "T5Tokenizer",
+   "unk_token": "<unk>"
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "architectures": [
+     "Qwen3Model"
+   ],
+   "model_type": "qwen3",
+   "vocab_size": 151936,
+   "hidden_size": 1024,
+   "intermediate_size": 3072,
+   "num_hidden_layers": 28,
+   "num_attention_heads": 16,
+   "num_key_value_heads": 8,
+   "head_dim": 128,
+   "hidden_act": "silu",
+   "max_position_embeddings": 32768,
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 1000000.0,
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "use_cache": false,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16"
+ }
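
Note: the text encoder is a 28-layer, 1024-wide Qwen3 base model used only for its last hidden state. A sketch of the raw encoding step that pipeline._encode_prompt performs, assuming a transformers version that ships Qwen3 (roughly >= 4.51) and a placeholder repo id.

# Sketch: producing the (batch, seq_len, 1024) hidden states fed to the LLM adapter.
import torch
from transformers import AutoTokenizer, Qwen3Model

repo = "your-namespace/anima"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer")
text_encoder = Qwen3Model.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.bfloat16)

inputs = tokenizer(["a watercolor fox"], padding=True, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    hidden = text_encoder(**inputs).last_hidden_state
print(hidden.shape)  # (1, seq_len, 1024)
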
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d10aa56a4da8a95d954d99228d9e20e27f96ac5fc8aa41b89a41532b16bb4817
+ size 1192135064
tokenizer/chat_template.jinja ADDED
@@ -0,0 +1,89 @@
+ {%- if tools %}
+ {{- '<|im_start|>system\n' }}
+ {%- if messages[0].role == 'system' %}
+ {{- messages[0].content + '\n\n' }}
+ {%- endif %}
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+ {%- for tool in tools %}
+ {{- "\n" }}
+ {{- tool | tojson }}
+ {%- endfor %}
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+ {%- else %}
+ {%- if messages[0].role == 'system' %}
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+ {%- for message in messages[::-1] %}
+ {%- set index = (messages|length - 1) - loop.index0 %}
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+ {%- set ns.multi_step_tool = false %}
+ {%- set ns.last_query_index = index %}
+ {%- endif %}
+ {%- endfor %}
+ {%- for message in messages %}
+ {%- if message.content is string %}
+ {%- set content = message.content %}
+ {%- else %}
+ {%- set content = '' %}
+ {%- endif %}
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+ {%- elif message.role == "assistant" %}
+ {%- set reasoning_content = '' %}
+ {%- if message.reasoning_content is string %}
+ {%- set reasoning_content = message.reasoning_content %}
+ {%- else %}
+ {%- if '</think>' in content %}
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+ {%- endif %}
+ {%- endif %}
+ {%- if loop.index0 > ns.last_query_index %}
+ {%- if loop.last or (not loop.last and reasoning_content) %}
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+ {%- else %}
+ {{- '<|im_start|>' + message.role + '\n' + content }}
+ {%- endif %}
+ {%- else %}
+ {{- '<|im_start|>' + message.role + '\n' + content }}
+ {%- endif %}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if (loop.first and content) or (not loop.first) %}
+ {{- '\n' }}
+ {%- endif %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+ {{- '<tool_call>\n{"name": "' }}
+ {{- tool_call.name }}
+ {{- '", "arguments": ' }}
+ {%- if tool_call.arguments is string %}
+ {{- tool_call.arguments }}
+ {%- else %}
+ {{- tool_call.arguments | tojson }}
+ {%- endif %}
+ {{- '}\n</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+ {{- '<|im_start|>user' }}
+ {%- endif %}
+ {{- '\n<tool_response>\n' }}
+ {{- content }}
+ {{- '\n</tool_response>' }}
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+ {%- endfor %}
+ {%- if add_generation_prompt %}
+ {{- '<|im_start|>assistant\n' }}
+ {%- if enable_thinking is defined and enable_thinking is false %}
+ {{- '<think>\n\n</think>\n\n' }}
+ {%- endif %}
+ {%- endif %}
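
Note: this is the stock Qwen-style chat template (tool calls plus <think> handling). The pipeline itself calls the tokenizer directly on the prompt and never applies this template, but for completeness a sketch of what applying it would produce follows; the repo id is a placeholder.

# Sketch: rendering the chat template via transformers (not used by the pipeline).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-namespace/anima", subfolder="tokenizer")
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Describe a watercolor fox."}],
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,   # emits the empty <think>...</think> block per the template
)
print(text)
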
tokenizer/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
+ size 11422650
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "add_prefix_space": false,
+   "backend": "tokenizers",
+   "bos_token": null,
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "extra_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "is_local": false,
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null
+ }
transformer/config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "_class_name": "CosmosTransformer3DModel",
+   "_diffusers_version": "0.37.0",
+   "in_channels": 16,
+   "out_channels": 16,
+   "num_attention_heads": 16,
+   "attention_head_dim": 128,
+   "num_layers": 28,
+   "mlp_ratio": 4.0,
+   "text_embed_dim": 1024,
+   "adaln_lora_dim": 256,
+   "max_size": [
+     128,
+     240,
+     240
+   ],
+   "patch_size": [
+     1,
+     2,
+     2
+   ],
+   "rope_scale": [
+     1.0,
+     4.0,
+     4.0
+   ],
+   "concat_padding_mask": true,
+   "extra_pos_embed_type": null
+ }
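
Note: a worked example of the latent/patch geometry implied by this config together with the VAE's spatial stride of 8, for the pipeline defaults height=768, width=1360 and a single frame. Plain arithmetic, not library output.

# Sketch: latent grid and token count seen by CosmosTransformer3DModel.
height, width, frames = 768, 1360, 1
vae_stride = 8            # spatial downsampling of the Wan VAE
pt, ph, pw = 1, 2, 2      # "patch_size" above

lat_f, lat_h, lat_w = frames, height // vae_stride, width // vae_stride   # 1, 96, 170
tokens = (lat_f // pt) * (lat_h // ph) * (lat_w // pw)                    # 1 * 48 * 85
print(lat_h, lat_w, tokens)  # 96 170 4080
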
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e9c0b348c119e44dcc26589102ad5ca64d26ac84d5db3b743d29f0fa2fc2f8b2
+ size 3912877072
vae/config.json ADDED
@@ -0,0 +1,56 @@
+ {
+   "_class_name": "AutoencoderKLWan",
+   "_diffusers_version": "0.33.0.dev0",
+   "attn_scales": [],
+   "base_dim": 96,
+   "dim_mult": [
+     1,
+     2,
+     4,
+     4
+   ],
+   "dropout": 0.0,
+   "latents_mean": [
+     -0.7571,
+     -0.7089,
+     -0.9113,
+     0.1075,
+     -0.1745,
+     0.9653,
+     -0.1517,
+     1.5508,
+     0.4134,
+     -0.0715,
+     0.5517,
+     -0.3632,
+     -0.1922,
+     -0.9497,
+     0.2503,
+     -0.2921
+   ],
+   "latents_std": [
+     2.8184,
+     1.4541,
+     2.3275,
+     2.6558,
+     1.2196,
+     1.7708,
+     2.6052,
+     2.0743,
+     3.2687,
+     2.1526,
+     2.8652,
+     1.5579,
+     1.6382,
+     1.1253,
+     2.8251,
+     1.916
+   ],
+   "num_res_blocks": 2,
+   "temperal_downsample": [
+     false,
+     true,
+     true
+   ],
+   "z_dim": 16
+ }
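
Note: a small sketch of how the pipeline's __init__ derives its scale factors from "temperal_downsample", and of the latent un-scaling it applies before vae.decode. Pure arithmetic mirroring the pipeline code; the mean/std tensors here are stand-ins for the config lists above, and no model is loaded.

# Sketch: scale factors and latent de-normalization for AutoencoderKLWan.
import torch

temperal_downsample = [False, True, True]
vae_scale_factor_temporal = 2 ** sum(temperal_downsample)   # 4
vae_scale_factor_spatial = 2 ** len(temperal_downsample)    # 8

latents = torch.randn(1, 16, 1, 96, 170)                    # z_dim=16, one frame, 768x1360 image
latents_mean = torch.zeros(1, 16, 1, 1, 1)                  # stand-in for "latents_mean"
latents_std_inv = 1.0 / torch.ones(1, 16, 1, 1, 1)          # stand-in for 1 / "latents_std"
decode_input = latents / latents_std_inv + latents_mean     # i.e. latents * std + mean, as in the pipeline
print(vae_scale_factor_temporal, vae_scale_factor_spatial, decode_input.shape)
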
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3b5bf326a6c4f66fb2b2250687fdccd1f126ee7c977d2f0170cb56fdacc70a9a
+ size 253806934