ChatGLM2 failed with bfloat16 using transformers 4.38.2

#109
by jasonjchen - opened

Running ChatGLM2 with bfloat16 fails under transformers 4.38.2.

  • Code:
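import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

amp_dtype = torch.bfloat16  # assumed from the report; the original snippet did not define it
model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm2-6b", torch_dtype=amp_dtype, low_cpu_mem_usage=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
# torch.autocast requires device_type; "cpu" is an assumption, since the inputs are never moved to a GPU
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output = model.generate(input_ids, **generate_kwargs)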
  • Error:
    RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::BFloat16 instead.
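
For reference, the same class of error can be reproduced in isolation. This is only a minimal sketch, assuming PyTorch 2.0 or later and that the message comes from `torch.nn.functional.scaled_dot_product_attention`. The tensors and the `sdpa_error` helper are hypothetical and are not taken from ChatGLM2's remote code, so this does not confirm the root cause there.

```python
import torch
import torch.nn.functional as F

# Hypothetical tensors shaped (batch, heads, seq_len, head_dim); not ChatGLM2's real activations.
q = torch.randn(1, 2, 4, 8, dtype=torch.float32)
k = torch.randn(1, 2, 4, 8, dtype=torch.float32)
v = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)


def sdpa_error(q, k, v):
    """Return the RuntimeError message raised by SDPA, or None if the call succeeds."""
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as exc:
        return str(exc)
    return None


# float32 query/key combined with a bfloat16 value is rejected by SDPA.
mismatch_message = sdpa_error(q, k, v)

# Casting query and key to the value's dtype removes the mismatch.
out = F.scaled_dot_product_attention(q.to(v.dtype), k.to(v.dtype), v)
```

If something similar is happening inside the model, it would suggest that the query and key are being upcast to float32 somewhere before attention while the value stays in bfloat16.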

BTW, the same code runs without this error on transformers 4.31.0.
