import copy
from typing import Dict, Sequence

import torch
import transformers

from llava.constants import (
    IGNORE_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from llava import conversation as conversation_lib
from llava.mm_utils import tokenizer_image_token


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings one at a time.

    Returns a dict with ``input_ids`` / ``labels`` (the SAME list object: the
    first row of each tokenized string) and ``input_ids_lens`` /
    ``labels_lens`` (also the same list: the number of non-pad tokens per
    string).
    """
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def _mask_targets(target, tokenized_lens, speakers):
    """Mask, in place, the header and the human turns of ``target``.

    Args:
        target: label tensor for one conversation; modified in place.
        tokenized_lens: token length of the header followed by the token
            length of each sentence, aligned with ``speakers``.
        speakers: the ``"from"`` value of each sentence.
    """
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            # The first 2 tokens of each human turn stay unmasked.
            # NOTE(review): presumably the speaker-signal prefix -- confirm.
            target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len


def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round.

    Each sentence's ``"value"`` is rewritten IN PLACE to
    ``"### <role>: <value>\\n"``. When ``get_conversation`` is true the
    rewritten values are appended to ``header``. The returned string always
    ends with a trailing ``"### "``.
    """
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = "unknown"
        sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation


def preprocess_multimodal(
    sources: Sequence[str], is_multimodal: bool, mm_use_im_start_end: bool
) -> Sequence:
    """Normalize image-token placement in every sentence, in place.

    Any sentence containing DEFAULT_IMAGE_TOKEN has the token moved to the
    front, followed by a newline. The token is then optionally wrapped in
    the start/end tokens. ``sources`` is returned unchanged when
    ``is_multimodal`` is false.
    """
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence["value"]:
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
                sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
                sentence["value"] = sentence["value"].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    # mmtag templates expect the image wrapped in <Image>...</Image>
                    # (as in LLaVA's mmtag conversation templates). Wrapping it in
                    # empty strings would make this branch a no-op.
                    sentence["value"] = sentence["value"].replace(
                        DEFAULT_IMAGE_TOKEN,
                        "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>",
                    )
            replace_token = DEFAULT_IMAGE_TOKEN
            if mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)

    return sources


def _apply_prompt_templates(sources, conv) -> list:
    """Render each conversation in ``sources`` with ``conv``'s prompt template.

    A leading sentence not from the human role is dropped. The remaining
    sentences must alternate human/gpt; an AssertionError carrying the
    conversation index is raised otherwise. ``conv.messages`` is reset and
    refilled for each conversation, so ``conv`` is mutated.

    Returns:
        list: one rendered prompt string per conversation.
    """
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())
    return conversations


