High GPU RAM usage

#30
by Alealejandrooo - opened

Has anyone experienced a high GPU RAM usage with this model?
When loaded in 8-bit, the model only takes about 7GB of GPU memory, but at inference time usage shoots up to 28GB.
Is anyone else seeing this, and does anyone have suggestions on how to fix it?

Facing the same issue +1

What configuration do you use when loading the model? default or overridden?

I load it in int8 with device_map="auto".

What about the floating-point dtype?

torch_dtype=torch.float32

Is that the issue?

all in all:

device_map="auto",
torch_dtype=torch.float32,
load_in_8bit=True
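
For context, the full loading call is roughly this (a minimal sketch; the model id is a placeholder, not the actual checkpoint name):

```python
import torch
from transformers import AutoModelForCausalLM

# placeholder model id -- substitute the actual checkpoint
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-model",
    device_map="auto",
    torch_dtype=torch.float32,  # full precision for non-quantized modules and activations
    load_in_8bit=True,
)
```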

The model's default configuration uses the bfloat16 floating-point format ("torch_dtype": "bfloat16"). Switching from float32 to half precision will definitely help reduce the memory footprint and improve performance.
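
For example (a minimal sketch with a placeholder model id; the only relevant change is the dtype):

```python
import torch
from transformers import AutoModelForCausalLM

# placeholder model id -- substitute the actual checkpoint
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-model",
    device_map="auto",
    torch_dtype=torch.bfloat16,  # match the checkpoint's native "torch_dtype": "bfloat16"
    load_in_8bit=True,
)
# weights only; activations and the KV cache used at inference come on top of this
print(f"{model.get_memory_footprint() / 1e9:.1f} GB")
```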

I'm seeing the same problem with generation using the cache: 28GB with bfloat16 and 39GB with float32. For LLaMA it is 20GB with bfloat16; this model has a longer context.
