import io
import os
import random
import re
import sys
from typing import Dict

import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as T
import transformers
from PIL import Image
from torch.utils.data import ConcatDataset, WeightedRandomSampler
from torchvision.transforms.functional import InterpolationMode
from xtuner.utils import IGNORE_INDEX

from ..utils import (IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
                     get_conv_template)

IGNORE_TOKEN_ID = IGNORE_INDEX

try:
    from petrel_client.client import Client
    from petrel_client.common.config import Config
except ImportError:
    print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.')


def preprocess(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ': '
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))
            exit()

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
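
# --- Hedged usage sketch (not part of the original file) --------------------
# Minimal illustration of how `preprocess` is typically invoked for a single
# image-text conversation. The template name, the tokens-per-image count, and
# the `sources` layout below are assumptions for illustration only; substitute
# the values your model actually uses.
def _example_preprocess_usage(tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    sources = [[
        {'from': 'human', 'value': '<image>\nDescribe this image.'},
        {'from': 'gpt', 'value': 'A cat sitting on a red sofa.'},
    ]]
    return preprocess(
        template_name='vicuna_v1.1',  # assumed colon-separated template; pick the one matching your LLM
        sources=sources,
        tokenizer=tokenizer,
        num_image_token_list=[256],   # assumed tokens per image; model dependent
        ds_name='example_ds',
    )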
def preprocess_mpt(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|im_end|><|im_start|>assistant\n
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 0
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids) + 1

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            instruction_len = len(tokenizer(parts[0]).input_ids)

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
            # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
            # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
            cur_len += turn_len
        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
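
# --- Hedged sketch of the MPT-style turn regrouping (not part of the original
# file). `preprocess_mpt` splits the rendered prompt on conv.sep (e.g.
# '<|im_end|>\n') and re-joins the pieces so each element holds one complete
# user/assistant exchange, with the system prompt folded into the first one.
# The toy separator and turn strings below are assumptions used purely to show
# the regrouping arithmetic.
def _example_mpt_regrouping():
    sep = '<SEP>'
    turns = ['system', 'user1', 'gpt1', 'user2', 'gpt2', '']
    re_turns = [sep.join(turns[:3])]  # system + first user + first gpt
    for conv_idx in range(3, len(turns), 2):
        re_turns.append(sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
    # -> ['system<SEP>user1<SEP>gpt1', 'user2<SEP>gpt2', '']
    return re_turns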
def preprocess_phi3(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    tokenizer.padding_side = 'right'
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|end|>\n<|assistant|>
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
        target[target == endoftext_id] = IGNORE_TOKEN_ID

        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            if i == 0:
                turn_len = len(tokenizer(turn).input_ids)
            else:
                turn_len = len(tokenizer(turn).input_ids) - 1

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if i == 0:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
            else:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
            # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
            # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
            cur_len += turn_len
        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
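
# --- Hedged debugging sketch (not part of the original file) -----------------
# Mirrors the dormant `if False:` inspection blocks above: decode a masked
# target row so you can eyeball which spans still contribute to the loss
# (masked positions render as the UNK token). Assumes the tokenizer defines
# `unk_token_id`; output is for manual inspection only.
def _inspect_masked_target(tokenizer: transformers.PreTrainedTokenizer,
                           target: torch.Tensor) -> str:
    z = target.clone()
    z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
    return repr(tokenizer.decode(z))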
def preprocess_internlm(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            sentence['value'] = sentence['value'].strip()
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())  # in InternLM, pad_token_id == eos_token_id
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID  # <s>
        parts = conversation.split(conv.roles[1])  # [UNUSED_TOKEN_146]assistant\n
        info = parts[0] + conv.roles[1]
        temp_len = len(tokenizer(info).input_ids) - 1  # drop the tokenizer's <s> (BOS) token
        target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
        cur_len = cur_len + temp_len

        for index in range(1, len(parts) - 1):
            info = parts[index]
            part1, part2 = info.split(conv.roles[0])
            temp_len = len(tokenizer(part1).input_ids) - 1
            cur_len = cur_len + temp_len
            part = conv.roles[0] + part2 + conv.roles[1]
            temp_len = len(tokenizer(part).input_ids) - 1
            target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
            cur_len = cur_len + temp_len
        last_info = parts[-1]
        temp_len = len(tokenizer(last_info).input_ids) - 1
        cur_len = cur_len + temp_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
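
# --- Hedged dispatch sketch (not part of the original file) ------------------
# The four preprocess variants differ only in how their chat templates delimit
# turns, so a caller might select one by template name. The mapping below is
# an assumption for illustration, not this repository's actual wiring; verify
# the template names registered in your conversation module before using it.
_PREPROCESS_FN_MAP = {
    'internlm2-chat': preprocess_internlm,  # assumed pairing
    'phi3-chat': preprocess_phi3,           # assumed pairing
    'Hermes-2': preprocess_mpt,             # assumed pairing
}

def _select_preprocess_fn(template_name: str):
    # Fall back to the generic colon-separated variant for unknown templates.
    return _PREPROCESS_FN_MAP.get(template_name, preprocess)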