LingoWhale-8B / modeling_lingowhale.py
DeepLangLvcc's picture
add chat function
9a542c4
# Copyright 2023 DeepLang AI. All Rights Reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from queue import Queue
from threading import Thread
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.utils import logging
from .configuration_lingowhale import LingoWhaleConfig
logger = logging.get_logger(__name__)
try:
from einops import rearrange
except ImportError:
rearrange = None
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
try:
from flash_attn.flash_attn_interface import \
flash_attn_varlen_func as flash_attn_unpadded_func
except ImportError:
flash_attn_unpadded_func = None
# Adapted from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
):
    """
    Make the additive causal mask used for uni-directional self-attention.

    Args:
        input_ids_shape: `(bsz, tgt_len)` of the tokens being processed.
        dtype: floating dtype of the returned mask.
        device: device on which the mask is created.
        past_key_values_length: number of cached positions; every query may
            attend to all of them, so that many zero columns are prepended.

    Returns:
        Tensor of shape `(bsz, 1, tgt_len, tgt_len + past_key_values_length)`
        holding `0` where attention is allowed and `torch.finfo(dtype).min`
        where a query would look at a later position.
    """
    bsz, tgt_len = input_ids_shape
    # Fill directly in the target dtype: routing the minimum through a
    # default-dtype (float32) tensor is an unnecessary detour and cannot
    # represent the float64 minimum.
    mask = torch.full(
        (tgt_len, tgt_len),
        torch.finfo(dtype).min,
        dtype=dtype,
        device=device,
    )
    positions = torch.arange(tgt_len, device=device)
    # Open up the lower triangle (diagonal included): key index <= query index.
    mask.masked_fill_(positions[None, :] <= positions[:, None], 0)

    if past_key_values_length > 0:
        past_columns = torch.zeros(tgt_len,
                                   past_key_values_length,
                                   dtype=dtype,
                                   device=device)
        mask = torch.cat([past_columns, mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len,
                                         tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
                 dtype: torch.dtype,
                 tgt_len: Optional[int] = None):
    """
    Turn a `[bsz, seq_len]` keep-mask into an additive
    `[bsz, 1, tgt_seq_len, src_seq_len]` mask.

    Entries that are 1 in `mask` become 0 in the result; every other entry
    becomes `torch.finfo(dtype).min`. `tgt_len` defaults to the source length.
    """
    batch, source_len = mask.size()
    if tgt_len is None:
        tgt_len = source_len

    # Broadcast the per-key flags over every query position.
    keep = mask[:, None, None, :].expand(batch, 1, tgt_len, source_len)
    keep = keep.to(dtype)

    # 1 -> 0 (visible), 0 -> 1 (blocked); blocked cells then get the minimum.
    blocked = 1.0 - keep
    fill_value = torch.finfo(dtype).min
    return blocked.masked_fill(blocked.to(torch.bool), fill_value)
class TextIterStreamer:
    """Iterator that hands decoded text from a generation loop to a consumer.

    The producer calls `put` with token tensors and `end` when finished. The
    consumer iterates over the instance; each item is the decoding of *all*
    tokens received so far (cumulative text, not a delta).
    """

    def __init__(self,
                 tokenizer,
                 skip_prompt=False,
                 skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        # Every token id seen so far (excluding a skipped prompt).
        self.tokens = []
        # Hand-off channel between producer and consumer; `None` ends the stream.
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        """Append the token ids in `value` and enqueue the updated text."""
        if self.skip_prompt and self.next_tokens_are_prompt:
            # The first call carries the prompt; drop it once.
            self.next_tokens_are_prompt = False
            return

        if len(value.shape) > 1:
            # Batched input: only the first row is streamed.
            value = value[0]
        self.tokens.extend(value.tolist())
        text = self.tokenizer.decode(
            self.tokens, skip_special_tokens=self.skip_special_tokens)
        self.text_queue.put(text)

    def end(self):
        """Signal the consumer that no more text will arrive."""
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        # Blocks until the producer has queued something.
        item = self.text_queue.get()
        if item is None:
            raise StopIteration()
        return item
class LingoWhaleRMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), computed along the feature axis.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Statistics are computed in float32, then cast back to the input
        # dtype before the gain is applied.
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
class LingoWhaleRotaryEmbedding(torch.nn.Module):
    """Pre-computed cos/sin tables for rotary position embeddings.

    `inv_freq`, `cos_cached` and `sin_cached` are stored as plain tensor
    attributes (they are not registered as buffers or parameters) and are kept
    in float32. `forward` moves the tables to the input's device on demand and
    rebuilds them when a longer sequence is requested.

    Args:
        dim: size of the rotated feature dimension (the per-head size).
        max_position_embeddings: number of positions cached up front.
        base: base of the geometric frequency progression.
        device: device on which the initial tables are built.
    """

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()
        self.inv_freq = 1.0 / (base**(
            torch.arange(0, dim, 2).float().to(device) / dim))
        self._set_cos_sin_cache(max_position_embeddings,
                                device=self.inv_freq.device)

    def _set_cos_sin_cache(self, seq_len, device):
        """(Re)build the tables for `seq_len` positions and place them on `device`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len,
                         device=self.inv_freq.device,
                         dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        # Each frequency is repeated for the two halves used by `rotate_half`.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :].to(
            torch.float32).to(device)
        self.sin_cached = emb.sin()[None, None, :, :].to(
            torch.float32).to(device)

    def forward(self, x, seq_len=None):
        """Return `(cos, sin)`, each of shape `[1, 1, seq_len, dim]`.

        Args:
            x: tensor whose device selects where the tables live; per the
                original comment it is `[bs, num_attention_heads, seq_len,
                head_size]`.
            seq_len: number of positions needed. When omitted it is taken
                from `x.shape[-2]` (previously `None` raised a `TypeError`).
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            # Requested length exceeds the cache: grow it.
            self._set_cos_sin_cache(seq_len, device=x.device)
        elif self.cos_cached.device != x.device:
            self.cos_cached = self.cos_cached.to(x.device)
            self.sin_cached = self.sin_cached.to(x.device)
        return (
            self.cos_cached[:, :, :seq_len, ...],
            self.sin_cached[:, :, :seq_len, ...],
        )
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second one.

    For `x = [a, b]` split at the midpoint the result is `[-b, a]`.
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
    """Rotate query and key states by their position-dependent angles.

    `cos_` / `sin_` are the cached tables; their two leading singleton
    dimensions are squeezed away and the rows named by `position_ids` are
    gathered. The rotation is done in float32 and each result is cast back to
    the dtype of its input.

    Returns:
        `(q_embed, k_embed)` with the same dtypes as `q` and `k`.
    """
    # [1, 1, seq_len, dim] -> [seq_len, dim]
    cos_table = cos_.squeeze(1).squeeze(0)
    sin_table = sin_.squeeze(1).squeeze(0)
    # Gather per-token rows and add a head axis: [bs, 1, seq_len, dim]
    cos = cos_table[position_ids].unsqueeze(1)
    sin = sin_table[position_ids].unsqueeze(1)

    def _rotate(states):
        states_fp32 = states.float()
        rotated = (states_fp32 * cos) + (rotate_half(states_fp32) * sin)
        return rotated.to(states.dtype)

    return _rotate(q), _rotate(k)
class LingoWhaleMLP(nn.Module):
    """Gated feed-forward block with a fused gate/up projection.

    One linear layer produces both the gate and the up activations (twice the
    intermediate size); the activated gate multiplies the up half before the
    down projection maps back to the hidden size.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Gate and up share one matmul; the output is split in `forward`.
        self.gate_and_up_proj = nn.Linear(self.hidden_size,
                                          self.intermediate_size * 2,
                                          bias=False)
        self.down_proj = nn.Linear(self.intermediate_size,
                                   self.hidden_size,
                                   bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # First half of the fused output is the gate, second half the up branch.
        gate, up = torch.chunk(self.gate_and_up_proj(x), 2, dim=-1)
        gated = self.act_fn(gate) * up
        return self.down_proj(gated)
class LingoWhaleAttention(nn.Module):
    """Multi-headed causal self-attention with rotary position embeddings.

    Queries, keys and values come from a single fused `qkv_proj` layer.
    Attention is computed by flash-attn when `config.use_flash_attention` is
    set and the kernel could be imported, otherwise by an explicit
    `baddbmm` + softmax implementation.

    NOTE: the flash path receives no attention mask; it relies solely on the
    kernel's `causal` flag, so padding information is not applied there.
    """

    def __init__(self, config: "LingoWhaleConfig"):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.dropout_p = config.attn_dropout_prob

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads}).")
        self.qkv_proj = nn.Linear(self.hidden_size,
                                  3 * self.hidden_size,
                                  bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim,
                                self.hidden_size,
                                bias=False)
        self.attention_dropout = torch.nn.Dropout(self.dropout_p)
        self._init_rope()

    def attention_mask_func(self, attention_scores, attention_mask):
        """Overwrite (in place) the scores selected by the boolean mask with -10000."""
        attention_scores.masked_fill_(attention_mask, -10000.0)
        return attention_scores

    def forward_torch_softmax(self, input, mask):
        """Masked softmax over the last dimension, computed in float32.

        Args:
            input: raw attention scores.
            mask: boolean tensor, `True` where a score must be suppressed, or
                `None` for no masking.

        Returns:
            Probabilities cast back to the dtype of `input`. (The result used
            to be forced to bfloat16, which made the subsequent `bmm` fail for
            non-bfloat16 inputs.)
        """
        out_dtype = input.dtype
        scores = input.float()
        if mask is not None:
            scores = self.attention_mask_func(scores, mask)
        probs = F.softmax(scores, dim=-1)
        return probs.to(out_dtype)

    def _self_attention(self, query_layer, key_layer, value_layer,
                        attention_mask):
        """Explicit scaled dot-product attention.

        Args:
            query_layer: `[sq, b, np, hn]`.
            key_layer / value_layer: `[sk, b, np, hn]`.
            attention_mask: boolean mask broadcastable to `[b, np, sq, sk]`
                (`True` = masked) or `None`.

        Returns:
            Context tensor of shape `[sq, b, np * hn]`.
        """
        # (b, np, sq, sk)
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )
        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2],
                                          output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3],
                                      output_size[0] * output_size[1], -1)

        # Scratch input for baddbmm. With beta=0 its contents are ignored, so
        # an uninitialised buffer is enough; sampling random numbers here
        # would only waste time and advance the RNG.
        matmul_input_buffer = torch.empty(
            (output_size[0] * output_size[1], output_size[2], output_size[3]),
            dtype=query_layer.dtype,
            device=query_layer.device,
        )
        norm_factor = math.sqrt(key_layer.shape[-1])

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / norm_factor),
        )
        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.forward_torch_softmax(attention_scores,
                                                     attention_mask)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # Context layer: (b, np, sq, hn)
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )
        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0),
                                          output_size[0] * output_size[1], -1)
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)
        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size, )
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer

    def _self_attention_flash(self, q, k, v):
        """Attention through the flash-attn variable-length kernel.

        `q`, `k`, `v` are batch-first (`[b, s, ...]`). All sequences in the
        batch are treated as having the full length `s`; no padding mask is
        applied.
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]
        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
        # Cumulative sequence boundaries for equally long sequences.
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q.device,
        )
        if self.training:
            # during training q,k,v always have same seqlen
            assert seqlen_k == seqlen_q
            is_causal = True
            cu_seqlens_k = cu_seqlens_q
            dropout_p = self.dropout_p
        else:
            # turn off FA causal mask after first inference autoregressive iteration
            # only on first autoregressive step q,k,v have same seqlen
            is_causal = seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(
                0,
                (batch_size + 1) * seqlen_k,
                step=seqlen_k,
                dtype=torch.int32,
                device=q.device,
            )
            dropout_p = 0
        output = flash_attn_unpadded_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqlen_q,
            seqlen_k,
            dropout_p,
            causal=is_causal,
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
        return output

    def _init_rope(self):
        self.rotary_emb = LingoWhaleRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (tensor.view(bsz, seq_len, self.num_heads,
                            self.head_dim).transpose(1, 2).contiguous())

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Run self-attention over `hidden_states` (`[bsz, q_len, hidden]`).

        Args:
            attention_mask: additive mask; entries below -0.5 are treated as
                masked. When omitted, a plain causal mask is used.
            position_ids: rotary positions of the new tokens. When omitted
                they continue from the cached length.
            past_key_value: cached `(key, value)`, each `[bsz, nh, t, hd]`.
            output_attentions: accepted for interface compatibility; neither
                attention implementation exposes its probabilities, so the
                returned weights are always `None`.
            use_cache: whether to return the updated `(key, value)` cache.

        Returns:
            `(attn_output, None, past_key_value)`.
        """
        bsz, q_len, _ = hidden_states.size()

        # Split the fused projection into three [bsz, q_len, hidden] views.
        fused = self.qkv_proj(hidden_states).unflatten(
            -1, (3, self.hidden_size))
        query_states, key_states, value_states = (
            part.view(bsz, q_len, self.num_heads,
                      self.head_dim).transpose(1, 2)
            for part in fused.unbind(dim=-2))

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if position_ids is None:
            # New tokens follow directly after the cached ones.
            position_ids = torch.arange(kv_seq_len - q_len,
                                        kv_seq_len,
                                        dtype=torch.long,
                                        device=hidden_states.device)
            position_ids = position_ids.unsqueeze(0)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # [bsz, nh, t, hd] -> [t, bsz, nh, hd]
        query_states = query_states.transpose(1, 2).transpose(0, 1)
        value_states = value_states.transpose(1, 2).transpose(0, 1)
        key_states = key_states.transpose(1, 2).transpose(0, 1)

        if attention_mask is None:
            # No mask supplied: block every key that lies after the query.
            query_pos = torch.arange(kv_seq_len - q_len,
                                     kv_seq_len,
                                     device=hidden_states.device)
            key_pos = torch.arange(kv_seq_len, device=hidden_states.device)
            attention_mask = (key_pos[None, :] >
                              query_pos[:, None])[None, None]
        else:
            # Additive mask -> boolean mask (True = masked).
            attention_mask = attention_mask < -0.5

        if self.config.use_flash_attention and flash_attn_unpadded_func is not None:
            if rearrange is None:
                raise ImportError(
                    "Please install einops first, e.g., with pip install einops"
                )
            q, k, v = [
                rearrange(x, "s b ... -> b s ...").contiguous()
                for x in (query_states, key_states, value_states)
            ]
            attn_output = self._self_attention_flash(q, k, v)
            attn_output = rearrange(attn_output,
                                    "b s h d -> s b (h d)").contiguous()
        else:
            attn_output = self._self_attention(query_states, key_states,
                                               value_states, attention_mask)

        # [t, bsz, hidden] -> [bsz, t, hidden]
        attn_output = attn_output.transpose(0, 1)
        attn_output = self.o_proj(attn_output)

        # Always defined, so `output_attentions=True` no longer hits an
        # unbound local at the return statement.
        attn_weights = None
        if output_attentions:
            logger.warning_once(
                "LingoWhaleAttention does not return attention weights; "
                "`output_attentions=True` yields `None`.")

        return attn_output, attn_weights, past_key_value
class LingoWhaleDecoderLayer(nn.Module):
    """One pre-norm transformer block: self-attention, then the gated MLP.

    Each sub-layer is preceded by an RMSNorm and wrapped in a residual
    connection.
    """

    def __init__(self, config: LingoWhaleConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LingoWhaleAttention(config=config)
        self.mlp = LingoWhaleMLP(config)
        self.input_layernorm = LingoWhaleRMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
        self.post_attention_layernorm = LingoWhaleRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
                                                 torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): mask of size
                `(batch, 1, tgt_len, src_len)`; padding positions carry very large negative values.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key and value projection states
            output_attentions (`bool`, *optional*):
                When true, the attention module's weights entry is appended to the returned tuple.
            use_cache (`bool`, *optional*):
                When true, the updated key/value states are appended to the returned tuple so they
                can be fed back as `past_key_value` on the next step.

        Returns:
            Tuple starting with the new hidden states, optionally followed by
            the attention weights and then the present key/value cache.
        """
        # --- self-attention sub-layer ---
        attn_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        after_attention = hidden_states + attn_output

        # --- feed-forward sub-layer ---
        mlp_output = self.mlp(self.post_attention_layernorm(after_attention))
        layer_output = after_attention + mlp_output

        results = [layer_output]
        if output_attentions:
            results.append(self_attn_weights)
        if use_cache:
            results.append(present_key_value)
        return tuple(results)
