Fix incorrect image embedding when running with a single GPU and 24GB VRAM
Browse files# Issue
When running the image encoding function on a single GPU with no more than 24GB ram, `model.encode_image(pixel_values, mode='InternVL-G')` returns incorrect value
# To Reproduce
- Hardware: Single GPU with **no more than 24GB memory** (e.g. RTX3090/4090).
- transformers==4.37.2
- accelerate==0.24.1
minimal code to reproduce:
```python
import torch
import requests
from io import BytesIO
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor
from transformers import AutoTokenizer
model_path = 'OpenGVLab/InternVL-14B-224px'
model = AutoModel.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
device_map='auto',
trust_remote_code=True
).eval()
image_processor = CLIPImageProcessor.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, add_eos_token=True)
tokenizer.pad_token_id = 0 # set pad_token_id to 0
print(model.hf_device_map)
request = requests.get('https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/2yzk5wUY-obL6H4rKiHlU.webp')
images = [
Image.open(BytesIO(request.content)).convert('RGB'),
]
pixel_values = image_processor(images=images, return_tensors='pt').pixel_values
pixel_values = pixel_values.to(torch.bfloat16).cuda()
with torch.inference_mode():
features = model.encode_image(pixel_values, mode='InternVL-G')
print(features)
```
expected output (when GPU memory >> 24GB or num of GPUs > 1 the output is as expected):
```
tensor([[-8.1055e-02, 1.1133e-01, 3.5889e-02, -1.4893e-02, 8.9722e-03,
1.5527e-01, 2.8320e-02, -5.5664e-02, 1.0352e-01, -1.1963e-02,
-5.4688e-02, -6.4941e-02, -6.8665e-03, -1.0498e-01, -1.2329e-02,
-5.7129e-02, 1.3062e-02, 4.4678e-02, -5.5176e-02, -7.8125e-02,
-9.5703e-02, 1.9409e-02, 4.5898e-02, -2.4414e-03, -4.2969e-02,
...
```
actual output (when GPU memory = 24GB and num of GPU = 1):
```
'clip_projector': 'cpu', 'clip_projector2': 'cpu', 'itm_head': 'cpu'}
tensor([[ 4.4434e-02, 1.0620e-02, 8.3008e-03, 4.7363e-02, -2.2583e-03,
-2.0996e-02, 3.5400e-02, -4.2969e-02, -5.0049e-02, -1.2451e-02,
-7.5195e-02, -8.3008e-03, -2.5391e-02, 6.5918e-03, -1.3306e-02,
-1.7700e-02, 2.8076e-02, -2.7222e-02, -1.4771e-02, -3.2227e-02,
8.1543e-02, 2.3926e-02, -1.6357e-02, -7.5195e-02, 1.8921e-02,
...
```
# Analysis
I've managed to trace down the problem and find that the error is caused by the CrossAttention module
In the original code inside `CrossAttention.forward`:
```
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
```
In this case, `self.q.weight` is automatically offloaded by *accelerate* to save memory. It is expected to be automatically loaded into GPU when the `self.q` is called.
However, the code references `self.q.weight` without actually calling the forward method of `self.q`. As a result, *accelerate* will not be able to load the correct weight onto GPU to perform the linear operation.
Performing F.linear on an uninitialized weight will produce unpredictable output (without any error or warning), and therefore lead to incorrect image embeddings.
# Fix
Simulate pre_forward and post_forward hook to tell *accelerate* to load and offload the weight
```python
# simulate module forward hooks to let accelerate load the actual weight
# see https://github.com/huggingface/accelerate/blob/1f7a79b428749f45187ec69485f2c966fe21926e/src/accelerate/hooks.py#L163
simulate_hooks = hasattr(self.q, '_hf_hook')
if simulate_hooks:
self.q._hf_hook.pre_forward(self.q, x)
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
if simulate_hooks:
self.q._hf_hook.post_forward(self.q, x)
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
if simulate_hooks:
self.k._hf_hook.pre_forward(self.k, k)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
if simulate_hooks:
self.k._hf_hook.post_forward(self.k, k)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
if simulate_hooks:
self.v._hf_hook.pre_forward(self.v, v)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
if simulate_hooks:
self.v._hf_hook.post_forward(self.v, v)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
```
- modeling_internvl.py +16 -0
@@ -114,13 +114,29 @@ class CrossAttention(nn.Module):
|
|
114 |
k_bias = self.k_bias
|
115 |
v_bias = self.v_bias
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
|
|
|
|
|
118 |
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
|
119 |
|
|
|
|
|
120 |
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
|
|
|
|
|
121 |
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
|
122 |
|
|
|
|
|
123 |
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
|
|
|
|
|
124 |
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
|
125 |
|
126 |
q = q * self.scale
|
|
|
114 |
k_bias = self.k_bias
|
115 |
v_bias = self.v_bias
|
116 |
|
117 |
+
# simulate module forward hooks to let accelerate load the actual weight
|
118 |
+
# see https://github.com/huggingface/accelerate/blob/1f7a79b428749f45187ec69485f2c966fe21926e/src/accelerate/hooks.py#L163
|
119 |
+
simulate_hooks = hasattr(self.q, '_hf_hook')
|
120 |
+
|
121 |
+
if simulate_hooks:
|
122 |
+
self.q._hf_hook.pre_forward(self.q, x)
|
123 |
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
|
124 |
+
if simulate_hooks:
|
125 |
+
self.q._hf_hook.post_forward(self.q, x)
|
126 |
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
|
127 |
|
128 |
+
if simulate_hooks:
|
129 |
+
self.k._hf_hook.pre_forward(self.k, k)
|
130 |
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
|
131 |
+
if simulate_hooks:
|
132 |
+
self.k._hf_hook.post_forward(self.k, k)
|
133 |
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
|
134 |
|
135 |
+
if simulate_hooks:
|
136 |
+
self.v._hf_hook.pre_forward(self.v, v)
|
137 |
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
|
138 |
+
if simulate_hooks:
|
139 |
+
self.v._hf_hook.post_forward(self.v, v)
|
140 |
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
|
141 |
|
142 |
q = q * self.scale
|