import json
import os
import random
import warnings

import numpy as np
import torch
import torchaudio
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import yaml
from PIL import Image
from torch.utils.data import Dataset

from model.tokenizer import Tokenizer
from . import conversation_lib
from . import video_utils
from .imu_utils import get_imu_frames

IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

T_random_resized_crop = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333),
                                 interpolation=BICUBIC, antialias=None),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711])])

# image transform
transform_img_train = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333),
                                 interpolation=3, antialias=None),  # 3 is bicubic
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711])])


class PairRandomResizedCrop(transforms.RandomResizedCrop):
    def forward(self, imgs):
        # sample one set of crop parameters and apply it to every image in the list,
        # so paired modalities (e.g. RGB + depth) stay spatially aligned
        i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
        return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation,
                               antialias=self.antialias) for img in imgs]


class PairToTensor(transforms.ToTensor):
    def __call__(self, pics):
        return [F.to_tensor(pic) for pic in pics]


class PairNormalize(transforms.Normalize):
    def forward(self, tensors):
        return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]


transform_pairimg_train = transforms.Compose([
    PairRandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333),
                          interpolation=3, antialias=None),  # 3 is bicubic
    PairToTensor(),
    PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073],
                  std=[0.26862954, 0.26130258, 0.27577711])])
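
# Usage sketch (illustrative only; the file paths are hypothetical). This mirrors how
# load_rgbx below consumes the pair transform, keeping RGB and depth/normal maps aligned:
#
#     rgb = Image.open("rgb.png").convert("RGB")
#     depth = Image.open("depth.png").convert("RGB")
#     rgb_t, depth_t = transform_pairimg_train([rgb, depth])   # two [3, 224, 224] tensors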


def pc_norm(pc):
    """pc: NxC tensor; center the xyz coordinates, scale them into the unit sphere,
    and keep any extra feature columns unchanged. Returns an NxC tensor."""
    xyz = pc[:, :3]
    other_feature = pc[:, 3:]

    centroid = torch.mean(xyz, dim=0)
    xyz = xyz - centroid
    m = torch.max(torch.sqrt(torch.sum(xyz ** 2, dim=1)))
    xyz = xyz / m

    pc = torch.cat((xyz, other_feature), dim=1)
    return pc
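
# Example (sketch): for a point cloud with xyz + rgb columns,
#     pc = torch.rand(8192, 6)
#     pc = pc_norm(pc)   # pc[:, :3] now zero-centered with max radius 1; pc[:, 3:] untouched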


def make_audio_features(wav_name, mel_bins=128, target_length=1024, aug=False):
    waveform, sr = torchaudio.load(wav_name)
    # assert sr == 16000, 'input audio sampling rate must be 16kHz'
    if sr != 16000:
        trans = torchaudio.transforms.Resample(sr, 16000)
        waveform = trans(waveform)

    waveform = waveform - waveform.mean()

    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=16000, use_energy=False,
        window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)

    # pad with zeros or truncate along the time axis to exactly target_length frames
    n_frames = fbank.shape[0]
    p = target_length - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]

    if aug:
        freqm = torchaudio.transforms.FrequencyMasking(48)
        timem = torchaudio.transforms.TimeMasking(192)
        fbank = torch.transpose(fbank, 0, 1)
        fbank = fbank.unsqueeze(0)
        fbank = freqm(fbank)
        fbank = timem(fbank)
        fbank = fbank.squeeze(0)
        fbank = torch.transpose(fbank, 0, 1)

    fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
    return fbank
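
# Usage sketch (the path is hypothetical):
#     fbank = make_audio_features("clip.wav", mel_bins=128, target_length=1024)
#     fbank.shape   # torch.Size([1024, 128]); load_audio below reshapes this to [1, 128, 1024]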


class ConversationGenerator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.header = f"{conversation_lib.default_conversation.system}\n\n"
        self._probe_tokenizer_style()

    def _probe_tokenizer_style(self):
        """
        Given a sentence, e.g. "My darling", some tokenizers make the space a separate token,
        while others merge the space into the next word, forming a token representing " darling".
        Knowing which style the tokenizer uses is necessary for correct ground-truth label masking.
        """
        probe = "Probe am I"
        sentence1 = self.tokenizer.encode(
            conversation_lib.default_conversation.roles[1] + ": " + probe, bos=False, eos=False)
        sentence2 = self.tokenizer.encode(probe, bos=False, eos=False)
        if sentence1[-len(sentence2):] == sentence2:
            self.space_before_to_predict = False
        else:
            sentence3 = self.tokenizer.encode(" " + probe, bos=False, eos=False)
            assert sentence1[-len(sentence3):] == sentence3
            self.space_before_to_predict = True
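
    # Illustration (hypothetical tokenizations, not tied to any specific tokenizer):
    #   encode("Assistant: Probe am I") -> [..., "_Probe", "_am", "_I"]   # space merged into "Probe"
    #   encode("Probe am I")            -> ["Probe", "_am", "_I"]         # not a suffix of the above
    # In that case space_before_to_predict is True, and a leading space is prepended to each
    # ground-truth segment so its tokenization matches the corresponding span in the dialog.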

    def add_speaker_and_signal(self, source, get_conversation=True):
        """Add speaker and start/end signal on each round."""
        BEGIN_SIGNAL = "### "
        END_SIGNAL = "\n"
        conversation = self.header
        to_predict_list = []

        for sentence in source:
            from_str = sentence["from"]
            if from_str.lower() in ["human"]:
                from_str = conversation_lib.default_conversation.roles[0]
            elif from_str.lower() in ["gpt", "assistant"]:
                from_str = conversation_lib.default_conversation.roles[1]
            else:
                raise ValueError(f"unknown dialog role: {from_str.lower()}")

            value = sentence["value"]
            if DEFAULT_IMAGE_TOKEN in value:
                value = value.replace(DEFAULT_IMAGE_TOKEN, '').strip()

            sentence_value = BEGIN_SIGNAL + from_str + ": " + value + END_SIGNAL

            if from_str == conversation_lib.default_conversation.roles[1]:
                to_predict_value = value + END_SIGNAL + "###"
                if self.space_before_to_predict:
                    to_predict_value = " " + to_predict_value
                to_predict_list.append(to_predict_value)

            if get_conversation:
                conversation = conversation + sentence_value

        conversation = conversation + BEGIN_SIGNAL
        return conversation, to_predict_list
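
# ConversationGenerator.add_speaker_and_signal renders a dialog roughly as
# (role names come from conversation_lib and may differ):
#
#     <system prompt>\n\n### Human: <question>\n### Assistant: <answer>\n###
#
# and to_predict_list collects each assistant answer plus its trailing "\n###",
# which is exactly the span whose labels are kept in __getitem__ below.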


DATASETS = dict(
    image=[
        dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_image.json", type='image'),
        dict(path='datasets/InstructionTuning/image/cococap_train.json', type='image'),
        dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_text.json", type='text'),
    ],
    audio=[
        dict(path="datasets/InstructionTuning/audio/audiocap_train.json", type='audio'),
        dict(path="datasets/InstructionTuning/audio/audiocap_val.json", type='audio'),
        dict(path="datasets/InstructionTuning/audio/audio_conversation.json", type='audio'),
    ],
    video=[
        dict(path="datasets/InstructionTuning/video/msrvtt_cap_trainval.json", type='video'),
        dict(path="datasets/InstructionTuning/video/msrvtt_cap_test.json", type='video'),
        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_train.json", type='video'),
        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_val.json", type='video'),
        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_test.json", type='video'),
        dict(path="datasets/InstructionTuning/video/video_complex_reasoning_10k.json", type='video'),
        dict(path="datasets/InstructionTuning/video/video_conversation_10k.json", type='video'),
        dict(path="datasets/InstructionTuning/video/video_detail_10k.json", type='video'),
    ],
    point=[
        dict(path="datasets/InstructionTuning/point/pointllm_70k_formated.json", type='point'),
    ],
    rgbd=[
        dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_depth.json", type='rgbd'),
    ],
    rgbn=[
        dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_normal.json", type='rgbn'),
    ],
    imu=[
        dict(path="datasets/InstructionTuning/imu/imu_fixed_50k.json", type='imu'),
    ],
    fmri=[
        dict(path="datasets/InstructionTuning/fmri/fmri_fixed.json", type='fmri'),
    ],
)
IMU_PATH = "/mnt/petrelfs/share_data/hanjiaming/ego4d/v2/processed_imu/"
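
# Based on how they are consumed below, the annotation JSONs above are expected to hold a
# list of entries shaped roughly like:
#
#     {
#         "image": "<path to image/audio/video/point/fmri file>",   # absent for pure-text entries
#         "conversations": [
#             {"from": "human", "value": "<image>\nDescribe the picture."},
#             {"from": "gpt", "value": "..."},
#         ],
#     }
#
# IMU entries instead carry "video_uid", "window_start" and "window_end" (see load_imu).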


class FinetuneDialogDataset(Dataset):
    def __init__(self, dataset=['image'], transform=T_random_resized_crop, max_words=2048,
                 image_words=30, tokenizer_path=None):
        if isinstance(dataset, str):
            dataset = [dataset]
        self.dataset = dataset

        group_ann = {}
        for d in dataset:
            for meta in DATASETS[d]:
                meta_path, meta_type = meta['path'], meta['type']
                meta_ext = os.path.splitext(meta_path)[-1]
                if meta_ext == ".json":
                    with open(meta_path) as f:
                        meta_l = json.load(f)
                    # add data_type to each record (temporary solution)
                    new_meta_l = []
                    for l in meta_l:
                        l['data_type'] = meta_type
                        new_meta_l.append(l)
                    meta_l = new_meta_l
                elif meta_ext == ".jsonl":
                    meta_l = []
                    with open(meta_path) as f:
                        for i, line in enumerate(f):
                            try:
                                meta_l.append(json.loads(line))
                            except json.decoder.JSONDecodeError as e:
                                print(f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}")
                                raise e
                else:
                    raise NotImplementedError(
                        f"Unknown meta file extension: \"{meta_ext}\". "
                        f"Currently, .json, .jsonl are supported. "
                        "If you are using a supported format, please set the file extension "
                        "so that the proper parsing routine can be called."
                    )
                if meta_type not in group_ann:
                    group_ann[meta_type] = []
                print(f"{meta_path}, type {meta_type}: len {len(meta_l)}")
                group_ann[meta_type] += meta_l

        # sort each group so that items in one global batch have similar length (higher efficiency)
        for meta_type, meta_l in group_ann.items():
            meta_l.sort(key=lambda data_item: sum(
                [len(_['value']) for _ in data_item['conversations']]))

        self.group_ann = group_ann
        self.ann = sum(list(self.group_ann.values()), start=[])

        self.group_indices = {}
        start_pos = 0
        for meta_type, meta_l in self.group_ann.items():
            self.group_indices[meta_type] = list(range(start_pos, start_pos + len(meta_l)))
            start_pos = start_pos + len(meta_l)

        print(f"total length: {len(self)}")
        self.transform = transform
        print(f"transform:\n{self.transform}")
        self.max_words = max_words
        self.image_words = image_words
        self.tokenizer = Tokenizer(model_path=tokenizer_path)
        self.conversation_generator = ConversationGenerator(self.tokenizer)

        self.load_funcs = dict(
            image=self.load_image,
            audio=self.load_audio,
            video=self.load_video,
            point=self.load_point,
            rgbd=self.load_rgbx,
            rgbn=self.load_rgbx,
            imu=self.load_imu,
            fmri=self.load_fmri,
        )

    def __len__(self):
        return len(self.ann)

    def load_image(self, data):
        filename = data['image']
        image = Image.open(filename).convert('RGB')
        image = self.transform(image)
        return image

    def load_audio(self, data):
        audio_path = data['image']
        fbank = make_audio_features(audio_path, mel_bins=128)
        fbank = fbank.transpose(0, 1)[None]  # [1, 128, 1024]
        return fbank

    def load_video(self, data):
        video_path = data['image']
        video_feats = video_utils.load_and_transform_video_data(
            video_path, video_path, clip_duration=1, clips_per_video=5)
        return video_feats[:, :, 0]

    def load_point(self, data):
        point_path = data['image']
        point_feat = torch.load(point_path, map_location='cpu')
        point_feat = point_feat.transpose(0, 1)
        return point_feat

    def load_rgbx(self, data):
        image_path = data['image']
        x_image_path = data['depth_image'] if 'depth_image' in data else data['normal_image']
        image = Image.open(image_path).convert('RGB')
        x_image = Image.open(x_image_path).convert('RGB')
        x_image = x_image.resize(image.size[-2:])

        image, x_image = transform_pairimg_train([image, x_image])
        # [2, 3, H, W]
        image = torch.stack([image, x_image], dim=0)
        return image

    def load_fmri(self, data):
        fmri_path = data['image']
        data = np.load(fmri_path)
        data = data.mean(axis=0)
        data = torch.tensor(data[None])
        return data

    def load_imu(self, data_dict):
        uid = data_dict["video_uid"]
        w_s = data_dict["window_start"]
        w_e = data_dict["window_end"]

        imu_data = get_imu_frames(
            IMU_PATH, uid,
            video_start_sec=w_s,
            video_end_sec=w_e,
        )
        if imu_data is None:
            raise ValueError(f"no IMU frames for uid {uid} in window [{w_s}, {w_e}]")
        return imu_data['signal']

    def __getitem__(self, index, expect_type=None):
        if expect_type is None:
            data_item = self.ann[index]
        else:
            # in case we want to get data of a specific data_type
            data_item = self.group_ann[expect_type][index]

        data_type = data_item['data_type']
        if data_type != 'text':
            if data_type in self.load_funcs:
                try:
                    image = self.load_funcs[data_type](data_item)
                    if image is None:
                        raise ValueError('Data is None')
                except Exception:
                    print('Error', data_item)
                    # fall back to a random sample of the same modality
                    rand_idx = random.randint(0, len(self.group_ann[data_type]) - 1)
                    return self.__getitem__(rand_idx, expect_type=data_type)
            else:
                raise ValueError(f'Does not support {data_type}')
        else:
            image = None
            # warnings.warn("pure black image for examples without image")
            # image = torch.zeros(3, 224, 224)

        source = data_item["conversations"]
        conversation, to_predict_values = self.conversation_generator.add_speaker_and_signal(source)
        if len(to_predict_values) == 0:
            warnings.warn(f"see dialog data with nothing to predict, data: {data_item}")
            return self[index - 1]

        tokenized_conversation = self.tokenizer.encode(conversation, bos=True, eos=True)
        labels = [IGNORE_INDEX for _ in tokenized_conversation]

        # keep labels only for the assistant responses; everything else stays IGNORE_INDEX
        check_pos = 0
        for value in to_predict_values:
            tokenized_value = self.tokenizer.encode(value, bos=False, eos=False)
            value_pos = find_sublist(tokenized_conversation[check_pos:], tokenized_value)
            if value_pos == -1:
                print("a sentence mismatches the corresponding piece in the conversation")
                return self[index - 1]
            value_pos = value_pos + check_pos
            labels[value_pos:value_pos + len(tokenized_value)] = tokenized_value
            assert labels[value_pos:value_pos + len(tokenized_value)] == \
                tokenized_conversation[value_pos:value_pos + len(tokenized_value)]
            check_pos = value_pos + len(tokenized_value)

        input2 = torch.tensor(tokenized_conversation, dtype=torch.int64)
        labels = torch.tensor(labels, dtype=torch.int64)

        # pad with -1 (masked out below) or truncate to max_words,
        # reserving image_words positions when a modality tensor is present
        if image is not None:
            max_words = self.max_words - self.image_words
        else:
            max_words = self.max_words
        padding = max_words - input2.shape[0]
        if padding > 0:
            input2 = torch.cat((input2, torch.zeros(padding, dtype=torch.int64) - 1))
            labels = torch.cat((labels, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            input2 = input2[:max_words]
            labels = labels[:max_words]

        input2_mask = input2.ge(0)
        label_mask = labels.ge(0)
        input2[~input2_mask] = 0
        labels[~label_mask] = 0
        input2_mask = input2_mask.float()
        label_mask = label_mask.float()

        if image is None:
            return input2, labels, data_item['data_type']
        else:
            return input2, labels, image, data_item['data_type']

    def groups(self):
        return list(self.group_indices.values())


def find_sublist(a: list, b: list):
    """Return the start index of the first occurrence of b inside a, or -1 if b is not found."""
    len_a, len_b = len(a), len(b)
    for i in range(len_a - len_b + 1):
        if a[i:i + len_b] == b:
            return i
    return -1
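

# Minimal usage sketch (illustrative; the tokenizer path below is a placeholder for your setup,
# and the annotation JSONs listed in DATASETS must exist on disk):
if __name__ == "__main__":
    dataset = FinetuneDialogDataset(
        dataset=['image'],                          # any subset of DATASETS' keys
        transform=T_random_resized_crop,
        max_words=2048,
        image_words=30,
        tokenizer_path="path/to/tokenizer.model",   # hypothetical path
    )
    sample = dataset[0]
    # non-text items yield (tokens, labels, modality_tensor, data_type);
    # pure-text items yield (tokens, labels, data_type)
    print(len(sample), sample[-1])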