Rewrite image embedding to remove the in-place op #53
by YenChunChen - opened
When attempting to LoRA-finetune Phi-3-V, the following error occurs when the CLIP encoder is LoRA-tuned together with the Phi-3 LM:
File "/home/yenchun/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-vision-128k-instruct/f998a184b56bf0399b3af85c50b20ec0d5688f5f/image_embedding_phi3_v.py", line 280, in forward
hidden_states[positions[idx, 0], positions[idx, 1] : positions[idx, 1] + cnt] = (
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
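For reference, the failure mode can be reproduced in isolation (toy tensors and shapes, not the actual model code): slice-assigning into a leaf tensor that requires grad is an in-place write on a view of that leaf, which autograd rejects.

import torch

hidden_states = torch.zeros(10, 8, requires_grad=True)  # stand-in leaf tensor with requires_grad
img_features = torch.randn(3, 8)                         # stand-in image features
try:
    hidden_states[2:5] = img_features  # slice assignment mutates a view of the leaf in place
except RuntimeError as err:
    print(err)  # same kind of error as in the traceback above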
This PR uses the out-of-place equivalent index_put instead. In addition, the rewrite of forward removes code paths for early model variants, and the final version of hd_transform is refactored as well, both for readability. See the comment below for parity tests.
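A minimal sketch of the out-of-place pattern this PR switches to (illustrative tensors and index names, not the actual forward code): Tensor.index_put returns a new tensor with the image features written at the given positions, so no leaf is mutated.

import torch

hidden_states = torch.zeros(2, 6, 4, requires_grad=True)  # (batch, seq, dim) leaf tensor
img_features = torch.randn(3, 4)                          # features for 3 image-placeholder tokens
batch_idx = torch.tensor([0, 0, 0])                       # example index of each image token
seq_idx = torch.tensor([2, 3, 4])                         # sequence position of each image token

# out-of-place: a fresh tensor is returned and the leaf stays untouched,
# so autograd has no in-place operation to complain about
merged = hidden_states.index_put((batch_idx, seq_idx), img_features)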
@haipingwu please review.
Parity and batching tests:
import copy
import requests
import torch
import torch.testing
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from configuration_phi3_v import Phi3VConfig
from image_embedding_phi3_v import Phi3ImageEmbedding

def load_models():
    model_path = 'microsoft/Phi-3-vision-128k-instruct'
    kwargs = {}
    kwargs['torch_dtype'] = torch.bfloat16
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype='auto',
        _attn_implementation='eager',
        revision='f998a184b56bf0399b3af85c50b20ec0d5688f5f',
    ).cuda()
    config = Phi3VConfig.from_pretrained(model_path, _attn_implementation='eager')
    embedding_config = {'embedding_cls': config.embd_layer['embedding_cls'], **config.embd_layer}
    image_embed_state_dict = model.model.vision_embed_tokens.state_dict()
    # old: the image embedder shipped with the pinned revision
    old_image_embedder = copy.deepcopy(model.model.vision_embed_tokens)
    old_image_embedder.load_state_dict(image_embed_state_dict)
    # new: the rewritten Phi3ImageEmbedding from this PR, loaded with the same weights
    new_image_embedder = (
        Phi3ImageEmbedding(config, wte=model.model.embed_tokens, **embedding_config)
        .bfloat16()
        .cuda()
    )
    new_image_embedder.load_state_dict(image_embed_state_dict)
    del model
    return processor, old_image_embedder, new_image_embedder

def test_input_1(processor):
    user_prompt = '<|user|>\n'
    assistant_prompt = '<|assistant|>\n'
    prompt_suffix = '<|end|>\n'
    prompt = (
        f'{user_prompt}<|image_1|>\nWhat is shown in this image?{prompt_suffix}{assistant_prompt}'
    )
    url = 'https://www.ilankelman.org/stopsigns/australia.jpg'
    print(f'>>> Prompt\n{prompt}')
    image = Image.open(requests.get(url, stream=True).raw)
    inputs = processor(prompt, image, return_tensors='pt').to('cuda:0')
    return inputs

def test_input_2(processor):
    user_prompt = '<|user|>\n'
    assistant_prompt = '<|assistant|>\n'
    prompt_suffix = '<|end|>\n'
    prompt = f'{user_prompt}<|image_1|>\nCan you convert the table to markdown format?{prompt_suffix}{assistant_prompt}'
    url = 'https://support.content.office.net/en-us/media/3dd2b79b-9160-403d-9967-af893d17b580.png'
    image = Image.open(requests.get(url, stream=True).raw)
    inputs = processor(prompt, image, return_tensors='pt').to('cuda:0')
    return inputs

def compare_old_and_new_forward(old_image_embedder, new_image_embedder, inputs):
    # clone input_ids per call since the embedder may modify them in place
    input_ids = inputs['input_ids'].clone().cuda()
    pixel_values = inputs['pixel_values'].bfloat16().cuda()
    image_sizes = inputs['image_sizes']
    with torch.no_grad():
        old_outputs = old_image_embedder(
            input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes
        )
        input_ids = inputs['input_ids'].clone().cuda()
        new_outputs = new_image_embedder(
            input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes
        )
    torch.testing.assert_close(old_outputs, new_outputs)
    return old_outputs, new_outputs

def main():
    processor, old_image_embedder, new_image_embedder = load_models()
    inputs_1 = test_input_1(processor)
    inputs_2 = test_input_2(processor)
    # test parity of single example
    old_outputs_1, new_outputs_1 = compare_old_and_new_forward(
        old_image_embedder, new_image_embedder, inputs_1
    )
    old_outputs_2, new_outputs_2 = compare_old_and_new_forward(
        old_image_embedder, new_image_embedder, inputs_2
    )
    # test parity of batched examples
    inputs_1and2 = {
        'input_ids': torch.nn.utils.rnn.pad_sequence(
            [
                inputs_1['input_ids'].squeeze(0).unsqueeze(1),
                inputs_2['input_ids'].squeeze(0).unsqueeze(1),
            ],
            batch_first=True,
            padding_value=processor.tokenizer.pad_token_id,
        ).squeeze(2),
        'pixel_values': torch.cat([inputs_1['pixel_values'], inputs_2['pixel_values']]),
        'image_sizes': torch.cat([inputs_1['image_sizes'], inputs_2['image_sizes']]),
    }
    old_outputs_1and2, new_outputs_1and2 = compare_old_and_new_forward(
        old_image_embedder, new_image_embedder, inputs_1and2
    )
    # test batching correctness
    len_1 = inputs_1['input_ids'].shape[1]
    len_2 = inputs_2['input_ids'].shape[1]
    torch.testing.assert_close(new_outputs_1[0], new_outputs_1and2[0, :len_1])
    torch.testing.assert_close(new_outputs_2[0], new_outputs_1and2[1, :len_2])
    # test parity for single example with multiple images
    inputs_1plus2 = {
        'input_ids': torch.cat([inputs_1['input_ids'], inputs_2['input_ids']], dim=1),
        'pixel_values': torch.cat([inputs_1['pixel_values'], inputs_2['pixel_values']]),
        'image_sizes': torch.cat([inputs_1['image_sizes'], inputs_2['image_sizes']]),
    }
    old_outputs_1plus2, new_outputs_1plus2 = compare_old_and_new_forward(
        old_image_embedder, new_image_embedder, inputs_1plus2
    )
    torch.testing.assert_close(new_outputs_1, new_outputs_1plus2[:, :len_1])
    torch.testing.assert_close(new_outputs_2, new_outputs_1plus2[:, -len_2:])
    # test batched examples with potentially different number of images
    inputs_complex = {
        'input_ids': torch.nn.utils.rnn.pad_sequence(
            [
                inputs_1['input_ids'].squeeze(0).unsqueeze(1),
                inputs_1plus2['input_ids'].squeeze(0).unsqueeze(1),
                inputs_2['input_ids'].squeeze(0).unsqueeze(1),
            ],
            batch_first=True,
            padding_value=processor.tokenizer.pad_token_id,
        ).squeeze(2),
        'pixel_values': torch.cat(
            [inputs_1['pixel_values'], inputs_1plus2['pixel_values'], inputs_2['pixel_values']]
        ),
        'image_sizes': torch.cat(
            [inputs_1['image_sizes'], inputs_1plus2['image_sizes'], inputs_2['image_sizes']]
        ),
    }
    old_outputs_complex, new_outputs_complex = compare_old_and_new_forward(
        old_image_embedder, new_image_embedder, inputs_complex
    )
    torch.testing.assert_close(new_outputs_1[0], new_outputs_complex[0, :len_1])
    torch.testing.assert_close(new_outputs_1plus2[0], new_outputs_complex[1, : len_1 + len_2])
    torch.testing.assert_close(new_outputs_2[0], new_outputs_complex[2, :len_2])


if __name__ == '__main__':
    main()
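Note: the script imports configuration_phi3_v and image_embedding_phi3_v from the working directory, so it is assumed to be run from a checkout of this branch, while the original embedder comes from the pinned revision passed to from_pretrained above.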
YenChunChen changed pull request status to open

hi @leoxiaobin , please merge this PR

@YenChunChen , the branch has merge conflicts. Please fix it.

@leoxiaobin done

leoxiaobin changed pull request status to merged