import copy

import transformers

from utils import printf


# NOTE(review): the HTML tag/entity literals used by the render helper below
# and the '</s>' literal in instruct_prompt.postprocess were reconstructed from
# a markup-stripped copy of this file -- confirm them against the upstream
# source before merging.
def _render_markdown_for_gradio(output):
    """Convert fenced markdown code blocks in *output* to HTML for gradio.

    Works around the gradio chatbot markdown code rendering bug: an opening
    fence line (three backticks followed by a language name) becomes a
    ``<pre><code>`` tag, a bare closing fence becomes ``</code></pre>``, and
    every other line except the first is HTML-escaped and prefixed with
    ``<br/>``. The rewritten lines are concatenated without separators.
    """
    lines = output.split("\n")
    for i, line in enumerate(lines):
        if "```" in line:
            if line != "```":
                # Everything after the three backticks is the language tag.
                lines[i] = f'<pre><code class="language-{line[3:]}">'
            else:
                lines[i] = '</code></pre>'
        elif i > 0:
            lines[i] = "<br/>" + (
                line.replace("<", "&lt;")
                .replace(">", "&gt;")
                .replace("__", "\\_\\_")  # keep markdown from reading __ as bold
            )
    # Replacing '<br/><pre>' with '\n<pre>' works for plain html, but not for gradio.
    return "".join(lines)


class prompt:
    """Base class holding the tokenizer and length settings for prompt builders."""

    def __init__(self, tokenizer, max_len, add_eos=True):
        """Store the configuration shared by all prompt builders.

        Args:
            tokenizer: callable returning a mapping with an ``"input_ids"``
                entry (presumably a Hugging Face tokenizer -- confirm).
            max_len: length budget used when truncating prompts.
            add_eos: whether training samples keep the trailing eos token.
        """
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.add_eos = add_eos


