# videollm-online / models / tokenization_live.py
# Retrieved from the Hugging Face file viewer: uploaded by chenjoya,
# commit 7d1b5a5 ("Upload 9 files", verified), file size 7.45 kB.
import torch
from transformers import AutoTokenizer
from functools import partial
from .configuration_live import LiveConfigMixin
def get_stream_placeholder_len(num_frames: int, model_config: LiveConfigMixin) -> int:
    """Return the character length of the placeholder text for ``num_frames`` frames.

    The measured text is
    ``frame_token_interval.join([frame_num_tokens * v_placeholder] * num_frames)``,
    i.e. ``num_frames`` groups of placeholders separated by ``num_frames - 1``
    interval strings.

    Args:
        num_frames: Number of frames in the stream message.
        model_config: Config providing ``frame_num_tokens``, ``v_placeholder``
            and ``frame_token_interval``.

    Returns:
        The length in characters; ``0`` when ``num_frames`` is not positive.
    """
    # Joining an empty list yields '', so there is no separator to count.
    # Without this guard the formula below would return a negative length.
    if num_frames <= 0:
        return 0
    frame_len = model_config.frame_num_tokens * len(model_config.v_placeholder)
    interval_len = len(model_config.frame_token_interval)
    return num_frames * frame_len + interval_len * (num_frames - 1)
def get_stream_placeholder_jinja2(model_config: LiveConfigMixin) -> str:
    """Build the Jinja2 expression that expands to a stream's placeholder text.

    The returned source code, once evaluated inside the chat template, joins
    ``message['num_frames']`` groups of ``frame_num_tokens`` placeholders with
    the frame token interval.
    """
    pattern = "'{}'.join([{} * '{}'] * message['num_frames'])"
    return pattern.format(
        model_config.frame_token_interval,
        model_config.frame_num_tokens,
        model_config.v_placeholder,
    )
def get_stream_learn_ranges(num_frames: int, model_config: LiveConfigMixin) -> torch.Tensor:
    """Return per-frame ``[start, stop)`` character ranges inside a stream placeholder.

    Offsets are relative to the first placeholder character. Each row starts
    right after one frame's placeholders. It spans the frame token interval
    when one is configured, otherwise one ``v_placeholder`` string.

    Returns:
        A tensor with one ``(start, stop)`` row per frame.
    """
    interval_len = len(model_config.frame_token_interval)
    placeholder_len = len(model_config.v_placeholder)
    # Characters occupied by one frame's placeholders plus the trailing interval.
    stride = model_config.frame_num_tokens * placeholder_len + interval_len
    # End of every frame+interval unit, moved back to where the interval begins.
    starts = torch.arange(stride, stride * num_frames + 1, stride) - interval_len
    if model_config.frame_token_interval:
        span = interval_len
    else:
        span = placeholder_len
    return torch.stack((starts, starts + span), dim=1)
def chat_template(self, stream_placeholder_jinja2: str):
    """Return the Jinja2 chat template with the stream placeholder expression filled in.

    Rendered layout:

        system prompt
        [<v>,<v>,<v>]
        User: ...
        Assistant: ...</s>
        [<v>,<v>]
        Assistant: ...</s>
        User: ...
        Assistant: ...</s>

    ``stream_placeholder_jinja2`` is Jinja2 source code that is substituted
    for the ``STREAM_PLACEHOLDER`` marker.
    """
    # Optional leading system message: emitted after the BOS token, then dropped.
    system_part = [
        "{% if messages[0]['role'] == 'system' %}",
        "{{ bos_token + messages[0]['content'] + '\n' }}",
        "{% set messages = messages[1:] %}",
        "{% endif %}",
    ]
    # One branch per role; stream messages without frames render nothing.
    messages_part = [
        "{% for message in messages %}",
        "{% if message['role'] == 'user' %}",
        "{% if add_stream_query_prompt %}",
        "{{ ']\nUser: ' + message['content'] }}",
        "{% else %}",
        "{{ '\nUser: ' + message['content'] }}",
        "{% endif %}",
        "{% elif message['role'] == 'assistant' %}",
        "{{ '\nAssistant: ' + message['content'] + eos_token }}",
        "{% elif message['role'] == 'stream' and message['num_frames'] > 0: %}",
        "{{ '\n[' + STREAM_PLACEHOLDER + ']' }}",
        "{% endif %}",
        "{% endfor %}",
    ]
    # Trailing prompt selected by the caller-supplied flags.
    prompt_part = [
        "{% if add_generation_prompt %}",
        "{{ '\nAssistant:' }}",
        "{% elif add_stream_prompt %}",
        "{{ '\n[' }}",
        "{% elif add_stream_generation_prompt %}",
        "{{ ']\nAssistant:' }}",
        "{% endif %}",
    ]
    skeleton = "".join(system_part + messages_part + prompt_part)
    return skeleton.replace('STREAM_PLACEHOLDER', stream_placeholder_jinja2)
