# InfMLLM2_7B_chat / modeling_infmllm_unified_hd_chat.py
# Uploaded by QianYEee ("Upload 18 files", commit 8a096e8, verified), 15.3 kB.
#coding=utf-8
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from contextlib import suppress
import logging
from einops import rearrange
from peft import LoraConfig, get_peft_model
from bigmodelvis import Visualization
from .clip_encoder_hd import CLIPVisionTowerHD
from .conversation import get_conv_template
from .processors_conv import preprocess_qwen
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.generation import GenerationConfig
from transformers import Qwen2Config, Qwen2ForCausalLM
def get_autocast(precision, cache_enabled=True):
    """Return a zero-argument factory producing the context manager for `precision`.

    Args:
        precision: One of "amp_bfloat16", "amp_bf16", "bf16" (bfloat16 autocast),
            "fp16" (float16 autocast) or "fp32" (no autocast).
        cache_enabled: Forwarded to ``torch.cuda.amp.autocast``.

    Returns:
        A callable; calling it yields a context manager. For "fp32" this is
        ``contextlib.suppress`` itself, which acts as a do-nothing context when
        called without arguments.

    Raises:
        ValueError: If `precision` is not one of the supported names.
    """
    if precision == 'fp32':
        return suppress

    if precision in ("amp_bfloat16", "amp_bf16", 'bf16'):
        # amp_bfloat16 is more stable than amp float16 for clip training
        amp_dtype = torch.bfloat16
    elif precision == 'fp16':
        amp_dtype = torch.float16
    else:
        raise ValueError('not supported precision: {}'.format(precision))

    def _cuda_autocast():
        # A fresh autocast context is created on every call.
        return torch.cuda.amp.autocast(dtype=amp_dtype, cache_enabled=cache_enabled)

    return _cuda_autocast
class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in float32 and returns the input dtype."""

    def forward(self, x: torch.Tensor):
        # Remember the caller's dtype so the result can be cast back to it.
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class MLP(nn.Module):
    """Stack of ``num_layers`` linear layers with ReLU between them (a plain FFN).

    The first layer maps ``input_dim -> hidden_dim`` (or straight to
    ``output_dim`` when there is a single layer), intermediate layers map
    ``hidden_dim -> hidden_dim`` and the last maps to ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_dims = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden_dims
        out_dims = hidden_dims + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        """Apply every layer in order; ReLU follows all layers except the last."""
        last_index = self.num_layers - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            if index < last_index:
                x = F.relu(x)
        return x
