Update modeling_mpt.py

#45
by ybelkada HF staff - opened
No description provided.

This PR adds `accelerate` support for MPT models, so that any user can load these models in 8-bit or 4-bit precision.

To load this model in 8-bit before the PR is merged:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'mosaicml/mpt-7b'

# MPT does not ship its own tokenizer; it uses the GPT-NeoX tokenizer.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,       # quantize weights to 8-bit via bitsandbytes
    device_map="auto",       # let accelerate place layers across available devices
    trust_remote_code=True,  # MPT modeling code lives in the model repo
    revision="pr/45",        # use this PR's branch until it is merged
)

prompt = "What is the boiling point of Nitrogen?"

# Move inputs to the device of the model's input embeddings. Hard-coding
# .to(0) is fragile: with device_map="auto" the first layer may not be on
# GPU 0 (or any GPU at all).
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
out = model.generate(input_ids)
print(tokenizer.decode(out[0], skip_special_tokens=True))
abhi-mosaic changed pull request status to merged

Sign up or log in to comment