# internlm-xcomposer-7b-4bit / modeling_InternLM_XComposer.py
# Last commit 9a3a17c by myownskyW7: "Add fine-tuning code"
import copy
import os
import sys
# Put the directory containing this file at the front of sys.path before the
# remaining imports run.
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)
import contextlib
import torch.utils.checkpoint
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from .modeling_perceive_sampler import BertConfig, BertLMHeadModel
# NOTE(review): names used below but not imported explicitly in this file
# (e.g. nn, partial, PreTrainedModel, InternLMForCausalLM, rank0_print,
# StoppingCriteriaList) presumably come from these star imports -- confirm.
from .modeling_vit import *
from .modeling_InternLM import *
from .modeling_utils import *
from transformers.utils import logging
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
class InternLMXComposerForCausalLM(PreTrainedModel):
    """Image-conditioned causal language model built around InternLM.

    An image is encoded by ``visual_encoder`` + ``ln_vision``, resampled to
    ``config.num_query_token`` query tokens by ``Qformer``, projected by
    ``internlm_proj`` and spliced into the text embedding sequence that is
    fed to ``internlm_model``.

    ``self.tokenizer`` is initialised to ``None``; it has to be assigned by
    the caller before any method that tokenizes or decodes text is used.
    """

    config_class = InternLMXComposerConfig
    _auto_class = "AutoModelForCausalLM"

    # System prompt prepended to the first turn of a conversation.
    meta_instruction = """meta instruction
You are an AI assistant whose name is 浦语.
- 浦语 is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- 浦语 can understand and communicate fluently in the language chosen by the user such as English and 中文.
conversation
"""

    # Class-level generation defaults; each instance takes its own copy in
    # ``__init__`` and adds its stopping criteria there.
    gen_config = dict(
        num_beams=5,
        do_sample=False,
        min_length=1,
        repetition_penalty=1.5,
        length_penalty=1.0,
        temperature=1.0,
        max_new_tokens=500,
    )

    # Token ids used as turn delimiters / end-of-sequence when masking
    # training targets and when stopping generation.
    _EOH_TOKEN_ID = 103027  # end of human
    _EOA_TOKEN_ID = 103028  # end of assistant
    _EOS_TOKEN_ID = 2  # eos

    def __init__(self, config):
        """Build the vision encoder, perceive sampler and InternLM decoder.

        Args:
            config: model configuration. This constructor reads
                ``max_length``, ``num_query_token``, ``internlm_lora`` and
                ``device`` from it and passes it on to
                ``InternLMForCausalLM._from_config``.
        """
        super().__init__(config)
        self.max_length = config.max_length

        rank0_print('Init VIT ... ', end='')
        self.visual_encoder = create_eva_vit_g()
        self.ln_vision = LayerNorm(self.visual_encoder.num_features)
        rank0_print('Done')

        rank0_print('Init Perceive Sampler ... ', end='')
        with all_logging_disabled():
            self.Qformer, self.query_tokens = self.init_qformer(
                config.num_query_token, self.visual_encoder.num_features)
            # Discard the text-side sub-modules; encode_img only feeds query
            # embeddings through the Q-Former (presumably these are unused
            # and dropped to save memory -- confirm).
            self.Qformer.bert.embeddings.word_embeddings = None
            self.Qformer.bert.embeddings.position_embeddings = None
            for layer in self.Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.Qformer.cls = None
        rank0_print('Done')

        rank0_print('Init InternLM ... ', end='')
        # Frozen marker embeddings concatenated before and after the image
        # tokens in encode_img. The width 4096 is hard-coded.
        self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.flag_image_start.requires_grad = False
        self.flag_image_end.requires_grad = False

        internlm_lora = config.internlm_lora
        self.internlm_lora = internlm_lora
        # NOTE: this sets an attribute on the InternLMForCausalLM *class*,
        # so it is visible to every instance of that class.
        setattr(InternLMForCausalLM, 'lora_cfg', internlm_lora)

        # Parse the major version from the text before the first dot instead
        # of looking at the first character only.
        torch_major = int(torch.__version__.split('.')[0])
        if torch_major == 1:
            self.internlm_model = InternLMForCausalLM._from_config(config).to(
                torch.float16)
        else:
            assert torch_major == 2
            # speed up init llm
            with torch.device('meta'):
                self.internlm_model = InternLMForCausalLM._from_config(config)
            self.internlm_model.to_empty(
                device=config.device).to(torch.float16)
            self.internlm_model.to(config.device)
        # Modules whose name contains 'lora' are converted back to float32.
        for name, module in self.internlm_model.named_modules():
            if 'lora' in name:
                module.float()

        self.internlm_proj = nn.Linear(self.Qformer.config.hidden_size,
                                       self.internlm_model.config.hidden_size)
        rank0_print('Done')

        self.vis_processor = transforms.Compose([
            transforms.Resize((224, 224),
                              interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

        self.tokenizer = None

        self.eoh = '<TOKENS_UNUSED_0>'  # end of human
        self.eoa = '<TOKENS_UNUSED_1>'  # end of assistant
        stop_words_ids = [
            torch.tensor([self._EOH_TOKEN_ID]).to(config.device),
            torch.tensor([self._EOA_TOKEN_ID]).to(config.device),
        ]
        stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=stop_words_ids)])
        # Keep a per-instance copy: writing into the class-level dict would
        # share these device-bound stop tensors between all instances.
        self.gen_config = dict(type(self).gen_config,
                               stopping_criteria=stopping_criteria)
        self.supports_gradient_checkpointing = True

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped InternLM model."""
        return self.internlm_model.get_input_embeddings()

    def _set_gradient_checkpointing(self, module, value=False):
        """Enable gradient checkpointing on the InternLM sub-model.

        Only ``value=True`` has an effect; ``module`` is not used here.
        """
        if not value:
            return
        enable = partial(self.internlm_model._set_gradient_checkpointing,
                         value=True)
        self.internlm_model.apply(enable)

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context, or a no-op context on CPU.

        Args:
            dtype: autocast dtype used when the model is not on CPU.
        """
        # if on cpu, don't use autocast
        if self.device == torch.device("cpu"):
            return contextlib.nullcontext()
        return torch.cuda.amp.autocast(dtype=dtype)

    @classmethod
    def init_qformer(cls,
                     num_query_token,
                     vision_width,
                     cross_attention_freq=2,
                     pretrain=True):
        """Create the Q-Former and its learnable query tokens.

        Args:
            num_query_token: number of query tokens.
            vision_width: feature width of the vision encoder, stored as
                ``encoder_width`` on the BERT config.
            cross_attention_freq: stored as ``cross_attention_freq`` on the
                BERT config.
            pretrain: accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(Qformer, query_tokens)`` where ``query_tokens`` is an
            ``nn.Parameter`` of shape ``(1, num_query_token, hidden_size)``
            drawn from a normal distribution with std
            ``encoder_config.initializer_range``.
        """
        encoder_config = BertConfig()
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size))
        query_tokens.data.normal_(mean=0.0,
                                  std=encoder_config.initializer_range)
        return Qformer, query_tokens

    def encode_img(self, image):
        """Encode an image into InternLM input embeddings.

        Args:
            image: ``None``, a file path (``str``) or a ``torch.Tensor``.
                A path is opened with PIL, converted to RGB and run through
                ``self.vis_processor``; a tensor is used as given.

        Returns:
            ``None`` when ``image`` is ``None``; otherwise the projected
            query outputs with ``flag_image_start`` prepended and
            ``flag_image_end`` appended along dim 1.
        """
        if image is None:
            return None
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
            image = self.vis_processor(image).unsqueeze(0).to(self.device)
        else:
            assert isinstance(image, torch.Tensor)
        device = image.device

        with self.maybe_autocast():
            image_embeds = self.ln_vision(
                self.visual_encoder(image)).to(device)
            # Attend to every visual position.
            image_atts = torch.ones(image_embeds.size()[:-1],
                                    dtype=torch.long).to(device)
            batch_size = image_embeds.shape[0]
            query_tokens = self.query_tokens.expand(batch_size, -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            inputs_internlm = self.internlm_proj(
                query_output.last_hidden_state)
            inputs_internlm = torch.cat(
                [
                    self.flag_image_start.expand(inputs_internlm.shape[0],
                                                 -1, -1),
                    inputs_internlm,
                    self.flag_image_end.expand(inputs_internlm.shape[0],
                                               -1, -1),
                ],
                dim=1)
        return inputs_internlm

    def encode_text(self, text, add_special_tokens=False):
        """Tokenize ``text`` and return its InternLM token embeddings.

        Requires ``self.tokenizer`` to have been set.
        """
        text_token_ids = self.tokenizer(
            text,
            return_tensors='pt',
            add_special_tokens=add_special_tokens,
        ).input_ids.to(self.device)
        return self.internlm_model.model.embed_tokens(text_token_ids)

    def decode_text(self, out_embeds):
        """Decode the first generated sequence and cut it at ``self.eoa``.

        Args:
            out_embeds: output of ``internlm_model.generate``; despite the
                name it is passed straight to ``tokenizer.batch_decode``.
        """
        decoded = self.tokenizer.batch_decode(out_embeds,
                                              skip_special_tokens=True)
        return decoded[0].split(self.eoa)[0]

    def wrap_text(self, user_text, bot_text='', add_special=True):
        """Format one user/bot exchange as a prompt string.

        Args:
            user_text: text of the user turn.
            bot_text: text placed after the bot tag.
            add_special: when true, ``self.eoh`` is inserted after the user
                turn; otherwise an empty string is used in its place.
        """
        eoh = self.eoh if add_special else ''
        return f' <|User|>:{user_text} \n{eoh} <|Bot|>:{bot_text}'

    def get_gen_args(self, **kwargs):
        """Return a deep copy of ``self.gen_config`` updated with kwargs."""
        new_kargs = copy.deepcopy(self.gen_config)
        new_kargs.update(kwargs)
        return new_kargs

    def generate(self, text, image=None, **kwargs):
        """Generate a single-turn reply to ``text`` (and optional image).

        Args:
            text: user text.
            image: forwarded to ``encode_img``.
            **kwargs: overrides for the generation arguments.

        Returns:
            The decoded reply string.
        """
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds, img_embeds)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        return self.decode_text(out_embeds)

    def chat(self, text, image=None, history=None, **kwargs):
        """Generate a reply within a multi-turn conversation.

        Args:
            text: user text for this turn.
            image: forwarded to ``encode_img``.
            history: list of embedding tensors from earlier turns, or
                ``None``. A list passed in is appended to in place.
            **kwargs: overrides for the generation arguments.

        Returns:
            Tuple ``(reply_text, history)`` where ``history`` has this
            turn's prompt and reply embeddings appended.
        """
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds,
                                         img_embeds,
                                         history=history)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        out_text = self.decode_text(out_embeds)

        # trunc at eoh and eoa: re-embed the cleaned reply and a prompt
        # without the special markers to form this turn's history entry.
        clean_out_text_token_ids = self.tokenizer(
            out_text, return_tensors='pt').input_ids.to(self.device)
        clean_out_text_embeds = self.internlm_model.model.embed_tokens(
            clean_out_text_token_ids)
        clean_prompt_embeds = self.wrap_prompt(text_embeds,
                                               img_embeds,
                                               add_special=False)
        cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
                                dim=1)
        if history is None:
            history = []
        history.append(cur_history)
        return out_text, history

    def wrap_prompt(self,
                    text_embeds,
                    img_embeds=None,
                    history=None,
                    add_special=True):
        """Assemble the embedding sequence for one prompt.

        The result is ``[user tag, image, text, bot tag]`` concatenated
        along dim 1, preceded by all ``history`` tensors when there are any.

        Args:
            text_embeds: embeddings of the user text.
            img_embeds: image embeddings, or ``None`` for a text-only turn.
            history: embeddings of earlier turns. ``None`` and an empty
                sequence both mean "first turn".
            add_special: when true the bot tag carries ``self.eoh`` and, on
                the first turn, the user tag carries the meta instruction.
        """
        # An empty history is a first turn too: it still needs the meta
        # instruction and the tokenizer's special tokens.
        has_history = history is not None and len(history) > 0

        if add_special:
            user_tag = ' <|User|>:'
            if not has_history:
                user_tag = self.meta_instruction + user_tag
            prompt_segs = [user_tag, f'\n{self.eoh} <|Bot|>:']
        else:
            prompt_segs = [' <|User|>:', ' <|Bot|>:']  # used in wrap history

        # Special tokens are added only to the very first segment of a
        # conversation.
        prompt_seg_embeds = [
            self.encode_text(
                seg, add_special_tokens=(not has_history and i == 0))
            for i, seg in enumerate(prompt_segs)
        ]

        if img_embeds is None:
            # Zero-length placeholder so the concatenation stays uniform.
            img_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
                                               text_embeds.size(-1))
        prompt_embeds = torch.cat([
            prompt_seg_embeds[0], img_embeds, text_embeds,
            prompt_seg_embeds[1]
        ],
                                  dim=1)
        if has_history:
            prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
        return prompt_embeds

    ######################
    #  code for training
    ######################
    def prompt_wrap(self, img_embeds, prompt):
        """Prepend the embedded prompt prefix to the image embeddings.

        Args:
            img_embeds: image embeddings; dim 0 is used as the batch size.
            prompt: string containing exactly one ``'<ImageHere>'``; only
                the part before the placeholder is used.

        Returns:
            Tuple ``(embeds, attention_mask, targets)``: the concatenated
            embeddings, an all-ones mask and all ``-100`` targets, the last
            two covering the first two dims of ``embeds``.

        Raises:
            ValueError: if ``prompt`` does not contain exactly one
                ``'<ImageHere>'``.
        """
        batch_size = img_embeds.shape[0]
        p_before, _ = prompt.split('<ImageHere>')
        p_before_tokens = self.tokenizer(
            p_before, return_tensors="pt",
            add_special_tokens=True).to(img_embeds.device)
        p_before_embeds = self.internlm_model.model.embed_tokens(
            p_before_tokens.input_ids).expand(batch_size, -1, -1)
        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds], dim=1)
        wrapped_atts_img = torch.ones(wrapped_img_embeds.size()[:-1],
                                      dtype=torch.long).to(img_embeds.device)
        # -100 marks every prefix/image position as excluded from the loss.
        wrapped_target = torch.full(
            (batch_size, wrapped_img_embeds.shape[1]),
            -100,
            dtype=torch.long,
            device=img_embeds.device)
        return wrapped_img_embeds, wrapped_atts_img, wrapped_target

    def align_text(self, samples, has_img=False):  ### add eos and eoa
        """Insert ``eoh``/``eoa`` markers and a closing ``</s>`` into texts.

        Args:
            samples: mapping with a ``"text_input"`` sequence of strings.
            has_img: when true, ``"<image>"`` is removed and everything up
                to and including the first ``"<|User|>:"`` is dropped (that
                prefix is re-created together with the image features).

        Returns:
            List of processed strings, one per input text.
        """
        if has_img:  ### remove the first user to wrap image features
            texts = [
                t.replace("<image>", "").split("<|User|>:", 1)[-1].lstrip()
                for t in samples["text_input"]
            ]
        else:
            texts = list(samples["text_input"])

        text_new = []
        for raw in texts:
            temp = raw + self.eoa + ' </s>'
            temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>')
            temp = temp.replace(' <|User|>', self.eoa + ' <|User|>')
            # An eoa appearing before the first eoh has no preceding answer;
            # drop that first eoa.
            if temp.find(self.eoh) > temp.find(self.eoa):
                temp = temp.replace(self.eoa, '', 1)
            text_new.append(temp)
        return text_new

    def text2emb(self, text):
        """Tokenize training text and build its loss targets.

        Args:
            text: list of strings (padded to the longest, truncated to
                ``self.max_length``, no special tokens added).

        Returns:
            Tuple ``(tokens, targets)``: the tokenizer output moved to
            ``self.device`` and the masked target ids on the same device.
        """
        to_regress_tokens = self.tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False).to(self.device)
        targets = self.mask_human_targets(to_regress_tokens.input_ids)
        targets = targets.to(self.device)
        return to_regress_tokens, targets

    def mask_human_targets(self, input_ids, pure=False):
        """Build loss targets with the non-answer tokens set to ``-100``.

        For each row: on every end-of-human token, everything from the end
        of the previous answer up to 5 tokens past that marker is masked
        (presumably the tokens of ``' <|Bot|>:'`` -- TODO confirm);
        everything after the first eos is masked; and if no eos was seen
        (truncated sample) and the last marker is an end-of-assistant,
        everything after that marker is masked.

        Args:
            input_ids: 2-D tensor of token ids, one row per sample.
            pure: accepted for interface compatibility; not used here.

        Returns:
            Tensor of the same shape as ``input_ids``.
        """
        target_batch = []
        for ids in input_ids:
            targets = ids.clone()
            cur_idx = 0
            last_eoa = 0
            last_eoh = 0
            hit_eos = False
            # Iterate over plain ints to avoid one tensor comparison (and,
            # on GPU, one sync) per token.
            for i, token_id in enumerate(ids.tolist()):
                if token_id == self._EOH_TOKEN_ID:  #### end of human
                    targets[cur_idx:i + 6] = -100
                    cur_idx = i + 6
                    last_eoh = i
                elif token_id == self._EOA_TOKEN_ID:  ### end of assistant
                    cur_idx = i + 1
                    last_eoa = i
                elif token_id == self._EOS_TOKEN_ID:  ### eos and following pad
                    targets[i + 1:] = -100  #### loss on eos, but not on pad
                    hit_eos = True
                    break
            ### truncation, end at last question
            if not hit_eos and last_eoa > last_eoh:
                targets[last_eoa + 1:] = -100  #### mask all after the last answer
            target_batch.append(targets)
        return torch.stack(target_batch, dim=0)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                **kwargs):
        """Training forward pass driven by a ``samples`` keyword argument.

        The named tensor arguments are accepted for interface compatibility
        but not read; inputs, attention mask and labels are all rebuilt from
        ``kwargs['samples']``.

        Args:
            **kwargs: must contain ``samples``, a mapping with
                ``"text_input"`` (list of strings). The sample is treated as
                an image sample when the key ``"images"`` is present; the
                image is then read from ``"image"``, falling back to
                ``"images"``.

        Returns:
            The output of ``self.internlm_model`` called with the assembled
            embeddings, attention mask and labels.

        Raises:
            ValueError: if ``samples`` is missing.
        """
        samples = kwargs.get('samples')
        if samples is None:
            raise ValueError("forward() requires a 'samples' keyword argument")
        has_img = 'images' in samples

        ### encode text
        text = self.align_text(samples, has_img=has_img)
        to_regress_tokens, targets = self.text2emb(text)
        to_regress_embeds = self.internlm_model.model.embed_tokens(
            to_regress_tokens.input_ids)
        attention_mask = to_regress_tokens.attention_mask

        if has_img:
            header = samples["text_input"][0].split(' <|User|>:')[0]
            prompt = header + ' <|User|>:<ImageHere>'

            ### encode image
            # Detection uses 'images' while the read used 'image'; accept
            # either key so a batch carrying only 'images' does not fail.
            image = (samples["image"]
                     if "image" in samples else samples["images"])
            img_embeds = self.encode_img(image)
            img_embeds, atts_img, wrapped_target = self.prompt_wrap(
                img_embeds, prompt)

            ### combine text and image
            to_regress_embeds = torch.cat([img_embeds, to_regress_embeds],
                                          dim=1)
            attention_mask = torch.cat([atts_img, attention_mask], dim=1)
            targets = torch.cat([wrapped_target, targets], dim=1)

        outputs = self.internlm_model(
            inputs_embeds=to_regress_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        return outputs