def _tokenize_prompts(conversations, tokenizer, has_image):
    """Tokenize rendered prompts into a tensor with one row per prompt.

    With ``has_image`` the per-prompt ``tokenizer_image_token`` outputs are
    stacked; they must all have the same length. Otherwise the tokenizer pads
    to the longest prompt and truncates at ``tokenizer.model_max_length``.
    """
    if has_image:
        return torch.stack(
            [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations],
            dim=0,
        )
    return tokenizer(
        conversations,
        return_tensors="pt",
        padding="longest",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids


def _mask_instruction_rounds(
    conversations,
    targets,
    tokenizer,
    *,
    round_sep,
    instruction_sep,
    has_image,
    check_mismatch=True,
):
    """Mask, in place, everything in ``targets`` except the assistant replies.

    Each prompt is split into rounds on ``round_sep``. Each round is split
    into instruction/response on ``instruction_sep``. The first token, every
    instruction span, and everything after the last parsed round are set to
    IGNORE_INDEX.

    When ``check_mismatch`` is true and the accumulated round length
    disagrees with the non-pad length of the row, the whole row is masked
    and a warning is printed. The fixed offsets (start at 1, ``- 2`` on the
    instruction length) follow the tokenizer conventions this module was
    written for.
    """

    def token_len(text):
        if has_image:
            return len(tokenizer_image_token(text, tokenizer))
        return len(tokenizer(text).input_ids)

    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for rou in conversation.split(round_sep):
            if rou == "":
                break

            parts = rou.split(instruction_sep)
            if len(parts) != 2:
                break
            parts[0] += instruction_sep

            round_len = token_len(rou)
            instruction_len = token_len(parts[0]) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if check_mismatch and cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_INDEX
            print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")


def preprocess_llama_2(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    """Tokenize LLaMA-2-style conversations and mask everything but the replies.

    Returns:
        Dict: ``input_ids`` and ``labels`` tensors, one row per conversation.
    """
    conv = conversation_lib.default_conversation.copy()
    conversations = _apply_prompt_templates(sources, conv)

    input_ids = _tokenize_prompts(conversations, tokenizer, has_image)
    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    _mask_instruction_rounds(
        conversations,
        targets,
        tokenizer,
        round_sep=conv.sep2,
        instruction_sep="[/INST] ",
        has_image=has_image,
    )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def _build_obj_context(tokenizer, obj_dict, obj_context_feature_type):
    """Tokenize the ``%%%% Object-centric context: ...%%%%`` span.

    Returns:
        tuple: ``(context_ids, context_labels)`` as 1-D long tensors.
        ``context_labels`` is IGNORE_INDEX everywhere except, in "vector"
        mode, at the IMAGE_TOKEN_INDEX feature placeholders.
    """
    obj_context = "%%%% Object-centric context:"
    if obj_context_feature_type == "vector":
        if tokenizer.sep_token is None:
            raise ValueError(
                "obj_context_feature_type='vector' requires tokenizer.sep_token to be set; "
                "it is used as the per-object feature placeholder."
            )
        for obj_id in obj_dict:
            # The sep token is a placeholder, later replaced by the actual feature vector.
            obj_context += f" <{obj_id}>: {tokenizer.sep_token};"
    else:
        for obj_id, obj_info in obj_dict.items():
            obj_context += f" <{obj_id}>: {obj_info};"
    obj_context += "%%%%"

    # Strip the bos and eos tokens.
    # NOTE(review): this assumes the tokenizer adds BOTH bos and eos. If it only
    # adds bos, the final context token is dropped here -- verify.
    context_ids = torch.tensor(tokenizer(obj_context).input_ids[1:-1], dtype=torch.long)

    if obj_context_feature_type == "vector":
        # IMAGE_TOKEN_INDEX marks where each feature vector will be inserted later.
        context_ids[context_ids == tokenizer.sep_token_id] = IMAGE_TOKEN_INDEX
        # Only the feature placeholders stay trainable; all string tokens are masked.
        context_labels = context_ids.clone()
        context_labels[context_ids != IMAGE_TOKEN_INDEX] = IGNORE_INDEX
    else:
        # Text context: every token is masked.
        context_labels = context_ids.clone()
        context_labels[:] = IGNORE_INDEX
    return context_ids, context_labels


def preprocess_llama_2_obj_identifier(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    obj_dict: Dict[str, Dict],
    obj_context_feature_type: str,
    mode: str,
) -> Dict:
    """Tokenize one LLaMA-2 conversation and splice in an object-centric context.

    The single IMAGE_TOKEN_INDEX placeholder in the tokenized prompt is
    replaced by the tokens of::

        %%%% Object-centric context: <{obj_id}>: {feature}; <{obj_id}>: {feature}; ...%%%%

    In "vector" mode each ``{feature}`` is ``tokenizer.sep_token``. That token
    is then swapped for IMAGE_TOKEN_INDEX so a later stage can find where to
    insert each object's feature vector. In "text" mode ``{feature}`` is the
    string form of the object's entry in ``obj_dict``. All context string
    tokens are labelled IGNORE_INDEX; only the feature placeholders keep a
    non-ignored label.

    Args:
        sources: the conversation sources; exactly one conversation is supported.
        tokenizer (transformers.PreTrainedTokenizer): the tokenizer.
        obj_dict (Dict[str, Dict]): the object dictionary for the scene, keyed
            by object id.
        obj_context_feature_type (str): "vector" or "text".
        mode (str): when "generate", the tokenization-mismatch check is
            skipped (there is no target while generating).

    Returns:
        Dict: 1-D ``input_ids`` and ``labels``. In "vector" mode it also
        contains ``obj_dict``, so the features can be embedded later.

    Raises:
        ValueError: unknown ``obj_context_feature_type``; no ``sep_token`` in
            "vector" mode; or the prompt does not contain exactly one
            IMAGE_TOKEN_INDEX placeholder.
        AssertionError: more than one conversation was supplied.
    """
    if obj_context_feature_type not in ("vector", "text"):
        raise ValueError(
            f"obj_context_feature_type must be 'vector' or 'text', got {obj_context_feature_type!r}"
        )

    conv = conversation_lib.conv_llava_llama_2_obj_identifier.copy()
    conversations = _apply_prompt_templates(sources, conv)

    input_ids = _tokenize_prompts(conversations, tokenizer, has_image=True)
    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    _mask_instruction_rounds(
        conversations,
        targets,
        tokenizer,
        round_sep=conv.sep2,
        instruction_sep="[/INST] ",
        has_image=True,
        # When generating there is no target, so the mismatch check is meaningless.
        check_mismatch=mode != "generate",
    )

    assert (
        input_ids.shape[0] == targets.shape[0] == 1
    ), "Only support tokenization for one conversation at a time"
    input_id = input_ids[0]
    target = targets[0]

    obj_context_ids, obj_context_labels = _build_obj_context(
        tokenizer, obj_dict, obj_context_feature_type
    )

    # Insert the object context where the single IMAGE_TOKEN_INDEX placeholder is.
    placeholder_positions = torch.where(input_id == IMAGE_TOKEN_INDEX)[0]
    if placeholder_positions.numel() != 1:
        raise ValueError(
            "Expected exactly one IMAGE_TOKEN_INDEX placeholder in the prompt, "
            f"found {placeholder_positions.numel()}"
        )
    separation_idx = int(placeholder_positions[0])

    input_id_with_obj_context = torch.cat(
        [input_id[:separation_idx], obj_context_ids, input_id[separation_idx + 1 :]]
    )
    target_with_obj_context = torch.cat(
        [target[:separation_idx], obj_context_labels, target[separation_idx + 1 :]]
    )

    result = dict(input_ids=input_id_with_obj_context, labels=target_with_obj_context)
    if obj_context_feature_type == "vector":
        # Return the object dictionary so that the features can be embedded later.
        result["obj_dict"] = obj_dict
    return result


def preprocess_v1(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    """Tokenize v1 (SeparatorStyle.TWO) conversations and mask everything but the replies.

    Returns:
        Dict: ``input_ids`` and ``labels`` tensors, one row per conversation.
    """
    conv = conversation_lib.default_conversation.copy()
    conversations = _apply_prompt_templates(sources, conv)

    input_ids = _tokenize_prompts(conversations, tokenizer, has_image)
    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    _mask_instruction_rounds(
        conversations,
        targets,
        tokenizer,
        round_sep=conv.sep2,
        instruction_sep=conv.sep + conv.roles[1] + ": ",
        has_image=has_image,
    )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_mpt(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize MPT-style conversations and mask everything but the replies.

    Rounds are rebuilt from ``conv.sep`` splits. The first round is
    system + user + gpt; every later round is user + gpt.

    Returns:
        Dict: ``input_ids`` and ``labels`` tensors, one row per conversation.
    """
    conv = conversation_lib.default_conversation.copy()
    conversations = _apply_prompt_templates(sources, conv)

    input_ids = _tokenize_prompts(conversations, tokenizer, has_image=True)
    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

    sep = conv.sep + conv.roles[1]
    # Loop-invariant: the token length of the separator dropped by str.split.
    sep_token_len = len(tokenizer_image_token(conv.sep, tokenizer))
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2]))  # user + gpt

        cur_len = 0
        for rou in re_rounds:
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            round_len = len(tokenizer_image_token(rou, tokenizer)) + sep_token_len
            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_INDEX
            print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build image + caption pairs, masking the image part of the labels.

    Each source must have exactly 2 sentences, and the first must contain
    DEFAULT_IMAGE_TOKEN. The first sentence is overwritten IN PLACE with
    just DEFAULT_IMAGE_TOKEN.

    Returns:
        Dict: lists of per-example ``input_ids`` and ``labels`` tensors.
    """
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
        source[0]["value"] = DEFAULT_IMAGE_TOKEN
        conversation = (
            source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
        )
        conversations.append(conversation)
    # tokenize conversations
    input_ids = [
        tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations
    ]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    """Tokenize conversations according to the default conversation template.

    Dispatches on ``conversation_lib.default_conversation``:
    - PLAIN separator style: ``preprocess_plain``
    - LLAMA_2 separator style: ``preprocess_llama_2``
    - version starting with "v1": ``preprocess_v1``
    - version "mpt": ``preprocess_mpt``

    Otherwise, given a list of sources (each a conversation list), it:
    1. Adds the signal '### ' at the beginning of each sentence and a
       newline end signal;
    2. Concatenates the conversation together;
    3. Tokenizes the concatenated conversation;
    4. Makes a deepcopy as the target and masks human words with IGNORE_INDEX.
    """
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
        return preprocess_llama_2(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "mpt":
        return preprocess_mpt(sources, tokenizer)

    # The header is the same for every conversation; build it once.
    header = f"{conversation_lib.default_conversation.system}\n\n"

    # add end signal and concatenate together
    # (_add_speaker_and_signal also rewrites each sentence's "value" in place,
    # so the per-sentence lengths below include the speaker signals.)
    conversations = [_add_speaker_and_signal(header, source) for source in sources]

    # tokenize conversations
    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]

    if has_image:
        input_ids = [
            tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations
        ]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)[
                "input_ids_lens"
            ]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)