File size: 10,037 Bytes
9469af5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 |
from typing import List, Optional, Tuple, Union
import torch
from transformers import (
MistralModel,
MistralPreTrainedModel,
MistralForCausalLM,
MistralConfig,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import Cache, DynamicCache
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer,
MistralRMSNorm,
MistralAttention,
MistralFlashAttention2,
MistralSdpaAttention,
MistralMLP,
)
from torch import nn
from transformers.utils import logging
from attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
logger = logging.get_logger(__name__)
class ModifiedMistralAttention(MistralAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
class ModifiedMistralFlashAttention2(MistralFlashAttention2):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
class ModifiedMistralSdpaAttention(MistralSdpaAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
MISTRAL_ATTENTION_CLASSES = {
"eager": ModifiedMistralAttention,
"flash_attention_2": ModifiedMistralFlashAttention2,
"sdpa": ModifiedMistralSdpaAttention,
}
class ModifiedMistralDecoderLayer(MistralDecoderLayer):
def __init__(self, config: MistralConfig, layer_idx: int):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](
config, layer_idx
)
self.mlp = MistralMLP(config)
self.input_layernorm = MistralRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_attention_layernorm = MistralRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
class MistralBiModel(MistralModel):
def __init__(self, config: MistralConfig):
MistralPreTrainedModel.__init__(self, config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
ModifiedMistralDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self._attn_implementation = config._attn_implementation
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
# Copied from forward() in transformers.models.mistral.modeling_mistral.MistralModel
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if (
attention_mask is not None
and self._attn_implementation == "flash_attention_2"
and use_cache
):
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = (
attention_mask
if (attention_mask is not None and 0 in attention_mask)
else None
)
elif self._attn_implementation == "sdpa" and not output_attentions:
# The original implementation is by-passed, see attn_mask_utils.py
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if use_legacy_cache
else next_decoder_cache
)
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class MistralBiForMNTP(MistralForCausalLM):
def __init__(self, config):
MistralPreTrainedModel.__init__(self, config)
self.model = MistralBiModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
|