Would it work well with sequence length > 2048?

#1
by SamuelAzran - opened

Like other MPT models do.

Nomic AI org

Possibly, although we haven't tested this extensively. Let us know if you find that it works well!

I've tried with the following code and it doesn't work for me (any sequence length other than 2048 fails):

The same code works for mosaicml/mpt-7b-instruct though.

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch

device = f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu'

print(f'Selected device is: {device}')

model_name = "nomic-ai/gpt4all-mpt"

config = AutoConfig.from_pretrained(
  model_name,
  trust_remote_code=True
)
# To use the optimized Triton implementation of FlashAttention, load the model with attn_impl='triton' and move it to bfloat16
#config.attn_config['attn_impl'] = 'triton'
config.init_device = device
# config.max_seq_len = 2048
# update the maximum sequence length for inference to 3072
config.max_seq_len = 3072

print(config)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)

model.eval()

I got the following error:

RuntimeError: Error(s) in loading state_dict for MPTForCausalLM:
    size mismatch for transformer.wpe.weight: copying a param with shape torch.Size([2048, 4096]) from checkpoint, the shape in current model is torch.Size([3072, 4096]).
    You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
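From the error it looks like the checkpoint stores a learned positional-embedding table (transformer.wpe.weight) with exactly 2048 rows, so raising config.max_seq_len makes the freshly built model expect a larger table than the checkpoint provides. An untested workaround I considered is to load at the native 2048 and then grow the table by hand; this assumes transformer.wpe is a plain nn.Embedding (the attribute path is guessed from the state-dict key above), and the extra positions would be untrained, so quality past 2048 tokens would likely be poor:

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "nomic-ai/gpt4all-mpt"

# leave config.max_seq_len at its default (2048) so the shapes match the checkpoint
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

new_len = 3072
old_wpe = model.transformer.wpe.weight.data            # [2048, 4096] per the error above
new_wpe = nn.Embedding(new_len, old_wpe.shape[1], dtype=old_wpe.dtype)
new_wpe.weight.data[: old_wpe.shape[0]] = old_wpe      # keep the trained positions
model.transformer.wpe = new_wpe                        # positions 2048+ are untrained
model.config.max_seq_len = new_len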

Setting ignore_mismatched_sizes=True, as the error message suggests, still doesn't fix it. Instead, you get a different error:

File /opt/anaconda3/lib/python3.9/site-packages/transformers/modeling_utils.py:3031, in PreTrainedModel._load_pretrained_model.<locals>._find_mismatched_keys(state_dict, model_state_dict, loaded_keys, add_prefix_to_model, remove_prefix_from_model, ignore_mismatched_sizes)
   3025 elif add_prefix_to_model:
   3026     # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
   3027     model_key = ".".join(checkpoint_key.split(".")[1:])
   3029 if (
   3030     model_key in model_state_dict
-> 3031     and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
   3032 ):
   3033     mismatched_keys.append(
   3034         (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
   3035     )
   3036     del state_dict[checkpoint_key]

KeyError: 'transformer.blocks.11.ffn.down_proj.weight'
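To narrow down the cause, it may help to look at what the checkpoint actually contains. A quick diagnostic sketch (the single pytorch_model.bin file name is a guess; the repo may ship sharded or safetensors weights instead):

from huggingface_hub import hf_hub_download
import torch

# "pytorch_model.bin" is assumed; adjust if the repo uses sharded or safetensors weights
path = hf_hub_download("nomic-ai/gpt4all-mpt", "pytorch_model.bin")
state_dict = torch.load(path, map_location="cpu")

print(state_dict["transformer.wpe.weight"].shape)                   # error above reports [2048, 4096]
print("transformer.blocks.11.ffn.down_proj.weight" in state_dict)   # the key the KeyError complains about

If the weights are split across several shards, the KeyError could simply be that key living in a different shard than the one _find_mismatched_keys is processing when ignore_mismatched_sizes=True is set.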

By the way, this model also doesn't support the optimized Triton implementation of FlashAttention the way mosaicml/mpt-7b-instruct does.
If you turn it on via config.attn_config['attn_impl'] = 'triton', you get the same KeyError: 'transformer.blocks.11.ffn.down_proj.weight'.

@zpn any chance you could shed some light on the possible cause of this error? Thanks a lot~