class instruct_prompt(prompt):
    """Alpaca-style instruction prompt for single-turn and multi-turn data."""

    # Template for samples that only carry an instruction.
    prompt = (
        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    )
    # Template for samples that carry an instruction plus an input.
    prompt_input = (
        "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
        "### Instruction:{instruction}\n\n### Input:{input}\n\n### Response:"
    )
    # One completed dialogue turn / the final open turn (multi-turn training).
    prompt_history = "User:{input}\n\nAssistant:{output}\n\n"
    prompt_post = "User:{input}\n\nAssistant:"

    def preprocess_gen(self, data_point):
        """Build the generation prompt for *data_point* and return its token ids.

        Two input formats are accepted:

        * single instruction: ``{'instruction': .., 'input': ..}`` (``input``
          optional);
        * multi turn: ``{'history': [{'input': .., 'output': ..}, ..],
          'input': ..}``.
        """
        if 'history' not in data_point:
            # Single instruction format: the template is already complete.
            if 'input' in data_point:
                user_prompt = self.prompt_input.format_map(data_point)
            else:
                user_prompt = self.prompt.format_map(data_point)
        else:
            # Multi turn format: flatten the dialogue, keep only the trailing
            # max_len *characters*, then wrap it as the instruction.
            turns = [
                "User:" + turn['input'] + "\n" + "Assistant:" + turn['output']
                for turn in data_point['history']
            ]
            dialogue = "\n".join(turns) + "\nUser:" + data_point['input'] + "\nAssistant:"
            dialogue = dialogue[-self.max_len:]
            # Wrapped only in this branch (as preprocess_train does) so that
            # single-instruction prompts are not wrapped in the template twice.
            user_prompt = self.prompt.format_map({'instruction': dialogue})
        return self.tokenizer(user_prompt)["input_ids"]

    def preprocess_train(self, data_point):
        """Tokenize one training sample and mask the prompt part of the labels.

        Two input formats are accepted:

        * single instruction: ``{'instruction': .., 'input': .., 'output': ..}``
          (``input`` optional);
        * multi turn: ``{'input': [..], 'output': [..]}`` where the last
          output is the training target.

        Returns:
            dict with ``input_ids``, ``labels`` (prompt positions set to
            ``-100``) and ``attention_mask``.
        """
        if 'instruction' in data_point:
            if 'input' in data_point:
                user_prompt = self.prompt_input.format_map(data_point)
            else:
                user_prompt = self.prompt.format_map(data_point)
            output = data_point["output"]
        else:
            inputs = data_point['input']
            outputs = data_point['output']
            # All turns but the last become history; the last turn stays open.
            dialogue = "".join(
                self.prompt_history.format_map({'input': inputs[i], 'output': outputs[i]})
                for i in range(len(inputs) - 1)
            )
            dialogue += self.prompt_post.format_map({'input': inputs[-1]})
            user_prompt = self.prompt.format_map({'instruction': dialogue})
            output = outputs[-1]

        prompt_ids = self.tokenizer(
            user_prompt,
            truncation=True,
            max_length=self.max_len + 1,
        )["input_ids"]
        len_user_prompt_tokens = len(prompt_ids) - 1  # no eos token
        full_tokens = self.tokenizer(
            user_prompt + output,
            truncation=True,
            max_length=self.max_len + 1,
            padding="max_length",
        )["input_ids"][:-1]
        # NOTE(review): the sequence is padded to max_length, yet the mask is
        # all ones and padded positions keep their ids in `labels` -- confirm
        # this is intended.
        return {
            "input_ids": full_tokens,
            "labels": [-100] * len_user_prompt_tokens
            + full_tokens[len_user_prompt_tokens:],
            "attention_mask": [1] * len(full_tokens),
        }

    def data_collator(self):
        """Return a causal-LM (``mlm=False``) collator bound to the tokenizer."""
        return transformers.DataCollatorForLanguageModeling(self.tokenizer, mlm=False)

    def postprocess(self, text, render=True):
        """Extract the model's answer from the decoded generation *text*.

        Args:
            text: decoded output that contains the ``"### Response:"`` marker.
            render: when true, convert markdown code fences to HTML for gradio.

        Raises:
            IndexError: if *text* does not contain ``"### Response:"``.
        """
        printf(text)
        output = text.split("### Response:")[1].strip()
        output = output.replace("Belle", "Vicuna")
        printf(output)
        # Cut the answer off where the model starts a new section or turn.
        if '###' in output:
            output = output.split("###")[0]
        if 'User' in output:
            output = output.split("User")[0]
        output = output.replace('�', '').replace('</s>', '')
        if render:
            output = _render_markdown_for_gradio(output)
        return output