class LingoWhalePreTrainedModel(PreTrainedModel):
    """Shared base class: wires the config class and weight initialisation."""

    config_class = LingoWhaleConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LingoWhaleDecoderLayer"]

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, initializer_range).

        Linear biases and the embedding row at `padding_idx` are zeroed.
        Modules of any other type are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the decoder stack carries the checkpointing switch.
        if not isinstance(module, LingoWhaleModel):
            return
        module.gradient_checkpointing = value
class LingoWhaleModel(LingoWhalePreTrainedModel):
    """Decoder stack: token embedding, `num_hidden_layers` decoder layers, final RMSNorm.

    NOTE: `forward` casts the embedded inputs to bfloat16 before the first
    decoder layer, regardless of the dtype of the embeddings.
    """

    def __init__(self, config: LingoWhaleConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         self.padding_idx)
        decoder_layers = [
            LingoWhaleDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = LingoWhaleRMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.drop = nn.Dropout(config.emb_dropout_prob)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                        inputs_embeds, past_key_values_length):
        """Combine the causal mask with the expanded padding mask.

        Returns an additive mask of shape `[bsz, 1, tgt_seq_len, src_seq_len]`,
        or `None` when there is a single query position and no padding mask.
        """
        # A causal mask is only needed when more than one position is queried.
        causal_mask = None
        if input_shape[-1] > 1:
            causal_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is None:
            return causal_mask

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        padding_mask = _expand_mask(attention_mask,
                                    inputs_embeds.dtype,
                                    tgt_len=input_shape[-1])
        padding_mask = padding_mask.to(inputs_embeds.device)
        if causal_mask is None:
            return padding_mask
        return padding_mask + causal_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given. Flags left
        as `None` fall back to the corresponding config values. Returns a
        `BaseModelOutputWithPast`, or a tuple of its non-`None` fields when
        `return_dict` is false.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        # Length of the cached prefix, read from the first layer's key tensor.
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length + past_key_values_length

        if position_ids is None:
            source = input_ids if input_ids is not None else inputs_embeds
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=source.device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Without a user mask every position (cache included) is attendable.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = self.drop(inputs_embeds)
        hidden_states = hidden_states.to(dtype=torch.bfloat16)

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        def _checkpointable(module):
            # Positional call: the four checkpointed inputs are followed by
            # `output_attentions` and `None` for `use_cache`.
            def _run(*inputs):
                return module(*inputs, output_attentions, None)

            return _run

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None
        # The cache entry follows the attention weights when both are returned.
        cache_slot = 2 if output_attentions else 1

        for layer_idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            layer_past = None
            if past_key_values is not None:
                layer_past = past_key_values[layer_idx]

            if checkpointing:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    _checkpointable(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=layer_past,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[cache_slot], )
            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = next_decoder_cache if use_cache else None
        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )
        candidates = [
            hidden_states, next_cache, all_hidden_states, all_self_attns
        ]
        return tuple(v for v in candidates if v is not None)
