import torch.nn as nn
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaMLP
from .configuration_llama_lm_feats import LlamaWithFeatsEncoderConfig

class LlamaFeatsMLP(LlamaMLP):
    def __init__(self, config):
        super().__init__(config)
        # Re-create the input projections so the MLP consumes feature vectors of
        # size `feats_hidden_size`; the inherited down_proj still maps the
        # intermediate activations back to the model's hidden size.
        self.gate_proj = nn.Linear(config.feats_hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.feats_hidden_size, self.intermediate_size, bias=False)


class LlamaWithFeatsForCausalLM(LlamaForCausalLM):
    config_class = LlamaWithFeatsEncoderConfig

    def __init__(self, config: LlamaWithFeatsEncoderConfig):
        super().__init__(config)
        self.feature_mlp = LlamaFeatsMLP(config)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        meta_features=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Embed the tokens here so the feature embeddings can be added before
        # the decoder runs.
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Project the extra features into the embedding space and add them
        # element-wise to the token embeddings.
        if meta_features is not None:
            feats_embeds = self.feature_mlp(meta_features)
            inputs_embeds = inputs_embeds + feats_embeds

        # Pass inputs_embeds (not input_ids) down, since the embeddings have
        # already been computed above.
        return super().forward(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
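

# --- Illustrative usage sketch (not part of the original file) ---
# Assumes LlamaWithFeatsEncoderConfig extends LlamaConfig with a
# `feats_hidden_size` field; the constructor arguments below are assumptions
# chosen only to keep the example small. Run as a module
# (e.g. `python -m <package>.<this_module>`) so the relative import resolves.
if __name__ == "__main__":
    import torch

    config = LlamaWithFeatsEncoderConfig(
        vocab_size=1000,
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
        feats_hidden_size=16,
    )
    model = LlamaWithFeatsForCausalLM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    # One feature vector per token so the projected features add element-wise
    # onto the token embeddings.
    meta_features = torch.randn(2, 8, config.feats_hidden_size)

    outputs = model(input_ids=input_ids, meta_features=meta_features, labels=input_ids)
    print(outputs.loss)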