# File size: 26,676 Bytes
# f15f02b
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
""" PyTorch StableLM-Alpha model. """
from typing import Optional, Tuple, Union
import math
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_stablelm_alpha import StableLMAlphaConfig
# Module-level logger, created through the `transformers.utils.logging` helper imported above.
logger = logging.get_logger(__name__)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, tgt_seq_len, src_seq_len]`."""
batch_size, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class LayerNorm(nn.LayerNorm):
    """`nn.LayerNorm` whose bias term can be switched off at construction time."""

    def __init__(self, normalized_shape: torch.Size, bias: bool = True, **kwargs):
        r"""
        Args:
            normalized_shape: forwarded unchanged to `nn.LayerNorm`.
            bias (`bool`, default = True): whether to keep the bias term.
            **kwargs: any further `nn.LayerNorm` keyword arguments (e.g. `eps`).
        """
        super().__init__(normalized_shape, **kwargs)
        if bias:
            return
        # Drop the bias that the parent constructor may have created.
        self.bias = None
class DecoderLayer(nn.Module):
    """Decoder block in which attention and the MLP both read one shared pre-norm.

    The block output is `input + attention(norm(input)) + mlp(norm(input))`.
    """

    def __init__(self, config: StableLMAlphaConfig):
        super().__init__()
        self.norm = LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.attention = Attention(config)
        self.mlp = MLP(config)

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """Run the block once.

        Returns:
            A tuple starting with the new hidden states, followed by the attention
            weights when `output_attentions` is truthy, followed by the present
            key/value pair when `use_cache` is truthy.
        """
        # Pre-norm: a single normalisation feeds both branches.
        normed = self.norm(hidden_states)

        # Self-attention branch.
        attn_output, attn_weights, present_key_value = self.attention(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        # Feed-forward branch, computed from the same normalised input.
        mlp_output = self.mlp(normed)

        # Both branches are added onto the un-normalised input (residual).
        block_output = hidden_states + attn_output + mlp_output

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(present_key_value)
        # hidden_states, (optional: attn_weights), (optional: present_key_value)
        return (block_output, *extras)
class MLP(nn.Module):
    """Gated feed-forward network.

    A single bias-free projection yields two equally sized halves; the first
    half is multiplied by the SiLU of the second half and projected back to
    `hidden_size`.
    """

    def __init__(self, config: StableLMAlphaConfig):
        super().__init__()
        hidden_size = config.hidden_size

        # Inner width: 8/3 of the hidden size, rounded up to a multiple of 256.
        multiple_of = 256
        ff_dim = int(8 * hidden_size / 3)
        intermediate_size = -(-ff_dim // multiple_of) * multiple_of

        self.gate_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform along the last axis of `x`."""
        projected = self.gate_proj(x)
        ff, ff_gate = torch.chunk(projected, 2, dim=-1)
        gated = ff * self.act(ff_gate)
        return self.out_proj(gated)
class RotaryEmbedding(nn.Module):
    """Precomputes and serves the cos/sin tables used for rotary position embeddings.

    Args:
        dim: Size of the last axis of the returned tables (number of rotated channels).
        max_position_embeddings: Number of positions precomputed at construction time.
        base: Base of the geometric progression of inverse frequencies.
        device: Device on which the initial buffers are created.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int,
        base: int = 10_000,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """(Re)build the `[1, 1, seq_len, dim]` cos/sin buffers for `seq_len` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
        """Return the `(cos, sin)` tables for the first `seq_len` positions, cast to `x.dtype`.

        Args:
            x: Tensor laid out as `[batch_size, num_heads, seq_len, head_size]`;
                only its shape, device and dtype are consulted.
            seq_len: Number of positions needed. When omitted it is taken from
                the sequence axis of `x` (`x.shape[-2]`).

        Returns:
            Two tensors of shape `[1, 1, seq_len, dim]`.
        """
        # x: [batch_size, num_heads, seq_len, head_size]
        if seq_len is None:
            # The signature advertises `None` as the default; comparing `None`
            # against an int would raise, so fall back to the input's length.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            # Grow the cache lazily when a longer sequence than ever before shows up.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.get_default_dtype())
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
def rotate_half(x: torch.Tensor):
    """Split the last axis into two chunks `(a, b)` and return `(-b, a)` concatenated."""
    first, second = x.chunk(2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate `q` and `k` with the cos/sin rows selected by `position_ids`.

    `cos` and `sin` arrive with two leading singleton axes; these are removed,
    the rows are gathered by `position_ids`, and a broadcast axis is inserted
    at position 1 before the tables are combined with `q` and `k`.

    Returns:
        The pair `(q_embed, k_embed)`.
    """

    def _rows_for_positions(table):
        # The first two dimensions are always 1, so they can be squeezed away:
        # [1, 1, seq_len, dim] -> [seq_len, dim] -> [batch_size, 1, seq_len, dim]
        return table.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)

    cos_pos = _rows_for_positions(cos)
    sin_pos = _rows_for_positions(sin)

    q_embed = (q * cos_pos) + (rotate_half(q) * sin_pos)
    k_embed = (k * cos_pos) + (rotate_half(k) * sin_pos)
    return q_embed, k_embed
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and partial rotary embeddings.

    Only the first `rotary_ndims` channels of every head are rotated; the
    remaining channels are passed through untouched.
    """

    def __init__(self, config: StableLMAlphaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim, leftover = divmod(self.hidden_size, self.num_heads)
        self.max_position_embeddings = config.max_position_embeddings

        if leftover != 0:
            raise ValueError(
                "`hidden_size` is not divisble by the number of attention heads! Make sure to update them"
            )

        self.qkv_proj = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        """Create the rotary table provider for the rotated fraction of each head."""
        cfg = self.config
        self.rotary_ndims = int(self.head_dim * cfg.rotary_pct)
        self.rotary_emb = RotaryEmbedding(
            self.rotary_ndims,
            max_position_embeddings=cfg.max_position_embeddings,
            base=cfg.rotary_emb_base,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        position_ids: torch.LongTensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute self-attention over `hidden_states`.

        Returns:
            `(attn_output, attn_weights, present_key_value)`; `attn_weights` is
            `None` unless `output_attentions` is truthy and `present_key_value`
            is `None` unless `use_cache` is truthy.
        """
        # Fused projection: [batch_size, seq_len, 3 * hidden_size]
        fused = self.qkv_proj(hidden_states)

        # Group the channels per head: [batch_size, seq_len, num_heads, 3 * head_dim]
        fused = fused.view(*fused.shape[:-1], self.num_heads, 3 * self.head_dim)

        # 3 * [batch_size, num_heads, seq_len, head_dim]
        query, key, value = (part.permute(0, 2, 1, 3) for part in fused.split(self.head_dim, dim=-1))

        # Separate rotated channels from pass-through channels.
        rot = self.rotary_ndims
        query_rot, query_pass = query[..., :rot], query[..., rot:]
        key_rot, key_pass = key[..., :rot], key[..., rot:]

        # Cached positions (when decoding) extend the range the tables must cover.
        cached_len = 0 if past_key_value is None else past_key_value[0].shape[-2]
        kv_seq_len = key.shape[-2] + cached_len

        # Rotate, then stitch rotated and pass-through channels back together.
        cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)

        # Prepend cached keys/values along the sequence axis.
        if past_key_value is not None:
            key = torch.cat((past_key_value[0], key), dim=2)
            value = torch.cat((past_key_value[1], value), dim=2)

        present_key_value = (key, value) if use_cache else None

        # Move the sequence axis in front of the head axis:
        # [batch_size, seq_len, num_heads, head_dim]
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()

        # Scaled dot-product scores: [batch_size, num_heads, tgt_len, src_len]
        softmax_scale = 1 / math.sqrt(self.head_dim)
        attn_scores = torch.einsum('bthd,bshd->bhts', query, key * softmax_scale)

        # Additive mask.
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        # Softmax is evaluated in float32 and cast back to the query dtype.
        attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype)
        context = torch.einsum('bhts,bshd->bthd', attn_weights, value)

        # Merge heads, then apply the output projection.
        context = context.reshape(context.shape[0], context.shape[1], -1)
        attn_output = self.out_proj(context)

        return attn_output, (attn_weights if output_attentions else None), present_key_value
def attention_mask_func(attention_scores: torch.Tensor, ltor_mask: torch.Tensor):
    """Overwrite, in place, every score whose `ltor_mask` entry is False.

    Masked scores are set to the most negative finite value of the scores'
    dtype. The same (mutated) `attention_scores` tensor is returned.
    """
    fill_value = torch.finfo(attention_scores.dtype).min
    blocked = ~ltor_mask
    return attention_scores.masked_fill_(blocked, fill_value)
class StableLMAlphaPreTrainedModel(PreTrainedModel):
    """An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.
    """

    config_class = StableLMAlphaConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module: nn.Module):
        """Initialize the weights of a single sub-module.

        Linear and embedding weights are drawn from a normal distribution with
        standard deviation `config.initializer_range`; linear biases and the
        embedding row at `padding_idx` are zeroed. Layer norms are reset to
        weight 1 and bias 0.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # A layer norm may have no bias (this file's `LayerNorm(bias=False)`
            # sets it to None) or no affine parameters at all, so guard both.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module: nn.Module, value=False):
        """Toggle the `gradient_checkpointing` flag on `StableLMAlphaModel` instances."""
        if isinstance(module, StableLMAlphaModel):
            module.gradient_checkpointing = value
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0
):
"""Make causal mask used for bi-directional self-attention."""
batch_size, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(torch.float16).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(batch_size, 1, tgt_len, tgt_len + past_key_values_length)
class StableLMAlphaModel(StableLMAlphaPreTrainedModel):
    """Decoder-only transformer body: token embedding, `DecoderLayer` stack, final layer norm."""

    def __init__(self, config: StableLMAlphaConfig):
        super().__init__(config)
        self.config = config

        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_norm = LayerNorm(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value: nn.Module):
        """Replace the token embedding module."""
        self.embed = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
    ):
        """Combine the causal mask and the expanded padding mask into one additive mask.

        Returns `None` when there is neither a causal component (single-token
        input) nor a padding mask.
        """
        tgt_len = input_shape[-1]

        # Causal part: [batch_size, 1, tgt_seq_len, src_seq_len]; skipped for one-token inputs.
        causal_mask = None
        if tgt_len > 1:
            causal_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is None:
            return causal_mask

        # Padding part: [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
        padding_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=tgt_len).to(
            inputs_embeds.device
        )
        if causal_mask is None:
            return padding_mask
        return padding_mask + causal_mask

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))`, one inner tuple per decoder layer):
            Key and value states cached from earlier calls; their length along
            axis 2 is used as the position offset of the new tokens. When given,
            only the tokens that are not yet covered by the cache need to be passed.
        use_cache (`bool`, *optional*):
            If truthy, the key/value states of every layer are returned so they
            can be passed back as `past_key_values` on the next call.

        Flags left as `None` fall back to the corresponding `config` values.
        Exactly one of `input_ids` and `inputs_embeds` must be provided.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict
        if use_cache is None:
            use_cache = cfg.use_cache

        # Exactly one input source is allowed.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        batch_size, seq_length = input_shape

        # Length of the cached prefix (0 when no cache is supplied).
        if past_key_values is None:
            past_key_values_length = 0
            past_key_values = tuple([None] * cfg.num_hidden_layers)
        else:
            past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            # New tokens continue counting where the cache stopped.
            position_ids = torch.arange(
                past_key_values_length, seq_length_with_past, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed(input_ids)

        # Attention mask: default to "attend everywhere", then fold in causality.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            logger.warning(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        def _checkpoint_adapter(module):
            # Appends the trailing positional arguments of the layer's forward.
            def custom_forward(*inputs):
                # `None` for `use_cache`
                return module(*inputs, output_attentions, None)

            return custom_forward

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        present_key_values = () if use_cache else None
        # Index of the cache entry inside a layer's output tuple.
        cache_slot = 2 if output_attentions else 1

        for decoder_layer, past_key_value in zip(self.layers, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if checkpointing:
                outputs = torch.utils.checkpoint.checkpoint(
                    _checkpoint_adapter(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    # `None` for `past_key_value`
                    None,
                )
            else:
                outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (outputs[1],)
            if use_cache:
                present_key_values += (outputs[cache_slot],)

        hidden_states = self.final_norm(hidden_states)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            candidates = (hidden_states, present_key_values, all_hidden_states, all_attentions)
            return tuple(item for item in candidates if item is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=present_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
class StableLMAlphaForCausalLM(StableLMAlphaPreTrainedModel):
    """`StableLMAlphaModel` with a bias-free linear language-modeling head on top."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: StableLMAlphaConfig):
        super().__init__(config)

        self.transformer = StableLMAlphaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run the transformer body and project its hidden states to vocabulary logits.
        When `labels` is given, a next-token cross-entropy loss is also computed
        (logits at position `i` are scored against the label at position `i + 1`).

        Example:

        ```python
        >>> from transformers import AutoTokenizer, StableLMAlphaForCausalLM, StableLMAlphaConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-base-alpha-3b-v2", trust_remote_code=True)
        >>> config = StableLMAlphaConfig.from_pretrained("stabilityai/stablelm-base-alpha-3b-v2")
        >>> config.is_decoder = True
        >>> model = StableLMAlphaForCausalLM.from_pretrained("stabilityai/stablelm-base-alpha-3b-v2", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        lm_loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shift_logits = logits[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithPast(
            loss=lm_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Assemble the keyword arguments for one `forward` call during generation.

        Returns:
            A dict holding either `inputs_embeds` or `input_ids`, plus
            `attention_mask`, `past_key_values`, `position_ids` and `use_cache`.
        """
        # Cut decoder_input_ids if past is used
        if past_key_values and past_key_values[0] is not None:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked-out positions get a dummy index of 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
                # Forward the caller's choice; `None` lets `forward` fall back
                # to `config.use_cache`, which is what happened before.
                "use_cache": kwargs.get("use_cache"),
            }
        )
        return model_inputs

    def _reorder_cache(self, past_key_values: Tuple[Tuple[torch.Tensor, ...], ...], beam_idx: torch.LongTensor):
        """Re-index the batch axis of every cached key/value tensor by `beam_idx`.

        Only the first two entries of each per-layer tuple are re-indexed; any
        further entries are carried over unchanged.
        """
        return tuple(
            tuple(
                # `beam_idx` is moved to each tensor's device because
                # `index_select` needs both on the same device.
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past[:2]
            )
            + layer_past[2:]
            for layer_past in past_key_values
        )
# Import-time side effect: register the config class and the causal-LM class
# through their `register_for_auto_class` hooks (the model under the
# "AutoModelForCausalLM" auto class).
# NOTE(review): presumably this is what lets the Auto* factories resolve these
# classes when the file is loaded as custom code — confirm against the
# transformers documentation.
StableLMAlphaConfig.register_for_auto_class()
StableLMAlphaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|