from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
import copy
import time

from .formatter import EmptyFormatter, StringFormatter
from .formatter import Formatter
from ...utils.constants import *

from transformers import PreTrainedTokenizer
import torch

@dataclass
class Template:
    format_image_token: "Formatter"
    format_user: "Formatter"
    format_assistant: "Formatter"
    system: "Formatter"
    separator: "Formatter"
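    # Each Formatter field above renders one piece of the conversation
    # template: `system` the system prompt, `format_user`/`format_assistant`
    # the per-turn wrappers, `format_image_token` the wrapper around the
    # <image> placeholder, and `separator` the (role separator, eos_token)
    # pair used when splitting the prompt back into rounds for label masking.
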
    def encode(self, messages, tokenizer, mode='train'):
        """
        1. split the message list (conversations: [{from: human, value: message},
           {from: gpt, value: message}]) into a question list and an answer list
        2. build the prompt string from the two lists
        3. tokenize the prompt
        4. build the training targets (labels) from the tokenized prompt
        """
        question_list, answer_list = self.get_list_from_message(messages)
        prompt = self.prompt(question_list, answer_list)
        input_ids = self.tokenizer_image_token(prompt, tokenizer, return_tensors='pt')
        if mode == 'train':
            labels = self.make_labels(input_ids, prompt, tokenizer)
            return dict(
                input_ids=input_ids,
                labels=labels
            )
        else:
            return dict(input_ids=input_ids, prompt=prompt)
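
    # Usage sketch (hypothetical wiring; the concrete Formatter instances are
    # supplied by the chat templates that instantiate this dataclass):
    #   template = Template(...)  # EmptyFormatter / StringFormatter fields
    #   batch = template.encode(
    #       [{'from': 'human', 'value': '<image>\nWhat is shown?'},
    #        {'from': 'gpt', 'value': 'A cat on a sofa.'}],
    #       tokenizer,
    #   )
    #   batch['input_ids'] and batch['labels'] are 1-D LongTensors of equal length.
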
    def get_list_from_message(self, messages):
        return self._get_list_from_message(messages)

    def _get_list_from_message(self, messages):
        """
        messages ====> [{from: human, value: message}, {from: gpt, value: message}]
        """
        question_list = []
        answer_list = []
        first_is_not_question = 0
        for i, message in enumerate(messages):
            if i == 0 and message['from'] != 'human':
                first_is_not_question = 1
                continue
            if i % 2 == first_is_not_question:
                question_list.append(message['value'])
            else:
                answer_list.append(message['value'])
        assert len(question_list) == len(answer_list), \
            f"qa does not match: length_q:{len(question_list)} vs length_a:{len(answer_list)}"
        return question_list, answer_list
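
    # Example:
    #   [{'from': 'human', 'value': '<image>\nHi'}, {'from': 'gpt', 'value': 'Hello!'}]
    # yields (['<image>\nHi'], ['Hello!']). A leading turn that is not from
    # 'human' is skipped, and the even/odd split is shifted accordingly.
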
    def prompt(self, question_list, answer_list):
        if isinstance(question_list, str):
            question_list = [question_list]
        if isinstance(answer_list, str):
            answer_list = [answer_list]
        msg = self._prompt(question_list, answer_list)
        return msg

    def _prompt(self, question_list, answer_list):
        msg = ""
        for i, (question, answer) in enumerate(zip(question_list, answer_list)):
            if i == 0:
                msg += self.system.apply()
            if DEFAULT_IMAGE_TOKEN in question:
                question = question.replace(DEFAULT_IMAGE_TOKEN, '').strip()
                question = self.format_image_token.apply(content=question).strip()
            msg += self.format_user.apply(content=question)
            msg += self.format_assistant.apply(content=answer)
        return msg
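
    # make_labels copies input_ids and masks everything that is not an
    # assistant answer with IGNORE_INDEX, so the loss is computed only on the
    # model's own turns. The prompt is split on eos_token into per-round
    # strings, and each round is re-tokenized to locate the boundary between
    # instruction and answer.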
    def make_labels(self, input_ids, prompt, tokenizer):
        labels = copy.deepcopy(input_ids)
        sep, eos_token = self.separator.apply()
        total_len = int(labels.ne(tokenizer.pad_token_id).sum())
        if tokenizer.pad_token_id == tokenizer.eos_token_id:
            total_len += prompt.count(eos_token)
        rounds = prompt.split(eos_token)
        eos_token_length = len(tokenizer.encode(eos_token))
        labels, cur_len = self._make_masks(labels, tokenizer, sep, eos_token_length, rounds)
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
                print("number of rounds: ", len(rounds) - 1)
                print("rounds: ", rounds[:-1])
                print("prompt: ", prompt)
                print(labels)
                print(input_ids)
                time.sleep(5)
                labels[:] = IGNORE_INDEX
        return labels
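
    # In _make_masks, each round is split on `sep`; the instruction part
    # (parts[0] plus the separator) is masked with IGNORE_INDEX while the
    # answer tokens stay intact. The '- 1' on instruction_len and the
    # '+ eos_token_length' on round_len compensate for tokenizer-added
    # BOS/EOS tokens (the exact offsets depend on the tokenizer). Anything
    # past the last well-formed round is masked as well.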
    def _make_masks(self, labels, tokenizer, sep, eos_token_length, rounds):
        cur_len = 0
        for rou in rounds:
            if rou == "":
                break
            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(self.tokenizer_image_token(rou, tokenizer)) + eos_token_length
            instruction_len = len(self.tokenizer_image_token(parts[0], tokenizer)) - 1
            labels[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len
        labels[cur_len:] = IGNORE_INDEX
        return labels, cur_len
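
    # tokenizer_image_token splits the prompt on the '<image>' placeholder,
    # tokenizes each text chunk separately, and splices image_token_index in
    # between the chunks (keeping a single BOS token at the front when the
    # tokenizer emits one), so image features can later be substituted at
    # those positions.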
    @classmethod
    def tokenizer_image_token(cls, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
        def _insert_separator(X, sep):
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
        input_ids = []
        offset = 0
        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
            offset = 1
            input_ids.append(prompt_chunks[0][0])
        for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])
        if return_tensors is not None:
            if return_tensors == 'pt':
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f'Unsupported tensor type: {return_tensors}')
        return input_ids
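
# Usage sketch (assumes LLaVA-style constants from ...utils.constants, e.g.
# IMAGE_TOKEN_INDEX == -200 and DEFAULT_IMAGE_TOKEN == '<image>'):
#   ids = Template.tokenizer_image_token("USER: <image>\nDescribe the image.", tokenizer)
# 'ids' is a plain list of token ids with IMAGE_TOKEN_INDEX spliced in where
# '<image>' appeared; pass return_tensors='pt' to get a 1-D LongTensor instead.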