ChatGLM2 failed with bfloat16 using transformers 4.38.2
#109
by jasonjchen
Running ChatGLM2 with bfloat16 fails under transformers 4.38.2.
- Code:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

amp_dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm2-6b", torch_dtype=amp_dtype, low_cpu_mem_usage=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

# device_type="cpu" is an assumption; the original snippet omitted this required argument
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids  # prompt defined elsewhere
    output = model.generate(input_ids, **generate_kwargs)  # generate_kwargs defined elsewhere
```
- Error:
```
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::BFloat16 instead.
```
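This message appears to come from the dtype check in `torch.nn.functional.scaled_dot_product_attention`, which ChatGLM2's remote modeling code calls; under autocast, the query/key tensors can end up in a different dtype than the value tensor. A minimal, transformers-independent illustration of the same check (shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8)                         # float32
k = torch.randn(1, 1, 4, 8)                         # float32
v = torch.randn(1, 1, 4, 8, dtype=torch.bfloat16)   # bfloat16
F.scaled_dot_product_attention(q, k, v)  # raises the same RuntimeError
```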
BTW, the same code runs fine with transformers 4.31.0.
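A workaround that may avoid the mismatch until this is fixed (a sketch, not verified against 4.38.2): load the weights directly in bfloat16 and drop the autocast context, so query, key, and value all reach attention in one dtype. The prompt and `max_new_tokens` below are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "THUDM/chatglm2-6b",
    torch_dtype=torch.bfloat16,  # cast weights up front instead of autocasting
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

prompt = "Hello"  # placeholder prompt
with torch.no_grad():  # no torch.autocast: every tensor stays bfloat16
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    output = model.generate(input_ids, max_new_tokens=32)  # placeholder kwargs
print(tokenizer.decode(output[0], skip_special_tokens=True))
```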