It seems that this project can only support a batch_size of 1 during inference?

by howard-hou - opened

I tried to make the input batch_size = 2 with:

```python
inputs = tokenizer([prompt, prompt], return_tensors="pt")
output = model.generate(inputs["input_ids"], max_new_tokens=256)
```

and it raises a runtime error:

```
    175     # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
--> 176     key = self.key(key).to(torch.float32).view(T, H, S).transpose(0, 1).transpose(-2, -1)
    177     value = self.value(value).to(torch.float32).view(T, H, S).transpose(0, 1)
    178     receptance = self.receptance(receptance).to(torch.float32).view(T, H, S).transpose(0, 1)

RuntimeError: shape '[47, 32, 64]' is invalid for input of size 192512
```
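The numbers in the error are consistent with a missing batch dimension: with T = 47 tokens, H = 32 heads, and S = 64 head size, `47 * 32 * 64 = 96256`, while the tensor actually holds `192512 = 2 * 96256` elements, i.e. both prompts flattened together. The `.view(T, H, S)` call therefore assumes batch_size = 1. Until batched inference is supported, a minimal workaround sketch (assuming the same `tokenizer`, `model`, and `prompt` as in the snippet above) is to generate one prompt at a time:

```python
import torch

# Workaround sketch: loop over prompts with batch_size = 1,
# since the attention reshape in the traceback assumes a single sequence.
prompts = [prompt, prompt]
outputs = []
for p in prompts:
    inputs = tokenizer(p, return_tensors="pt")
    with torch.no_grad():
        out = model.generate(inputs["input_ids"], max_new_tokens=256)
    outputs.append(tokenizer.decode(out[0], skip_special_tokens=True))
```

This trades throughput for correctness; the per-prompt results are identical to what batch_size = 1 generation produces today.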

I encountered the same problem.
