# 3D-GRAND/llava/util/tokenization.py
import copy
from typing import Dict, Sequence
import torch
import transformers
from llava.constants import (
IGNORE_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
IMAGE_TOKEN_INDEX,
)
from llava import conversation as conversation_lib
from llava.mm_utils import tokenizer_image_token
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
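
# Illustrative sketch (not part of the original module): _tokenize_fn returns the
# per-string token ids plus their unpadded lengths, and intentionally aliases
# input_ids and labels to the same list. With a LLaMA-style tokenizer, roughly:
#   out = _tokenize_fn(["Hello there"], tokenizer)
#   out["input_ids"][0]      # tensor starting with the bos token, then word pieces
#   out["input_ids_lens"][0] # number of non-pad tokens (exact ids depend on the tokenizer)
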
def _mask_targets(target, tokenized_lens, speakers):
# cur_idx = 0
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:]
target[:cur_idx] = IGNORE_INDEX
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker == "human":
target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
cur_idx += tokenized_len
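
# Worked example of the masking above: with tokenized_lens = [5, 9, 7] and
# speakers = ["human", "gpt"], tokens 0..4 (the header) are set to IGNORE_INDEX,
# tokens 7..13 of the human turn are also ignored (the "+ 2" offset leaves the
# first two tokens of that turn unmasked), and the gpt turn at positions 14..20
# stays trainable. The exact indices depend on the tokenizer's segmentation.
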
def _add_speaker_and_signal(header, source, get_conversation=True):
"""Add speaker and start/end signal on each round."""
BEGIN_SIGNAL = "### "
END_SIGNAL = "\n"
conversation = header
for sentence in source:
from_str = sentence["from"]
if from_str.lower() == "human":
from_str = conversation_lib.default_conversation.roles[0]
elif from_str.lower() == "gpt":
from_str = conversation_lib.default_conversation.roles[1]
else:
from_str = "unknown"
sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
if get_conversation:
conversation += sentence["value"]
conversation += BEGIN_SIGNAL
return conversation
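
# Hedged example: with the default Vicuna-style roles ("Human", "Assistant"),
# a source like
#   [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}]
# is rewritten in place and the returned conversation becomes
#   header + "### Human: Hi\n### Assistant: Hello!\n### "
# The role names come from conversation_lib.default_conversation and may differ.
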
def preprocess_multimodal(
    sources: Sequence[Sequence[Dict]], is_multimodal: bool, mm_use_im_start_end: bool
) -> Sequence[Sequence[Dict]]:
if not is_multimodal:
return sources
for source in sources:
for sentence in source:
if DEFAULT_IMAGE_TOKEN in sentence["value"]:
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
sentence["value"] = sentence["value"].strip()
if "mmtag" in conversation_lib.default_conversation.version:
sentence["value"] = sentence["value"].replace(
DEFAULT_IMAGE_TOKEN,
"<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>",
)
replace_token = DEFAULT_IMAGE_TOKEN
if mm_use_im_start_end:
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
return sources
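
# Hedged example: with is_multimodal=True and mm_use_im_start_end=True, a value
# such as "What is shown here? <image>" (where "<image>" is DEFAULT_IMAGE_TOKEN)
# is normalized so the image token comes first and then wrapped, yielding
#   DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\nWhat is shown here?"
# (plus <Image>...</Image> tags when the conversation version uses mmtag).
# The literal token strings are defined in llava.constants.
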
def preprocess_llama_2(
sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack(
[
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
# Mask targets
sep = "[/INST] "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
                print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
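
# Hedged sketch of the masking logic above: a LLaMA-2 style round looks roughly like
#   "[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST] {answer} "
# with rounds separated by conv.sep2 (the eos string). Inside each round,
# everything up to and including "[/INST] " is set to IGNORE_INDEX, so the loss
# covers only the assistant answers; the exact template comes from conversation_lib.
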
def preprocess_llama_2_obj_identifier(
sources,
tokenizer: transformers.PreTrainedTokenizer,
obj_dict: Dict[str, Dict],
obj_context_feature_type: str,
mode: str,
) -> Dict:
    """Tokenize the conversation into the following format:
    %%%% Object-centric context: <obj_0>: <obj_0_feat>; <obj_1>: <obj_1_feat>; ... <obj_i>: <obj_i_feat>;%%%%
    where <obj_i_feat> is currently a placeholder (IMAGE_TOKEN_INDEX)
    that will later be replaced by the actual feature in vector form.
    All string tokens are marked as not trainable; only the feature vectors remain trainable.
    Args:
        sources: the conversation sources
        tokenizer (transformers.PreTrainedTokenizer): the tokenizer
        obj_dict (Dict[str, Dict]): the object dictionary for the scene
        obj_context_feature_type (str): the type of object feature to use for the object context ("vector" or "text")
        mode (str): "generate" skips the target-masking sanity check, since there is no target at generation time
    Returns:
        Dict: the tokenized input_ids and labels
    """
conv = conversation_lib.conv_llava_llama_2_obj_identifier.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
input_ids = torch.stack(
[tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations],
dim=0,
)
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
# Mask targets
sep = "[/INST] "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
        if (
            cur_len < tokenizer.model_max_length and mode != "generate"
        ):  # check that the target is correctly masked; at generation time there is no target
if cur_len != total_len:
target[:] = IGNORE_INDEX
                print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")
assert (
input_ids.shape[0] == targets.shape[0] == 1
), "Only support tokenization for one conversation at a time"
input_id = input_ids[0]
targets = targets[0]
    # TODO: replace -200 (IMAGE_TOKEN_INDEX) with object identifier tokens.
    # We want the LLM to see:
    # %%%% Object-centric context: <obj_0>: <obj_0_feat>; <obj_1>: <obj_1_feat>; ... <obj_i>: <obj_i_feat>;%%%%
    # where <obj_i_feat> will later be replaced by the actual feature in vector form.
if obj_context_feature_type == "vector":
obj_context = "%%%% Object-centric context:"
for obj_id, obj_info in obj_dict.items():
obj_context += f" <{obj_id}>: {tokenizer.sep_token};" # use </s> as a placeholder, later it will be replaced by the actual feature vector
obj_context += "%%%%"
tokenized_obj_context = tokenizer(obj_context).input_ids[
1:-1
] # strip the bos and eos tokens
tokenized_obj_context = torch.tensor(tokenized_obj_context, dtype=torch.long)
tokenized_obj_context[tokenized_obj_context == tokenizer.sep_token_id] = (
IMAGE_TOKEN_INDEX # replace </s> with IMAGE_TOKEN_INDEX, so that later we can use -200 to find where the feature vector should be inserted
)
        # mark all string tokens as not trainable; only the feature vectors remain trainable
tokenized_obj_context_target = tokenized_obj_context.clone()
tokenized_obj_context_target[tokenized_obj_context != IMAGE_TOKEN_INDEX] = IGNORE_INDEX
elif obj_context_feature_type == "text":
obj_context = "%%%% Object-centric context:"
for obj_id, obj_info in obj_dict.items():
obj_context += f" <{obj_id}>: {obj_info};"
obj_context += "%%%%"
tokenized_obj_context = tokenizer(obj_context).input_ids[
1:-1
] # strip the bos and eos tokens
tokenized_obj_context = torch.tensor(tokenized_obj_context, dtype=torch.long)
tokenized_obj_context_target = tokenized_obj_context.clone()
tokenized_obj_context_target[:] = IGNORE_INDEX # mark all tokens as not trainable
    # now insert the object context into input_id and target at the position of IMAGE_TOKEN_INDEX
separation_idx = torch.where(input_id == IMAGE_TOKEN_INDEX)[0]
input_id_with_obj_context = torch.cat(
[input_id[:separation_idx], tokenized_obj_context, input_id[separation_idx + 1 :]]
)
target_with_obj_context = torch.cat(
[
targets[:separation_idx],
tokenized_obj_context_target,
targets[separation_idx + 1 :],
]
)
if obj_context_feature_type == "vector":
return dict(
input_ids=input_id_with_obj_context,
labels=target_with_obj_context,
obj_dict=obj_dict, # return the object dictionary so that we can later embed the features
)
elif obj_context_feature_type == "text":
return dict(input_ids=input_id_with_obj_context, labels=target_with_obj_context)
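
# Hedged usage sketch (the obj_dict contents and the mode string are assumptions,
# not part of this module):
#   obj_dict = {"obj_0": {"label": "chair"}, "obj_1": {"label": "table"}}
#   out = preprocess_llama_2_obj_identifier(
#       [source], tokenizer, obj_dict,
#       obj_context_feature_type="vector", mode="train",
#   )
#   # With the "vector" type, out["input_ids"] contains one IMAGE_TOKEN_INDEX (-200)
#   # placeholder per object; within the inserted object context, only those
#   # placeholder positions remain trainable in out["labels"], and out["obj_dict"]
#   # is passed through so the features can be embedded later.
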
def preprocess_v1(
sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack(
[
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
],
dim=0,
)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
# Mask targets
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
                print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
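
# Hedged sketch: with SeparatorStyle.TWO the assistant marker is
#   sep = conv.sep + conv.roles[1] + ": "   (e.g. " ASSISTANT: " for the v1 template)
# and rounds are split on conv.sep2 (the eos string). As in preprocess_llama_2,
# only the assistant replies after that marker contribute to the loss.
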
def preprocess_mpt(
sources,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations
input_ids = torch.stack(
[tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations],
dim=0,
)
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
# Mask targets
sep = conv.sep + conv.roles[1]
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep)
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
for conv_idx in range(3, len(rounds), 2):
re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt
cur_len = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
if rou == "":
break
parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
round_len = len(tokenizer_image_token(rou, tokenizer)) + len(
tokenizer_image_token(conv.sep, tokenizer)
)
instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_INDEX
                print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
)
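
# Hedged sketch: for the MPT-style template, turns are delimited by conv.sep
# (typically "<|im_end|>") and the assistant marker is conv.sep + conv.roles[1].
# The first re-round bundles system + user + assistant, later re-rounds bundle
# user + assistant pairs; everything before each assistant reply is masked.
# The exact separator strings come from conversation_lib and may differ.
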
def preprocess_plain(
    sources: Sequence[Sequence[Dict]],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
source[0]["value"] = DEFAULT_IMAGE_TOKEN
conversation = (
source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
)
conversations.append(conversation)
# tokenize conversations
input_ids = [
tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations
]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
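
# Hedged example: preprocess_plain expects exactly two turns, replaces the human
# turn with the bare image token, and masks it, so for
#   [{"from": "human", "value": "<image>\nDescribe."}, {"from": "gpt", "value": "A chair."}]
# the trainable targets cover only "A chair." plus the trailing separator.
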
def preprocess(
    sources: Sequence[Sequence[Dict]],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    """
    Given a list of sources, each of which is a conversation list, this transform:
    1. Adds the signal '### ' at the beginning of each sentence, with end signal '\n';
    2. Concatenates the conversations together;
    3. Tokenizes the concatenated conversation;
    4. Makes a deepcopy as the target and masks the human words with IGNORE_INDEX.
    """
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
return preprocess_plain(sources, tokenizer)
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
return preprocess_llama_2(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version.startswith("v1"):
return preprocess_v1(sources, tokenizer, has_image=has_image)
if conversation_lib.default_conversation.version == "mpt":
return preprocess_mpt(sources, tokenizer)
# add end signal and concatenate together
conversations = []
for source in sources:
header = f"{conversation_lib.default_conversation.system}\n\n"
conversation = _add_speaker_and_signal(header, source)
conversations.append(conversation)
# tokenize conversations
def get_tokenize_len(prompts):
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
if has_image:
input_ids = [
tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
for prompt in conversations
]
else:
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
if has_image:
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
else:
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)[
"input_ids_lens"
]
speakers = [sentence["from"] for sentence in source]
_mask_targets(target, tokenized_lens, speakers)
return dict(input_ids=input_ids, labels=targets)
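
# Hedged usage sketch for the dispatcher above (the tokenizer checkpoint and the
# conv_templates key are assumptions; any LLaMA-compatible tokenizer with
# model_max_length set should behave the same way):
#   tokenizer = transformers.AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
#   tokenizer.model_max_length = 2048
#   conversation_lib.default_conversation = conversation_lib.conv_templates["llava_llama_2"]
#   sources = [[
#       {"from": "human", "value": "<image>\nWhat is in the scene?"},
#       {"from": "gpt", "value": "A wooden chair next to a table."},
#   ]]
#   sources = preprocess_multimodal(sources, is_multimodal=True, mm_use_im_start_end=False)
#   batch = preprocess(sources, tokenizer, has_image=True)
#   # batch["input_ids"] contains IMAGE_TOKEN_INDEX where visual features will be
#   # spliced in, and batch["labels"] is IGNORE_INDEX everywhere except the
#   # assistant reply.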