How did you apply bi-directional attention for embedding only?
Hello. Thanks for your great work and the helpful code in your GitHub repo! It has helped me a lot.
I am wondering how you implemented bi-directional attention only for embedding while maintaining causal attention for generation simultaneously.
Here is what I have understood so far from your code (https://github.com/ContextualAI/gritlm/blob/main/gritlm/training/model.py).
I see that you put an 'is_causal' key into the `kwargs` dict and pass it to the model call when self.attn[:2] == 'bb' (the default attn is 'bbcc').
But I can't understand how this removes the causal mask, since most models (e.g. MistralForCausalLM) have no argument named 'is_causal' in their forward function.
Would you mind giving me some advice or pointers on how this works?
Thank you so much.
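```python
# Excerpt from gritlm/training/model.py (runs inside a method; self, features,
# attention_mask and instruction_lens are defined by the surrounding code)
kwargs = {'input_ids': features.get('input_ids'), 'attention_mask': attention_mask}
if self.attn[:2] == 'cb':
    kwargs['instruction_lens'] = instruction_lens
elif self.attn[:2] == 'bb':
    kwargs['is_causal'] = False
out = (getattr(self.model, self.embedding_attr) if self.embedding_attr else self.model)(**kwargs)[0]
```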
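A toy illustration in plain Python of why putting `is_causal` into `kwargs` cannot by itself change the attention: the three `*_forward` functions below are hypothetical stand-ins, not the transformers or GritLM code. Splatting a dict that contains `is_causal` into a call either raises a `TypeError`, is silently swallowed by `**kwargs`, or has an effect only when the callee actually reads the parameter.

```python
# Hypothetical stand-ins for a model's forward(); not real library code.

def stock_forward(input_ids, attention_mask):
    # No 'is_causal' parameter and no **kwargs: the call raises TypeError.
    return "causal"

def lenient_forward(input_ids, attention_mask, **kwargs):
    # Accepts extra kwargs but never reads 'is_causal',
    # so the flag is silently ignored and attention stays causal.
    return "causal"

def custom_forward(input_ids, attention_mask, is_causal=True):
    # Consumes the flag, so it can actually switch the attention mode.
    return "causal" if is_causal else "bidirectional"

kwargs = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]], "is_causal": False}

for fn in (stock_forward, lenient_forward, custom_forward):
    try:
        print(fn.__name__, "->", fn(**kwargs))
    except TypeError as err:
        print(fn.__name__, "-> TypeError:", err)
```

Only the last variant changes behavior, which is why the model being called has to define and use the flag itself.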
The repo has a custom modeling file whose forward accepts the is_causal kwarg: https://huggingface.co/GritLM/GritLM-7B/blob/main/modeling_gritlm7b.py
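How modeling_gritlm7b.py uses the flag internally is not spelled out in this thread. As a rough sketch, assuming the flag simply decides whether the causal (lower-triangular) restriction is applied on top of the padding mask, the two modes differ as below. `build_additive_mask` is a hypothetical helper for illustration, not a function from that file or from transformers.

```python
NEG_INF = float("-inf")

def build_additive_mask(padding_mask, is_causal=True):
    """Hypothetical helper: additive attention mask of shape [batch][query][key].

    padding_mask: list of lists, 1 for real tokens, 0 for padding.
    is_causal=True  -> query i may attend to real keys j <= i (generation).
    is_causal=False -> query i may attend to every real key (embedding).
    Allowed positions get 0.0 and blocked ones get -inf; such a mask would be
    added to the attention scores before the softmax. Padding query rows are
    not treated specially here.
    """
    batch_mask = []
    for row in padding_mask:
        seq_len = len(row)
        rows = []
        for i in range(seq_len):
            scores = []
            for j in range(seq_len):
                allowed = row[j] == 1 and (not is_causal or j <= i)
                scores.append(0.0 if allowed else NEG_INF)
            rows.append(scores)
        batch_mask.append(rows)
    return batch_mask

padding = [[1, 1, 1, 0]]  # one sequence whose last position is padding
causal = build_additive_mask(padding, is_causal=True)
bidir = build_additive_mask(padding, is_causal=False)
print(causal[0][0])  # [0.0, -inf, -inf, -inf]: first token sees only itself
print(bidir[0][0])   # [0.0, 0.0, 0.0, -inf]: first token sees all real tokens
```

The padding handling is identical in both modes; only the `j <= i` condition is dropped when `is_causal` is False, which is what lets embedding passes see the whole input while generation stays causal.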
Thank you so much! Now I understand!