license: apache-2.0
language:
  - en
pipeline_tag: text-generation
inference: false
datasets:
  - the_pile_books3

mpt-7b-storywriter: sharded

Open In Colab

This is a version of the mpt-7b-storywriter model, sharded into 2 GB chunks for low-RAM loading (e.g. on Colab). The weights are stored in bfloat16, so in theory you can run this on CPU, though generation will be very slow.
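As a rough illustration of why sharding helps, the sketch below estimates the total weight size and the approximate number of 2 GB shards. It assumes a nominal 7 billion parameters (taken from the "7b" in the model name, not an exact count) and 2 bytes per parameter for bfloat16; the function names are hypothetical, and the real shard count in the repo may differ.

```python
import math


def estimate_weights_gb(n_params: float, bytes_per_param: int = 2) -> float:
    """Approximate size of the weights in GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


def estimate_shard_count(
    n_params: float, shard_gb: float = 2.0, bytes_per_param: int = 2
) -> int:
    """Approximate number of shards if each shard holds at most shard_gb."""
    return math.ceil(estimate_weights_gb(n_params, bytes_per_param) / shard_gb)


# Assumption: nominal parameter count implied by the "7b" in the model name.
NOMINAL_PARAMS = 7e9

total_gb = estimate_weights_gb(NOMINAL_PARAMS)      # bfloat16: 2 bytes/param
n_shards = estimate_shard_count(NOMINAL_PARAMS)     # 2 GB per shard
print(f"~{total_gb:.0f} GB of weights in roughly {n_shards} shards")
```

With a single large checkpoint file, the loader may need to hold much of that file in RAM at once; with small shards, only one chunk at a time has to be read before its tensors are placed on the target device.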

Please refer to the previously linked repo for details on usage/implementation/etc. This model was downloaded from the original repo under Apache-2.0 and is redistributed under the same license.

Basic Usage

Install/upgrade packages:

pip install -U torch transformers accelerate

Load the model:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'ethzanalytics/mpt-7b-storywriter-sharded'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    revision='b51ddaf1a256420debfb44fd7367ed7b291b7c19', # optional, but pinning a revision is a good idea with trust_remote_code=True
    device_map='auto',
    load_in_8bit=False, # install bitsandbytes, then set to True for 8-bit loading
)
model = torch.compile(model) # optional; requires PyTorch 2.0 or newer
tokenizer = AutoTokenizer.from_pretrained(model_name)

You can then call model.generate() as usual; see the notebook for details.
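For reference, a minimal generation sketch might look like the following. The helper name `generate_text` and the sampling settings (`temperature`, `top_p`, `max_new_tokens`) are illustrative assumptions rather than values recommended by the model authors; it expects the `model` and `tokenizer` objects loaded above.

```python
def generate_text(model, tokenizer, prompt, max_new_tokens=64):
    """Hypothetical helper: tokenize a prompt, sample a continuation, decode it."""
    # Tokenize and move the input tensors to the device the model lives on.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,        # sample instead of greedy decoding
        temperature=0.8,       # illustrative value, not tuned
        top_p=0.95,            # illustrative value, not tuned
        pad_token_id=tokenizer.eos_token_id,  # avoids a missing-pad-token warning
    )
    # generate() returns a batch; decode the first (and only) sequence.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


# Usage, assuming `model` and `tokenizer` were loaded as shown above:
# story = generate_text(model, tokenizer, "The lighthouse keeper opened the door and")
# print(story)
```

The decoded text includes the prompt followed by the sampled continuation, since the returned sequence contains the input tokens as well.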