def chat_template_transition(tokenizer):
    """Return the literal text emitted at each role boundary of the chat template.

    Keys are ``(previous_role, next_role)`` tuples (``None`` marks the start of
    the conversation), plus two string keys: ``'assistant'`` for the assistant
    prefix and ``'eos_token'`` for the tokenizer's EOS token.
    """
    entries = [
        ((None, 'system'), tokenizer.bos_token),
        (('system', 'user'), '\n\nUser: '),
        (('system', 'stream'), '\n\n['),
        (('user', 'assistant'), '\nAssistant: '),
        (('user', 'stream'), '\n['),
        (('user', 'user'), '\nUser: '),
        # An assistant turn is always closed by the EOS token.
        (('assistant', 'user'), '{}\nUser: '.format(tokenizer.eos_token)),
        (('assistant', 'stream'), '{}\n['.format(tokenizer.eos_token)),
        # A stream block is closed by ']' before the next turn starts.
        (('stream', 'user'), ']\nUser: '),
        (('stream', 'assistant'), ']\nAssistant: '),
        ('assistant', 'Assistant: '),
        ('eos_token', tokenizer.eos_token),
    ]
    return dict(entries)
def chat_template_offsets(tokenizer):
    """Map every key of ``chat_template_transition`` to the character length of its text."""
    lengths = {}
    for key, text in chat_template_transition(tokenizer).items():
        lengths[key] = len(text)
    return lengths
def get_learn_ranges(conversation: list[dict], *, chat_template_offsets: dict[tuple, int], model_config: LiveConfigMixin) -> list[range]:
    """Compute the character ranges of the rendered prompt that should be learned.

    The conversation is walked in order while tracking the character offset
    that the chat template would have reached, using the lengths of the
    transition strings and of each message body.

    Args:
        conversation: Messages with a ``'role'`` key. ``'stream'`` messages
            carry ``'num_frames'``; other roles carry ``'content'``. An
            optional ``'learn'`` entry marks what to learn: truthy on an
            assistant message selects the whole ``Assistant: ...<eos>`` span;
            on a stream message ``True`` selects every per-frame range and a
            non-bool value keeps only the first ``learn`` ranges.
        chat_template_offsets: Lengths of the transition strings, keyed by
            ``(previous_role, role)`` plus ``'assistant'`` and ``'eos_token'``.
        model_config: Config forwarded to the stream placeholder helpers.

    Returns:
        A list of ``range`` objects in character offsets.

    Raises:
        KeyError: If a role transition is missing from ``chat_template_offsets``.
    """
    offset = 0
    learn_ranges = []
    last_role = None
    for message in conversation:
        role = message['role']
        # The chat template renders nothing for a stream message without
        # frames (its branch requires num_frames > 0). Such a message must
        # neither advance the offset nor become the previous role of the next
        # transition, otherwise every later range is misaligned.
        if role == 'stream' and message['num_frames'] <= 0:
            continue
        offset += chat_template_offsets[(last_role, role)]
        last_role = role
        if role == 'stream':
            if message.get('learn', False):
                ranges = get_stream_learn_ranges(message['num_frames'], model_config) + offset
                # The last frame is followed by "]\n": also cover the "\n".
                ranges[-1, 1] += 1
                if not isinstance(message['learn'], bool):
                    # A non-bool `learn` keeps only the first `learn` frame ranges.
                    ranges = ranges[:message['learn']]
                learn_ranges.extend(range(start, stop) for start, stop in ranges.tolist())
            offset += get_stream_placeholder_len(message['num_frames'], model_config)
        else:
            if role == 'assistant' and message.get('learn', False):
                # Span from the start of "Assistant: " through the closing EOS token.
                learn_ranges.append(range(
                    offset - chat_template_offsets['assistant'],
                    offset + len(message['content']) + chat_template_offsets['eos_token'],
                ))
            offset += len(message['content'])
    return learn_ranges
