import requests
import re, ujson, os, sys, fire, glob, random, time, json
import numpy as np
import io
import torch
from torch.utils.data import default_collate
import torchaudio
from typing import *
from dataclasses import dataclass, field
import transformers
from transformers.modeling_outputs import ModelOutput
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from functools import lru_cache
from io import BytesIO
from PIL import Image, ImageOps, ImageFile
import concurrent.futures as cf
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
from transformers.image_utils import PILImageResampling

torch.set_num_threads(1)
ImageFile.LOAD_TRUNCATED_IMAGES = True

import base64
from decord import VideoReader, cpu
import cv2
import av
import imagesize
import tempfile
import math
from multiprocessing import Pool
from cairosvg import svg2png
import hashlib

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200

VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768

def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor

def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar

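# Worked example (illustrative only): with the default factor of 28, a 720x1280
# frame maps to smart_resize(720, 1280) == (728, 1288) -- both sides are rounded
# to the nearest multiple of 28, and 728 * 1288 pixels already lies inside
# [MIN_PIXELS, MAX_PIXELS], so no area rescaling is applied.
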
def split_text(text, match_regex):
    matches = list(re.finditer(match_regex, text))

    result = []
    match_flag_list = []

    last_end = 0

    for match in matches:
        if text[last_end:match.start()]:
            result.append(text[last_end:match.start()])
            match_flag_list.append(False)

        result.append(match.group(0))
        match_flag_list.append(True)

        last_end = match.end()

    if text[last_end:]:
        result.append(text[last_end:])
        match_flag_list.append(False)
    return result, match_flag_list

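# Example (illustrative only): split_text("a<x>b</x>c", re.compile("<x>.*?</x>"))
# returns (["a", "<x>b</x>", "c"], [False, True, False]); the flag list marks which
# chunks matched the pattern, which the processor later uses to tell multimodal
# placeholders apart from plain text.
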
def read_video(image_path, max_frame_number, decode_way):
    if decode_way == '1fps':
        try:
            # sample roughly one frame per second with decord
            vr = VideoReader(image_path, ctx=cpu(0))
            total_frame_num = len(vr)
            fps = round(vr.get_avg_fps())
            frame_idx = [i for i in range(0, len(vr), fps)]
            frames = vr.get_batch(frame_idx).asnumpy()
            cnt = len(frames)
            frame_times = list(range(cnt))
        except Exception as e:
            print(image_path)
            print('error is', e)
            return None
    elif decode_way == 'key':
        try:
            # decode only key frames with PyAV
            with av.open(image_path) as container:
                stream = container.streams.video[0]
                stream.codec_context.skip_frame = 'NONKEY'
                frames = []
                frame_times = []
                fps = int(stream.average_rate)
                cnt = 0
                for frame in container.decode(stream):
                    image = np.array(frame.to_image())
                    frames.append(image)
                    frame_time = int(frame.time)
                    frame_times.append(frame_time)
                    cnt += 1
        except Exception as e:
            print('error is', e)
            return None
    else:
        raise ValueError('invalid decode way!!!')
    if frames is None or len(frames) == 0:
        return None
    if len(frames) > max_frame_number and max_frame_number > 0:
        # uniformly subsample to at most max_frame_number frames; cast to numpy
        # arrays so the index array works for both decode paths
        indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
        frames = np.asarray(frames)[indices]
        frame_times = np.asarray(frame_times)[indices]
    return frames, frame_times

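# Note: read_video returns (frames, frame_times), where frames holds H x W x 3 RGB
# images and frame_times holds the corresponding timestamps in seconds ('1fps'
# mode) or the key-frame presentation times ('key' mode); it returns None when
# decoding fails.
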
class OmniImageProcessor:
    def __init__(self, config, **kwargs):
        self.config = config
        self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
        self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
        self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
        self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
        self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
        self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2

    def image_transform(self, strseq, return_mm_data=True):
        image = None
        if isinstance(strseq, str):
            if return_mm_data:
                image = Image.open(strseq).convert("RGB")
        else:
            try:
                image = Image.open(BytesIO(strseq)).convert("RGB")
            except Exception:
                # fall back to rasterizing SVG bytes
                image = Image.open(BytesIO(svg2png(bytestring=strseq))).convert("RGB")

        image = np.array(image.convert("RGB"))
        image_org_size = image.shape[:2]

        # resize so both sides are multiples of patch_size * spatial_merge_size
        # and the pixel count stays within [min_pixels, max_pixels]
        resized_height, resized_width = smart_resize(
            image_org_size[0], image_org_size[1],
            factor=self.patch_size * self.spatial_merge_size,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )
        output_size = (resized_height, resized_width)

        image = resize(image, output_size, PILImageResampling.BICUBIC)
        img = image.transpose(2, 0, 1)

        # scale to [0, 1] and standardize with the configured mean/std
        image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:, np.newaxis, np.newaxis]

        patches = image[np.newaxis, :]
        if patches.shape[0] == 1:
            # duplicate the single frame so it fills one temporal patch
            patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.spatial_merge_size,
            self.spatial_merge_size,
            self.patch_size,
            grid_w // self.spatial_merge_size,
            self.spatial_merge_size,
            self.patch_size,
        )
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
        )

        return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)

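# Worked example (illustrative only): with the default patch_size=14 and
# spatial_merge_size=2, an image resized to 728x1288 gives grid_t=1, grid_h=52,
# grid_w=92, i.e. flatten_patches of shape (4784, 1176), where
# 1176 = 3 * temporal_patch_size * 14 * 14; after 2x2 patch merging the language
# model sees 4784 // 4 = 1196 image placeholder tokens for this image.
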
class OmniAudioProcessor:

    def __init__(
        self,
        config,
        **kwargs
    ):
        assert(len(torchaudio.list_audio_backends()) > 0)
        self.config = config
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=1 + self.config.n_fft // 2,
            num_mel_filters=self.config.num_mel_bins,
            min_frequency=0.0,
            max_frequency=self.config.sampling_rate / 2.0,
            sampling_rate=self.config.sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )
        self.window = torch.hann_window(self.config.n_fft)

    @staticmethod
    def dynamic_range_compression(x, C=1, clip_val=1e-6):
        return torch.log(torch.clamp(x, min=clip_val) * C)

    @staticmethod
    def zero_mean_unit_var_norm(x):
        return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)

    def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
        metadata = torchaudio.info(uri)
        assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels)
        waveform_tensor, _ = torchaudio.load(uri, normalize=True)
        if self.config.sampling_rate != metadata.sample_rate:
            waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate, lowpass_filter_width=128)

        # downmix stereo to mono
        if metadata.num_channels > 1:
            waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)

        if do_normalize:
            waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)

        if return_tensors:
            return waveform_tensor
        else:
            return waveform_tensor.numpy()

    def split_with_overlap(self, waveform):
        channels, wave_samples = waveform.shape
        max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
        if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
            return [waveform]

        split_waveform, start = [], 0
        while start < wave_samples:
            if start > int(self.config.sampling_rate * self.config.split_overlap):
                start -= int(self.config.sampling_rate * self.config.split_overlap)
            end = min(start + max_audio_samples, wave_samples)
            if end - start >= self.config.n_fft:
                # drop tail fragments shorter than one FFT window
                split_waveform.append(waveform[:, start:end])
            start = end
        return split_waveform

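    # Example (illustrative only; the real limits come from the audio config): with
    # max_audio_seconds=30, sampling_rate=16000 and split_overlap=2, a 70 s waveform
    # is split into the windows [0s, 30s], [28s, 58s] and [56s, 70s], i.e.
    # consecutive chunks share a 2-second overlap.
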
    @classmethod
    def inference_output_length(cls, config, input_length):
        # standard conv output-length formula applied twice: stride 1, then stride_size
        kernel_size = config.kernel_size
        stride_size = config.stride_size
        avg_pooler = config.avg_pooler
        encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1
        encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1
        if avg_pooler > 1:
            bridge_length = encoder_length // avg_pooler
        else:
            # without pooling the bridge keeps the encoder resolution
            bridge_length = encoder_length
        return encoder_length, bridge_length

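    # Worked example (assumed values, not necessarily the shipped config): with
    # kernel_size=3, stride_size=2 and avg_pooler=4, an input of 3000 mel frames
    # gives encoder_length = 1500 and bridge_length = 375, i.e. that clip would be
    # represented by 375 audio placeholder tokens.
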
    def extract_fbank_features(self, waveform):
        channels, wave_samples = waveform.shape
        assert(wave_samples >= self.config.n_fft)
        valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
        # pad or trim the waveform to exactly max_audio_seconds
        if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
            waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
        else:
            waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]

        # Whisper-style log-mel spectrogram
        stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=self.window, return_complex=True)
        magnitudes = stft[..., :-1].abs() ** 2

        mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
        mel_spec = mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        if waveform.dim() == 2:
            max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
            log_spec = torch.maximum(log_spec, max_val - 8.0)
        else:
            log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0

        # zero out the padded frames beyond the valid length
        log_spec = log_spec[0].numpy()
        log_spec[:, valid_frame_nums:] = 0.0

        return log_spec, valid_frame_nums

    def data_augment(self, feature: np.ndarray, input_length, training=True):

        def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
            num_masked_span = int(mask_prob * input_length / mask_length + random.random())
            num_masked_span = max(num_masked_span, min_masks)
            start_indices = list(range(input_length - mask_length))
            random.shuffle(start_indices)
            start_indices = start_indices[:num_masked_span]
            return start_indices

        if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
            return feature
        if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
            return feature
        if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
            return feature

        if self.config.mask_time_prob > 0:
            start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
            for start_idx in start_indices:
                feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
        if self.config.mask_feature_prob > 0:
            start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
            for start_idx in start_indices:
                feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0

        return feature

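    # Note: data_augment applies SpecAugment-style masking during training --
    # mask_time_* zeroes random spans of frames (columns) and mask_feature_* zeroes
    # random spans of mel bins (rows); at inference the features pass through untouched.
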
@dataclass
class OmniProcessorOutput(ModelOutput):
    input_ids: Optional["List|torch.Tensor"] = None
    labels: Optional["List|torch.Tensor"] = None
    attention_mask: Optional["List|torch.Tensor"] = None
    position_ids: Optional["List|torch.Tensor"] = None
    seqlens: Optional["List|torch.Tensor"] = None

    # audio fields
    audios: Optional["List|torch.Tensor"] = None
    encoder_length: Optional["List|torch.Tensor"] = None
    bridge_length: Optional["List|torch.Tensor"] = None

    # image fields
    images: Optional["List|torch.Tensor"] = None
    patch_nums: Optional["List|torch.Tensor"] = None
    images_size: Optional["List|torch.Tensor"] = None
    crop_size: Optional["List|torch.Tensor"] = None
    images_grid: Optional["List|torch.Tensor"] = None

    # video fields
    videos: Optional["List|torch.Tensor"] = None
    videos_patch_nums: Optional["List|torch.Tensor"] = None
    videos_size: Optional["List|torch.Tensor"] = None
    videos_crop_size: Optional["List|torch.Tensor"] = None
    videos_grid: Optional["List|torch.Tensor"] = None

    raw_text: Optional[str] = None
    index: Optional[int] = None

    def concatenate(self, other):
        # field-wise concatenation; None fields are treated as empty
        def concat_one(a, b):
            if a is None and b is None:
                return None
            elif a is None and b is not None:
                return b
            elif a is not None and b is None:
                return a
            else:
                return a + b
        return OmniProcessorOutput(
            input_ids=concat_one(self.input_ids, other.input_ids),
            labels=concat_one(self.labels, other.labels),
            audios=concat_one(self.audios, other.audios),
            encoder_length=concat_one(self.encoder_length, other.encoder_length),
            bridge_length=concat_one(self.bridge_length, other.bridge_length),
            images=concat_one(self.images, other.images),
            images_grid=concat_one(self.images_grid, other.images_grid),
            patch_nums=concat_one(self.patch_nums, other.patch_nums),

            videos=concat_one(self.videos, other.videos),
            videos_grid=concat_one(self.videos_grid, other.videos_grid),
            videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),

            position_ids=concat_one(self.position_ids, other.position_ids),
            seqlens=concat_one(self.seqlens, other.seqlens),
            images_size=concat_one(self.images_size, other.images_size),
            videos_size=concat_one(self.videos_size, other.videos_size),
            index=self.index
        )

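# Usage sketch (illustrative only): per-modality chunks are merged via concatenate, e.g.
#   ret = OmniProcessorOutput(input_ids=[1, 2])
#   ret = ret.concatenate(OmniProcessorOutput(input_ids=[3], patch_nums=[4]))
#   # -> ret.input_ids == [1, 2, 3] and ret.patch_nums == [4]
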
class OmniMMProcessor(object):
    def __init__(self,
                 tokenizer: transformers.PreTrainedTokenizer,
                 config,
                 training,
                 relative_path=None,
                 parallel=None,
                 **kwargs,
                 ):
        self.tokenizer = tokenizer
        self.config = config
        self.audio_processor = OmniAudioProcessor(config.audio_config)
        self.visual_processor = None
        if hasattr(config, "visual_config"):
            self.visual_processor = OmniImageProcessor(config.visual_config)
        self.video_processor = None
        if hasattr(config, "video_config"):
            self.video_processor = OmniImageProcessor(config.video_config)
        self.training = training
        self.relative_path = relative_path
        self.parallel = parallel

        # audio special tokens
        self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
        self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
        self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
        self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
        self.audiogen_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_start_token_id)
        self.audiogen_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_end_token_id)

        # image / video special tokens
        self.image_start_tag = None
        self.image_end_tag = None
        self.image_pad_tag = None
        self.video_start_tag = None
        self.video_end_tag = None

        self.videoframe_start_tag = '<videoframe_start_omni>'
        self.videoframe_end_tag = '<videoframe_end_omni>'
        if hasattr(self.config, "visual_config"):
            self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
            self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
            self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
            self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
            self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
        if hasattr(self.config, "video_config"):
            self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
            self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
            self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
            self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
            self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
            self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)

            self.frame_pattern = getattr(self.config.video_config, 'frame_pattern', '<frame>')

    def _get_audio(self, audio_info):
        try:
            audio_info = ujson.loads(audio_info)
            if 'path' in audio_info.keys():
                audio_uri = None
                if os.path.exists(audio_info['path']):
                    audio_uri = audio_info['path']
                elif self.relative_path is not None:
                    audio_uri = os.path.join(self.relative_path, audio_info['path'].lstrip('/'))
                    if not os.path.exists(audio_uri):
                        audio_uri = None
                if audio_uri is not None:
                    waveform = self.audio_processor.load_audio_waveform(audio_uri, True)
                    waveforms = self.audio_processor.split_with_overlap(waveform)

                    ret = OmniProcessorOutput()
                    for i, waveform in enumerate(waveforms):
                        audio, input_length = self.audio_processor.extract_fbank_features(waveform)
                        audio = self.audio_processor.data_augment(audio, input_length, self.training)
                        encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
                        if bridge_length <= 0:
                            continue
                        current_ret = OmniProcessorOutput(
                            audios=[audio[:, :input_length]],
                            encoder_length=[encoder_length],
                            bridge_length=[bridge_length],
                        )
                        if ret.audios is None:
                            ret = current_ret
                        else:
                            ret = ret.concatenate(current_ret)
                    return ret
                else:
                    raise ValueError("can not find path in audio_info")
        except Exception as e:
            print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
            return OmniProcessorOutput()

    def _get_image(self, image_info):
        try:
            try:
                image_info = ujson.loads(image_info)
            except Exception:
                # tolerate single-quoted pseudo-JSON
                image_info = re.sub(r"(?<!\\)'", '"', image_info)
                image_info = ujson.loads(image_info)
            if 'base64' in image_info.keys():
                image_data = base64.b64decode(image_info['base64'])
                image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
            elif 'local' in image_info.keys():
                image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'])
            elif 'path' in image_info.keys() and os.path.exists(image_info['path']):
                image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['path'])
            elif 'url' in image_info.keys():
                image_bytes = self._get_vision_obj_byte('url', image_info['url'])
                image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
            else:
                raise ValueError("can not find any path in image_info")

            merge_length = self.visual_processor.merge_size**2
            patch_nums = np.array(image_list).prod() // merge_length

            if org_size[0] * org_size[1] > 16**2:
                return OmniProcessorOutput(
                    images=[image_feat],
                    patch_nums=[patch_nums],
                    crop_size=[image_list],
                    images_size=[org_size],
                    images_grid=[image_list]
                )
            else:
                print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
                return OmniProcessorOutput()

        except Exception as e:
            print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
            return OmniProcessorOutput()

    def _get_video_frame(self, video_frame_infos):
        try:
            pattern = r'\{.*?\}'
            matches = re.findall(pattern, video_frame_infos)
            ret = OmniProcessorOutput()

            for match in matches:
                video_frame_info = ujson.loads(match)

                if 'local' in video_frame_info.keys():
                    image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'])
                elif 'path' in video_frame_info.keys() and os.path.exists(video_frame_info['path']):
                    image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['path'])
                else:
                    raise ValueError("can not find any path in video_info")

                merge_length = self.video_processor.merge_size**2
                patch_nums = np.array(image_list).prod() // merge_length

                if org_size[0] * org_size[1] > 16**2:
                    ret = ret.concatenate(
                        OmniProcessorOutput(
                            videos=[image_feat],
                            videos_patch_nums=[patch_nums],
                            videos_crop_size=[image_list],
                            videos_size=[org_size],
                            videos_grid=[image_list]
                        )
                    )
                else:
                    print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
            return ret

        except Exception as e:
            print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_infos)))
            return OmniProcessorOutput()

    def _get_vision_obj_byte(self, source, path):
        vision_obj_byte = None
        if source == "local":
            if os.path.exists(path):
                vision_obj_byte = open(path, "rb").read()
            else:
                vision_obj_byte = None
        if source == "base64":
            vision_obj_byte = base64.b64decode(path)
        if source == "url":
            vision_obj_byte = requests.get(url=path).content
        return vision_obj_byte

    def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
        if decode_way == '1fps':
            frame_suffix = f'_frames'
        elif decode_way == 'key':
            frame_suffix = f'_keyframes'
        else:
            raise ValueError('invalid decode way!!!')

        # locate the video bytes and the directory that caches its extracted frames
        if 'local' in video_info.keys():
            video_path = video_info['local']
            frame_path = video_path.split('.')[0] + frame_suffix
            mm_obj_byte = self._get_vision_obj_byte('local', video_path)
        elif 'base64' in video_info.keys():
            md5 = hashlib.md5(video_info['base64'].encode('utf-8')).hexdigest()
            if self.relative_path is not None:
                video_path = os.path.join(self.relative_path, md5)
            else:
                video_path = os.path.join(os.getcwd(), md5)
            frame_path = md5 + frame_suffix
            mm_obj_byte = self._get_vision_obj_byte('base64', video_info['base64'])
        elif 'url' in video_info.keys():
            md5 = hashlib.md5(video_info['url'].encode('utf-8')).hexdigest()
            if self.relative_path is not None:
                video_path = os.path.join(self.relative_path, md5)
            else:
                video_path = os.path.join(os.getcwd(), md5)
            frame_path = md5 + frame_suffix
            mm_obj_byte = self._get_vision_obj_byte('url', video_info['url'])
        else:
            raise ValueError('invalid video source !!!')

        if mm_obj_byte is None:
            return ""
        if not os.path.exists(frame_path) or len(os.listdir(frame_path)) == 0:
            # decode once and cache the frames as JPEG files named by their timestamp
            os.makedirs(frame_path, exist_ok=True)
            frames, frame_times = read_video(io.BytesIO(mm_obj_byte), max_frame_number=-1, decode_way=decode_way)
            for frame_idx, frame in enumerate(frames):
                output_filename = os.path.join(frame_path, f"{frame_times[frame_idx]}.jpg")
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                cv2.imwrite(output_filename, frame)
        frame_paths = os.listdir(frame_path)

        frame_times = [int(filename.split('/')[-1].replace('.jpg', '')) for filename in frame_paths if filename.endswith('.jpg')]
        frame_times.sort()
        frame_number = len(frame_times)
        if frame_number > max_frame_number:
            indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
        else:
            indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)

        # build one image_start_tag{...}image_end_tag chunk per sampled frame
        replace_str = ""
        for frame_idx, idx in enumerate(indices):
            frame_time = frame_times[idx]
            frame_dict = {"local": os.path.join(frame_path, f'{frame_time}.jpg')}
            frame_str = self.frame_pattern.format(frame_idx) if '{}' in self.frame_pattern else self.frame_pattern
            frame_str = frame_str.replace('<TIMEIDX>', str(frame_time))
            frame_str = frame_str.replace('<TIMESTAMP>', time.strftime("%H:%M:%S", time.gmtime(frame_time)))
            frame_str = frame_str.replace('<frame>', f'{self.image_start_tag}{json.dumps(frame_dict)}{self.image_end_tag}')
            replace_str += frame_str

        return replace_str

    def sample_frame(self, frames_str, max_frame=32):
        def uniform_sample(lst, num_samples):
            if num_samples > len(lst):
                return lst
            interval = len(lst) / num_samples
            samples = [lst[int(i * interval)] for i in range(num_samples)]
            return samples

        p = rf'({self.image_start_tag}.*?{self.image_end_tag})'
        frames_str_split = re.split(p, frames_str)
        frame_idxs = [idx for idx in range(len(frames_str_split)) if self.image_start_tag in frames_str_split[idx]]
        sample_frame_idxs = set(uniform_sample(frame_idxs, max_frame))
        return ''.join([item for idx, item in enumerate(frames_str_split) if idx in sample_frame_idxs or self.image_start_tag not in frames_str_split[idx]])

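    # Example (illustrative only): uniform_sample(list(range(100)), 4) picks the
    # elements at indices 0, 25, 50 and 75, so sample_frame keeps at most max_frame
    # evenly spaced frame chunks while leaving the surrounding text intact.
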
    def _get_video_frame_str(self, video_info):
        try:
            if self.videoframe_start_tag in video_info:
                # frames are already listed inline; just swap in the image tags
                frames_str = video_info
                frames_str = frames_str.replace(self.videoframe_start_tag, self.image_start_tag).replace(self.videoframe_end_tag, self.image_end_tag)
                return self.sample_frame(frames_str, max_frame=self.config.video_config.max_frame_num)
            video_info = ujson.loads(video_info)

            frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
            return frames_str
        except Exception as e:
            print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
            return ""

    def _replace_image(self, image_text):
        image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
        ret = self._get_image(image_info)
        if ret.patch_nums is None:
            # image loading failed; return an empty result and an empty placeholder string
            return ret, ''
        return ret, self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag

    def _replace_video_frame(self, video_frame_text):
        video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
        ret = self._get_video_frame(video_frame_info)
        if ret.videos_patch_nums is None:
            # frame loading failed; return an empty result and an empty placeholder string
            return ret, ''
        video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
        return ret, ''.join(video_frame_str)

    def split_multimodal_chunk(self, text_list, mm_label_list, trainable_list, mtype='audio'):
        # pick the start/end tag pair for the requested modality, then split every
        # text chunk into plain-text and multimodal sub-chunks
        if (self.audio_start_tag is not None) and (mtype == 'audio'):
            match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag, re.S)
            drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag, re.S)
        elif (self.image_start_tag is not None) and (mtype == 'image'):
            match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag, re.S)
            drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag, re.S)
        elif (self.audiogen_start_tag is not None) and (mtype == 'audiogen'):
            match_regex = re.compile(self.audiogen_start_tag + '.*?' + self.audiogen_end_tag, re.S)
            drop_regex = re.compile(self.audiogen_start_tag + "|" + self.audiogen_end_tag, re.S)
        elif (self.video_start_tag is not None) and (mtype == 'video'):
            match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag, re.S)
            drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag, re.S)
        else:
            raise ValueError("mtype not supported!")
        new_text_list = []
        new_mm_label_list = []
        new_trainable_flag_list = []
        for text, mm_label, trainable in zip(text_list, mm_label_list, trainable_list):
            for t, m in zip(*split_text(text, match_regex)):
                new_trainable_flag_list.append(trainable)
                if m:
                    new_text_list.append(re.sub(drop_regex, '', t))
                    new_mm_label_list.append(mtype)
                else:
                    new_text_list.append(t)
                    new_mm_label_list.append(mm_label)
        return new_text_list, new_mm_label_list, new_trainable_flag_list

    def process_multimodal_chunk(self, text, mm_label, trainable):
        ret = OmniProcessorOutput()
        if mm_label == 'audio':
            ret = self._get_audio(text)
            if ret.bridge_length is not None:
                ret.input_ids = self.tokenizer.encode(self.audio_start_tag, add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag, add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audio_end_tag, add_special_tokens=False)
            else:
                raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
        elif mm_label == 'audiogen':
            ret = self._get_audio(text)
            if ret.bridge_length is not None:
                ret.input_ids = self.tokenizer.encode(self.audiogen_start_tag, add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag, add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audiogen_end_tag, add_special_tokens=False)
            else:
                raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
        elif mm_label == 'image':
            ret, input_str = self._replace_image(text)
            if input_str:
                ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
            else:
                raise ValueError("Get image data Failed at Process image chunk")
        elif mm_label == 'video':
            frame_str = self.video_start_tag + self._get_video_frame_str(text) + self.video_end_tag
            ret, input_str = self._replace_video_frame(frame_str)
            if input_str:
                ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
            else:
                raise ValueError("Get video data Failed at Process video chunk")
        elif mm_label == 'text':
            ret.input_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if len(ret.input_ids) > self.tokenizer.model_max_length - 1:
                raise ValueError(f"Text too long, please check text length! 【{text[:5] + '...' * 6 + text[-5:]}】")
        else:
            raise ValueError(f"mm_label not supported! must be in ['audio', 'audiogen', 'image', 'video', 'text'] but got {mm_label}")
        return ret

    def process_one(self, text, index=0, raw_only=False):
        ret = OmniProcessorOutput(index=index)
        all_text_list = []
        all_mm_label_list = []
        all_trainable_flag_list = []
        # split on <trainable_start>...<trainable_end> spans and record which chunks
        # were marked as trainable
        text_list, match_flag = split_text(text, re.compile("<trainable_start>.*?<trainable_end>", re.S))
        if len(text_list) == 1:
            text = re.sub(re.compile("<trainable_start>|<trainable_end>", re.S), '', text_list[0])
            all_text_list.append(text)
            all_mm_label_list.append('text')
            all_trainable_flag_list.append(True)
        else:
            for text, match in zip(text_list, match_flag):
                text = re.sub(re.compile("<trainable_start>|<trainable_end>", re.S), '', text)
                if text.strip() == '':
                    continue
                all_text_list.append(text)
                all_mm_label_list.append('text')
                all_trainable_flag_list.append(match)

        # carve out the chunks of each configured modality
        for mtype in self.config.multimodal:
            all_text_list, all_mm_label_list, all_trainable_flag_list = self.split_multimodal_chunk(all_text_list, all_mm_label_list, all_trainable_flag_list, mtype)
        if len(all_text_list) == 0:
            print(f"Process {text} chunk error: No valid Text data!!!!!")
            return OmniProcessorOutput(index=index)

        for text, mm_label, trainable in zip(all_text_list, all_mm_label_list, all_trainable_flag_list):
            try:
                mret = self.process_multimodal_chunk(text, mm_label, trainable)
                ret = ret.concatenate(mret)
            except ValueError as e:
                tt = text[:24].replace('\n', '<LF>')
                print(f"Process {tt if mm_label == 'text' else text} {mm_label} chunk error: {str(e)}")
                return OmniProcessorOutput(index=index)

        if raw_only:
            ret.raw_text = self.tokenizer.decode(ret.input_ids, skip_special_tokens=False)
        return ret

    @torch.no_grad()
    def __call__(self, example, parallel=128):
        if isinstance(example, Dict):
            # Dict input currently falls through unhandled (returns None)
            pass
        elif isinstance(example, str):
            return self.process_one(example)
        elif isinstance(example, List):
            # process the batch in parallel, then left-pad and collate
            with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
                future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
                batch_data = [key.result() for key in cf.as_completed(future_list)]
            valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
            assert(valid_num == len(batch_data))
            batch_data = sorted(batch_data, key=lambda x: x.index)

            ret = OmniProcessorOutput()
            for i in range(len(batch_data)):
                ret = ret.concatenate(batch_data[i])
            self.tokenizer.padding_side = "left"
            max_len = min(max([len(x.input_ids) for x in batch_data]), self.tokenizer.model_max_length)
            padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
            ret.input_ids = padding_result["input_ids"]
            ret.attention_mask = padding_result["attention_mask"]

            if ret.audios is not None:
                # pad the mel features of each clip to the longest one in the batch
                max_audios_len = max([x.shape[-1] for x in ret.audios])
                ret.audios = default_collate([np.pad(x, ((0, 0), (0, max_audios_len - x.shape[-1])), 'constant', constant_values=0) for x in ret.audios])

                ret.encoder_length = default_collate(ret.encoder_length)
                ret.bridge_length = default_collate(ret.bridge_length)

            if ret.images is not None:
                ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
                ret.patch_nums = default_collate(ret.patch_nums)

            if ret.videos is not None:
                ret.videos = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.videos]
                ret.videos_patch_nums = default_collate(ret.videos_patch_nums)

            return ret

        else:
            raise ValueError("example format not supported yet")

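
# Minimal self-check (hedged sketch): exercises only the pure helpers above, since
# OmniMMProcessor itself needs a tokenizer and a full multimodal config to run.
if __name__ == "__main__":
    # resize a 720p frame onto the 28-pixel grid
    print(smart_resize(720, 1280))  # -> (728, 1288)
    # split a string around <x>...</x> placeholders
    print(split_text("a<x>b</x>c", re.compile("<x>.*?</x>", re.S)))
    # -> (['a', '<x>b</x>', 'c'], [False, True, False])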