File size: 24,406 Bytes
7ef3558 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 |
from __future__ import annotations

import warnings
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nested._internal.nested_tensor import nested_from_padded
from transformers import (
    LlamaConfig,
    LlamaModel,
    LlamaPreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    rotate_half,
)
from transformers.processing_utils import Unpack
class ModifiedLlamaAttention(LlamaAttention):
    """Llama attention layer with causal masking switched off.

    The layer reuses the parent's ``q_proj``/``k_proj``/``v_proj``/``o_proj``
    projections, applies this file's ``apply_rotary_pos_emb`` and always
    computes attention through ``sdpa_attention_forward`` with
    ``is_causal=False``, whatever ``config._attn_implementation`` says.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the parent attention module, then mark it as non-causal."""
        super().__init__(*args, **kwargs)
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run non-causal self-attention over ``hidden_states``.

        Args:
            hidden_states: Input activations; every dimension except the last
                is preserved in the output.
            position_embeddings: ``(cos, sin)`` pair used for the rotary
                position embedding.
            attention_mask: Mask handed to ``sdpa_attention_forward``; may be
                ``None``.
            past_key_value: Optional cache. When given, its ``update`` method
                is called with the new key/value states and ``self.layer_idx``.
            cache_position: Forwarded to the cache inside ``cache_kwargs``.
            **kwargs: Forwarded to ``sdpa_attention_forward``.
                ``output_attentions`` is only inspected to decide whether to
                emit a warning.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` has passed
            through ``o_proj`` and ``attn_weights`` is whatever
            ``sdpa_attention_forward`` returns as its second element.
        """
        input_shape = hidden_states.shape[:-1]
        # Split the projected features into (..., heads, head_dim).
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # SDPA cannot return attention weights and this module has no eager
        # path, so tell the caller what they will actually get.
        if self.config._attn_implementation == "sdpa" and kwargs.get(
            "output_attentions", False
        ):
            warnings.warn(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                "This attention module always uses SDPA, so attention weights are returned as `None`."
            )

        attn_output, attn_weights = sdpa_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0,
            scaling=self.scaling,
            is_causal=False,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
def sdpa_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
dropout: float = 0.0,
scaling: Optional[float] = None,
is_causal: Optional[bool] = None,
**kwargs: Any,
) -> Tuple[torch.Tensor, None]:
if hasattr(module, "num_key_value_groups"):
if key.is_nested:
key = repeat_jagged_kv(key, module.num_key_value_groups)
value = repeat_jagged_kv(value, module.num_key_value_groups)
else:
key = repeat_dense_kv(key, module.num_key_value_groups)
value = repeat_dense_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None and causal_mask.ndim == 4:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
# Reference: https://github.com/pytorch/pytorch/issues/112577.
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
if is_causal is None:
is_causal = query.shape[2] > 1 and causal_mask is None
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
is_causal = is_causal.item()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=causal_mask,
dropout_p=dropout,
scale=scaling,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
def repeat_jagged_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dimension 1.

    Produces the same result as
    ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). The sequence
    dimension is expanded with ``-1`` so its size is never spelled out
    (presumably because it is ragged for jagged inputs).
    """
    # Unpacking also enforces that the input is 4-D.
    batch, kv_heads, _, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # (batch, kv_heads, seq, 1, head_dim) -> (batch, kv_heads, seq, n_rep, head_dim)
    repeated = hidden_states.unsqueeze(3).expand(
        (batch, kv_heads, -1, n_rep, head_dim)
    )
    # Bring heads next to the repeat axis, merge them, then restore the layout.
    merged = repeated.transpose(1, 2).flatten(2, 3)
    return merged.transpose(1, 2)
def repeat_dense_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dimension 1.

    Produces the same result as
    ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    """
    # Unpacking also enforces that the input is 4-D.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
if q.is_nested and k.is_nested:
if q.layout != torch.jagged:
raise NotImplementedError(f"Unsupported layout: {q.layout}")
if k.layout != torch.jagged:
raise NotImplementedError(f"Unsupported layout: {k.layout}")
return _jagged_tensor_forward(q, k, cos, sin)
else:
return _padded_tensor_forward(q, k, cos, sin)
def _jagged_tensor_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the rotary embedding to nested query/key tensors.

    Each tensor is padded to dense form with zeros, rotated with the usual
    ``x * cos + rotate_half(x) * sin`` formula, and converted back to a
    jagged tensor that reuses the input's offsets.
    """

    def _rotate(jagged: torch.Tensor) -> torch.Tensor:
        # Rotate in dense form, then restore the original jagged structure.
        dense = jagged.to_padded_tensor(0.0)
        rotated = (dense * cos) + (rotate_half(dense) * sin)
        return convert_dense_to_jagged(jagged, rotated)

    return _rotate(q), _rotate(k)
def _padded_tensor_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the rotary embedding to dense query/key tensors.

    Returns ``(q_embed, k_embed)``, each computed as
    ``x * cos + rotate_half(x) * sin``.
    """

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        return (x * cos) + (rotate_half(x) * sin)

    return _rotate(q), _rotate(k)
def convert_dense_to_jagged(nested_q: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
padded_max_S = nested_q._get_max_seqlen()
total_L = nested_q._values.shape[nested_q._ragged_idx - 1]
if padded_max_S is None:
# use upper bound on max seqlen if it's not present
padded_max_S = total_L
# convert dense tensor -> jagged
q = q.expand(
[
x if i != nested_q._ragged_idx else padded_max_S
for i, x in enumerate(q.shape)
]
)
nested_result = nested_from_padded(
q,
offsets=nested_q._offsets,
ragged_idx=nested_q._ragged_idx,
sum_S=total_L,
min_seqlen=nested_q._get_min_seqlen(),
max_seqlen=padded_max_S,
)
return nested_result
class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose self-attention is ``ModifiedLlamaAttention``."""

    def __init__(self, config: LlamaConfig, layer_idx: int) -> None:
        """Create the layer's submodules from ``config``.

        ``nn.Module.__init__`` is called directly instead of the parent
        constructor, and every submodule is built here.
        """
        nn.Module.__init__(self)
        hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.hidden_size: int = hidden_size
        # Submodule creation order is kept as-is (it fixes registration order).
        self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(hidden_size, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=norm_eps)
class LlamaBiModel(LlamaModel):
    """Llama backbone built from ``ModifiedLlamaDecoderLayer`` layers.

    Key/value caching is disabled: ``forward`` ignores the ``use_cache``
    argument and always reports ``past_key_values=None`` in its output.
    """

    def __init__(self, config: LlamaConfig) -> None:
        """Create embeddings, decoder layers, final norm and rotary embedding.

        ``LlamaPreTrainedModel.__init__`` is called directly (not
        ``LlamaModel.__init__``) and all submodules are built here.
        """
        LlamaPreTrainedModel.__init__(self, config)
        self.padding_idx: int = config.pad_token_id
        self.vocab_size: int = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                ModifiedLlamaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_seen_tokens=None,
        output_attentions=False,
    ):
        """Prepare the attention mask handed to the decoder layers.

        Args:
            attention_mask: Mask supplied by the caller, or ``None``.
            input_tensor: Tensor whose dtype is used for the expanded mask.
            cache_position: Unused; kept for signature compatibility.
            past_seen_tokens: Unused; kept for signature compatibility.
            output_attentions: Unused; kept for signature compatibility.

        Returns:
            * With ``flash_attention_2``: the mask unchanged if it contains
              any zero entry, otherwise ``None``.
            * Otherwise: ``None`` or a 4-D mask is returned unchanged; any
              other mask is expanded with
              ``AttentionMaskConverter._expand_mask`` to ``input_tensor.dtype``.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        if attention_mask is None or attention_mask.dim() == 4:
            return attention_mask

        return AttentionMaskConverter._expand_mask(
            mask=attention_mask,
            dtype=input_tensor.dtype,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        """Run the decoder stack over token ids or precomputed embeddings.

        Args:
            input_ids: Token ids; exactly one of ``input_ids`` and
                ``inputs_embeds`` must be given.
            attention_mask: Optional mask. It is only processed when the
                embeddings are not nested; nested inputs run without a mask.
            position_ids: Optional positions; defaults to
                ``cache_position.unsqueeze(0)``.
            past_key_values: Passed through to the decoder layers and used
                to offset ``cache_position`` when that is not supplied.
            inputs_embeds: Precomputed embeddings (dense or nested).
            use_cache: Ignored; caching is always disabled.
            output_attentions: Collect each layer's second output. Defaults
                to ``config.output_attentions``.
            output_hidden_states: Collect the hidden states before every
                layer and after the final norm. Defaults to
                ``config.output_hidden_states``.
            return_dict: Return a ``BaseModelOutputWithPast`` instead of a
                tuple. Defaults to ``config.use_return_dict``.
            cache_position: Optional positions of the current tokens.

        Returns:
            ``BaseModelOutputWithPast`` with ``past_key_values=None``, or a
            tuple of its non-``None`` fields when ``return_dict`` is false.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        # Caching is intentionally disabled, so the caller's `use_cache` is
        # overridden. The legacy-cache handling that depended on it could
        # never run and has been removed.
        use_cache = False
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            # Nested inputs have no single sequence length; use the longest.
            if inputs_embeds.is_nested:
                seq_len = inputs_embeds._get_max_seqlen()
            else:
                seq_len = inputs_embeds.shape[1]
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + seq_len,
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if not inputs_embeds.is_nested:
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values,
            )
        else:
            causal_mask = None

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            # The cache slot is always None, so it never appears in the tuple.
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class DramaModel(LlamaBiModel):
    """
    DramaModel is a modified version of the LlamaModel that supports bi-directional attention
    and provides query and document encoding functionalities.
    """

    def __init__(self, config: LlamaConfig):
        """
        Initializes the DramaModel by disabling causal masking in self-attention layers.
        """
        super().__init__(config)
        for layer in self.layers:
            layer.self_attn.is_causal = False
        # Prefix prepended to every query text (documents get no prefix).
        self.query_prefix = "Query: "
        # Default truncation length used when the caller gives none.
        self.max_seq_len = 8192
        self.hidden_size = config.hidden_size

    def _average_pool(
        self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the average pooled representation of the last hidden states.

        Positions where ``attention_mask`` is false are zeroed, the result is
        summed over dimension 1 and divided by the mask sum over dimension 1.
        """
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def _tokenize(
        self,
        tokenizer: PreTrainedTokenizer,
        texts: list[str],
        max_seq_len: Optional[int] = None,
        use_nested: bool = False,
    ) -> dict[str, Any]:
        """
        Tokenizes input text sequences with optional sequence length restriction.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            texts (list[str]): Texts to tokenize.
            max_seq_len (int, optional): Truncation length; defaults to
                ``self.max_seq_len``.
            use_nested (bool): If True, sequences are not padded and the ids
                are packed into a jagged nested tensor.
        Returns:
            dict: ``{"input_ids": ..., "attention_mask": ...}``, both on
            ``self.device``. ``attention_mask`` is ``None`` when
            ``use_nested`` is True.
        """
        if max_seq_len is None:
            max_seq_len = self.max_seq_len

        if use_nested:
            tokenized = tokenizer(
                texts,
                truncation=True,
                max_length=max_seq_len,
                return_length=True,
            )
            input_ids = torch.nested.nested_tensor(
                tokenized.input_ids, layout=torch.jagged
            ).to(self.device)
            # Unpadded jagged input needs no padding mask.
            attention_mask = None
        else:
            tokenized = tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=max_seq_len,
                return_tensors="pt",
            ).to(self.device)
            input_ids = tokenized.input_ids
            attention_mask = tokenized.attention_mask

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _validate_encode_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        texts: list[str],
        dim: Optional[int],
        name: str,
    ) -> None:
        """
        Validates the arguments shared by ``encode_queries`` and ``encode_documents``.

        Args:
            tokenizer (PreTrainedTokenizer): Must not be None.
            texts (list[str]): Must be a non-empty list of strings.
            dim (int, optional): If given, must lie in ``[1, self.hidden_size]``.
            name (str): Argument name used in the error messages.
        Raises:
            ValueError: If any of the checks above fails.
        """
        if not texts:
            raise ValueError(f"{name} must not be empty.")
        if not isinstance(texts, list) or not all(
            isinstance(text, str) for text in texts
        ):
            raise ValueError(f"{name} must be a list of strings.")
        if tokenizer is None:
            raise ValueError("tokenizer must not be None.")
        if dim is not None and (dim < 1 or dim > self.hidden_size):
            raise ValueError(f"dim must be in range [1, {self.hidden_size}].")

    def encode(self, input_ids, attention_mask, dim, *args, **kwargs):
        """
        Pass through the model and compute normalized embeddings.
        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask tensor.
            dim (int): Dimensionality for output embeddings; None keeps all
                features.
        Returns:
            torch.Tensor: Normalized output embeddings.
        """
        outputs = self.forward(
            input_ids, attention_mask, *args, **kwargs
        ).last_hidden_state

        if not outputs.is_nested:
            if dim is not None:
                outputs = outputs[:, :, :dim]
            embeddings = self._average_pool(outputs, attention_mask)
        else:
            if dim is not None:
                # Keep the first `dim` features and drop the rest.
                outputs, _ = outputs.split_with_sizes(
                    split_sizes=[dim, outputs.shape[-1] - dim], dim=-1
                )
            # A sum is used here instead of a mean; the two differ only by a
            # positive per-row factor, which the L2 normalization removes.
            embeddings = outputs.sum(dim=-2)

        # normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    def encode_queries(
        self,
        tokenizer: PreTrainedTokenizer,
        queries: list[str],
        max_seq_len: Optional[int] = None,
        dim: Optional[int] = None,
        use_nested: bool = False,
    ):
        """
        Encodes a list of queries into embeddings.
        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            queries (list[str]): List of query texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.
            use_nested (bool): Tokenize into a jagged nested tensor instead
                of a padded batch.
        Returns:
            torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
        Raises:
            ValueError: If ``queries`` is empty or not a list of strings,
                ``tokenizer`` is None, or ``dim`` is out of range.
        """
        self._validate_encode_inputs(tokenizer, queries, dim, "queries")
        queries = [self.query_prefix + query for query in queries]
        tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len, use_nested)
        embeddings = self.encode(**tokenized_queries, dim=dim)
        return embeddings

    def encode_documents(
        self,
        tokenizer: PreTrainedTokenizer,
        documents: list[str],
        max_seq_len: Optional[int] = None,
        dim: Optional[int] = None,
        use_nested: bool = False,
    ):
        """
        Encodes a list of documents into embeddings.
        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            documents (list[str]): List of document texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.
            use_nested (bool): Tokenize into a jagged nested tensor instead
                of a padded batch.
        Returns:
            torch.Tensor: Encoded document embeddings in shape (num_documents, dim).
        Raises:
            ValueError: If ``documents`` is empty or not a list of strings,
                ``tokenizer`` is None, or ``dim`` is out of range.
        """
        self._validate_encode_inputs(tokenizer, documents, dim, "documents")
        tokenized_documents = self._tokenize(
            tokenizer, documents, max_seq_len, use_nested
        )
        embeddings = self.encode(**tokenized_documents, dim=dim)
        return embeddings
|