class LingoWhaleForCausalLM(LingoWhalePreTrainedModel):
    """LingoWhale decoder with a language-modelling head and chat helpers."""

    def __init__(self, config):
        super().__init__(config)
        self.model = LingoWhaleModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(config.hidden_size,
                                       config.vocab_size,
                                       bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        """Load the model, defaulting `torch_dtype` to the value stored in the config.

        When `config` is not already a `PretrainedConfig`, it is loaded first
        (from `config` if given, else from `pretrained_model_name_or_path`) so
        that its `torch_dtype` can be used when the caller did not pass one.
        All other arguments are forwarded to the parent `from_pretrained`.
        """
        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = (config if config is not None else
                           pretrained_model_name_or_path)
            # Defaults for the config loader. Caller-supplied values in
            # `kwargs` take precedence instead of colliding with hard-coded
            # keyword arguments ("got multiple values for keyword argument").
            config_kwargs = {
                "resume_download": False,
                "proxies": None,
                "subfolder": "",
                "_from_auto": False,
                "_from_pipeline": None,
            }
            config_kwargs.update(kwargs)
            config, _ = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                **config_kwargs,
            )

        if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = config.torch_dtype

        return super(LingoWhaleForCausalLM, cls).from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute next-token logits and, when `labels` is given, the LM loss.

        The loss is the cross-entropy between the logits at position `i` and
        the label at position `i + 1`. Returns a `CausalLMOutputWithPast`, or
        a tuple `(loss?, logits, *decoder_extras)` when `return_dict` is false.
        """
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Once a cache exists only the last token (and its position) is fed to
        the model. `inputs_embeds` is used only on the first step.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked-out positions get a placeholder position of 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        })
        return model_inputs

    def build_chat_input(self,
                         tokenizer,
                         messages: List[dict],
                         max_new_tokens: int = 0,
                         user_token_ids=[3],
                         assistant_tokens=[4]):
        """Encode a chat history into a `[1, n]` LongTensor of prompt ids.

        Messages are grouped into rounds, each starting at a `'user'` message
        and holding at most one reply. A round is encoded as
        `user_token_ids + user text + assistant_tokens [+ reply text]`.
        Rounds are added from newest to oldest for as long as the prompt fits
        into `config.model_max_length - max_new_tokens` tokens. If even the
        newest round is too long, its trailing tokens are kept.

        Args:
            tokenizer: object whose `encode(text)` returns a list of ids.
            messages: dicts with `'role'` and `'content'` keys.
            max_new_tokens: room reserved for the generated answer.
            user_token_ids: ids prepended to every user message.
            assistant_tokens: ids placed before every assistant reply.

        Returns:
            `torch.LongTensor` of shape `[1, n]` on the model's device. (The
            truncation path used to return a bare Python list.)

        Raises:
            ValueError: if `max_new_tokens` leaves no room for any prompt.
        """
        max_input_tokens = self.config.model_max_length - max_new_tokens
        if max_input_tokens <= 0:
            raise ValueError(
                f"max_new_tokens ({max_new_tokens}) leaves no room for the "
                f"prompt (model_max_length={self.config.model_max_length}).")

        # Copies, so the (list) default arguments can never be mutated.
        user_prefix = list(user_token_ids)
        assistant_prefix = list(assistant_tokens)

        def _parse_messages(messages):
            # A new round starts at every user message.
            chat_rounds, chat_round = [], []
            for message in messages:
                if message['role'] == 'user' and len(chat_round) > 0:
                    chat_rounds.append(chat_round)
                    chat_round = []
                chat_round.append(message)
            if len(chat_round) > 0:
                chat_rounds.append(chat_round)
            return chat_rounds

        def _encode_round(chat_round):
            assert len(chat_round) < 3, (
                "a chat round holds one user message and at most one reply")
            tokens = list(user_prefix)
            tokens += tokenizer.encode(chat_round[0]['content'])
            tokens += assistant_prefix
            if len(chat_round) == 2:
                tokens += tokenizer.encode(chat_round[1]['content'])
            return tokens

        input_tokens = []
        # Newest round first, so the oldest history is what gets dropped.
        for chat_round in reversed(_parse_messages(messages)):
            chat_tokens = _encode_round(chat_round)
            if len(chat_tokens) + len(input_tokens) > max_input_tokens:
                if not input_tokens:
                    # The latest round alone is too long: keep its tail,
                    # which ends with the assistant marker.
                    input_tokens = chat_tokens[-max_input_tokens:]
                break
            input_tokens = chat_tokens + input_tokens

        return torch.LongTensor([input_tokens]).to(self.device)

    def chat(self,
             tokenizer,
             messages: List[dict],
             stream=False,
             generation_config: Optional[GenerationConfig] = None,
             max_new_tokens=100):
        """Generate the assistant's reply to a chat history.

        Args:
            tokenizer: tokenizer used to encode the prompt and decode the reply.
            messages: chat history, see `build_chat_input`.
            stream: when true, generation runs in a background thread and a
                `TextIterStreamer` is returned immediately.
            generation_config: optional generation settings passed to
                `generate`. If its `max_new_tokens` is set, it replaces the
                `max_new_tokens` argument.
            max_new_tokens: room reserved for the answer when the prompt is
                built. NOTE: it is not forwarded to `generate`; the generated
                length is controlled by the generation configuration.

        Returns:
            The decoded reply string, or a `TextIterStreamer` when `stream`
            is true.
        """
        # `GenerationConfig.max_new_tokens` may be unset (None); only override
        # the prompt budget when a real value is available.
        if (generation_config is not None
                and generation_config.max_new_tokens is not None):
            max_new_tokens = generation_config.max_new_tokens

        input_ids = self.build_chat_input(tokenizer, messages, max_new_tokens)

        if stream:
            streamer = TextIterStreamer(tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
            Thread(target=self.generate,
                   kwargs=dict(inputs=input_ids,
                               streamer=streamer,
                               generation_config=generation_config)).start()
            return streamer

        outputs = self.generate(input_ids,
                                generation_config=generation_config)
        # Drop the echoed prompt and decode only the new tokens.
        prompt_length = input_ids.shape[1]
        response = tokenizer.decode(outputs[0][prompt_length:],
                                    skip_special_tokens=True)
        return response

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached tensor along the batch axis with `beam_idx`."""
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx)
                for past_state in layer_past)
            for layer_past in past_key_values)