# smallm_70_rope/model.py
import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from .config import SmalLmConfig
from typing import Optional
import logging
from einops import rearrange
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from einops._torch_specific import allow_ops_in_compiled_graph
allow_ops_in_compiled_graph()
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import unpad_input, pad_input
logger = logging.getLogger(__name__)
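# Gated feed-forward block (SwiGLU): up_proj produces the value and gate halves in one
# projection, they are split along the last dimension, and silu(gate) * up is projected
# back down to the model width.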
class SwiGLU(nn.Module):
def __init__(
self, input_size: int, hidden_size: int, bias: bool = False, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.input_size = input_size
self.hidden_size = hidden_size
self.up_proj = nn.Linear(input_size, hidden_size * 2, bias=bias)
self.down_proj = nn.Linear(hidden_size, input_size, bias=bias)
def forward(self, x):
up_gate = self.up_proj(x)
up, gate = rearrange(up_gate, "... (d span) -> span ... d", d=self.hidden_size)
down = F.silu(gate) * up
return self.down_proj(down)
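# Top-k sigmoid router over the routed experts. An optional softplus-scaled Gaussian
# noise term can be added to the gate logits, and a persistent per-expert bias (updated
# by update_bias from running selection counts) nudges selection toward balanced expert
# load without an auxiliary loss; the bias only affects which experts are picked, while
# the mixing weights come from the unbiased sigmoid scores.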
class Router(nn.Module):
def __init__(self, config: SmalLmConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.experts_to_select = self.config.token_experts - self.config.shared_experts
self.gate = nn.Linear(config.hidden_size, config.routed_experts, bias=False)
self.gate_noise = (
nn.Linear(config.hidden_size, config.routed_experts, bias=False)
if config.noisy_experts is True
else None
)
self.bias_coef = config.balancing_coef
self.register_buffer(
"bias", torch.zeros(config.routed_experts), persistent=True
)
self.register_buffer(
"expert_counts", torch.zeros(config.routed_experts), persistent=False
)
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
        # x: (num_tokens, hidden_size); gate_logits: (num_tokens, routed_experts)
gate_logits = self.gate(x)
if self.gate_noise is not None:
gate_logits_noise = F.softplus(self.gate_noise(x))
gate_logits_noise = torch.randn_like(gate_logits_noise) * gate_logits_noise
gate_logits = gate_logits + gate_logits_noise
gate_weights = gate_logits.sigmoid()
original_weights = gate_weights
gate_weights = gate_weights + self.bias
_, top_experts_idx = torch.topk(gate_weights, self.experts_to_select, dim=-1)
counts = torch.bincount(
top_experts_idx.flatten(), minlength=self.config.routed_experts
).detach()
if self.training:
self.expert_counts += counts
top_experts_weights = original_weights.gather(1, top_experts_idx)
top_experts_weights = top_experts_weights / top_experts_weights.sum(
dim=-1, keepdim=True
)
return top_experts_idx, top_experts_weights.type_as(x), counts.tolist()
def update_bias(self):
mean = self.expert_counts.float().mean()
delta = self.bias_coef * torch.sign(mean - self.expert_counts)
self.bias += delta
self.expert_counts.zero_()
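# Mixture-of-experts layer: one always-on shared SwiGLU applied to every token plus a
# pool of routed SwiGLU experts, combined with the router's normalized per-token weights.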
class MoE(nn.Module):
def __init__(self, config: SmalLmConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.shared_experts = SwiGLU(
config.hidden_size,
config.shared_experts * config.expert_size,
config.moe_bias,
)
self.routed_experts = nn.ModuleList(
[
SwiGLU(config.hidden_size, config.expert_size, config.moe_bias)
for _ in range(config.routed_experts)
]
)
self.router = Router(config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
shape = x.size()
x = x.view(-1, self.config.hidden_size)
experts_idx, experts_weights, counts = self.router(x)
out = torch.zeros_like(x)
for i, expert in enumerate(self.routed_experts):
if counts[i] == 0:
continue
idx, pos = torch.where(experts_idx == i)
out[idx] += expert(x[idx]) * experts_weights[idx, pos, None]
shared_out = self.shared_experts(x)
return (out + shared_out).view(shape)
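# Per-head ALiBi slopes. Note that the slopes here grow linearly with the head index
# (2**-8 / num_heads * head_idx) rather than following the geometric schedule of the
# original ALiBi paper.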
def build_alibi_bias(config: SmalLmConfig) -> torch.Tensor:
    """Build per-head ALiBi slopes.

    Returns:
        Tensor of slopes with shape [num_attention_heads].
    """
bias = (
2**-8
/ config.num_attention_heads
* torch.arange(1, config.num_attention_heads + 1).float()
)
return bias
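# Helpers for the NTK-by-parts (YaRN-style) RoPE extension: calc_rotation finds the
# (fractional) dimension index whose wavelength completes the given number of rotations
# over seq_len, and get_ramp_interpolation builds a linear ramp between two such indices.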
def calc_rotation(num_rotations, dim, base, seq_len):
    return (
        dim
        * torch.log(torch.tensor(seq_len).float() / (num_rotations * 2 * torch.pi))
        / torch.log(torch.tensor(base).float())
    )
def get_ramp_interpolation(min_idx, max_idx, thetas_dim, eps=1e-6):
if min_idx == max_idx:
max_idx += eps
mult = (torch.arange(thetas_dim) - min_idx) / (max_idx - min_idx)
mult = torch.clamp(mult, 0, 1)
return 1 - mult
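# Precomputes complex RoPE rotations exp(i * m * theta_j) with theta_j = base**(-2j/dim)
# for every position m up to max_seq_len. When max_seq_len exceeds original_seq_len, the
# frequencies are blended between the original and a linearly interpolated (scaled)
# version following the NTK-by-parts scheme above.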
def build_rope_bias(config: SmalLmConfig) -> torch.Tensor:
dim = config.head_size
theta = 1.0 / (config.rope_base ** (torch.arange(0, dim, 2).float() / dim))
    # NTK-by-parts (YaRN-style) correction when extending past the original context length
    if config.max_seq_len > config.original_seq_len:
        scale = config.max_seq_len / config.original_seq_len
        # wavelength relation: lambda_i = 2*pi / theta_i and lambda = seq_len / num_rotations
low_interpolation_idx = max(
0,
torch.ceil(
calc_rotation(
config.high_rotations,
dim,
config.rope_base,
config.original_seq_len,
)
).item(),
)
high_interpolation_idx = min(
dim - 1,
torch.floor(
calc_rotation(
config.low_rotations, dim, config.rope_base, config.original_seq_len
)
).item(),
)
interpolation_mult = get_ramp_interpolation(
low_interpolation_idx, high_interpolation_idx, dim // 2
)
theta = (1 - interpolation_mult) * theta / scale + interpolation_mult * theta
seq_idx = torch.arange(config.max_seq_len)
seq_theta = torch.outer(seq_idx, theta)
bias = torch.polar(torch.ones_like(seq_theta), seq_theta)
return bias
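# Applies the precomputed rotations via complex multiplication. Assumes head_size is
# even; adjacent channel pairs are interpreted as (real, imag) components.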
def apply_rope_bias(x: torch.Tensor, precompute_bias: torch.Tensor) -> torch.Tensor:
ini_dtype = x.dtype
    # cast to float32 for stability; torch.view_as_complex also requires a float dtype
x = rearrange(x.float(), "b n s (d i) -> b n s d i", i=2).contiguous()
x = torch.view_as_complex(x)
x = x * precompute_bias
x = torch.view_as_real(x)
x = rearrange(x, "b n s d i -> b n s (d i)")
return x.to(ini_dtype)
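# FlashAttention-2 path: tokens are unpadded with the 2D attention mask, run through the
# varlen causal kernel, and re-padded to the original (batch, seq) layout. ALiBi slopes
# are forwarded to the kernel only when ALiBi positional bias is configured.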
def flash_attention_forward(
module: nn.Module,
x: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
    alibi_slope: Optional[torch.Tensor],
) -> tuple[torch.Tensor, None]:
query = rearrange(query, "b n s d -> b s n d")
key = rearrange(key, "b n s d -> b s n d")
value = rearrange(value, "b n s d -> b s n d")
query, idx_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(query, attention_mask)
key, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(key, attention_mask)
value, _, _, _, _ = unpad_input(value, attention_mask)
key = key.contiguous()
value = value.contiguous()
query = query.contiguous()
attention_probs = flash_attn_varlen_func(
query,
key,
value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=module.config.attention_dropout if module.training else 0.0,
causal=True,
        alibi_slopes=alibi_slope if module.config.positional_bias_type == "alibi" else None,
)
attention_probs = pad_input(attention_probs, idx_q, x.size(0), x.size(1))
out = rearrange(attention_probs, "b s n d -> b s (n d)")
return out, None
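# SDPA path: relies on F.scaled_dot_product_attention with enable_gqa=True (available in
# recent PyTorch releases, 2.5+ to my knowledge) so kv heads do not need to be repeated;
# any additive mask (including ALiBi, when configured) comes in via attn_mask.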
def sdpa_attention_forward(
module: nn.Module,
x: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
    alibi_slope: Optional[torch.Tensor],
) -> tuple[torch.Tensor, None]:
is_causal = attention_mask is None and query.size(-2) > 1
attention_probs = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
enable_gqa=True,
is_causal=is_causal,
dropout_p=module.config.attention_dropout if module.training else 0.0,
)
out = rearrange(attention_probs, "b n s d -> b s (n d)")
return out, None
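# Reference eager path: query heads are grouped with their kv head explicitly, scores are
# scaled, masked, softmaxed, and dropout is applied to the attention probabilities.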
def eager_attention_forward(
    module: nn.Module,
    x: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    alibi_slope: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    # group query heads with their kv head for GQA: (b, n, s, d) -> (b, kv, group, s, d)
    query = rearrange(
        query,
        "b (kv group) s d -> b kv group s d",
        kv=module.config.num_kv_heads,
        group=module.head_per_group,
    )
    key = rearrange(key, "b kv s d -> b kv 1 s d")
    value = rearrange(value, "b kv s d -> b kv 1 s d")
    attention_weights = query @ key.transpose(-1, -2)
    attention_weights = attention_weights / torch.sqrt(
        torch.tensor(value.size(-1), dtype=query.dtype, device=x.device)
    )
    # the additive mask already carries the causal / padding terms and, when ALiBi is
    # configured for a non-flash implementation, the ALiBi bias as well, so it is the
    # only term added to the scores here (alibi_slope is kept for interface parity)
    if attention_mask is not None:
        attention_mask = attention_mask.expand(
            -1, module.config.num_attention_heads, -1, -1
        )
        attention_mask = rearrange(
            attention_mask,
            "b (kv group) s1 s2 -> b kv group s1 s2",
            kv=module.config.num_kv_heads,
            group=module.head_per_group,
        )
        attention_weights = attention_weights + attention_mask
    attention_probs = F.softmax(attention_weights, dim=-1)
    attention_probs = F.dropout(
        attention_probs, p=module.config.attention_dropout if module.training else 0.0
    )
    attention_probs = attention_probs @ value
    out = rearrange(attention_probs, "b kv group s d -> b s (kv group d)")
    return out, attention_weights
ALL_ATTENTION_FUNCTIONS = {
"eager": eager_attention_forward,
"sdpa": sdpa_attention_forward,
"flash_attention_2": flash_attention_forward,
}
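# Grouped-query self-attention with a fused key/value projection. RoPE (when configured)
# is applied to q/k before the KV cache update, so cached keys are already rotated.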
class CausalSelfAttention(nn.Module):
def __init__(self, config: SmalLmConfig, layer_idx: int, *args, **kwargs):
super().__init__(*args, **kwargs)
if config.num_attention_heads % config.num_kv_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_kv_heads")
self.config = config
self.layer_idx = layer_idx
self.head_per_group = config.num_attention_heads // config.num_kv_heads
self.q_proj = nn.Linear(
config.hidden_size,
config.head_size * config.num_attention_heads,
bias=config.attention_bias,
)
self.kv_proj = nn.Linear(
config.hidden_size,
config.head_size * config.num_kv_heads * 2,
bias=config.attention_bias,
)
self.out_proj = nn.Linear(
config.head_size * config.num_attention_heads,
config.hidden_size,
bias=config.attention_bias,
)
def forward(
self,
x: torch.Tensor,
attention_mask: torch.Tensor,
past_key_values: Optional[Cache | torch.FloatTensor],
cache_position: Optional[torch.LongTensor],
bias: torch.Tensor,
):
q = self.q_proj(x)
kv = self.kv_proj(x)
q = rearrange(q, "b s (n d) -> b n s d", n=self.config.num_attention_heads)
k, v = rearrange(kv, "b s (n d q) -> q b n s d", q=2, d=self.config.head_size)
if self.config.positional_bias_type == "rope":
k = apply_rope_bias(k, bias)
q = apply_rope_bias(q, bias)
if past_key_values is not None:
            # cache_position is required for StaticCache-style updates
            cache_kwargs = {"cache_position": cache_position}
            k, v = past_key_values.update(
                key_states=k,
                value_states=v,
                layer_idx=self.layer_idx,
                cache_kwargs=cache_kwargs,
)
attention_interface = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
out, attention_weights = attention_interface(
self,
x,
q,
k,
v,
attention_mask,
bias if self.config.positional_bias_type == "alibi" else None
)
out = self.out_proj(out)
return out, attention_weights
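# Residual connection with a learnable per-channel scale on the skip (identity) branch;
# trainability of the scale is controlled by config.static_residual.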
class WeightedResidual(nn.Module):
def __init__(self, config: SmalLmConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.weight = nn.Parameter(
torch.ones(config.hidden_size), requires_grad=config.static_residual
)
def forward(self, short, long):
return self.weight * short + long
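# Pre-norm transformer block: RMSNorm -> attention -> weighted residual, then
# RMSNorm -> SwiGLU or MoE -> weighted residual. MoE replaces the dense MLP only on
# layers selected by moe_period / no_moe_layers when use_moe is enabled.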
class Block(nn.Module):
def __init__(self, config: SmalLmConfig, layer_idx: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn_norm = nn.RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
elementwise_affine=config.rms_affine,
)
self.ffn_norm = nn.RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
elementwise_affine=config.rms_affine,
)
self.dropout1 = nn.Dropout(config.layer_dropout)
self.dropout2 = nn.Dropout(config.layer_dropout)
self.attention = CausalSelfAttention(config, layer_idx)
self.mlp = (
MoE(config)
if (
config.use_moe
and layer_idx % config.moe_period == 0
and layer_idx > config.no_moe_layers
)
else SwiGLU(config.hidden_size, config.intermediate_size, config.mlp_bias)
)
self.attention_residual = WeightedResidual(config)
self.ffn_residual = WeightedResidual(config)
def forward(
self,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor,
past_key_values: Optional[Cache | torch.FloatTensor],
output_attentions: bool,
cache_position: Optional[torch.LongTensor],
bias: torch.Tensor,
) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
identity = inputs_embeds
# attention block
out = self.attn_norm(inputs_embeds)
out, attention_probs = self.attention(
out, attention_mask, past_key_values, cache_position, bias
)
out = self.dropout1(out)
identity = self.attention_residual(identity, out)
# swiglu / MoE block
out = self.dropout2(self.mlp(self.ffn_norm(identity)))
out = self.ffn_residual(identity, out)
if output_attentions:
return out, attention_probs
return (out,)
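# Hugging Face plumbing: config class, weight initialization, and flags advertising SDPA /
# FlashAttention-2 support and gradient checkpointing.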
class SmalLmPreTrainedModel(PreTrainedModel):
config_class = SmalLmConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Block"]
_skip_keys_device_placement = "past_key_values"
_supports_sdpa = True
_supports_flash_attn_2 = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
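# Backbone: token embedding, a stack of Blocks, and a final RMSNorm. The positional bias
# (RoPE rotations or ALiBi slopes) is precomputed once and registered as a non-persistent
# buffer; causal / padding masks are built per forward pass.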
class SmalLmModel(SmalLmPreTrainedModel):
def __init__(self, config: SmalLmConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
        self.config = config
        self.pad_idx = config.pad_token_id
        self.pad_token_id = config.pad_token_id
        self.vocab_size = config.vocab_size
precompute_bias = (
build_alibi_bias(config)
if config.positional_bias_type == "alibi"
else build_rope_bias(config)
)
self.register_buffer("precompute_bias", precompute_bias, persistent=False)
        # remember weight sharing with the output head: self.embedding.weight = self.output.weight
self.embedding = nn.Embedding(
self.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
)
self.embedding_dropout = nn.Dropout(config.embedding_dropout)
self.layers = nn.ModuleList(
[Block(config, idx) for idx in range(1, config.num_hidden_layers + 1)]
)
self.out_norm = nn.RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
elementwise_affine=config.rms_affine,
)
self.gradient_checkpointing = False
self.post_init()
def get_input_embeddings(self):
return self.embedding
def set_input_embeddings(self, value):
self.embedding = value
def forward(
self,
# input options
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
# output options
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
# cache options
use_cache: Optional[bool] = None,
past_key_values: Optional[Cache | torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple | BaseModelOutputWithPast:
# check additional parameters
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = (
use_cache
if use_cache is not None
else (False if self.training else self.config.use_cache)
)
return_dict = (
return_dict if return_dict is not None else self.config.return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You must specify only input_ids or inputs_embeds, not both"
)
if self.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
# calculating position for StaticCache
if cache_position is None:
last_position = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
last_position,
last_position + inputs_embeds.size(1),
device=inputs_embeds.device,
)
causal_mask = self._get_causal_masks(
attention_mask, inputs_embeds, past_key_values, cache_position
)
if self.config.positional_bias_type == "rope":
end_pos = (
inputs_embeds.size(1)
if past_key_values is None
else cache_position[-1] + 1
)
start_pos = 0 if past_key_values is None else cache_position[0]
bias = self.precompute_bias[start_pos:end_pos]
elif self.config.positional_bias_type == "alibi":
if self.config._attn_implementation == "flash_attention_2":
bias = self.precompute_bias
else:
i = torch.arange(
(
inputs_embeds.size(1)
if past_key_values is None
else cache_position[-1] + 1
),
device=inputs_embeds.device,
)
bias = i[:, None] - i[None, :]
bias = torch.tril(bias).expand(
inputs_embeds.size(0), self.config.num_attention_heads, -1, -1
) * rearrange(self.precompute_bias, "n -> 1 n 1 1")
if causal_mask is not None:
causal_mask = causal_mask + bias
else:
causal_mask = bias
hidden_state = inputs_embeds
hidden_states = [hidden_state] if output_hidden_states else None
attentions = [] if output_attentions else None
for idx, layer in enumerate(self.layers, 1):
if self.gradient_checkpointing:
# for details see:
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3107
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3149
layer_out = self._gradient_checkpointing_func(
layer.__call__,
hidden_state,
causal_mask,
past_key_values,
output_attentions,
cache_position,
bias,
)
else:
layer_out = layer(
hidden_state,
causal_mask,
past_key_values,
output_attentions,
cache_position,
bias,
)
hidden_state = layer_out[0]
if output_hidden_states:
hidden_states.append(hidden_state)
if output_attentions:
attentions.append(layer_out[1])
hidden_state = self.out_norm(hidden_state)
out = BaseModelOutputWithPast(
last_hidden_state=hidden_state,
past_key_values=past_key_values if use_cache else None,
hidden_states=tuple(hidden_states) if hidden_states is not None else None,
attentions=tuple(attentions) if attentions is not None else None,
)
return out if return_dict else out.to_tuple()
def _get_causal_masks(
self,
attention_mask: Optional[torch.Tensor],
inputs_embeds: torch.Tensor,
past_key_values: Optional[torch.Tensor],
cache_position: Optional[torch.Tensor],
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is None:
attention_mask = torch.ones(
(inputs_embeds.size(0), inputs_embeds.size(1)), device=inputs_embeds.device
).long()
return attention_mask
dtype, device = inputs_embeds.dtype, inputs_embeds.device
past_token = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
if attention_mask is not None and torch.all(attention_mask == 0.0):
return None
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values_length=past_token,
is_training=self.training,
):
return None
sequence_length = inputs_embeds.size(1)
target_length = (
attention_mask.size(-1)
if isinstance(attention_mask, torch.Tensor)
else past_token + sequence_length + 1
)
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask=attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=inputs_embeds.size(0),
)
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: Optional[torch.Tensor],
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: Optional[torch.Tensor],
batch_size: int,
):
if attention_mask is not None and attention_mask.dim() == 4:
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone()
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length]
+ attention_mask[:, None, None, :]
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
return causal_mask
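# Causal LM head on top of the backbone; lm_head.weight is tied to the input embedding
# (via _tied_weights_keys) when config.tie_word_embeddings is set. logits_to_keep lets
# generation compute logits only for the trailing positions.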
class SmalLmForCausalLM(SmalLmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: SmalLmConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.config = config
self.model = SmalLmModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward(
self,
# input options
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
# output options
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
# cache options
use_cache: Optional[bool] = None,
past_key_values: Optional[Cache | torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
# generation options
labels: Optional[torch.Tensor] = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs,
) -> tuple | CausalLMOutputWithPast:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.return_dict
)
model_outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = model_outputs[0]
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
)
        if not return_dict:
            output = (logits,) + model_outputs[1:]
            return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=model_outputs.past_key_values,
hidden_states=model_outputs.hidden_states,
attentions=model_outputs.attentions,
)
__all__ = ["SmalLmForCausalLM", "SmalLmModel", "SmalLmPreTrainedModel"]
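# Rough usage sketch (hypothetical repo id; assumes the matching config.json and
# tokenizer files are present in the repo and trust_remote_code is acceptable):
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained("Azrail/smallm_70_rope", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("Azrail/smallm_70_rope")
#   out = model.generate(**tokenizer("Hello", return_tensors="pt"), max_new_tokens=20)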