[ISSUE] forward() requires input_ids even when inputs_embeds is provided instead

#23
by x5fu - opened

The forward() implementation (https://huggingface.co/THUDM/glm-4-9b-chat/blob/cbc9aaf3ec306a41351dab9b262b120b610f9ad9/modeling_chatglm.py#L757) still requires input_ids even when inputs_embeds is provided.
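
For concreteness, here is a minimal repro sketch of the call one would expect to work (assumptions: the checkpoint is loaded per the model card's documented usage with trust_remote_code=True, and get_input_embeddings() — the standard PreTrainedModel accessor — is available on this model):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "THUDM/glm-4-9b-chat", trust_remote_code=True, torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/glm-4-9b-chat", trust_remote_code=True
)

input_ids = tokenizer("hello", return_tensors="pt").input_ids
# Build the embeddings ourselves, e.g. to inject soft prompts or
# custom vectors before the first transformer layer.
inputs_embeds = model.get_input_embeddings()(input_ids)

# Llama/Gemma/Mistral accept this call with input_ids omitted; here
# forward() still insists on input_ids, defeating the purpose.
outputs = model(inputs_embeds=inputs_embeds)
```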

This is not aligned with the design of most widely used HF Transformers models (e.g. Gemma, Llama, Mistral), whose documentation describes inputs_embeds as follows:

inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional):
Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert input_ids indices into associated vectors than the
model's internal embedding lookup matrix.

This is a crucial feature, and it would be very beneficial if you could update the implementation to align with common practice. For reference, see the forward() implementations in Gemma and Llama; a sketch of that pattern is included below. Hope this helps you address the problem. Thank you!
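
For illustration, a sketch of the Llama-style handling (a paraphrase of the usual transformers pattern, not the exact upstream source; on the ChatGLM side the embedding module has a different attribute name, so self.embed_tokens below is a stand-in):

```python
# Inside forward(), before positions/masks are built: accept exactly one
# of input_ids / inputs_embeds, then fall back to the internal lookup.
if input_ids is not None and inputs_embeds is not None:
    raise ValueError(
        "You cannot specify both input_ids and inputs_embeds at the same time"
    )
elif input_ids is not None:
    batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
    batch_size, seq_length, _ = inputs_embeds.shape
else:
    raise ValueError("You have to specify either input_ids or inputs_embeds")

# Only use the internal embedding lookup when no embeddings were supplied.
if inputs_embeds is None:
    inputs_embeds = self.embed_tokens(input_ids)
```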
