import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
import math

from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit
from . import qwen_vl

@dataclass
class Llama2Config:
    vocab_size: int = 128320
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    max_position_embeddings: int = 8192
    rms_norm_eps: float = 1e-5
    rope_theta: float = 500000.0
    transformer_type: str = "llama"
    head_dim = 128
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = False
    rope_dims = None

@dataclass
class Qwen25_3BConfig:
    vocab_size: int = 151936
    hidden_size: int = 2048
    intermediate_size: int = 11008
    num_hidden_layers: int = 36
    num_attention_heads: int = 16
    num_key_value_heads: int = 2
    max_position_embeddings: int = 128000
    rms_norm_eps: float = 1e-6
    rope_theta: float = 1000000.0
    transformer_type: str = "llama"
    head_dim = 128
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = True
    rope_dims = None

@dataclass
class Qwen25_7BVLI_Config:
    vocab_size: int = 152064
    hidden_size: int = 3584
    intermediate_size: int = 18944
    num_hidden_layers: int = 28
    num_attention_heads: int = 28
    num_key_value_heads: int = 4
    max_position_embeddings: int = 128000
    rms_norm_eps: float = 1e-6
    rope_theta: float = 1000000.0
    transformer_type: str = "llama"
    head_dim = 128
    rms_norm_add = False
    mlp_activation = "silu"
    qkv_bias = True
    rope_dims = [16, 24, 24]

@dataclass
class Gemma2_2B_Config:
    vocab_size: int = 256000
    hidden_size: int = 2304
    intermediate_size: int = 9216
    num_hidden_layers: int = 26
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    max_position_embeddings: int = 8192
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    transformer_type: str = "gemma2"
    head_dim = 256
    rms_norm_add = True
    mlp_activation = "gelu_pytorch_tanh"
    qkv_bias = False
    rope_dims = None

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        self.add = add

    def forward(self, x: torch.Tensor):
        w = self.weight
        if self.add:
            # Gemma-style norms store the weight as an offset from 1.0, so add it back before scaling.
            w = w + 1.0
        return comfy.ldm.common_dit.rms_norm(x, w, self.eps)

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
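
# Worked example (illustrative, not part of the original module): for head_dim = 4,
# rotate_half([a, b, c, d]) -> [-c, -d, a, b], which is exactly the pairing consumed
# by the (cos, sin) rotation in apply_rope below.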

def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
    theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))

    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos()
    sin = emb.sin()
    if rope_dims is not None and position_ids.shape[0] > 1:
        # Multi-axis RoPE (Qwen2-VL style): split the frequencies into per-axis sections
        # and pick each section from its matching position-id axis.
        mrope_section = rope_dims * 2
        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
    else:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    return (cos, sin)

def apply_rope(xq, xk, freqs_cis):
    cos = freqs_cis[0]
    sin = freqs_cis[1]
    q_embed = (xq * cos) + (rotate_half(xq) * sin)
    k_embed = (xk * cos) + (rotate_half(xk) * sin)
    return q_embed, k_embed
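
# Shape sketch (illustrative only; the tensors below are assumptions, not values taken
# from this file): apply_rope expects query/key tensors laid out as
# (batch, heads, seq_len, head_dim) plus the (cos, sin) pair from precompute_freqs_cis.
# With rope_dims=None the cos/sin tensors broadcast over the head axis, so grouped-query
# keys with fewer heads work unchanged:
#
#   position_ids = torch.arange(0, 16).unsqueeze(0)                # (1, seq_len)
#   freqs_cis = precompute_freqs_cis(128, position_ids, 500000.0)  # cos/sin: (1, 1, 16, 128)
#   xq = torch.randn(1, 32, 16, 128)                               # 32 query heads
#   xk = torch.randn(1, 8, 16, 128)                                # 8 key/value heads
#   xq, xk = apply_rope(xq, xk, freqs_cis)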

class Attention(nn.Module):
    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.inner_size = self.num_heads * self.head_dim

        ops = ops or nn
        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
        self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        batch_size, seq_length, _ = hidden_states.shape
        xq = self.q_proj(hidden_states)
        xk = self.k_proj(hidden_states)
        xv = self.v_proj(hidden_states)

        xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)

        xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

        # Grouped-query attention: expand the key/value heads to match the query head count.
        xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

        output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
        return self.o_proj(output)

class MLP(nn.Module):
    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        ops = ops or nn
        self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
        self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
        self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
        if config.mlp_activation == "silu":
            self.activation = torch.nn.functional.silu
        elif config.mlp_activation == "gelu_pytorch_tanh":
            self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh")

    def forward(self, x):
        return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))

class TransformerBlock(nn.Module):
    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        # Self Attention
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(
            hidden_states=x,
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
        )
        x = residual + x

        # MLP
        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = residual + x

        return x

class TransformerBlockGemma2(nn.Module):
    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        # Self Attention
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(
            hidden_states=x,
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
        )
        x = self.post_attention_layernorm(x)
        x = residual + x

        # MLP
        residual = x
        x = self.pre_feedforward_layernorm(x)
        x = self.mlp(x)
        x = self.post_feedforward_layernorm(x)
        x = residual + x

        return x

class Llama2_(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = ops.Embedding(
            config.vocab_size,
            config.hidden_size,
            device=device,
            dtype=dtype
        )
        if self.config.transformer_type == "gemma2":
            transformer = TransformerBlockGemma2
            self.normalize_in = True
        else:
            transformer = TransformerBlock
            self.normalize_in = False

        self.layers = nn.ModuleList([
            transformer(config, device=device, dtype=dtype, ops=ops)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
        if embeds is not None:
            x = embeds
        else:
            x = self.embed_tokens(x, out_dtype=dtype)

        if self.normalize_in:
            # Gemma scales the token embeddings by sqrt(hidden_size) before the first block.
            x *= self.config.hidden_size ** 0.5

        if position_ids is None:
            position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)

        freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                         position_ids,
                                         self.config.rope_theta,
                                         self.config.rope_dims,
                                         device=x.device)

        # Convert the padding mask into an additive bias and merge it with a causal mask.
        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask

        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        intermediate = None
        all_intermediate = None
        if intermediate_output is not None:
            if intermediate_output == "all":
                all_intermediate = []
                intermediate_output = None
            elif intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output

        for i, layer in enumerate(self.layers):
            if all_intermediate is not None:
                all_intermediate.append(x.unsqueeze(1).clone())
            x = layer(
                x=x,
                attention_mask=mask,
                freqs_cis=freqs_cis,
                optimized_attention=optimized_attention,
            )
            if i == intermediate_output:
                intermediate = x.clone()

        x = self.norm(x)
        if all_intermediate is not None:
            all_intermediate.append(x.unsqueeze(1).clone())

        if all_intermediate is not None:
            intermediate = torch.cat(all_intermediate, dim=1)

        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.norm(intermediate)

        return x, intermediate

class BaseLlama:
    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, embeddings):
        self.model.embed_tokens = embeddings

    def forward(self, input_ids, *args, **kwargs):
        return self.model(input_ids, *args, **kwargs)

class Llama2(BaseLlama, torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = Llama2Config(**config_dict)
        self.num_layers = config.num_hidden_layers

        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype

class Qwen25_3B(BaseLlama, torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = Qwen25_3BConfig(**config_dict)
        self.num_layers = config.num_hidden_layers

        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype

class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = Qwen25_7BVLI_Config(**config_dict)
        self.num_layers = config.num_hidden_layers

        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype

    def preprocess_embed(self, embed, device):
        if embed["type"] == "image":
            image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
            return self.visual(image.to(device, dtype=torch.float32), grid), grid
        return None, None

    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
        grid = None
        for e in embeds_info:
            if e.get("type") == "image":
                grid = e.get("extra", None)
                # Build 3-axis (temporal, height, width) position ids for the image token span (mRoPE).
                position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
                start = e.get("index")
                position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
                end = e.get("size") + start
                len_max = int(grid.max()) // 2
                start_next = len_max + start
                position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
                position_ids[0, start:end] = start
                max_d = int(grid[0][1]) // 2
                position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
                max_d = int(grid[0][2]) // 2
                position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]

        if grid is None:
            position_ids = None

        return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)

class Gemma2_2B(BaseLlama, torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = Gemma2_2B_Config(**config_dict)
        self.num_layers = config.num_hidden_layers

        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype
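
# Usage sketch (hypothetical, for orientation only; `comfy.ops.manual_cast` and the
# values below are assumptions, not taken from this file): the wrapper classes are
# built from a config dict whose keys override the dataclass defaults, plus a dtype,
# a device, and an `operations` namespace supplying Linear/Embedding layers. Because
# Llama2_.forward calls embed_tokens(..., out_dtype=dtype), `operations` needs
# ComfyUI-style ops rather than plain torch.nn.
#
#   import comfy.ops
#   model = Llama2({}, dtype=torch.float16, device="cpu", operations=comfy.ops.manual_cast)
#   tokens = torch.zeros((1, 8), dtype=torch.long)
#   hidden, penultimate = model(tokens, intermediate_output=-2)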