def build_live_tokenizer_and_update_config(llm_pretrained: str, model_config: LiveConfigMixin) -> AutoTokenizer:
    """Load a tokenizer, register the vision placeholder and sync ids into the config.

    Args:
        llm_pretrained: Name or path passed to ``AutoTokenizer.from_pretrained``.
        model_config: Config that provides ``v_placeholder`` and
            ``frame_token_interval``. It is updated in place with
            ``v_placeholder_id``, ``frame_token_interval_id`` and ``eos_token_id``.

    Returns:
        The tokenizer, with the live chat template installed and a
        ``get_learn_ranges`` callable attached.
    """
    tokenizer = AutoTokenizer.from_pretrained(llm_pretrained, use_fast=True, padding_side='left')
    tokenizer.add_special_tokens({'additional_special_tokens': [model_config.v_placeholder]})
    # Look the id up by name rather than assuming it is `len(tokenizer) - 1`:
    # if the placeholder already exists in the vocabulary nothing is appended
    # and the last id would belong to an unrelated token.
    v_placeholder_id = tokenizer.convert_tokens_to_ids(model_config.v_placeholder)
    if model_config.frame_token_interval:
        frame_token_interval_id = tokenizer.convert_tokens_to_ids(model_config.frame_token_interval)
    else:
        frame_token_interval_id = None
    # Reuse the EOS token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    model_config.update(dict(
        v_placeholder_id=v_placeholder_id,
        frame_token_interval_id=frame_token_interval_id,
        eos_token_id=tokenizer.eos_token_id,
    ))
    tokenizer.chat_template = chat_template(tokenizer, get_stream_placeholder_jinja2(model_config))
    # Bind the template offsets and config so callers only pass the conversation.
    tokenizer.get_learn_ranges = partial(
        get_learn_ranges,
        chat_template_offsets=chat_template_offsets(tokenizer),
        model_config=model_config,
    )
    return tokenizer
if __name__ == '__main__':
    # Demo: render a sample conversation, compute its learn ranges and turn
    # them into a label tensor, then print input ids and labels.
    config = LiveConfigMixin(frame_token_interval=',', frame_token_cls=True, frame_token_pooled=[3,3], frame_num_tokens=10)
    tokenizer = build_live_tokenizer_and_update_config('meta-llama/Meta-Llama-3-8B-Instruct', config)
    # On stream messages an integer `learn` keeps only the first N frame ranges;
    # on assistant messages `learn: True` marks the whole reply.
    chat = [
        {'role': 'system', 'content': 'cool.'},
        {'role': 'stream', 'num_frames': 2, 'learn': 1},
        {'role': 'user', 'content': 'cool?'},
        {'role': 'assistant', 'content': 'cool.', 'learn': True},
        {'role': 'stream', 'num_frames': 3, 'learn': 3},
        {'role': 'assistant', 'content': 'so cool.', 'learn': True},
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
    learn_ranges = tokenizer.get_learn_ranges(chat)
    # Offsets mapping gives each token's character span, used below to convert
    # character ranges into token index ranges.
    batch = tokenizer([prompt], return_offsets_mapping=True, add_special_tokens=False, return_tensors="pt", padding=True)
    # Start with every position ignored (-100).
    batch_labels = torch.full_like(batch.input_ids, -100, dtype=torch.long)
    for text, labels, input_ids, offset_mapping, learn_range in zip(
        [prompt], batch_labels, batch.input_ids, batch.offset_mapping, [learn_ranges]
    ):
        for learn_r in learn_range:
            # Index of the token whose start character equals the range start
            # (`.item()` requires exactly one match).
            start = torch.nonzero(offset_mapping[:,0] == learn_r.start).item()
            if offset_mapping[:,0][-1] >= learn_r.stop:
                stop = torch.nonzero(offset_mapping[:,0] == learn_r.stop).item()
            else: # the range ends at the final EOS token, so it runs to the end of the sequence
                stop = len(input_ids)
            # Shift by one position: the label at index i is the token at index i + 1.
            labels[start-1:stop-1] = input_ids[start:stop]
            # NOTE: label ids may reach len(tokenizer) - 1 or above (the last id is
            # the added vision placeholder), because some frame positions have
            # v_placeholder_id as their target. Replace those with the EOS token id.
            labels[labels >= len(tokenizer) - 1] = tokenizer.eos_token_id
    print(batch.input_ids)
    print(batch_labels)