class chat_prompt(prompt):
    """Multi-turn chat prompt with a fixed system preamble."""

    prompt_pre = (
        "The following is a conversation between an AI assistant called Assistant and a human user called User. "
        "The assistant is intelligent, knowledgeable and polite to answer questions of user.\n\n"
    )
    # One completed dialogue turn / the final open turn.
    prompt_history = "User:{input}\n\nAssistant:{output}\n\n"
    prompt_post = "User:{input}\n\nAssistant:"

    def _count_tokens(self, text):
        """Return the number of tokens in *text*, without special tokens."""
        return len(self.tokenizer(text, add_special_tokens=False)["input_ids"])

    def preprocess_gen(self, data_point):
        """Build the chat generation prompt and return its token ids.

        ``data_point`` is ``{'history': [{'input': .., 'output': ..}, ..],
        'input': ..}``. History is shortened to fit ``max_len`` tokens: first
        the longer side of each turn is halved (oldest turn first), then the
        oldest turns are dropped, always keeping the most recent one. The
        caller's ``data_point`` is left unmodified.
        """
        input_prompt = self.prompt_post.format_map({'input': data_point['input']})
        len_avail = (
            self.max_len
            - self._count_tokens(self.prompt_pre)
            - self._count_tokens(input_prompt)
        )

        # Work on copies so the caller's history is not truncated in place.
        history = [copy.copy(turn) for turn in data_point['history']]
        tokenized_lens = [
            self._count_tokens(self.prompt_history.format_map(turn))
            for turn in history
        ]

        # Heuristic: halve (by characters) the longer side of a turn,
        # starting from the oldest turn, until the history fits.
        for i, turn in enumerate(history):
            if sum(tokenized_lens) <= len_avail:
                break
            input_len = len(turn['input'])
            output_len = len(turn['output'])
            if output_len > input_len:
                turn['output'] = turn['output'][:output_len // 2]
            else:
                turn['input'] = turn['input'][:input_len // 2]
            tokenized_lens[i] = self._count_tokens(self.prompt_history.format_map(turn))

        # Still too long: drop whole turns, oldest first.
        total_len = sum(tokenized_lens)
        while total_len > len_avail and len(history) > 1:
            total_len -= tokenized_lens.pop(0)
            history.pop(0)

        # Final assembly: preamble + remaining history + open user turn.
        user_prompt = self.prompt_pre
        user_prompt += "".join(self.prompt_history.format_map(turn) for turn in history)
        user_prompt += input_prompt
        printf({'real_input:': user_prompt})
        return self.tokenizer(user_prompt)["input_ids"]

    def preprocess_train(self, data_point):
        """Tokenize one multi-turn sample and mask the prompt part of the labels.

        ``data_point`` is ``{'input': [..], 'output': [..]}``; the last output
        is the training target and all texts are stripped of surrounding
        whitespace.

        Returns:
            dict with ``input_ids``, ``labels`` (prompt positions set to
            ``-100``) and ``attention_mask``.
        """
        inputs = data_point['input']
        outputs = data_point['output']
        user_prompt = self.prompt_pre
        user_prompt += "".join(
            self.prompt_history.format_map(
                {'input': inputs[i].strip(), 'output': outputs[i].strip()}
            )
            for i in range(len(inputs) - 1)
        )
        user_prompt += self.prompt_post.format_map({'input': inputs[-1].strip()})

        prompt_ids = self.tokenizer(
            user_prompt,
            truncation=True,
            max_length=self.max_len,
        )["input_ids"]
        len_user_prompt_tokens = len(prompt_ids) - 1  # remove extra eos

        full_text = user_prompt + outputs[-1].strip()
        if self.add_eos:
            full_tokens = self.tokenizer(
                full_text,
                truncation=True,
                padding=False,
                max_length=self.max_len,
            )["input_ids"]  # need eos
        else:
            full_tokens = self.tokenizer(
                full_text,
                truncation=True,
                padding=False,
                max_length=self.max_len + 1,
            )["input_ids"][:-1]  # delete eos
        return {
            "input_ids": full_tokens,
            "labels": [-100] * len_user_prompt_tokens
            + full_tokens[len_user_prompt_tokens:],
            "attention_mask": [1] * len(full_tokens),
        }

    def data_collator(self):
        """Return a seq2seq collator bound to the tokenizer."""
        return transformers.DataCollatorForSeq2Seq(self.tokenizer)

    def postprocess(self, text, render=False):
        """Extract the assistant's last reply from the decoded generation *text*.

        Args:
            text: decoded output; the part after the last ``"Assistant:"`` is
                treated as the reply.
            render: when true, convert markdown code fences to HTML for gradio.
        """
        output = text.split("Assistant:")[-1].strip()
        # Cut the reply off where the model starts speaking for the user.
        if 'User:' in output:
            output = output.split("User:")[0]
        output = output.replace('�', '')
        if render:
            output = _render_markdown_for_gradio(output)
        return output


def get_data_collator():
    """Return the ``DataCollatorForLanguageModeling`` class (not an instance)."""
    return transformers.DataCollatorForLanguageModeling