RuntimeError: shape is invalid for input of size N

#133
by gnomes72 - opened

Complete beginner here. I downloaded the model with huggingface-cli. After initializing the pipeline with the example code and calling pipeline("Hello"), I get this error:

  File "/home/me/miniconda3/envs/py/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 195, in forward
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[1, 1, 32, 128]' is invalid for input of size 1024
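The numbers in the error can be checked by hand: view(1, 1, 32, 128) asks for 1 × 1 × 32 × 128 = 4096 elements, but the k_proj output holds only 1024, which is exactly 8 × 128. A minimal plain-Python sketch that computes how many heads of size head_dim the tensor actually contains (implied_kv_heads is a hypothetical helper, not part of transformers):

```python
from math import prod


def implied_kv_heads(numel: int, bsz: int, q_len: int, head_dim: int) -> int:
    """Return how many heads of size head_dim fit in a tensor of numel elements.

    Hypothetical helper for reading the RuntimeError; not a transformers API.
    """
    per_head = prod((bsz, q_len, head_dim))
    if numel % per_head:
        raise ValueError(f"{numel} is not a multiple of bsz*q_len*head_dim={per_head}")
    return numel // per_head


# Values taken from the traceback above.
expected = prod((1, 1, 32, 128))  # elements the .view() call asks for
actual_heads = implied_kv_heads(1024, bsz=1, q_len=1, head_dim=128)
print(expected, actual_heads)  # 4096 8
```

Eight key/value heads next to 32 query heads is the pattern used by grouped-query attention (Meta's Llama 3 8B, for instance, has 32 attention heads and 8 key/value heads of size 128). The traceback line reshapes the k_proj output with self.num_heads rather than a separate key/value head count, so one plausible reading, not confirmed in this thread, is that the modeling code being executed does not account for that setting.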

I'm using Python 3.10.14 and am completely stumped. I've read through the model card several times, so apologies if I missed a step.
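To see whether the downloaded model uses fewer key/value heads than attention heads, you can read the config.json that huggingface-cli saved alongside the weights. A minimal sketch, assuming a Llama-style config with the keys hidden_size, num_attention_heads and num_key_value_heads; the helper names, the directory path and the example values are placeholders, not taken from this model:

```python
import json
from pathlib import Path


def kv_head_summary(config: dict) -> dict:
    """Summarize attention-head settings from a Llama-style config.json dict.

    As in transformers' LlamaConfig, num_key_value_heads falls back to
    num_attention_heads when it is absent (plain multi-head attention).
    """
    heads = config["num_attention_heads"]
    kv_heads = config.get("num_key_value_heads", heads)
    head_dim = config["hidden_size"] // heads
    return {
        "num_attention_heads": heads,
        "num_key_value_heads": kv_heads,
        "head_dim": head_dim,
        "k_proj_out_features": kv_heads * head_dim,  # width of k_proj's output
    }


def load_config(model_dir: str) -> dict:
    # model_dir is a placeholder: point it at the folder huggingface-cli downloaded.
    return json.loads((Path(model_dir) / "config.json").read_text())


# Example with values typed in by hand (hypothetical, not read from this model):
example = {"hidden_size": 4096, "num_attention_heads": 32, "num_key_value_heads": 8}
print(kv_head_summary(example))
```

If k_proj_out_features for the real config comes out as 1024 per token, it matches the size in the error, which would point at a mismatch between the model's head layout and the code reshaping it rather than at a broken download.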

Also, when I load this model in oobabooga/text-generation-webui, the chat responses are of very poor quality. But as soon as I load another model, such as mistralai_Mistral-7B-Instruct-v0.2, without changing any other settings, everything works as expected.

Thank you and sorry for the noob questions!

I'm having the same issue. Raised again in #151.

/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 195, in forward
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[18, 217, 32, 128]' is invalid for input of size 3999744  
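The second traceback follows the same pattern: the requested view needs 18 × 217 × 32 × 128 = 15,998,976 elements, while the tensor holds 3,999,744, exactly a quarter of that, i.e. 8 heads of 128 instead of 32. A small stdlib sketch (diagnose is a hypothetical helper) that pulls these numbers out of any such message, assuming the four dimensions are (bsz, q_len, num_heads, head_dim) as in the .view() call in the traceback:

```python
import re
from math import prod

# Matches messages like: shape '[18, 217, 32, 128]' is invalid for input of size 3999744
_PATTERN = re.compile(r"shape '\[([\d, ]+)\]' is invalid for input of size (\d+)")


def diagnose(message: str) -> dict:
    """Extract the requested shape and actual size from a torch .view() error.

    Assumes exactly four dims ordered (bsz, q_len, num_heads, head_dim);
    unpacking raises ValueError otherwise.
    """
    match = _PATTERN.search(message)
    if match is None:
        raise ValueError("not a shape/size RuntimeError message")
    shape = [int(d) for d in match.group(1).split(",")]  # int() ignores the spaces
    size = int(match.group(2))
    bsz, q_len, num_heads, head_dim = shape
    return {
        "requested": prod(shape),
        "actual": size,
        # floor division; exact for both messages in this thread
        "heads_in_tensor": size // (bsz * q_len * head_dim),
        "heads_requested": num_heads,
    }


msg = "RuntimeError: shape '[18, 217, 32, 128]' is invalid for input of size 3999744"
print(diagnose(msg))
# {'requested': 15998976, 'actual': 3999744, 'heads_in_tensor': 8, 'heads_requested': 32}
```

Both reports in this thread yield heads_in_tensor = 8 against heads_requested = 32, which suggests the same head-count mismatch in both cases, independent of batch size and sequence length.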

It's very weird that the same code runs on a machine with CUDA 12.2 but not on ones with CUDA 11.3 or CUDA 12.0...
I'm using a bitsandbytes 8-bit quantized model, though.

Full list of packages in the environment:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
accelerate                0.30.1                   pypi_0    pypi
bitsandbytes              0.42.0                   pypi_0    pypi
blas                      1.0                         mkl  
brotli-python             1.0.9            py39h6a678d5_8  
bzip2                     1.0.8                h5eee18b_6  
ca-certificates           2024.3.11            h06a4308_0  
certifi                   2024.2.2         py39h06a4308_0  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
cuda-cudart               11.8.89                       0    nvidia
cuda-cupti                11.8.87                       0    nvidia
cuda-libraries            11.8.0                        0    nvidia
cuda-nvrtc                11.8.89                       0    nvidia
cuda-nvtx                 11.8.86                       0    nvidia
cuda-runtime              11.8.0                        0    nvidia
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.13.1           py39h06a4308_0  
freetype                  2.12.1               h4a9f257_0  
fsspec                    2024.5.0                 pypi_0    pypi
gmp                       6.2.1                h295c915_3  
gmpy2                     2.1.2            py39heeb90bb_0  
gnutls                    3.6.15               he1e5248_0  
huggingface-hub           0.23.1                   pypi_0    pypi
idna                      3.7              py39h06a4308_0  
intel-openmp              2023.1.0         hdb19cb5_46306  
jinja2                    3.1.3            py39h06a4308_0  
jpeg                      9e                   h5eee18b_1  
lame                      3.100                h7b6447c_0  
lcms2                     2.12                 h3be6417_0  
ld_impl_linux-64          2.38                 h1181459_1  
lerc                      3.0                  h295c915_0  
libabseil                 20240116.2      cxx17_h59595ed_0    conda-forge
libcublas                 11.11.3.6                     0    nvidia
libcufft                  10.9.0.58                     0    nvidia
libcufile                 1.9.1.3                       0    nvidia
libcurand                 10.3.5.147                    0    nvidia
libcusolver               11.4.1.48                     0    nvidia
libcusparse               11.7.5.86                     0    nvidia
libdeflate                1.17                 h5eee18b_1  
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 13.2.0               h77fa898_7    conda-forge
libgomp                   13.2.0               h77fa898_7    conda-forge
libiconv                  1.16                 h5eee18b_3  
libidn2                   2.3.4                h5eee18b_0  
libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
libnpp                    11.8.0.86                     0    nvidia
libnvjpeg                 11.9.0.86                     0    nvidia
libpng                    1.6.39               h5eee18b_0  
libprotobuf               4.25.3               h08a7969_0    conda-forge
libsentencepiece          0.2.0                hb0b37bd_1    conda-forge
libstdcxx-ng              13.2.0               hc0a3c3a_7    conda-forge
libtasn1                  4.19.0               h5eee18b_0  
libtiff                   4.5.1                h6a678d5_0  
libunistring              0.9.10               h27cfd23_0  
libwebp-base              1.3.2                h5eee18b_0  
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               14.0.6               h9e868ea_0  
lz4-c                     1.9.4                h6a678d5_1  
markupsafe                2.1.3            py39h5eee18b_0  
mkl                       2023.1.0         h213fc3f_46344  
mkl-service               2.4.0            py39h5eee18b_1  
mkl_fft                   1.3.8            py39h5eee18b_0  
mkl_random                1.2.4            py39hdb19cb5_0  
mpc                       1.1.0                h10f8cd9_1  
mpfr                      4.0.2                hb69a4c5_1  
mpmath                    1.3.0            py39h06a4308_0  
ncurses                   6.4                  h6a678d5_0  
nettle                    3.7.3                hbbd107a_1  
networkx                  3.1              py39h06a4308_0  
numpy                     1.26.4           py39h5f9d8c6_0  
numpy-base                1.26.4           py39hb5e798b_0  
openh264                  2.1.1                h4ff587b_0  
openjpeg                  2.4.0                h3ad879b_0  
openssl                   3.0.13               h7f8727e_2  
packaging                 24.0                     pypi_0    pypi
pillow                    10.3.0           py39h5eee18b_0  
pip                       24.0             py39h06a4308_0  
psutil                    5.9.8                    pypi_0    pypi
pysocks                   1.7.1            py39h06a4308_0  
python                    3.9.19               h955ad1f_1  
python_abi                3.9                      2_cp39    conda-forge
pytorch                   2.2.2           py3.9_cuda11.8_cudnn8.7.0_0    pytorch
pytorch-cuda              11.8                 h7e8668a_5    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pyyaml                    6.0.1            py39h5eee18b_0  
readline                  8.2                  h5eee18b_0  
regex                     2024.5.15                pypi_0    pypi
requests                  2.31.0           py39h06a4308_1  
safetensors               0.4.3                    pypi_0    pypi
scipy                     1.13.0                   pypi_0    pypi
sentencepiece             0.2.0                hf3d152e_1    conda-forge
sentencepiece-python      0.2.0            py39ha537242_1    conda-forge
sentencepiece-spm         0.2.0                hb0b37bd_1    conda-forge
setuptools                69.5.1           py39h06a4308_0  
sqlite                    3.45.3               h5eee18b_0  
sympy                     1.12             py39h06a4308_0  
tbb                       2021.8.0             hdb19cb5_0  
tk                        8.6.14               h39e8969_0  
tokenizers                0.19.1                   pypi_0    pypi
torchaudio                2.2.2                py39_cu118    pytorch
torchtriton               2.2.0                      py39    pytorch
torchvision               0.17.2               py39_cu118    pytorch
tqdm                      4.66.4                   pypi_0    pypi
transformers              4.41.0                   pypi_0    pypi
typing_extensions         4.11.0           py39h06a4308_0  
tzdata                    2024a                h04d1e81_0  
urllib3                   2.2.1            py39h06a4308_0  
wheel                     0.43.0           py39h06a4308_0  
xz                        5.4.6                h5eee18b_1  
yaml                      0.2.5                h7b6447c_0  
zlib                      1.2.13               hd590300_5    conda-forge
zstd                      1.5.5                hc292b87_2 
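Since the behaviour differs between machines and conda environments, it may be worth confirming which versions the Python process that raises the error actually imports, rather than relying on the environment listing. A minimal stdlib sketch using importlib.metadata; the package list is an assumption you can edit, and the torch import is optional:

```python
from importlib.metadata import PackageNotFoundError, version

# Packages that appear in the traceback or the environment listing above.
PACKAGES = ["transformers", "torch", "accelerate", "bitsandbytes", "tokenizers"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:15} {ver}")
    try:
        import torch  # optional: also show the CUDA version torch was built against

        print("torch CUDA build:", torch.version.cuda)
    except ImportError:
        pass
```

Running this with the same interpreter that runs the pipeline code (for example, from the failing script itself) shows what was actually imported, including the CUDA version of the torch build.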
