japanese-stablelm-instruct-alpha-7b-s / modeling_japanese_stablelm_alpha.py
camenduru's picture
thanks to Stability AI Japan ❤
a07de4f
# coding=utf-8
# Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch JapaneseStableLMAlpha model. """
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig
logger = logging.get_logger(__name__)
class JapaneseStableLMAlphaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = JapaneseStableLMAlphaConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
if module.bias is not None:
module.bias.data.zero_()
if module.weight is not None:
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, JapaneseStableLMAlphaModel):
module.gradient_checkpointing = value
class JapaneseStableLMAlphaModel(JapaneseStableLMAlphaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_in
def set_input_embeddings(self, value):
self.embed_in = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * self.config.num_hidden_layers)
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# Attention mask.
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for layer_past
return module(*inputs, use_cache, None, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
attention_mask,
position_ids,
head_mask[i],
)
else:
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.final_layer_norm(hidden_states)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class DecoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(
config.hidden_size,
eps=config.layer_norm_eps,
elementwise_affine=False,
)
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size,
eps=config.layer_norm_eps
)
self.attention = Attention(config)
self.mlp = MLP(config)
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
layer_past: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
):
attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states),
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
outputs = attention_layer_outputs[1:]
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
hidden_states = hidden_states + mlp_output + attn_output
if use_cache:
outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights)
else:
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
return outputs
class MLP(nn.Module):
def __init__(self, config: JapaneseStableLMAlphaConfig):
super().__init__()
hidden_size = config.hidden_size
multiple_of = 256
ff_dim = int(8 * hidden_size / 3)
intermediate_size = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
self.packed_input_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
self.out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.act = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
ff, ff_gate = self.packed_input_proj(x).chunk(2, dim=-1)
return self.out_proj(ff * self.act(ff_gate))
class RotaryEmbedding(torch.nn.Module):
"""Based on Tri Dao's XPos: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/layers/rotary.py"""
def __init__(
self,
dim: int,
max_position_embeddings: int,
base: int = 10_000,
scale_base: int = 512,
device: str = None
):
super().__init__()
self.dim = dim
self.seq_len_cached = max_position_embeddings
# Set up `inv_freq` term
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq)
# Set up `scale` term
self.scale_base = scale_base
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
if scale_base is not None else None
)
self.register_buffer("scale", scale)
# Seet up `cos..` and `sin...` cache terms
t = torch.arange(self.seq_len_cached, device=device, dtype=torch.float32)
freqs = torch.outer(t, self.inv_freq)
# freqs = torch.cat((freqs, freqs), dim=-1)
seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device)
power = (seq_range - self.seq_len_cached // 2) / self.scale_base
scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1)
# scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)
def forward(self, x, seq_len=None):
if seq_len > self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
freqs = torch.outer(t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1)
seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device)
power = (seq_range - self.seq_len_cached // 2) / self.scale_base
scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1)
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)
return (
self.cos_cached[:seq_len, ...],
self.sin_cached[:seq_len, ...],
self.cos_k_cached[:seq_len, ...],
self.sin_k_cached[:seq_len, ...],
)
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, cos_k=None, sin_k=None):
"""
q, k: [bs, num_heads, seq_len, rot_dim]
cos, sin: [seq_len, rot_dim / 2]
position_ids: [bs, seq_len]
"""
# print(f"q: {q.shape}, k: {k.shape}, cos: {cos.shape}, sin: {sin.shape}, position_ids: {position_ids.shape}")
import einops
cos = einops.repeat(cos, 's r -> s (2 r)')
sin = einops.repeat(sin, 's r -> s (2 r)')
cos_k = einops.repeat(cos_k, 's r -> s (2 r)')
sin_k = einops.repeat(sin_k, 's r -> s (2 r)')
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
cos_k = cos_k[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
sin_k = sin_k[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
return q_embed, k_embed
class Attention(nn.Module):
def __init__(self, config):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size is not divisble by the number of attention heads! Make sure to update them"
)
self.head_size = self.hidden_size // self.num_attention_heads
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
self.rotary_ndims = int(self.head_size * config.rotary_pct)
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims,
max_position_embeddings=config.max_position_embeddings,
base=config.rotary_emb_base,
scale_base=config.rotary_scale_base,
)
self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
persistent=False,
)
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
head_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.query_key_value(hidden_states)
# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape)
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
query_pass = query[..., self.rotary_ndims :]
key_rot = key[..., : self.rotary_ndims]
key_pass = key[..., self.rotary_ndims :]
# Compute token offset for rotary embeddings (when decoding)
kv_seq_len = key.shape[-2]
if has_layer_past:
kv_seq_len += layer_past[0].shape[-2]
# Add rotary embeddings to query and key
# TODO: Check if using xpos
cos, sin, cos_k, sin_k = self.rotary_emb(value, seq_len=kv_seq_len)
query, key = apply_rotary_pos_emb(
query_rot, key_rot, cos, sin, position_ids, cos_k=cos_k, sin_k=sin_k)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)
# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# Compute attention
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# Merge attn_head_size dim and num_attn_heads dim into hidden dim
# [bs, seq_len, num_attention_heads, attn_head_size]
attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
attn_output = attn_output.view(attn_output.size(0), attn_output.size(1), self.num_attention_heads * self.head_size)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
# compute causal mask from causal mask buffer
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
batch_size * num_attention_heads,
query_length,
key_length,
dtype=query.dtype,
device=key.device,
)
attn_scores = torch.baddbmm(
attn_scores,
query,
key.transpose(1, 2),
beta=1.0,
alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
)
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
mask_value = torch.finfo(attn_scores.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype, device=attn_scores.device)
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
if attention_mask is not None:
# Apply the attention mask
attn_scores = attn_scores + attention_mask
# NOTE: Upcast to float32
attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).type_as(value)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
return attention_scores
class JapaneseStableLMAlphaForCausalLM(JapaneseStableLMAlphaPreTrainedModel):
_tied_weights_keys = ["embed_out.weight"]
def __init__(self, config):
super().__init__(config)
self.transformer = JapaneseStableLMAlphaModel(config)
self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.embed_out
def set_output_embeddings(self, new_embeddings):
self.embed_out = new_embeddings
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Example:
```python
>>> import torch
>>> from transformers import LlamaTokenizer, JapaneseStableLMAlphaForCausalLM, JapaneseStableLMAlphaConfig
>>> tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1")
>>> config = JapaneseStableLMAlphaConfig.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b")
>>> config.is_decoder = True
>>> model = JapaneseStableLMAlphaForCausalLM.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b", config=config, trust_remote_code=True)
>>> inputs = tokenizer("日本語の美しいところは、", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
lm_logits = self.embed_out(hidden_states)
lm_loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithPast(
loss=lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
input_shape = input_ids.shape
# cut decoder_input_ids if past is used
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
)
return model_inputs
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past