import torch.nn as nn

from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaMLP

from .configuration_llama_lm_feats import LlamaWithFeatsEncoderConfig


class LlamaFeatsMLP(LlamaMLP):
    """Gated MLP that projects meta-features into the token-embedding space.

    Reuses the standard Llama MLP structure, but the gate/up projections read
    inputs of size ``config.feats_hidden_size`` instead of ``config.hidden_size``.
    ``down_proj`` is left untouched, so the output matches ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Override the input projections so the MLP accepts feature vectors
        # of size `feats_hidden_size`.
        self.gate_proj = nn.Linear(config.feats_hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.feats_hidden_size, self.intermediate_size, bias=False)


class LlamaWithFeatsForCausalLM(LlamaForCausalLM):
    """Llama causal LM that conditions on meta-features by adding their MLP
    projection to the token embeddings before the decoder stack."""

    config_class = LlamaWithFeatsEncoderConfig

    def __init__(self, config: LlamaWithFeatsEncoderConfig):
        super().__init__(config)
        self.feature_mlp = LlamaFeatsMLP(config)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        meta_features=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Embed the tokens ourselves so the feature projection can be added
        # before the decoder runs.
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if meta_features is not None:
            feats_embeds = self.feature_mlp(meta_features)
            inputs_embeds = inputs_embeds + feats_embeds

        # Forward only `inputs_embeds` (not `input_ids`), since the two are
        # mutually exclusive in LlamaForCausalLM.forward.
        return super().forward(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
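

# --- Usage sketch (illustrative only, not part of the model definition) ---
# A minimal example of how this model might be driven, assuming
# `LlamaWithFeatsEncoderConfig` extends `LlamaConfig` with a
# `feats_hidden_size` field and that `meta_features` is a float tensor of
# shape (batch, seq_len, feats_hidden_size). The tiny sizes below are
# placeholders chosen for illustration.
if __name__ == "__main__":
    import torch

    config = LlamaWithFeatsEncoderConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        feats_hidden_size=16,  # assumed field added by the custom config
    )
    model = LlamaWithFeatsForCausalLM(config)

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    meta_features = torch.randn(1, 8, config.feats_hidden_size)

    # The meta-features are projected by `feature_mlp` and added to the
    # token embeddings; labels are passed through for the LM loss.
    out = model(input_ids=input_ids, meta_features=meta_features, labels=input_ids)
    print(out.loss)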