TypeError: forward() takes 2 positional arguments but 3 were given

#5
by prabhatk579 - opened

Hi, when I try to load the model in a 4-bit configuration, I get the following error:

Traceback (most recent call last):
  File ".../model_load_4k.py", line 38, in <module>
    response, _ = model.chat(
...
  File ".../.venv/lib/python3.9/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given

Here is the code for loading the model:

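import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization with double quantization, computing in bfloat16
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)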

# init model and tokenizer
model = AutoModel.from_pretrained(
    "internlm/internlm-xcomposer2-4khd-7b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    # low_cpu_mem_usage=True,
    # load_in_4bit=True,
    cache_dir="./model",
    quantization_config=nf4_config,
).eval()

# the tokenizer has no weights, so it takes no quantization arguments
tokenizer = AutoTokenizer.from_pretrained(
    "internlm/internlm-xcomposer2-4khd-7b",
    trust_remote_code=True,
    cache_dir="./model",
)

###############
# First Round
###############

query = "<ImageHere>Illustrate the fine details present in the image"
image = "examples/example4.jpeg"
with torch.cuda.amp.autocast():
    response, _ = model.chat(
        tokenizer,
        query=query,
        image=image,
        hd_num=55,
        history=[],
        do_sample=False,
        num_beams=3,
    )
print(response)

How can I load the model in 4-bit? I have limited resources.

Hi! I figured out the cause: Plora derives from nn.Linear instead of nn.Module, and bitsandbytes targets vanilla linear layers for replacement. I wrote a script that converts the checkpoint to a loadable format. It also requires swapping build_mlp's Plora for one with an inner Linear layer.

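import os
import torch

bin_file = 'pytorch_model-00002-of-00002.bin'
sd = torch.load(bin_file, map_location='cpu')
suffixes = ['.w1', '.w3', '.w2', '.wqkv', '.wo']
upd = {}      # renamed key -> tensor
to_pop = []   # old keys to drop once the renamed copies are in place

for k, v in sd.items():
    for s in suffixes:
        if s in k and 'Plora' not in k:
            # insert ".linear" after the layer name: "...wqkv.weight" -> "...wqkv.linear.weight"
            head, tail = k.split(s, 1)
            new_name = head + s + '.linear' + tail
            to_pop.append(k)
            upd[new_name] = v
            break

print(len(to_pop))
sd.update(upd)
for k in to_pop:
    sd.pop(k)

# save to a temporary file first so a failed save does not destroy the original shard
tmp_file = bin_file + '.tmp'
torch.save(sd, tmp_file)
os.replace(tmp_file, bin_file)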
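
Before rewriting a shard, it may help to dry-run the renaming rule on a few key names. The sketch below is a slightly stricter variant of the rule in the script above: it matches whole dotted components rather than substrings. The function name rename_key and the sample keys are illustrative assumptions, not taken from the actual checkpoint.

```python
# Layer names whose parameters move under an inner ".linear" submodule.
LAYER_NAMES = {"w1", "w3", "w2", "wqkv", "wo"}


def rename_key(key):
    """Insert 'linear' after the first matching layer name; leave other keys alone."""
    if "Plora" in key:
        # LoRA branch parameters stay where they are
        return key
    parts = key.split(".")
    for i, part in enumerate(parts):
        if part in LAYER_NAMES:
            return ".".join(parts[: i + 1] + ["linear"] + parts[i + 1:])
    return key


# Illustrative key names (assumed, not read from the real checkpoint)
samples = [
    "model.layers.0.attention.wqkv.weight",
    "model.layers.0.attention.wqkv.Plora_A.weight",
    "model.layers.0.feed_forward.w2.weight",
    "model.tok_embeddings.weight",
]
for key in samples:
    print(key, "->", rename_key(key))
```

Only the base weights of the listed layers should change; Plora and unrelated keys should print unchanged.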

https://huggingface.co/internlm/internlm-xcomposer2-7b-4bit/blob/main/build_mlp.py
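
The linked build_mlp.py is the actual fix. As a rough, torch-free illustration of why wrapping helps, the sketch below uses plain Python stand-ins: QuantLinear and WrappedPlora are hypothetical classes, and the im_mask argument is an assumption about the extra argument Plora receives. It is not the real library or model code.

```python
# Torch-free stand-ins; none of these classes are real library or model code.

class QuantLinear:
    """Stand-in for a quantized linear layer whose forward accepts only x."""

    def forward(self, x):
        return [2 * v for v in x]


class WrappedPlora:
    """Stand-in for a Plora that owns an inner linear layer instead of being one."""

    def __init__(self, linear):
        self.linear = linear  # only this inner attribute would be quantized

    def forward(self, x, im_mask=None):
        out = self.linear.forward(x)  # the inner layer still gets a single argument
        if im_mask is not None:
            # placeholder for the LoRA branch applied at masked positions
            out = [o + 1 if m else o for o, m in zip(out, im_mask)]
        return out


x, im_mask = [1, 2, 3], [True, False, True]

# Case 1: the Plora layer itself was replaced by the quantized layer,
# but callers still pass the extra argument.
try:
    QuantLinear().forward(x, im_mask)
    error = None
except TypeError as exc:
    error = str(exc)
print(error)

# Case 2: only the inner layer is replaced, so the outer signature survives.
result = WrappedPlora(QuantLinear()).forward(x, im_mask)
print(result)
```

Case 1 raises the same kind of TypeError as in the traceback, while case 2 runs because the wrapper keeps the two-argument forward.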
