import torch.nn as nn

from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaMLP

from .configuration_llama_lm_feats import LlamaWithFeatsEncoderConfig


class LlamaFeatsMLP(LlamaMLP):
    """LlamaMLP variant whose input projections read feature vectors of width
    ``config.feats_hidden_size`` instead of ``config.hidden_size``."""

    def __init__(self, config):
        super().__init__(config)
        # Re-create the input projections so the block maps
        # feats_hidden_size -> intermediate_size -> hidden_size.
        self.gate_proj = nn.Linear(config.feats_hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.feats_hidden_size, self.intermediate_size, bias=False)


class LlamaWithFeatsForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM that projects per-token ``meta_features`` into the
    embedding space and adds them to the token embeddings before decoding."""

    config_class = LlamaWithFeatsEncoderConfig

    def __init__(self, config: LlamaWithFeatsEncoderConfig):
        super().__init__(config)
        # Projects meta features (feats_hidden_size) to the model's hidden size.
        self.feature_mlp = LlamaFeatsMLP(config)
        # Re-run weight initialization for the newly added module.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        meta_features=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Embed the tokens here so the feature embeddings can be added
        # before the decoder runs.
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Project the meta features and add them to the token embeddings.
        # ``meta_features`` is expected to be broadcastable to
        # (batch, seq_len, feats_hidden_size).
        if meta_features is not None:
            feats_embeds = self.feature_mlp(meta_features)
            inputs_embeds = inputs_embeds + feats_embeds

        # Pass ``inputs_embeds`` (not ``input_ids``) so the parent class
        # skips its own embedding lookup.
        return super().forward(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
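
# Usage sketch (not part of the module's API). The exact constructor fields of
# ``LlamaWithFeatsEncoderConfig`` are defined in ``configuration_llama_lm_feats.py``;
# the kwargs below assume it subclasses ``LlamaConfig`` and adds
# ``feats_hidden_size``, so treat the field values as illustrative only.
#
#     import torch
#
#     config = LlamaWithFeatsEncoderConfig(
#         vocab_size=32000,
#         hidden_size=256,
#         intermediate_size=512,
#         num_hidden_layers=2,
#         num_attention_heads=4,
#         feats_hidden_size=16,  # width of the per-token meta features (assumed field)
#     )
#     model = LlamaWithFeatsForCausalLM(config)
#
#     input_ids = torch.randint(0, config.vocab_size, (1, 8))
#     meta_features = torch.randn(1, 8, config.feats_hidden_size)
#     out = model(input_ids=input_ids, meta_features=meta_features, labels=input_ids)
#     out.loss.backward()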