|
|
|
|
|
|
|
|
|
|
|
|
|
import io |
|
import matplotlib.pyplot as plt |
|
from transformers.trainer_pt_utils import LabelSmoother |
|
|
|
IGNORE_TOKEN_ID = LabelSmoother.ignore_index  # -100; label positions with this value are ignored by the loss
|
import os |
|
import random |
|
import re |
|
from collections import Counter |
|
from typing import Dict |
|
|
|
import cv2 |
|
import imageio |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import torchvision.transforms as T |
|
import transformers |
|
from decord import VideoReader |
|
from internvl.conversation import get_conv_template |
|
from PIL import Image |
|
from torch.utils.data import ConcatDataset, WeightedRandomSampler |
|
from torchvision.transforms.functional import InterpolationMode |
|
|
|
from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD, |
|
IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN, |
|
SIGLIP_MEAN, SIGLIP_STD) |
|
|
|
try: |
|
from petrel_client.client import Client |
|
from petrel_client.common.config import Config |
|
except ImportError:
|
print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.') |
|
import sys |
|
|
|
|
|
def calculate_ngram_repetition(text, n): |
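    """Return the fraction of n-grams in `text` that repeat: the number of distinct
    n-grams occurring more than once divided by the total n-gram count
    (0 if the text has fewer than n words)."""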
|
words = text.split() |
|
ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)] |
|
ngram_counts = Counter(ngrams) |
|
total_ngrams = len(ngrams) |
|
repeated_ngrams = sum(1 for count in ngram_counts.values() if count > 1) |
|
return repeated_ngrams / total_ngrams if total_ngrams > 0 else 0 |
|
|
|
|
|
def check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10): |
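    """Raise an Exception when any 'gpt' answer has an n-gram repetition ratio above `repeat_threshold`."""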
|
for conversation in conversations: |
|
if conversation['from'] == 'gpt': |
|
model_answer = conversation['value'] |
|
repeat_ratio = calculate_ngram_repetition(model_answer, ngram) |
|
if repeat_ratio > repeat_threshold: |
|
                raise Exception(f'High n-gram repetition in model answer (ratio {repeat_ratio:.2f} > {repeat_threshold})')
|
|
|
|
|
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): |
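    """Select frame indices from a video with `vlen` frames.

    With sample='rand' one random frame is taken from each of `num_frames` uniform
    intervals ('middle' takes the interval centers, `fix_start` a fixed offset into
    each interval). With sample='fps{x}' frames are taken at x frames per second
    based on `input_fps`, capped at `max_num_frames` when that is positive.
    """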
|
if sample in ['rand', 'middle']: |
|
acc_samples = min(num_frames, vlen) |
|
|
|
intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) |
|
ranges = [] |
|
for idx, interv in enumerate(intervals[:-1]): |
|
ranges.append((interv, intervals[idx + 1] - 1)) |
|
if sample == 'rand': |
|
try: |
|
frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] |
|
            except IndexError:  # an interval can be empty when vlen is close to acc_samples
|
frame_indices = np.random.permutation(vlen)[:acc_samples] |
|
frame_indices.sort() |
|
frame_indices = list(frame_indices) |
|
elif fix_start is not None: |
|
frame_indices = [x[0] + fix_start for x in ranges] |
|
elif sample == 'middle': |
|
frame_indices = [(x[0] + x[1]) // 2 for x in ranges] |
|
else: |
|
raise NotImplementedError |
|
|
|
if len(frame_indices) < num_frames: |
|
padded_frame_indices = [frame_indices[-1]] * num_frames |
|
padded_frame_indices[:len(frame_indices)] = frame_indices |
|
frame_indices = padded_frame_indices |
|
elif 'fps' in sample: |
|
output_fps = float(sample[3:]) |
|
duration = float(vlen) / input_fps |
|
delta = 1 / output_fps |
|
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) |
|
frame_indices = np.around(frame_seconds * input_fps).astype(int) |
|
frame_indices = [e for e in frame_indices if e < vlen] |
|
if max_num_frames > 0 and len(frame_indices) > max_num_frames: |
|
frame_indices = frame_indices[:max_num_frames] |
|
|
|
else: |
|
raise ValueError |
|
return frame_indices |
|
|
|
|
|
def read_frames_gif( |
|
video_path, num_frames, sample='rand', fix_start=None, |
|
client=None, min_num_frames=4 |
|
): |
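    """Read a GIF from a local path or 's3://' URI (via the petrel `client`) and
    return a list of RGB PIL frames; the frame count is drawn uniformly from
    [min_num_frames, num_frames]."""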
|
if 's3://' in video_path: |
|
video_bytes = client.get(video_path) |
|
gif = imageio.get_reader(io.BytesIO(video_bytes)) |
|
else: |
|
gif = imageio.get_reader(video_path) |
|
vlen = len(gif) |
|
|
|
t_num_frames = np.random.randint(min_num_frames, num_frames + 1) |
|
frame_indices = get_frame_indices( |
|
t_num_frames, vlen, sample=sample, fix_start=fix_start |
|
) |
|
frames = [] |
|
for index, frame in enumerate(gif): |
|
if index in frame_indices: |
|
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB).astype(np.uint8) |
|
frame = Image.fromarray(frame) |
|
frames.append(frame) |
|
return frames |
|
|
|
|
|
def read_frames_decord( |
|
video_path, num_frames, sample='rand', fix_start=None, |
|
client=None, clip=None, min_num_frames=4 |
|
): |
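    """Decode a video with decord (local path or 's3://' URI) and return sampled RGB PIL frames.

    If `clip=(start, end)` is given in seconds, frames are sampled from that segment
    only; the frame count is drawn uniformly from [min_num_frames, num_frames].
    """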
|
if 's3://' in video_path: |
|
video_bytes = client.get(video_path) |
|
video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1) |
|
else: |
|
video_reader = VideoReader(video_path, num_threads=1) |
|
vlen = len(video_reader) |
|
fps = video_reader.get_avg_fps() |
|
duration = vlen / float(fps) |
|
if clip: |
|
start, end = clip |
|
duration = end - start |
|
vlen = int(duration * fps) |
|
start_index = int(start * fps) |
|
|
|
|
|
t_num_frames = np.random.randint(min_num_frames, num_frames + 1) |
|
|
|
frame_indices = get_frame_indices( |
|
t_num_frames, vlen, sample=sample, fix_start=fix_start, |
|
input_fps=fps |
|
) |
|
if clip: |
|
frame_indices = [f + start_index for f in frame_indices] |
|
frames = video_reader.get_batch(frame_indices).asnumpy() |
|
frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])] |
|
return frames |
|
|
|
|
|
def extract_frame_number(filename): |
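    """Return the integer N from a filename ending in '_N.jpg', or -1 if the pattern does not match."""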
|
|
|
    match = re.search(r'_(\d+)\.jpg$', filename)
|
return int(match.group(1)) if match else -1 |
|
|
|
|
|
def sort_frames(frame_paths): |
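    """Sort frame file paths by the frame number embedded in their filenames."""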
|
|
|
return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x))) |
|
|
|
|
|
def read_frames_folder( |
|
video_path, num_frames, sample='rand', fix_start=None, |
|
client=None, clip=None, min_num_frames=4 |
|
): |
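    """Load frames stored as individual images in a folder (local or 's3://'), sorted by
    frame number, and subsample them when there are more than a target count drawn
    uniformly from [min_num_frames, num_frames]."""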
|
if 's3://' in video_path: |
|
image_list = sort_frames(client.list(video_path)) |
|
frames = [] |
|
for image in image_list: |
|
fp = os.path.join(video_path, image) |
|
            frame = Image.open(io.BytesIO(client.get(fp))).convert('RGB')
|
frames.append(frame) |
|
else: |
|
image_list = sort_frames(list(os.listdir(video_path))) |
|
frames = [] |
|
for image in image_list: |
|
fp = os.path.join(video_path, image) |
|
frame = Image.open(fp).convert('RGB') |
|
frames.append(frame) |
|
vlen = len(frames) |
|
|
|
t_num_frames = np.random.randint(min_num_frames, num_frames + 1) |
|
|
|
if vlen > t_num_frames: |
|
frame_indices = get_frame_indices( |
|
t_num_frames, vlen, sample=sample, fix_start=fix_start |
|
) |
|
frames = [frames[i] for i in frame_indices] |
|
return frames |
|
|
|
|
|
class WeightedConcatDataset(ConcatDataset): |
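    """ConcatDataset whose iterator yields indices drawn with replacement by a
    WeightedRandomSampler over `weights`, so the constituent datasets can be
    re-balanced at sampling time."""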
|
def __init__(self, datasets, weights): |
|
super().__init__(datasets) |
|
self.weights = torch.DoubleTensor(weights) |
|
self.total_size = sum(len(d) for d in datasets) |
|
self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True) |
|
|
|
def __iter__(self): |
|
return iter(self.sampler) |
|
|
|
def __len__(self): |
|
return self.total_size |
|
|
|
|
|
def pil_loader(img_str): |
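    """Decode raw image bytes into an RGB PIL image."""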
|
buff = io.BytesIO(img_str) |
|
img = Image.open(buff) |
|
return img.convert('RGB') |
|
|
|
|
|
class TCSLoader(object): |
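    """Reads images and videos through a petrel (ceph) client.

    Calling the loader with image_type='image' returns one RGB PIL image; with
    image_type='video' it returns a list of PIL frames from a frame folder
    (path ending in '/'), a GIF, or a regular video file.
    """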
|
|
|
def __init__(self, conf_path, sc_config_key='sensecore'): |
|
print(f'[TCSLoader] config_path: {conf_path}') |
|
print('--> before Client(conf_path)') |
|
self.client = Client(conf_path) |
|
self.sc_config_key = sc_config_key |
|
print('--> after Client(conf_path)') |
|
|
|
def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=8, sample='rand', clip=None): |
|
if image_type == 'image': |
|
img_value_str = self.client.get(fn) |
|
img = pil_loader(img_value_str) |
|
return img |
|
|
|
elif image_type == 'video': |
|
if fn.endswith('/'): |
|
frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, |
|
client=self.client, sample=sample) |
|
elif fn.endswith('.gif'): |
|
frames = read_frames_gif(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, |
|
client=self.client, sample=sample) |
|
else: |
|
frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, |
|
client=self.client, sample=sample, clip=clip) |
|
return frames |
|
|
|
|
|
def expand2square(pil_img, background_color): |
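    """Pad `pil_img` with `background_color` to a centered square whose side equals its longer edge."""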
|
width, height = pil_img.size |
|
if width == height: |
|
return pil_img |
|
elif width > height: |
|
result = Image.new(pil_img.mode, (width, width), background_color) |
|
result.paste(pil_img, (0, (width - height) // 2)) |
|
return result |
|
else: |
|
result = Image.new(pil_img.mode, (height, height), background_color) |
|
result.paste(pil_img, ((height - width) // 2, 0)) |
|
return result |
|
|
|
|
|
def simulate_jpeg_degradation(quality): |
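    """Return a transform that round-trips a PIL image through in-memory JPEG encoding
    at the given `quality`, simulating compression artifacts."""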
|
def jpeg_degrade(img): |
|
with io.BytesIO() as output: |
|
img.convert('RGB').save(output, format='JPEG', quality=quality) |
|
output.seek(0) |
|
img_jpeg = Image.open(output).copy() |
|
return img_jpeg |
|
return jpeg_degrade |
|
|
|
|
|
|
|
# JPEG quality levels used by the random compression augmentation in build_transform
qualities = list(range(75, 101))
|
jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities} |
|
|
|
|
|
def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'): |
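    """Build the torchvision preprocessing pipeline for one input size.

    The training transform adds a random JPEG-degradation augmentation; the evaluation
    transform optionally pads the image to a square first (pad2square). `normalize_type`
    selects ImageNet, CLIP, or SigLIP normalization statistics.
    """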
|
if normalize_type == 'imagenet': |
|
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD |
|
elif normalize_type == 'clip': |
|
MEAN, STD = CLIP_MEAN, CLIP_STD |
|
elif normalize_type == 'siglip': |
|
MEAN, STD = SIGLIP_MEAN, SIGLIP_STD |
|
else: |
|
raise NotImplementedError |
|
if is_train: |
|
transform = T.Compose([ |
|
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), |
|
T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]), |
|
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), |
|
T.ToTensor(), |
|
T.Normalize(mean=MEAN, std=STD) |
|
]) |
|
else: |
|
if pad2square is False: |
|
transform = T.Compose([ |
|
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), |
|
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), |
|
T.ToTensor(), |
|
T.Normalize(mean=MEAN, std=STD) |
|
]) |
|
else: |
|
transform = T.Compose([ |
|
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), |
|
T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))), |
|
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), |
|
T.ToTensor(), |
|
T.Normalize(mean=MEAN, std=STD) |
|
]) |
|
|
|
return transform |
|
|
|
|
|
def preprocess( |
|
template_name, |
|
sources, |
|
tokenizer: transformers.PreTrainedTokenizer, |
|
num_image_token_list: list, |
|
text_only: bool = False, |
|
group_by_length: bool = False, |
|
use_packed_ds: bool = False, |
|
ds_name: str = None, |
|
num_image: int = 1 |
|
) -> Dict: |
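    """Tokenize conversations rendered with the `template_name` template and build training targets.

    '<image>' placeholders are replaced by IMG_START/IMG_CONTEXT/IMG_END token spans
    (one span per image), and every token outside the assistant replies is set to
    IGNORE_TOKEN_ID in the labels. Returns a dict of input_ids, labels and attention_mask.
    """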
|
conv = get_conv_template(template_name) |
|
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} |
|
|
|
|
|
conversations = [] |
|
for i, source in enumerate(sources): |
|
if roles[source[0]['from']] != conv.roles[0]: |
|
|
|
source = source[1:] |
|
|
|
conv.messages = [] |
|
for j, sentence in enumerate(source): |
|
role = roles[sentence['from']] |
|
assert role == conv.roles[j % 2], f'{i}' |
|
conv.append_message(role, sentence['value']) |
|
conversations.append(conv.get_prompt()) |
|
|
|
if not text_only: |
|
new_conversations = [] |
|
for conversation in conversations: |
|
for i in range(num_image): |
|
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' |
|
conversation = conversation.replace('<image>', image_tokens, 1) |
|
new_conversations.append(conversation) |
|
conversations = new_conversations |
|
|
|
|
|
input_ids = tokenizer( |
|
conversations, |
|
return_tensors='pt', |
|
padding=False if group_by_length or use_packed_ds else 'max_length', |
|
max_length=tokenizer.model_max_length, |
|
truncation=True, |
|
).input_ids |
|
targets = input_ids.clone() |
|
|
|
|
|
|
|
|
|
sep = conv.sep + conv.roles[1] + ': ' |
|
for conversation, target in zip(conversations, targets): |
|
total_len = int(target.ne(tokenizer.pad_token_id).sum()) |
|
|
|
turns = conversation.split(conv.sep2) |
|
cur_len = 1 |
|
target[:cur_len] = IGNORE_TOKEN_ID |
|
for i, turn in enumerate(turns): |
|
if turn == '': |
|
break |
|
turn_len = len(tokenizer(turn).input_ids) |
|
|
|
parts = turn.split(sep) |
|
if len(parts) != 2: |
|
break |
|
parts[0] += sep |
|
|
|
instruction_len = len(tokenizer(parts[0]).input_ids) - 2 |
|
|
|
if i != 0 and not tokenizer.legacy: |
|
|
|
instruction_len -= 1 |
|
|
|
|
|
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID |
|
cur_len += turn_len |
|
|
|
if i != 0 and not tokenizer.legacy: |
|
|
|
cur_len -= 1 |
|
|
|
target[cur_len:] = IGNORE_TOKEN_ID |
|
|
|
if False: |
|
z = target.clone() |
|
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) |
|
            print(repr(tokenizer.decode(z)))
|
exit() |
|
|
|
if cur_len < tokenizer.model_max_length: |
|
if cur_len != total_len: |
|
target[:] = IGNORE_TOKEN_ID |
|
print( |
|
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' |
|
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' |
|
) |
|
sys.stdout.flush() |
|
|
|
return dict( |
|
input_ids=input_ids, |
|
labels=targets, |
|
attention_mask=input_ids.ne(tokenizer.pad_token_id), |
|
) |
|
|
|
|
|
def preprocess_mpt( |
|
template_name, |
|
sources, |
|
tokenizer: transformers.PreTrainedTokenizer, |
|
num_image_token_list: list, |
|
text_only: bool = False, |
|
group_by_length: bool = False, |
|
use_packed_ds: bool = False, |
|
ds_name: str = None, |
|
num_image: int = 1 |
|
) -> Dict: |
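    """Same contract as `preprocess`, but for templates whose turns are separated by
    `conv.sep` (MPT/ChatML-style): prompts are masked and only assistant replies
    remain as labels."""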
|
conv = get_conv_template(template_name) |
|
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} |
|
|
|
|
|
conversations = [] |
|
for i, source in enumerate(sources): |
|
if roles[source[0]['from']] != conv.roles[0]: |
|
|
|
source = source[1:] |
|
|
|
conv.messages = [] |
|
for j, sentence in enumerate(source): |
|
role = roles[sentence['from']] |
|
assert role == conv.roles[j % 2], f'{i}' |
|
conv.append_message(role, sentence['value']) |
|
conversations.append(conv.get_prompt()) |
|
|
|
if not text_only: |
|
new_conversations = [] |
|
for conversation in conversations: |
|
for i in range(num_image): |
|
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' |
|
conversation = conversation.replace('<image>', image_tokens, 1) |
|
new_conversations.append(conversation) |
|
conversations = new_conversations |
|
|
|
|
|
input_ids = tokenizer( |
|
conversations, |
|
return_tensors='pt', |
|
padding=False if group_by_length or use_packed_ds else 'max_length', |
|
max_length=tokenizer.model_max_length, |
|
truncation=True, |
|
).input_ids |
|
targets = input_ids.clone() |
|
|
|
|
|
sep = conv.sep + conv.roles[1] |
|
for conversation, target in zip(conversations, targets): |
|
total_len = int(target.ne(tokenizer.pad_token_id).sum()) |
|
|
|
turns = conversation.split(conv.sep) |
|
re_turns = [conv.sep.join(turns[:3])] |
|
for conv_idx in range(3, len(turns), 2): |
|
re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) |
|
cur_len = 0 |
|
target[:cur_len] = IGNORE_TOKEN_ID |
|
for i, turn in enumerate(re_turns): |
|
if turn == '': |
|
break |
|
turn_len = len(tokenizer(turn).input_ids) + 1 |
|
|
|
parts = turn.split(sep) |
|
if len(parts) != 2: |
|
break |
|
parts[0] += sep |
|
instruction_len = len(tokenizer(parts[0]).input_ids) |
|
|
|
|
|
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID |
|
|
|
|
|
|
|
cur_len += turn_len |
|
|
|
target[cur_len:] = IGNORE_TOKEN_ID |
|
|
|
if cur_len < tokenizer.model_max_length: |
|
if cur_len != total_len: |
|
target[:] = IGNORE_TOKEN_ID |
|
print( |
|
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' |
|
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' |
|
) |
|
sys.stdout.flush() |
|
|
|
return dict( |
|
input_ids=input_ids, |
|
labels=targets, |
|
attention_mask=input_ids.ne(tokenizer.pad_token_id), |
|
) |
|
|
|
|
|
def preprocess_phi3( |
|
template_name, |
|
sources, |
|
tokenizer: transformers.PreTrainedTokenizer, |
|
num_image_token_list: list, |
|
text_only: bool = False, |
|
group_by_length: bool = False, |
|
use_packed_ds: bool = False, |
|
ds_name: str = None, |
|
num_image: int = 1 |
|
) -> Dict: |
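    """Same contract as `preprocess`, adapted to Phi-3 templates: pads on the right,
    masks '<|endoftext|>' padding tokens, and keeps only assistant replies as labels."""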
|
conv = get_conv_template(template_name) |
|
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} |
|
|
|
|
|
conversations = [] |
|
for i, source in enumerate(sources): |
|
if roles[source[0]['from']] != conv.roles[0]: |
|
|
|
source = source[1:] |
|
|
|
conv.messages = [] |
|
for j, sentence in enumerate(source): |
|
role = roles[sentence['from']] |
|
assert role == conv.roles[j % 2], f'{i}' |
|
conv.append_message(role, sentence['value']) |
|
conversations.append(conv.get_prompt()) |
|
|
|
if not text_only: |
|
new_conversations = [] |
|
for conversation in conversations: |
|
for i in range(num_image): |
|
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' |
|
conversation = conversation.replace('<image>', image_tokens, 1) |
|
new_conversations.append(conversation) |
|
conversations = new_conversations |
|
|
|
|
|
tokenizer.padding_side = 'right' |
|
input_ids = tokenizer( |
|
conversations, |
|
return_tensors='pt', |
|
padding=False if group_by_length or use_packed_ds else 'max_length', |
|
max_length=tokenizer.model_max_length, |
|
truncation=True, |
|
).input_ids |
|
targets = input_ids.clone() |
|
|
|
|
|
sep = conv.sep + conv.roles[1] |
|
for conversation, target in zip(conversations, targets): |
|
total_len = int(target.ne(int(tokenizer.pad_token_id)).sum()) |
|
|
|
turns = conversation.split(conv.sep) |
|
re_turns = [conv.sep.join(turns[:3])] |
|
for conv_idx in range(3, len(turns), 2): |
|
re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) |
|
cur_len = 1 |
|
target[:cur_len] = IGNORE_TOKEN_ID |
|
endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>') |
|
target[target == endoftext_id] = IGNORE_TOKEN_ID |
|
|
|
for i, turn in enumerate(re_turns): |
|
if turn == '': |
|
break |
|
if i == 0: |
|
turn_len = len(tokenizer(turn).input_ids) |
|
else: |
|
turn_len = len(tokenizer(turn).input_ids) - 1 |
|
parts = turn.split(sep) |
|
if len(parts) != 2: |
|
break |
|
parts[0] += sep |
|
|
|
if i == 0: |
|
instruction_len = len(tokenizer(parts[0]).input_ids) - 1 |
|
else: |
|
instruction_len = len(tokenizer(parts[0]).input_ids) - 2 |
|
|
|
|
|
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID |
|
|
|
|
|
|
|
cur_len += turn_len |
|
|
|
target[cur_len:] = IGNORE_TOKEN_ID |
|
|
|
if False: |
|
z = target.clone() |
|
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) |
|
print(repr(tokenizer.decode(z))) |
|
|
|
if cur_len < tokenizer.model_max_length: |
|
if cur_len != total_len: |
|
target[:] = IGNORE_TOKEN_ID |
|
print( |
|
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' |
|
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' |
|
) |
|
sys.stdout.flush() |
|
|
|
return dict( |
|
input_ids=input_ids, |
|
labels=targets, |
|
attention_mask=input_ids.ne(tokenizer.pad_token_id), |
|
) |
|
|
|
|
|
def preprocess_internlm( |
|
template_name, |
|
sources, |
|
tokenizer: transformers.PreTrainedTokenizer, |
|
num_image_token_list: list, |
|
text_only: bool = False, |
|
group_by_length: bool = False, |
|
use_packed_ds: bool = False, |
|
ds_name: str = None, |
|
num_image: int = 1 |
|
) -> Dict: |
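    """Same contract as `preprocess`, adapted to InternLM2 templates: the rendered
    prompt is split on the assistant role marker so that only assistant replies
    remain as labels."""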
|
conv = get_conv_template(template_name) |
|
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} |
|
|
|
|
|
conversations = [] |
|
for i, source in enumerate(sources): |
|
if roles[source[0]['from']] != conv.roles[0]: |
|
|
|
source = source[1:] |
|
|
|
conv.messages = [] |
|
for j, sentence in enumerate(source): |
|
role = roles[sentence['from']] |
|
assert role == conv.roles[j % 2], f'{i}' |
|
sentence['value'] = sentence['value'].strip() |
|
conv.append_message(role, sentence['value']) |
|
conversations.append(conv.get_prompt()) |
|
|
|
if not text_only: |
|
new_conversations = [] |
|
for conversation in conversations: |
|
for i in range(num_image): |
|
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' |
|
conversation = conversation.replace('<image>', image_tokens, 1) |
|
new_conversations.append(conversation) |
|
conversations = new_conversations |
|
|
|
|
|
input_ids = tokenizer( |
|
conversations, |
|
return_tensors='pt', |
|
padding=False if group_by_length or use_packed_ds else 'max_length', |
|
max_length=tokenizer.model_max_length, |
|
truncation=True, |
|
).input_ids |
|
targets = input_ids.clone() |
|
|
|
for conversation, target in zip(conversations, targets): |
|
total_len = int(target.ne(tokenizer.pad_token_id).sum()) |
|
cur_len = 1 |
|
target[:cur_len] = IGNORE_TOKEN_ID |
|
parts = conversation.split(conv.roles[1]) |
|
info = parts[0] + conv.roles[1] |
|
temp_len = len(tokenizer(info).input_ids) - 1 |
|
target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID |
|
cur_len = cur_len + temp_len |
|
|
|
for index in range(1, len(parts) - 1): |
|
info = parts[index] |
|
part1, part2 = info.split(conv.roles[0]) |
|
temp_len = len(tokenizer(part1).input_ids) - 1 |
|
cur_len = cur_len + temp_len |
|
part = conv.roles[0] + part2 + conv.roles[1] |
|
temp_len = len(tokenizer(part).input_ids) - 1 |
|
target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID |
|
cur_len = cur_len + temp_len |
|
last_info = parts[-1] |
|
temp_len = len(tokenizer(last_info).input_ids) - 1 |
|
cur_len = cur_len + temp_len |
|
|
|
target[cur_len:] = IGNORE_TOKEN_ID |
|
if False: |
|
z = target.clone() |
|
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) |
|
print(repr(tokenizer.decode(z))) |
|
|
|
if cur_len < tokenizer.model_max_length: |
|
if cur_len != total_len: |
|
target[:] = IGNORE_TOKEN_ID |
|
print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.') |
|
sys.stdout.flush() |
|
|
|
return dict( |
|
input_ids=input_ids, |
|
labels=targets, |
|
attention_mask=input_ids.ne(tokenizer.pad_token_id), |
|
) |
|
|
|
|
|
def preprocess_internvl2_5( |
|
template_name, |
|
sources, |
|
tokenizer: transformers.PreTrainedTokenizer, |
|
num_image_token_list: list, |
|
text_only: bool = False, |
|
group_by_length: bool = False, |
|
use_packed_ds: bool = False, |
|
ds_name: str = None, |
|
num_image: int = 1 |
|
) -> Dict: |
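    """Tokenize a single conversation into the InternVL 2.5 chat format.

    Messages are wrapped as '<|im_start|>{role}\n...<|im_end|>\n', '<image>' placeholders
    in human turns are expanded per image, and only assistant replies contribute to the
    labels. The sequence is truncated to tokenizer.model_max_length and, unless grouping
    or packing is used, padded to it.
    """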
|
    assert len(sources) == 1, 'preprocess_internvl2_5 handles a single conversation at a time'
|
conversations = sources[0] |
|
|
|
if conversations[0]['from'] == 'system': |
|
system_prompt = conversations[0]['value'] |
|
conversations = conversations[1:] |
|
else: |
|
conv = get_conv_template(template_name) |
|
system_prompt = conv.system_message |
|
|
|
|
|
if not text_only: |
|
new_conversations = [] |
|
current_image_idx = 0 |
|
for conversation in conversations: |
|
if conversation['from'] == 'human': |
|
image_cnt = conversation['value'].count('<image>') |
|
for i in range(image_cnt): |
|
if current_image_idx == num_image: |
|
break |
|
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[current_image_idx]}{IMG_END_TOKEN}' |
|
conversation['value'] = conversation['value'].replace('<image>', image_tokens, 1) |
|
current_image_idx += 1 |
|
new_conversations.append(conversation) |
|
conversations = new_conversations |
|
assert current_image_idx == num_image, f'{current_image_idx} != {num_image}' |
|
|
|
batches, roles = [], [] |
|
if system_prompt is not None: |
|
batches.append(f'<|im_start|>system\n{system_prompt}<|im_end|>\n') |
|
roles.append('system') |
|
for conversation in conversations: |
|
if conversation['from'] == 'human': |
|
batches.append(f'<|im_start|>user\n{conversation["value"]}<|im_end|>\n') |
|
roles.append('human') |
|
elif conversation['from'] == 'gpt': |
|
batches.append(f'<|im_start|>assistant\n{conversation["value"]}<|im_end|>\n') |
|
roles.append('gpt') |
|
else: |
|
raise NotImplementedError |
|
|
|
add_bos_token = getattr(tokenizer, 'add_bos_token', False) |
|
if add_bos_token: |
|
batches[0] = tokenizer.bos_token + batches[0] |
|
|
|
|
|
input_ids = tokenizer( |
|
batches, |
|
return_tensors='np', |
|
padding=False, |
|
max_length=tokenizer.model_max_length, |
|
truncation=False, |
|
).input_ids |
|
|
|
if add_bos_token: |
|
input_ids = [item[1:] for item in input_ids] |
|
|
|
final_input_ids, final_targets = [], [] |
|
ignore_ids = tokenizer('<|im_start|>assistant\n', return_tensors='np').input_ids[0] |
|
ignore_len = ignore_ids.shape[0] - 1 if add_bos_token else ignore_ids.shape[0] |
|
for role, input_id in zip(roles, input_ids): |
|
final_input_ids.append(input_id) |
|
if role == 'system' or role == 'human': |
|
final_targets.append(np.full(input_id.shape, IGNORE_TOKEN_ID)) |
|
elif role == 'gpt': |
|
target = input_id.copy() |
|
target[:ignore_len] = IGNORE_TOKEN_ID |
|
target[-1:] = IGNORE_TOKEN_ID |
|
final_targets.append(target) |
|
else: |
|
raise NotImplementedError |
|
input_ids = torch.tensor(np.concatenate(final_input_ids))[:tokenizer.model_max_length] |
|
targets = torch.tensor(np.concatenate(final_targets))[:tokenizer.model_max_length] |
|
|
|
padding = False if group_by_length or use_packed_ds else True |
|
if padding: |
|
current_length = input_ids.size(0) |
|
padding_length = tokenizer.model_max_length - current_length |
|
input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_token_id) |
|
targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID) |
|
|
|
input_ids = input_ids.unsqueeze(0) |
|
targets = targets.unsqueeze(0) |
|
|
|
return dict( |
|
input_ids=input_ids, |
|
labels=targets, |
|
attention_mask=input_ids.ne(tokenizer.pad_token_id), |
|
) |
|
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): |
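    """Return the (columns, rows) tiling from `target_ratios` whose aspect ratio is
    closest to the image's, preferring the tiling with more blocks on ties when the
    image area exceeds half of that tiling's area."""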
|
best_ratio_diff = float('inf') |
|
best_ratio = (1, 1) |
|
area = width * height |
|
for ratio in target_ratios: |
|
target_aspect_ratio = ratio[0] / ratio[1] |
|
ratio_diff = abs(aspect_ratio - target_aspect_ratio) |
|
if ratio_diff < best_ratio_diff: |
|
best_ratio_diff = ratio_diff |
|
best_ratio = ratio |
|
elif ratio_diff == best_ratio_diff: |
|
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: |
|
best_ratio = ratio |
|
|
|
return best_ratio |
|
|
|
|
|
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, return_ratio=False): |
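    """Split a PIL image into between `min_num` and `max_num` tiles of size
    image_size x image_size, using the supported tiling closest to the image's aspect
    ratio; optionally append an image_size thumbnail of the whole image and/or also
    return the chosen (columns, rows) ratio."""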
|
orig_width, orig_height = image.size |
|
aspect_ratio = orig_width / orig_height |
|
|
|
|
|
target_ratios = set( |
|
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if |
|
i * j <= max_num and i * j >= min_num) |
|
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) |
|
|
|
|
|
target_aspect_ratio = find_closest_aspect_ratio( |
|
aspect_ratio, target_ratios, orig_width, orig_height, image_size) |
|
|
|
|
|
target_width = image_size * target_aspect_ratio[0] |
|
target_height = image_size * target_aspect_ratio[1] |
|
blocks = target_aspect_ratio[0] * target_aspect_ratio[1] |
|
|
|
|
|
resized_img = image.resize((target_width, target_height)) |
|
processed_images = [] |
|
for i in range(blocks): |
|
box = ( |
|
(i % (target_width // image_size)) * image_size, |
|
(i // (target_width // image_size)) * image_size, |
|
((i % (target_width // image_size)) + 1) * image_size, |
|
((i // (target_width // image_size)) + 1) * image_size |
|
) |
|
|
|
split_img = resized_img.crop(box) |
|
processed_images.append(split_img) |
|
assert len(processed_images) == blocks |
|
if use_thumbnail and len(processed_images) != 1: |
|
thumbnail_img = image.resize((image_size, image_size)) |
|
processed_images.append(thumbnail_img) |
|
if return_ratio: |
|
return processed_images, target_aspect_ratio |
|
return processed_images |
|
|
|
|
|
def dynamic_preprocess_mask(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): |
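    """Tile a float mask tensor of shape (num_masks, H, W) the same way
    `dynamic_preprocess` tiles images: resize bilinearly to the chosen tiling,
    re-binarize (> 0), split into image_size x image_size tiles, and optionally
    append a thumbnail."""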
|
|
|
length, orig_height, orig_width = image.shape |
|
aspect_ratio = orig_width / orig_height |
|
|
|
|
|
target_ratios = set( |
|
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if |
|
i * j <= max_num and i * j >= min_num) |
|
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) |
|
|
|
|
|
target_aspect_ratio = find_closest_aspect_ratio( |
|
aspect_ratio, target_ratios, orig_width, orig_height, image_size) |
|
|
|
|
|
target_width = image_size * target_aspect_ratio[0] |
|
target_height = image_size * target_aspect_ratio[1] |
|
blocks = target_aspect_ratio[0] * target_aspect_ratio[1] |
|
|
|
|
|
|
|
tensor_images = image.unsqueeze(1) |
|
|
|
resized_images = F.interpolate(tensor_images, size=(target_height, target_width), mode='bilinear', align_corners=False) |
|
resized_images = resized_images > 0 |
|
|
|
|
|
processed_images = [] |
|
for i in range(blocks): |
|
top = (i // (target_width // image_size)) * image_size |
|
left = (i % (target_width // image_size)) * image_size |
|
bottom = top + image_size |
|
right = left + image_size |
|
|
|
split_img = resized_images[..., top:bottom, left:right] |
|
processed_images.append(split_img) |
|
|
|
|
|
|
|
|
|
|
|
processed_images = [img.squeeze(1) for img in processed_images] |
|
|
|
assert len(processed_images) == blocks |
|
if use_thumbnail and len(processed_images) != 1: |
|
thumbnail_img = F.interpolate(tensor_images, size=(image_size, image_size), mode='bilinear', align_corners=False).squeeze(1) |
|
thumbnail_img = thumbnail_img > 0 |
|
|
|
processed_images.append(thumbnail_img) |
|
return processed_images |
|
|