class InfMLLM_Unified_HD_Chat(PreTrainedModel):
    """Multimodal chat model: image encoder + MLP adapter + Qwen2 causal LM.

    Images are encoded by ``CLIPVisionTowerHD``, projected to the language
    model's hidden size by a two-layer GELU adapter, and spliced into the
    prompt embeddings right after the ``<|image|>`` placeholder token before
    the language model generates text.
    """

    def __init__(self, config, debug=False):
        """Build tokenizer, language model, image encoder and adapter from `config`.

        Args:
            config: Model config. Reads ``_name_or_path``, ``lm_config``,
                ``max_txt_len``, ``conv_style``, ``vision_config``,
                ``precision`` and (optionally) ``apply_lemmatizer``.
            debug: Accepted for interface compatibility; not used here.
        """
        super().__init__(config)

        ## Initialize LM model
        self.lm_tokenizer = AutoTokenizer.from_pretrained(
            config._name_or_path, use_fast=False, trust_remote_code=True
        )
        self.media_token_img = "<|image|>"
        # `.item()` only succeeds if the placeholder tokenizes to exactly one id.
        self.media_token_id_img = self.lm_tokenizer(
            self.media_token_img, return_tensors="pt", add_special_tokens=False
        ).input_ids.item()

        self.lm_model = Qwen2ForCausalLM(config.lm_config)
        self.lm_tokenizer.model_max_length = config.max_txt_len

        self.template_name = config.conv_style
        self.preprocess_function = preprocess_qwen

        # Learned tokens handed to the image encoder on every forward pass.
        # NOTE(review): 4096 is hard-coded; presumably it must match the feature
        # width the encoder works in (encoder_img.num_features * 4) — confirm.
        self.separate = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.newline = nn.Parameter(torch.zeros([1, 1, 1, 4096]))

        ## Initialize image encoder
        self.encoder_img = CLIPVisionTowerHD(config.vision_config, vision_select_layer=-2)
        # Identity placeholder: no normalization is applied to encoder outputs.
        self.encoder_img_ln = lambda x: x
        hidden_size = self.lm_model.config.hidden_size
        self.adapter_img = nn.Sequential(
            nn.Linear(self.encoder_img.num_features * 4, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

        ## Others
        self.config = config
        self.precision = config.precision
        self._apply_lemmatizer = getattr(config, 'apply_lemmatizer', False)
        self._lemmatizer = None  # loaded lazily by the `lemmatizer` property

    def forward_encoder_img(self, image):
        """Encode a list of images and project them into the LM embedding space.

        Args:
            image: List of images in whatever form ``CLIPVisionTowerHD`` accepts.

        Returns:
            ``(image_embeds, image_split)``: adapter outputs and the second value
            returned by the encoder (used by ``generate`` as the per-image number
            of valid embeddings).
        """
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            assert isinstance(image, list), "image must be a list"
            image_embeds, image_split = self.encoder_img(image, self.separate, self.newline)
            image_embeds = self.encoder_img_ln(image_embeds)  # [bsz, L, D]
            image_embeds = self.adapter_img(image_embeds)
        return image_embeds, image_split

    def _concat_embeds(self,
                       prompt_embeds, prompt_ids, prompt_masks,
                       labels=None, padding='left'):
        """Batch per-sample sequences, padding them to a common length if needed.

        Args:
            prompt_embeds: Sequence of per-sample embedding tensors ``[L_i, D]``.
            prompt_ids: Sequence of per-sample token-id tensors ``[L_i]``.
            prompt_masks: Sequence of per-sample attention masks ``[L_i]``.
            labels: Optional sequence of per-sample label tensors ``[L_i]``.
            padding: ``'left'`` or ``'right'``; only consulted when lengths differ.

        Returns:
            ``(embeds, ids, masks)`` or ``(embeds, ids, masks, labels)`` as
            batched tensors. Padded positions hold the pad-token embedding, the
            pad token id, mask 0 and label -100.

        Raises:
            ValueError: If lengths differ and `padding` is not 'left' or 'right'.
        """
        emb_lens = [len(emb) for emb in prompt_embeds]

        # Fast path: every sample already has the same length, just stack.
        if len(set(emb_lens)) == 1:
            stacked = (
                torch.stack(prompt_embeds, dim=0),
                torch.stack(prompt_ids, dim=0),
                torch.stack(prompt_masks, dim=0),
            )
            if labels is not None:
                stacked += (torch.stack(labels, dim=0),)
            return stacked

        if padding not in ('left', 'right'):
            raise ValueError(f"padding must be 'left' or 'right', got {padding!r}")

        batch_size, max_len = len(emb_lens), max(emb_lens)
        pad_token_id = self.lm_tokenizer.pad_token_id
        pad_emb = self.lm_model.get_input_embeddings()(
            torch.tensor(pad_token_id, device=prompt_embeds[0].device)
        )
        # Start from fully padded buffers and overwrite the valid region per sample.
        prompt_embeds_new = pad_emb.expand(batch_size, max_len, -1).clone()
        prompt_ids_new = torch.full(
            (batch_size, max_len), pad_token_id,
            dtype=prompt_ids[0].dtype, device=prompt_ids[0].device,
        )
        prompt_masks_new = torch.zeros(
            (batch_size, max_len),
            dtype=prompt_masks[0].dtype, device=prompt_masks[0].device,
        )
        labels_new = None
        if labels is not None:
            labels_new = torch.full(
                (batch_size, max_len), -100,
                dtype=prompt_ids[0].dtype, device=prompt_ids[0].device,
            )

        for i, length in enumerate(emb_lens):
            # Explicit offsets (rather than `-length:`) keep zero-length samples valid.
            start = max_len - length if padding == 'left' else 0
            span = slice(start, start + length)
            prompt_embeds_new[i, span] = prompt_embeds[i]
            prompt_ids_new[i, span] = prompt_ids[i]
            prompt_masks_new[i, span] = prompt_masks[i]
            if labels is not None:
                labels_new[i, span] = labels[i]

        if labels is not None:
            return prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new
        return prompt_embeds_new, prompt_ids_new, prompt_masks_new

    def _insert_media_feat(self,
                           prompt_embeds, prompt_ids, prompt_masks,
                           is_languages,
                           embeds_media, media_token_id,
                           index_list=None,
                           labels=None, len_media=None):
        """Splice media embeddings into each prompt right after its media token.

        Args:
            prompt_embeds: Per-sample prompt embeddings (indexable by sample).
            prompt_ids: Per-sample token ids.
            prompt_masks: Per-sample attention masks.
            is_languages: Per-sample flags; when true the inserted media
                positions get mask 0 (pure-language sample).
            embeds_media: Media embeddings, one entry per media item.
            media_token_id: Token id marking where media is inserted.
            index_list: Optional list mapping media item -> sample index.
                Samples not listed are passed through unchanged. When omitted,
                media item ``b`` belongs to sample ``b``.
            labels: Optional per-sample labels; inserted positions get -100.
            len_media: Optional per-media number of leading embeddings to use.

        Returns:
            Lists ``(embeds, ids, masks)`` or ``(embeds, ids, masks, labels)``
            with one (possibly longer) tensor per sample.

        Raises:
            ValueError: If a sample that should receive media has no media token.
        """
        ## insert embeds_media into prompt
        prompt_embeds_new = []
        prompt_masks_new = []
        prompt_ids_new = []
        labels_new = []

        device = embeds_media[0].device
        if index_list is not None:
            assert len(index_list) == len(embeds_media)
            assert len(embeds_media) <= len(prompt_embeds)

        for b in range(len(prompt_embeds)):
            if (index_list is not None) and (b not in index_list):
                # Sample carries no media: keep it untouched.
                prompt_embeds_new.append(prompt_embeds[b])
                prompt_ids_new.append(prompt_ids[b])
                prompt_masks_new.append(prompt_masks[b])
                if labels is not None:
                    labels_new.append(labels[b])
                continue

            try:
                _idx = prompt_ids[b].tolist().index(media_token_id)
            except ValueError as err:
                raise ValueError(
                    f"sample {b}: media token id {media_token_id} not found in prompt_ids"
                ) from err

            b_media = index_list.index(b) if index_list is not None else b
            if len_media is not None:
                cur_embeds_media = embeds_media[b_media, :len_media[b_media]]
            else:
                cur_embeds_media = embeds_media[b_media]
            n_media = len(cur_embeds_media)

            # The placeholder token itself is kept; media goes right after it.
            head, tail = slice(None, _idx + 1), slice(_idx + 1, None)
            ignore_fill = torch.full((n_media,), -100, dtype=torch.long, device=device)

            prompt_embeds_new.append(
                torch.cat([prompt_embeds[b][head], cur_embeds_media, prompt_embeds[b][tail]], dim=0)
            )
            prompt_ids_new.append(
                torch.cat([prompt_ids[b][head], ignore_fill, prompt_ids[b][tail]], dim=0)
            )
            if labels is not None:
                labels_new.append(
                    torch.cat([labels[b][head], ignore_fill, labels[b][tail]], dim=0)
                )

            # if is pure-language sample, mask out image-embeddings
            if is_languages[b]:
                media_mask = torch.zeros(n_media, dtype=torch.long, device=device)
            else:
                media_mask = torch.ones(n_media, dtype=torch.long, device=device)
            prompt_masks_new.append(
                torch.cat([prompt_masks[b][head], media_mask, prompt_masks[b][tail]], dim=0)
            )

        if labels is not None:
            return prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new
        return prompt_embeds_new, prompt_ids_new, prompt_masks_new

    @torch.no_grad()
    def generate(
        self,
        samples,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        temperature=0.,
        return_prompts=False
    ):
        """Generate one reply per conversation, optionally conditioned on images.

        Args:
            samples: Mapping with key ``'conversations'`` (a list of
                conversations, each a list of ``{'from': 'human'|'gpt',
                'value': str}`` turns), optional ``'images'`` (image ``i``
                belongs to conversation ``i``) and optional
                ``'apply_lemmatizer'``. The mapping is not modified.
            num_beams: Beam count forwarded to the LM's ``generate``.
            max_length: Forwarded as ``max_new_tokens``.
            min_length: Forwarded as ``min_length``.
            top_p: Forwarded as ``top_p``.
            temperature: Forwarded as ``temperature``; sampling is enabled only
                when it is greater than 0.
            return_prompts: When true, also return the rendered text prompts.

        Returns:
            List of decoded, stripped strings, or ``(texts, prompts)`` when
            `return_prompts` is true.
        """
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            conversations = samples['conversations']
            is_languages = [False] * len(conversations)

            image_img = samples.get('images', None)
            # Text-only requests carry no 'images' entry; treat that as "no media".
            index_img = list(range(len(image_img))) if image_img is not None else []

            embeds_img, len_img = None, None
            special_prefix = ["" for _ in range(len(conversations))]
            if (self.config.encoder_img is not None) and (image_img is not None) and len(index_img) > 0:
                for i in index_img:
                    special_prefix[i] = self.media_token_img + special_prefix[i]
                new_image_img = [image_img[index] for index in index_img]
                embeds_img, len_img = self.forward_encoder_img(new_image_img)

            if embeds_img is not None:
                device = embeds_img.device
            else:
                # No image branch ran: fall back to where the LM embeddings live.
                device = self.lm_model.get_input_embeddings().weight.device

            conv = get_conv_template(self.template_name)
            roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
            prompts = []
            for i, source in enumerate(conversations):
                if roles[source[0]['from']] != conv.roles[0]:
                    # Skip the first one if it is not from human
                    source = source[1:]

                per_prefix = special_prefix[i]
                conv.messages = []
                for j, sentence in enumerate(source):
                    role = roles[sentence['from']]
                    assert role == conv.roles[j % 2], f'{i}'

                    # llava-1.5 add <image> to the begin of the question, remove here.
                    # Work on a local copy so repeated calls with the same `samples`
                    # do not accumulate media tokens in the caller's data.
                    text = sentence['value'].replace("<image>", "").strip()
                    if j == 0:
                        text = per_prefix + text
                    conv.append_message(role, text)
                prompts.append(conv.get_prompt())

            self.lm_tokenizer.padding_side = "left"
            if self.lm_tokenizer.bos_token is not None:
                prompt_text = [self.lm_tokenizer.bos_token + t for t in prompts]
            else:
                prompt_text = prompts

            prompt_tokens = self.lm_tokenizer(
                prompt_text,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=False
            ).to(device)

            prompt_embeds = self.lm_model.get_input_embeddings()(prompt_tokens.input_ids)
            prompt_masks = prompt_tokens.attention_mask  # [bsz, n2]
            prompt_ids = prompt_tokens.input_ids
            assert torch.all(prompt_ids[:, -1] != self.lm_tokenizer.pad_token_id), "make sure padding left"

            if embeds_img is not None:
                prompt_embeds, prompt_ids, prompt_masks = self._insert_media_feat(
                    prompt_embeds=prompt_embeds,
                    prompt_ids=prompt_ids,
                    prompt_masks=prompt_masks,
                    is_languages=is_languages,
                    embeds_media=embeds_img,
                    media_token_id=self.media_token_id_img,
                    index_list=index_img,
                    len_media=len_img,
                )

                # pad and concat embeds (only needed after insertion turned the
                # batch into per-sample tensors of possibly different lengths)
                prompt_embeds, prompt_ids, prompt_masks = self._concat_embeds(
                    prompt_embeds, prompt_ids, prompt_masks, padding="left"
                )
                assert torch.all(prompt_ids[:, -1] != self.lm_tokenizer.pad_token_id), "make sure padding left"

            outputs = self.lm_model.generate(
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_masks,
                do_sample=temperature > 0,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                eos_token_id=self.lm_tokenizer.eos_token_id,
                min_length=min_length,
                max_new_tokens=max_length,
            )

            output_text = self.lm_tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )
            output_text = [text.strip() for text in output_text]

        if self._apply_lemmatizer or samples.get("apply_lemmatizer"):
            output_text = self._lemmatize(output_text)

        if return_prompts:
            return output_text, prompts
        return output_text

    def _lemmatize(self, answers):
        """Replace nouns and verbs in each answer by their lemma.

        Args:
            answers: Iterable of strings.

        Returns:
            List of strings rebuilt by joining the processed tokens with spaces.
        """
        def apply(answer):
            doc = self.lemmatizer(answer)
            words = [
                token.lemma_ if token.pos_ in ["NOUN", "VERB"] else token.text
                for token in doc
            ]
            return " ".join(words)

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        """Lazily load and cache the spaCy ``en_core_web_sm`` pipeline.

        Raises:
            SystemExit: If spaCy cannot be imported (exit code 1).
            OSError: Re-raised from ``spacy.load`` (e.g. the model package is
                not installed) after logging installation instructions.
        """
        if self._lemmatizer is None:
            install_hint = (
                "\n"
                "Please install spacy and en_core_web_sm model to apply lemmatization.\n"
                "python -m spacy download en_core_web_sm\n"
                "OR\n"
                "import spacy.cli\n"
                "spacy.cli.download(\"en_core_web_sm\")\n"
            )
            try:
                import spacy
                self._lemmatizer = spacy.load("en_core_web_sm")
            except ImportError as err:
                logging.error(install_hint)
                # `exit()` is injected by the `site` module and may be absent;
                # raising SystemExit keeps the same exit-with-code-1 behavior.
                raise SystemExit(1) from err
            except OSError:
                # spaCy is installed but the model could not be loaded.
                logging.error(install_hint)
                raise

        return self._lemmatizer