Configuring Command-R for long context tasks
Apologies for the duplicate post, but the previous related discussion was unclear to me.
saurabhdash mentions:
"This implementation is based on the Llama implementation which materializes this huge buffer which would not be feasible for 128k context. The model does support 128k context with a better implementation."
and then gives the following line of Python:
causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool)
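For context, a quick back-of-envelope estimate (my own arithmetic, not from the original discussion) of what that buffer costs at 128k, assuming 1 byte per torch.bool element:

# Rough cost of a fully materialized boolean causal mask; it grows quadratically with sequence length.
seq_len = 131_072  # 128k context
mask_bytes = seq_len * seq_len  # 1 byte per torch.bool element
print(f"~{mask_bytes / 2**30:.0f} GiB just for the causal mask")  # ~16 GiB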
What exact steps do we need to follow to implement this?
I've tried editing max_position_embeddings directly in config.json, and can only run a 25k-token prompt with max_position_embeddings=32768 and 8-bit quantization on a machine with 2x A100s (approx. 160GB VRAM).
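For reference, the same override can also be done programmatically without touching config.json (a sketch using the standard AutoConfig API; 32768 is just the value I tested with):

from transformers import AutoConfig

# Load the config, raise the context length, then pass it to
# from_pretrained(..., config=config) instead of editing config.json on disk.
config = AutoConfig.from_pretrained("CohereForAI/c4ai-command-r-v01", trust_remote_code=True)
config.max_position_embeddings = 32768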
Can someone indicate what needs to change in the default setup below to use the better implementation mentioned above?
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_id = "CohereForAI/c4ai-command-r-v01"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# 8-bit quantization via bitsandbytes; device_map="auto" shards the model across both GPUs
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map="auto", quantization_config=bnb_config)
Hi! Apart from the materialized attention mask, there is another problem: the logits are upcast to fp32. With a sequence length of 128k, the logits alone would take up 128k * 256k * 4 bytes = 131GB. If the goal is generation, one could get rid of this and just do a log-softmax over the last token's logits.
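Roughly, the idea looks like this (just a sketch, not the actual modeling_cohere.py code; hidden_states and lm_head stand in for the model's final hidden states and output projection):

import torch.nn.functional as F

# Keep only the last position before the fp32 up-cast, so the big
# (batch, seq_len, vocab) logits tensor is never materialized.
last_hidden = hidden_states[:, -1:, :]             # (batch, 1, hidden_dim)
last_logits = lm_head(last_hidden).float()         # up-cast only (batch, 1, vocab)
next_token_log_probs = F.log_softmax(last_logits, dim=-1)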
Thanks for your answer @saurabhdash! In terms of implementation:
- Would the causal_mask construction at line 614 of modeling_cohere.py (in forward()) need to change to your implementation above?
- Where would the logits handling need to change? Any tips on how to do so?
- What's a reasonable amount of VRAM to expect for a 128k task with these optimisations? Am I over-optimistic in thinking we can fit a context of that size on 2x A100s? (I've put a rough back-of-envelope attempt below.)
Apologies if these are silly questions; I'm still a little new to all this.
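For the last question, here's the rough KV-cache estimate I mentioned (my own sketch; it reads the layer/head counts from the model config rather than assuming them, and ignores weights, activations and the logits discussed above):

from transformers import AutoConfig

def kv_cache_bytes(cfg, seq_len, bytes_per_elem=2):
    # 2 tensors (K and V) per layer; one vector of size head_dim per KV head per position.
    num_kv_heads = getattr(cfg, "num_key_value_heads", cfg.num_attention_heads)
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    return 2 * cfg.num_hidden_layers * seq_len * num_kv_heads * head_dim * bytes_per_elem

cfg = AutoConfig.from_pretrained("CohereForAI/c4ai-command-r-v01", trust_remote_code=True)
print(f"~{kv_cache_bytes(cfg, 131_072) / 2**30:.0f} GiB of fp16 KV cache at 128k")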
I'd recommend waiting for (or using) the vLLM implementation. That should help you scale the context to the maximum.
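Something along these lines, once your vLLM build supports Command-R (a sketch; the tensor_parallel_size and max_model_len values are just examples for a 2x A100 setup):

from vllm import LLM, SamplingParams

# Illustrative vLLM setup for long-context generation.
llm = LLM(
    model="CohereForAI/c4ai-command-r-v01",
    tensor_parallel_size=2,   # shard across the two A100s
    max_model_len=131072,     # target context window
)

outputs = llm.generate(
    ["Summarize the following document: ..."],
    SamplingParams(temperature=0.3, max_tokens=512),
)
print(outputs[0].outputs[0].text)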