import json
from transformers import LlamaTokenizer
class MiniCPMVTokenizer(LlamaTokenizer):
    """Llama tokenizer extended with MiniCPM-V multimodal markers.

    Adds marker strings for image, reference, box, quad, point and slice
    spans, vocabulary-id lookups for the image markers, and a simple chat
    template that renders ``{"role", "content"}`` messages into one prompt.
    """

    def __init__(self, **kwargs):
        """Initialise the base tokenizer and define the marker strings.

        Args:
            **kwargs: Forwarded unchanged to ``LlamaTokenizer.__init__``.
        """
        super().__init__(**kwargs)
        # im_start / im_end are converted to vocabulary ids by the
        # im_start_id / im_end_id properties, so they must be distinct
        # tokens (empty strings would collapse both to the same lookup).
        # NOTE(review): assumes all markers below are registered tokens in
        # the checkpoint's vocabulary -- confirm against the tokenizer files.
        self.im_start = "<image>"
        self.im_end = "</image>"
        self.ref_start = "["
        self.ref_end = "]"
        self.box_start = "<box>"
        self.box_end = "</box>"
        self.quad_start = "<quad>"
        self.quad_end = "</quad>"
        self.point_start = "<point>"
        self.point_end = "</point>"
        self.slice_start = "<slice>"
        self.slice_end = "</slice>"

    @property
    def eos_id(self):
        """Return the end-of-sequence id reported by ``self.sp_model``."""
        return self.sp_model.eos_id()

    @property
    def bos_id(self):
        """Return the beginning-of-sequence id reported by ``self.sp_model``."""
        return self.sp_model.bos_id()

    @property
    def unk_id(self):
        """Return the unknown-token id reported by ``self.sp_model``."""
        return self.sp_model.unk_id()

    @property
    def im_start_id(self):
        """Return the vocabulary id of the image-start marker ``self.im_start``."""
        return self._convert_token_to_id(self.im_start)

    @property
    def im_end_id(self):
        """Return the vocabulary id of the image-end marker ``self.im_end``."""
        return self._convert_token_to_id(self.im_end)

    def apply_chat_template(self,
                            conversation,
                            add_image_msg: bool = True):
        """Render a user/assistant conversation into a single prompt string.

        User turns are prefixed with ``<用户>`` and assistant turns with
        ``<AI>``; the prompt ends with a trailing ``<AI>`` tag.

        Args:
            conversation: A list of ``{"role": ..., "content": ...}`` dicts,
                or a JSON string that decodes to such a list.
            add_image_msg: When it is exactly ``True`` (identity check, so
                other truthy values do not count), prepend the image
                placeholder ``(<image>./</image>)`` to the first message
                unless that message already contains it.

        Returns:
            The rendered prompt string.

        Raises:
            json.JSONDecodeError: If ``conversation`` is a string that is not
                valid JSON.
            KeyError: If a message lacks a ``"role"`` or ``"content"`` key.
            AssertionError: If a role is not ``"user"``/``"assistant"`` or the
                first message is not from the user. (These checks are
                ``assert`` statements and are skipped under ``python -O``.)
        """
        if isinstance(conversation, str):
            conversation = json.loads(conversation)

        image_placeholder = "(<image>./</image>)"
        parts = []
        for i, msg in enumerate(conversation):
            role = msg["role"]
            content = msg["content"]
            assert role in ("user", "assistant"), f"Unsupported role: {role!r}"
            if i == 0:
                assert role == "user", "The role of first msg should be user"
                if add_image_msg is True and image_placeholder not in content:
                    content = image_placeholder + content
            parts.append("<用户>" if role == "user" else "<AI>")
            parts.append(content)
        # Leave the prompt open on an assistant tag -- presumably so the
        # model's continuation is read as the assistant's reply.
        parts.append("<AI>")
        return "".join(parts)