Falcon KV cache shape issue

#3
by luxiaoz - opened

https://huggingface.co/tiiuae/falcon-rw-7b/blob/main/modeling_falcon.py#L314-L315 Here, Falcon concatenates both key and value along dim=1, but the comments claim that key is (bs_times_num_head, head_dim, seq_len) while value is (bs_times_num_head, seq_len, head_dim). How does this work? With those shapes, dim=1 is head_dim for key but seq_len for value, so the same concat dim cannot extend the sequence axis of both. Should both key and value be (bs_times_num_head, seq_len, head_dim) instead? For comparison, see Bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L305-L306
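To make the mismatch concrete, here is a small shape-only sketch. It does not run Falcon's code. `cat_shape` is a hypothetical helper that mimics the shape rule of `torch.cat` (all dimensions except the concat one must match), and the sizes are made up for illustration. It assumes the cached and new tensors really have the layouts named in the comments.

```python
def cat_shape(a, b, dim):
    """Shape of concatenating two tensors of shapes a and b along dim.

    Hypothetical helper mirroring the torch.cat shape rule: every
    dimension except `dim` must match, and `dim` is summed.
    """
    if len(a) != len(b):
        raise ValueError("rank mismatch")
    for i, (x, y) in enumerate(zip(a, b)):
        if i != dim and x != y:
            raise ValueError(f"size mismatch at dim {i}: {x} vs {y}")
    return tuple(x + y if i == dim else x for i, (x, y) in enumerate(zip(a, b)))


# Made-up sizes: 8 = batch_size * num_heads, 5 cached tokens, 1 new token.
bh, head_dim, past_len, new_len = 8, 64, 5, 1

# Key layout as the comment claims: (bs_times_num_head, head_dim, seq_len).
# Here dim=1 is head_dim, so concatenating along it fails once the cached
# and new sequence lengths differ.
try:
    cat_shape((bh, head_dim, past_len), (bh, head_dim, new_len), dim=1)
    transposed_key_ok = True
except ValueError:
    transposed_key_ok = False

# With that layout the sequence axis is dim=2.
key_shape_dim2 = cat_shape((bh, head_dim, past_len), (bh, head_dim, new_len), dim=2)

# If key and value are both (bs_times_num_head, seq_len, head_dim),
# dim=1 extends the sequence axis for both.
key_shape = cat_shape((bh, past_len, head_dim), (bh, new_len, head_dim), dim=1)
value_shape = cat_shape((bh, past_len, head_dim), (bh, new_len, head_dim), dim=1)

# Edge case: equal sequence lengths do not raise in the transposed layout,
# but head_dim is doubled instead of seq_len.
silent_shape = cat_shape((bh, head_dim, 1), (bh, head_dim, 1), dim=1)

print(transposed_key_ok, key_shape_dim2, key_shape, value_shape, silent_shape)
```

So under these assumptions, concat on dim=1 for both tensors only makes sense if both share the (bs_times_num_head, seq_len, head_dim) layout. A key stored as (bs_times_num_head, head_dim, seq_len) would need dim=2.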
