import os
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image
from torch import nn

from peft import PeftConfig, PeftModel
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerFast,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from .configuration_wemm import WeMMConfig
from .connector import Idefics2Connector
from .image_processor import Idefics2ImageProcessor
from .modeling_downsampler import DownsamplerModel
from .modeling_internlm2 import InternLM2ForCausalLM
from .modeling_projector import ProjectorModel
from .tokenization_internlm2 import InternLM2Tokenizer
from .vision_model import Idefics2VisionTransformer

IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
IGNORE_INDEX = -100

class StopWordStoppingCriteria(StoppingCriteria):
    """StopWord stopping criteria."""

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        cur_text = self.tokenizer.decode(input_ids[0])
        cur_text = cur_text.replace('\r', '').replace('\n', '')
        return cur_text[-self.length:] == self.stop_word


def get_stop_criteria(
    tokenizer,
    stop_words=[],
):
    stop_criteria = StoppingCriteriaList()
    for word in stop_words:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria
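
# Example (illustrative): the model class below builds its stop criteria exactly this way,
# so generation halts at the InternLM2 chat terminator '<|im_end|>':
#
#     stop_criteria = get_stop_criteria(tokenizer=tokenizer, stop_words=['<|im_end|>'])
#     output = language_model.generate(**mm_inputs, stopping_criteria=stop_criteria)
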
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # Use half of the dimensions to encode grid_h, the other half for grid_w.
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
    emb = np.concatenate([emb_h, emb_w], axis=-1)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a 2D grid of positions to be encoded, shape (H, W)
    out: (H, W, embed_dim)
    """
    assert embed_dim % 2 == 0
    # np.float was removed from NumPy; use an explicit float64 dtype instead.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega
    pos = np.squeeze(pos)
    out = np.einsum('hw,d->hwd', pos, omega)
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=-1)
    return emb

def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w, cls_token=False):
    """
    grid_size_h, grid_size_w: height and width of the token grid
    return:
    pos_embed: (grid_size_h, grid_size_w, embed_dim); if cls_token is True, a zero row is
    prepended (this path expects a flattened table and is not used in this file)
    """
    grid_h = np.arange(grid_size_h, dtype=np.float32)
    grid_w = np.arange(grid_size_w, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
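
# Shape sketch (illustrative): for a 2 x 3 token grid with embed_dim=8,
# get_2d_sincos_pos_embed(8, 2, 3) returns an array of shape (2, 3, 8) --
# one sin/cos positional vector per (row, column) position.
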
def recover_navit_subimages_with_pos_emb(
        sub_image_hidden_states,
        attention_mask,
        num_sub_images,
        visual_embedding_group,
        pos_hidden_size,
        thumbnail_only=False):
    if num_sub_images < 0:
        num_sub_images = 0
    _slice = int(np.sqrt(num_sub_images))
    N, L, D = sub_image_hidden_states.shape
    _, H, W = attention_mask.shape
    if thumbnail_only is True:
        num_sub_images += 1

    sub_image_hidden_states = sub_image_hidden_states.reshape(-1, num_sub_images, H, W, D)
    attention_mask = attention_mask.reshape(-1, num_sub_images, H, W)
    if thumbnail_only is True:
        # Keep only the last sub-image (the thumbnail) and treat it as a 1x1 tiling.
        sub_image_hidden_states = sub_image_hidden_states[:, -1:, :, :, :]
        attention_mask = attention_mask[:, -1:, :, :]
        _slice = 1

    def _infer_ori_image_patch_shape(sub_image_attention_mask):
        ind_h, ind_w = torch.where(sub_image_attention_mask > 0)
        return torch.max(ind_h) + 1, torch.max(ind_w) + 1

    def _pad_to_same(image_hidden):
        _dtype = image_hidden.dtype
        visual_downsample_stride = int(np.sqrt(visual_embedding_group))
        full_h, full_w, _ = image_hidden.shape
        target_h, target_w = H * _slice, W * _slice

        # Pad to the shared target size, rounded up to a multiple of the downsample stride.
        to_pad_h = (target_h - full_h) + (
            visual_downsample_stride - target_h % visual_downsample_stride) % visual_downsample_stride
        to_pad_w = (target_w - full_w) + (
            visual_downsample_stride - target_w % visual_downsample_stride) % visual_downsample_stride

        image_hidden = image_hidden.permute(2, 0, 1).unsqueeze(0)
        pad_size = (0, to_pad_w, 0, to_pad_h)
        image_hidden = F.pad(image_hidden.to(torch.float32), pad_size, mode='replicate').squeeze(0).permute(1, 2, 0)
        return image_hidden.to(_dtype)

    image_hidden_states = list()
    valid_image_token = list()
    image_2d_pos = list()
    for batch_id in range(len(sub_image_hidden_states)):
        # Valid (unpadded) patch extent of this image, and the stitched full-resolution grid.
        ori_h, ori_w = _infer_ori_image_patch_shape(attention_mask[batch_id][0])
        full_h, full_w = ori_h * _slice, ori_w * _slice

        this_image_hidden = sub_image_hidden_states[batch_id][:, 0:ori_h, 0:ori_w, :] \
            .view(_slice, _slice, ori_h, ori_w, D).permute(0, 2, 1, 3, 4).contiguous().view(full_h, full_w, D)
        pos_emb = get_2d_sincos_pos_embed(pos_hidden_size, grid_size_h=full_h,
                                          grid_size_w=full_w)
        pos_emb = torch.tensor(pos_emb, dtype=this_image_hidden.dtype, device=this_image_hidden.device)
        image_hidden_states.append(_pad_to_same(this_image_hidden))
        image_2d_pos.append(_pad_to_same(pos_emb))
        valid_image_token.append([full_h, full_w])
    image_hidden_states = torch.stack(image_hidden_states)
    image_2d_pos = torch.stack(image_2d_pos)
    valid_image_token = torch.tensor(valid_image_token, dtype=torch.int64)
    return image_hidden_states, image_2d_pos, valid_image_token
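
# Output shapes (for reference): image_hidden_states is (batch, H', W', D) with every image
# padded to a shared H' x W'; image_2d_pos matches it with pos_hidden_size channels;
# valid_image_token is (batch, 2) holding each image's unpadded (height, width) token counts.
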
def visiual_token_downsample(
        visual_downsampler,
        image_hidden_states,
        valid_image_token,
        visual_embedding_group,
        image_2d_pos):
    if image_2d_pos is not None:
        image_hidden_states = image_hidden_states + image_2d_pos
    image_hidden_states = visual_downsampler(image_hidden_states)
    # Bookkeeping: after downsampling, the valid token grid shrinks by a factor of
    # sqrt(visual_embedding_group) per side (ceil to cover partial groups).
    valid_image_token = torch.ceil(valid_image_token / np.sqrt(visual_embedding_group)).to(torch.int64)
    return image_hidden_states, valid_image_token
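
# Arithmetic sketch: with visual_embedding_group=4 (as used in generate() below), the
# per-side stride is sqrt(4) = 2, so a valid grid of, say, 70 x 65 tokens is recorded as
# ceil(70 / 2) x ceil(65 / 2) = 35 x 33 tokens after downsampling.
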
def merge_native_qformer(
        clip_embeddings_native_patch,
        valid_image_token_shape,
        clip_embeddings_qformer,
        visual_source_spliter,
        num_sub_images):

    def add_split_token_for_qformer_token(qformer_emb):
        # Wrap each sub-image's slice of Q-Former tokens with a pair of learned splitter
        # embeddings (indices 2*i and 2*i + 1).
        len_per_token = int(qformer_emb.size(0) // (num_sub_images + 1))
        qformer_emb_with_spliter = list()
        for i in range(num_sub_images + 1):
            qformer_emb_with_spliter.append(
                visual_source_spliter(torch.tensor([2 * i]).to(visual_source_spliter.weight.device))
            )
            qformer_emb_with_spliter.append(qformer_emb[i * len_per_token:(i + 1) * len_per_token])
            qformer_emb_with_spliter.append(
                visual_source_spliter(torch.tensor([2 * i + 1]).to(visual_source_spliter.weight.device))
            )
        return torch.cat(qformer_emb_with_spliter, dim=0)

    merged_visual_embeddings = list()
    for batch_id in range(clip_embeddings_native_patch.size(0)):
        # Flatten the valid (unpadded) h x w native-patch grid into a token sequence.
        h, w = valid_image_token_shape[batch_id]
        native_patch_emb = clip_embeddings_native_patch[batch_id][:h, :w, :].reshape(h*w, -1)
        if clip_embeddings_qformer is not None:
            qformer_emb = clip_embeddings_qformer[batch_id]
            qformer_emb = add_split_token_for_qformer_token(qformer_emb)
            merged_visual_embeddings.append(
                torch.cat(
                    [visual_source_spliter(torch.tensor([10]).to(visual_source_spliter.weight.device)),
                     native_patch_emb,
                     visual_source_spliter(torch.tensor([11]).to(visual_source_spliter.weight.device)),
                     qformer_emb],
                    dim=0))
        else:
            merged_visual_embeddings.append(
                torch.cat(
                    [visual_source_spliter(torch.tensor([0]).to(visual_source_spliter.weight.device)),
                     native_patch_emb,
                     visual_source_spliter(torch.tensor([1]).to(visual_source_spliter.weight.device))],
                    dim=0))

    return merged_visual_embeddings
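
# Layout sketch: in the Q-Former-free path used by generate() below
# (clip_embeddings_qformer=None), each image contributes the sequence
#     [spliter(0)] + h*w flattened native patch tokens + [spliter(1)],
# where spliter(k) is row k of the visual_source_spliter embedding table.
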
class WemmForConditionalGeneration(PreTrainedModel):
    config_class = WeMMConfig

    def __init__(self, config: WeMMConfig):
        super().__init__(config)

        self.vision_tower = Idefics2VisionTransformer(config.vision_config)
        self.image_processor = Idefics2ImageProcessor(config.image_processor)
        self.language_model = InternLM2ForCausalLM(config.text_config)
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path, trust_remote_code=True, encode_special_tokens=True)
        self.downsampler = DownsamplerModel(config.downsampler_config)
        self.visual_source_spliter_emb = torch.nn.Embedding(**config.spliter_emb_config)

        self.gen_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
        )
        self.do_image_splitting = config.do_image_splitting
        self.stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=['<|im_end|>'])
        self.config = config

    def chat(self, conversations, gen_config=None):
        # Render the conversation with the InternLM2 chat template and answer about the
        # first image referenced by the conversation.
        prompt = ""
        image_path = conversations[0]['images'][0]
        for ann in conversations:
            if ann['role'] == 'user':
                prompt += f"<|im_start|>user\n{ann['content']}<|im_end|>\n"
            elif ann['role'] == 'assistant':
                prompt += f"<|im_start|>assistant\n{ann['content']}<|im_end|>\n"
        prompt += '<|im_start|>assistant\n'
        with torch.no_grad():
            output = self.generate(image_path, prompt, gen_config=gen_config)
        return output
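
    # Expected `conversations` structure (illustrative, path is a placeholder):
    #
    #     conversations = [
    #         {"role": "user",
    #          "content": "<image>\nDescribe this image.",
    #          "images": ["/path/to/image.jpg"]},
    #     ]
    #
    # Only the first image of the first turn is used; the '<image>' placeholder in the
    # content marks where the visual tokens are spliced into the prompt.
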
    def mm_generate(self, image_path, prompt, gen_config=None):
        # Convenience wrapper: prepend the image placeholder and wrap the question in the
        # InternLM2 chat template before generating.
        prompt = "<image>" + '\n' + prompt
        prompt = f"<|im_start|>user\n{prompt}<|im_end|><|im_start|>assistant\n"
        return self.generate(image_path, prompt, gen_config)

    def generate(self, image_path, prompt, gen_config=None):
        image = Image.open(image_path).convert('RGB')
        navit980_images = self.image_processor([[image]], return_tensors="pt", do_image_splitting=self.do_image_splitting)
        batch_size_navit = navit980_images['pixel_values'].shape[0]
        navit_pixel_values = navit980_images['navit_pixel_values'].cuda()
        navit_patch_attention_mask = navit980_images["pixel_attention_mask"].cuda()
        clip_visual_outputs = self.vision_tower(pixel_values=navit_pixel_values, patch_attention_mask=navit_patch_attention_mask).last_hidden_state

        # Recover the per-image patch grid (thumbnail only), track the valid token extent,
        # downsample, then wrap the tokens with the splitter embeddings.
        super_image_hidden_states, image_2d_pos, valid_image_token_shape = \
            recover_navit_subimages_with_pos_emb(
                clip_visual_outputs, navit_patch_attention_mask, num_sub_images=-1,
                visual_embedding_group=4,
                pos_hidden_size=4096,
                thumbnail_only=True
            )
        clip_embeddings_native_patch, valid_image_token_shape = visiual_token_downsample(
            self.downsampler,
            super_image_hidden_states, valid_image_token_shape,
            visual_embedding_group=4, image_2d_pos=None
        )
        merged_visual_embeddings = \
            merge_native_qformer(
                clip_embeddings_native_patch,
                valid_image_token_shape,
                clip_embeddings_qformer=None,
                visual_source_spliter=self.visual_source_spliter_emb,
                num_sub_images=-1
            )

        # Tokenize the prompt around the single '<image>' placeholder and mark the image
        # slot with IMAGE_TOKEN_INDEX.
        chunk_encode = []
        for idx, chunk in enumerate(prompt.split(DEFAULT_IMAGE_TOKEN)):
            if idx == 0:
                cur_encode = self.tokenizer.encode(chunk)
            else:
                cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
            chunk_encode.append(cur_encode)
        assert len(chunk_encode) == 2
        ids = []
        for idx, cur_chunk_encode in enumerate(chunk_encode):
            ids.extend(cur_chunk_encode)
            if idx != len(chunk_encode) - 1:
                ids.append(IMAGE_TOKEN_INDEX)
        ids = torch.tensor(ids).cuda().unsqueeze(0)
        pixel_values = None
        mm_inputs = self.prepare_inputs_labels_for_multimodal(
            llm=self.language_model, input_ids=ids, pixel_values=pixel_values, clip_embeddings=merged_visual_embeddings)

        generate_output = self.language_model.generate(
            **mm_inputs,
            generation_config=gen_config if gen_config is not None else self.gen_config,
            streamer=None,
            bos_token_id=self.tokenizer.bos_token_id,
            stopping_criteria=self.stop_criteria
        )
        predict = self.tokenizer.decode(
            generate_output[0], skip_special_tokens=True).strip()
        return predict
    def get_valid_visual_embedding(self, embedding, valid_token_shape):
        if valid_token_shape is None:
            return embedding
        h, w = valid_token_shape
        return embedding[:h, :w, :].reshape(h*w, -1)

    def prepare_inputs_labels_for_multimodal(
            self,
            llm: PreTrainedModel,
            input_ids: torch.LongTensor = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            pixel_values: Optional[torch.FloatTensor] = None,
            clip_embeddings: Optional[torch.FloatTensor] = None,
            hard_coded_max_len: Optional[int] = None,
            **kwargs):
        # Text-only case: nothing to splice, pass the inputs through unchanged.
        if pixel_values is None and clip_embeddings is None:
            return {
                'input_ids': input_ids,
                'position_ids': position_ids,
                'attention_mask': attention_mask,
                'past_key_values': past_key_values,
                'inputs_embeds': None,
                'labels': labels
            }

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(
                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Drop padded positions before splicing in the visual embeddings.
        input_ids = [
            cur_input_ids[cur_attention_mask]
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        labels = [
            cur_labels[cur_attention_mask]
            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
        ]

        new_inputs_embeds = []
        new_labels = []
        new_img_masks = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # No image placeholder in this sample: keep the text embeddings and append
                # zero-length slices of the visual features so they stay in the graph.
                cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
                cur_clip_emb = clip_embeddings[cur_image_idx] if clip_embeddings is not None else None
                cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
                if cur_clip_emb is not None and cur_pixel_values is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_pixel_values[0:0], cur_clip_emb[0:0]], dim=0)
                elif cur_pixel_values is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
                elif cur_clip_emb is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_clip_emb[0:0]], dim=0)
                else:
                    raise ValueError
                new_inputs_embeds.append(cur_inputs_embeds)
                new_labels.append(labels[batch_idx])
                new_img_masks.append(torch.zeros(
                    cur_inputs_embeds.shape[0], device=cur_inputs_embeds.device).bool())
                cur_image_idx += 1
                continue

            # Split the token sequence at every IMAGE_TOKEN_INDEX placeholder.
            image_token_indices = [-1] + torch.where(
                cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
                    cur_input_ids.shape[0]
                ]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(
                    cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
                cur_labels_noim.append(
                    cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_inputs_embeds = llm.get_input_embeddings()(
                torch.cat(cur_input_ids_noim))
            cur_inputs_embeds_no_im = torch.split(
                cur_inputs_embeds, split_sizes, dim=0)
            cur_new_inputs_embeds = []
            cur_new_labels = []
            cur_img_masks = []

            # Interleave text segments with the visual embeddings; image positions get
            # IGNORE_INDEX labels and are flagged in the image mask.
            for i in range(num_images + 1):
                cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                cur_img_masks.append(torch.zeros(
                    cur_inputs_embeds_no_im[i].shape[0], device=cur_inputs_embeds_no_im[i].device).bool())
                if i < num_images:
                    cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
                    cur_clip_emb = clip_embeddings[cur_image_idx] if clip_embeddings is not None else None
                    cur_image_idx += 1

                    if cur_pixel_values is not None:
                        cur_new_inputs_embeds.append(cur_pixel_values)
                        cur_img_masks.append(torch.ones(
                            cur_pixel_values.shape[0], device=cur_pixel_values.device).bool())
                        cur_new_labels.append(
                            torch.full((cur_pixel_values.shape[0], ),
                                       IGNORE_INDEX,
                                       device=cur_labels.device,
                                       dtype=cur_labels.dtype))

                    if cur_clip_emb is not None:
                        cur_new_inputs_embeds.append(cur_clip_emb)
                        cur_img_masks.append(torch.ones(
                            cur_clip_emb.shape[0], device=cur_clip_emb.device).bool())
                        cur_new_labels.append(
                            torch.full((cur_clip_emb.shape[0],),
                                       IGNORE_INDEX,
                                       device=cur_labels.device,
                                       dtype=cur_labels.dtype))

            cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
            cur_new_labels = torch.cat(cur_new_labels)
            cur_img_masks = torch.cat(cur_img_masks)

            new_inputs_embeds.append(cur_new_inputs_embeds)
            new_labels.append(cur_new_labels)
            new_img_masks.append(cur_img_masks)

        # Right-pad every sample in the batch to a common length.
        max_len = max(x.shape[0] for x in new_inputs_embeds)
        if hard_coded_max_len is not None:
            max_len = min(max_len, hard_coded_max_len)
        batch_size = len(new_inputs_embeds)

        new_inputs_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len),
                                       IGNORE_INDEX,
                                       dtype=new_labels[0].dtype,
                                       device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len),
                                     dtype=attention_mask.dtype,
                                     device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len),
                                   dtype=position_ids.dtype,
                                   device=position_ids.device)
        new_img_masks_padded = torch.zeros((batch_size, max_len), device=new_img_masks[0].device).bool()

        for i, (cur_new_embed,
                cur_new_labels, cur_new_img_masks) in enumerate(zip(new_inputs_embeds, new_labels, new_img_masks)):
            cur_new_embed = cur_new_embed[:max_len]
            cur_new_labels = cur_new_labels[:max_len]
            cur_new_img_masks = cur_new_img_masks[:max_len]

            cur_len = cur_new_embed.shape[0]
            new_inputs_embeds_padded.append(
                torch.cat((cur_new_embed,
                           torch.zeros((max_len - cur_len, cur_new_embed.shape[1]),
                                       dtype=cur_new_embed.dtype,
                                       device=cur_new_embed.device)),
                          dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0,
                    cur_len,
                    dtype=position_ids.dtype,
                    device=position_ids.device)
                new_img_masks_padded[i, :cur_len] = cur_new_img_masks

        new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        prepared_data = {
            'input_ids': None,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'past_key_values': past_key_values,
            'inputs_embeds': new_inputs_embeds,
            'labels': new_labels,
        }

        prepared_data.update({'im_mask': new_img_masks_padded})
        return prepared_data

AutoConfig.register("wemm_hf", WeMMConfig)
AutoModel.register(WeMMConfig, WemmForConditionalGeneration)
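

# Minimal usage sketch (assumptions: a checkpoint directory that ships this module, a CUDA
# device, and an example image path -- the paths below are placeholders, not part of this file).
if __name__ == "__main__":
    model = AutoModel.from_pretrained(
        "path/to/wemm_checkpoint",  # hypothetical checkpoint location
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ).cuda().eval()
    # mm_generate() prepends '<image>' and wraps the question in the chat template;
    # generate() moves the image tensors to the GPU, so a CUDA device is required.
    answer = model.mm_generate("path/to/example.jpg", "Describe this image.")
    print(answer)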