Inference VRAM usage is abnormally large

#37
by manu - opened

I get GPU OOM on 40GB A100s with a batch size of 1 and context lengths of just a tad more than 512 tokens in greedy search, even though I can train the model with a micro batch size of 4 in low-rank (bf16). I don't have this problem with comparable sizes of llama, pythia, bloom, opt, etc...
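For reference, here is a minimal sketch of how I measure peak VRAM during greedy generation; the checkpoint name is just a placeholder, swap in the model you are testing:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-rw-1b"  # placeholder checkpoint, adjust as needed
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")

prompt = "Hello " * 512  # roughly the context length where the OOM appears
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=64, do_sample=False)  # greedy search
print(f"Peak VRAM: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
```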

I am wondering if this is caused by the multi-query attention key-value cache, which could be badly configured in the rw_modeling file?
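For scale, here is a rough back-of-envelope comparison of the KV cache size under multi-query attention (one shared K/V head per layer) versus a cache that mistakenly keeps a full K/V pair per attention head. The layer count, head count, and head dimension below are assumptions for illustration, not values read from the actual config:

```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    # 2x for keys and values; bf16 -> 2 bytes per element
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem

n_layers, n_heads, head_dim = 60, 64, 128  # assumed geometry, for illustration only
seq_len = 512

mqa = kv_cache_bytes(1, seq_len, n_layers, 1, head_dim)             # shared KV head
per_head = kv_cache_bytes(1, seq_len, n_layers, n_heads, head_dim)  # per-head KV

print(f"MQA cache:      {mqa / 1e9:.3f} GB")
print(f"Per-head cache: {per_head / 1e9:.3f} GB  ({n_heads}x larger)")
```

The cache itself stays small in both cases at 512 tokens, so if the cache shape is wrong the extra memory probably comes from intermediate tensors (e.g. attention scores) broadcast against the mis-shaped cache rather than from the cache storage alone.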
