# Baichuan-Omni-1d5-Base / processor_omni.py
import requests
import re, ujson, os, sys, fire, glob, random, time, json
import numpy as np
import io
import torch
from torch.utils.data import default_collate
import torchaudio
from typing import *
from dataclasses import dataclass, field
import transformers
from transformers.modeling_outputs import ModelOutput
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
from functools import lru_cache
from io import BytesIO
from PIL import Image
import concurrent.futures as cf
from transformers.image_transforms import resize, center_crop, get_resize_output_image_size
from transformers.image_utils import PILImageResampling
from PIL import Image, ImageOps
from PIL import ImageFile
torch.set_num_threads(1)  # limit torch's thread count, otherwise it may hang
ImageFile.LOAD_TRUNCATED_IMAGES = True
import base64
from decord import VideoReader, cpu
import cv2
import av
import imagesize
import tempfile
import math
from multiprocessing import Pool
from cairosvg import svg2png
import hashlib
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
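# Illustrative sanity check (hand-computed, not part of the original pipeline): with factor=28,
#   round_by_factor(100, 28) == 112, ceil_by_factor(100, 28) == 112, floor_by_factor(100, 28) == 84.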
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
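# Illustrative example (hand-computed, not from the original file): a 1080x1920 frame with the
# default IMAGE_FACTOR of 28 maps to (1092, 1932) -- both sides are multiples of 28, the pixel
# count stays within [MIN_PIXELS, MAX_PIXELS], and the aspect ratio is almost unchanged:
#   smart_resize(1080, 1920)  # -> (1092, 1932)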
def split_text(text, match_regex):
matches = list(re.finditer(match_regex, text))
    # initialize the result lists
    result = []
    match_flag_list = []
    # end position of the previous match
    last_end = 0
    # iterate over all matches
    for match in matches:
        # add the segment before the match
        if text[last_end:match.start()]:
            result.append(text[last_end:match.start()])
            match_flag_list.append(False)
        # add the matched segment
        result.append(match.group(0))
        match_flag_list.append(True)
        # update the end position of the previous match
        last_end = match.end()
    # add the remainder after the last match
    if text[last_end:]:
        result.append(text[last_end:])
        match_flag_list.append(False)
return result, match_flag_list
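# Illustrative example (hypothetical tag and regex, not from the original file): splitting on a
# paired tag returns the interleaved segments plus a parallel flag list marking which pieces matched:
#   split_text("a<x>b</x>c", re.compile("<x>.*?</x>"))
#   # -> (["a", "<x>b</x>", "c"], [False, True, False])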
def read_video(image_path, max_frame_number, decode_way):
if decode_way=='1fps':
try:
# print(image_path)
vr = VideoReader(image_path, ctx=cpu(0))
total_frame_num = len(vr)
fps = round(vr.get_avg_fps())
frame_idx = [i for i in range(0, len(vr), fps)]
frames = vr.get_batch(frame_idx).asnumpy()
cnt = len(frames)
frame_times = range(cnt)
except Exception as e:
print(image_path)
print('error is', e)
return None
elif decode_way=='key':
try:
with av.open(image_path) as container:
stream = container.streams.video[0]
stream.codec_context.skip_frame = 'NONKEY'
frames = []
frame_times = []
fps = int(stream.average_rate)
cnt = 0
                for frame in container.decode(stream):  # store each keyframe as an image array
image = np.array(frame.to_image())
frames.append(image)
frame_time = int(frame.time)
frame_times.append(frame_time)
cnt += 1
except Exception as e:
print('error is', e)
return None
if frames is None or len(frames)==0:
return None
    if len(frames) > max_frame_number and max_frame_number > 0:
        # sample max_frame_number evenly spaced indices
        indices = np.linspace(0, len(frames) - 1, max_frame_number, dtype=int)
        # keep only the frames/times at those indices (works for both list and ndarray inputs)
        frames = [frames[i] for i in indices]
        frame_times = [frame_times[i] for i in indices]
    return frames, frame_times
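# Usage sketch (hypothetical path): decode one frame per second and keep at most 32 frames.
#   frames, frame_times = read_video("clip.mp4", max_frame_number=32, decode_way="1fps")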
class OmniImageProcessor:
def __init__(self, config, **kwargs):
self.config = config # visual_config
self.min_pixels = self.config.min_pixels if hasattr(self.config, 'min_pixels') else 56 * 56
self.max_pixels = self.config.max_pixels if hasattr(self.config, 'max_pixels') else 28 * 28 * 1280
self.patch_size = self.config.patch_size if hasattr(self.config, 'patch_size') else 14
self.temporal_patch_size = self.config.temporal_patch_size if hasattr(self.config, 'temporal_patch_size') else 2
self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
self.spatial_merge_size = self.config.spatial_merge_size if hasattr(self.config, 'spatial_merge_size') else 2
def image_transform(self, strseq, return_mm_data = True):
image = None
if isinstance(strseq, str):
if return_mm_data:
image = Image.open(strseq).convert("RGB")
else:
try:
image = Image.open(BytesIO(strseq)).convert("RGB")
except:
                image = Image.open(BytesIO(svg2png(bytestring=strseq))).convert("RGB")  # some interleaved data are vector graphics (SVG) and need conversion first
        image = np.array(image.convert("RGB"))  # make sure the image has three channels (R, G, B) and convert it to a NumPy array for later processing
        image_org_size = image.shape[:2]  # keep the original (height, width); image.shape is (height, width, channels)
        # resize, crop, scale, normalize
        # smart_resize returns the target (height, width) used for the subsequent resize
resized_height, resized_width = smart_resize(
image_org_size[0], image_org_size[1],
factor=self.patch_size * self.spatial_merge_size,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
output_size = (resized_height, resized_width)
        # Resize the image to output_size with bicubic interpolation (PILImageResampling.BICUBIC),
        # a smooth resampling method that generally preserves image quality well.
        image = resize(image, output_size, PILImageResampling.BICUBIC)
        img = image.transpose(2, 0, 1)
        # normalize to [0, 1] and standardize with the configured per-channel mean/std
        image = (img / 255.0 - np.array(self.config.image_mean)[:, np.newaxis, np.newaxis]) / np.array(self.config.image_std)[:, np.newaxis, np.newaxis]
        # split into patches
patches = image[np.newaxis, :]
if patches.shape[0] == 1:
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
channel = patches.shape[1]
grid_t = patches.shape[0] // self.temporal_patch_size
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
patches = patches.reshape(
grid_t,
self.temporal_patch_size,
channel,
grid_h // self.spatial_merge_size,
self.spatial_merge_size,
self.patch_size,
grid_w // self.spatial_merge_size,
self.spatial_merge_size,
self.patch_size,
)
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
)
return flatten_patches, image_org_size, (grid_t, grid_h, grid_w)
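# Illustrative shape walk-through (assumed config: patch_size=14, spatial_merge_size=2,
# temporal_patch_size=2, 3 input channels): an image resized to 392x532 gives a grid of
# (grid_t, grid_h, grid_w) = (1, 28, 38), so flatten_patches has shape
# (1*28*38, 3*2*14*14) = (1064, 1176); with merge_size=2 the downstream pad-token count is
# 1064 // 4 = 266.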
class OmniAudioProcessor:
    # basic audio feature extraction plus input-data parsing utilities
def __init__(
self,
config, # audio processor config
**kwargs
):
        # make sure ffmpeg is installed for torchaudio, e.g. conda install -c conda-forge 'ffmpeg<7'
assert(len(torchaudio.list_audio_backends()) > 0)
self.config = config
self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + self.config.n_fft // 2,
num_mel_filters=self.config.num_mel_bins,
min_frequency=0.0,
max_frequency=self.config.sampling_rate / 2.0,
sampling_rate=self.config.sampling_rate,
norm="slaney",
mel_scale="slaney",
)
self.window = torch.hann_window(self.config.n_fft)
@staticmethod
def dynamic_range_compression(x, C=1, clip_val=1e-6):
return torch.log(torch.clamp(x, min=clip_val) * C)
@staticmethod
def zero_mean_unit_var_norm(x):
return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
def load_audio_waveform(self, uri, return_tensors=True, do_normalize=False):
metadata = torchaudio.info(uri) # sample_rate, num_frames, num_channels, bits_per_sample, encoding=PCM_S
assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # whisper only accept mono channel audio
waveform_tensor, _ = torchaudio.load(uri, normalize=True)
if self.config.sampling_rate != metadata.sample_rate:
waveform_tensor = torchaudio.functional.resample(waveform_tensor, metadata.sample_rate, self.config.sampling_rate, lowpass_filter_width=128)
# downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
if metadata.num_channels > 1:
waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
# normalized to zero mean
if do_normalize:
waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
if return_tensors: # (channels, samples)
return waveform_tensor
else:
return waveform_tensor.numpy()
    def split_with_overlap(self, waveform):  # if the waveform exceeds the max length, split it into overlapping chunks
        channels, wave_samples = waveform.shape
        max_audio_samples = self.config.max_audio_seconds * self.config.sampling_rate
        if wave_samples <= max_audio_samples or self.config.split_overlap < 0:
            return [waveform]  # not over the max length (or truncation mode); always return a list
        split_waveform, start = [], 0
        while start < wave_samples:  # align the overlap on whole seconds
            if start > int(self.config.sampling_rate * self.config.split_overlap):
                start -= int(self.config.sampling_rate * self.config.split_overlap)  # 0 means no overlap, >0 is the overlap in seconds
            end = min(start + max_audio_samples, wave_samples)
            if end - start >= self.config.n_fft:  # keep at least one frame of data
                split_waveform.append(waveform[:, start:end])  # note: very short tail chunks are possible and should be filtered in preprocessing
            start = end
        return split_waveform
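    # Illustrative example (assumed config: max_audio_seconds=30, sampling_rate=16000,
    # split_overlap=1): a 70 s waveform is split into chunks covering roughly [0, 30),
    # [29, 59) and [58, 70) seconds, i.e. consecutive chunks share about 1 s of audio.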
@classmethod
def inference_output_length(cls, config, input_length):
# for whisper + bridge
kernel_size = config.kernel_size
stride_size = config.stride_size
avg_pooler = config.avg_pooler
        encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1  # conv layer1 with pad=1
        encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1  # conv layer2 with pad=1
        if avg_pooler > 1:
            bridge_length = encoder_length // avg_pooler
        else:
            bridge_length = encoder_length
        return encoder_length, bridge_length
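    # Illustrative example (assumed config: kernel_size=3, stride_size=2, avg_pooler=4):
    # 3000 input fbank frames -> conv layer1 keeps the length (3000), conv layer2 halves it
    # (1500), and the bridge sees 1500 // 4 = 375 tokens.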
def extract_fbank_features(self, waveform):
# ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
channels, wave_samples = waveform.shape
assert(wave_samples >= self.config.n_fft)
valid_frame_nums = min(self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length, wave_samples // self.config.hop_length + 1)
if wave_samples < self.config.max_audio_seconds * self.config.sampling_rate:
waveform = torch.nn.functional.pad(waveform, (0, self.config.max_audio_seconds * self.config.sampling_rate - wave_samples), "constant", 0)
else:
waveform = waveform[:, :self.config.max_audio_seconds * self.config.sampling_rate]
# window = torch.hann_window(self.config.n_fft)
stft = torch.stft(waveform, self.config.n_fft, self.config.hop_length, window=self.window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
magnitudes = stft[..., :-1].abs() ** 2
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
mel_spec = mel_filters.T @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
if waveform.dim() == 2:
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
log_spec = torch.maximum(log_spec, max_val - 8.0)
else:
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
log_spec[:, valid_frame_nums:] = 0.0 # pad0
return log_spec, valid_frame_nums
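    # Illustrative shapes (assumed config: n_fft=400, hop_length=160, max_audio_seconds=30,
    # sampling_rate=16000): the returned log-mel matrix has 30*16000/160 = 3000 frame slots,
    # with every frame beyond valid_frame_nums zeroed out.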
def data_augment(self, feature: np.array, input_length, training=True):
# reference https://arxiv.org/pdf/1904.08779
def mask_start_indices(input_length, mask_length, min_masks, mask_prob):
num_masked_span = int(mask_prob * input_length / mask_length + random.random())
num_masked_span = max(num_masked_span, min_masks)
start_indices = list(range(input_length - mask_length))
random.shuffle(start_indices)
start_indices = start_indices[:num_masked_span]
return start_indices
if not training or (self.config.mask_time_prob <= 0 and self.config.mask_feature_prob <= 0):
return feature
if input_length < self.config.mask_time_length * self.config.mask_time_min_masks + 1:
return feature
if self.config.num_mel_bins < self.config.mask_feature_length * self.config.mask_feature_min_masks + 1:
return feature
if self.config.mask_time_prob > 0:
start_indices = mask_start_indices(input_length, self.config.mask_time_length, self.config.mask_time_min_masks, self.config.mask_time_prob)
for start_idx in start_indices:
feature[:, start_idx: start_idx + self.config.mask_time_length] = 0.0
if self.config.mask_feature_prob > 0:
start_indices = mask_start_indices(self.config.num_mel_bins, self.config.mask_feature_length, self.config.mask_feature_min_masks, self.config.mask_feature_prob)
for start_idx in start_indices:
feature[start_idx: start_idx + self.config.mask_feature_length, :] = 0.0
return feature
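    # data_augment above follows the SpecAugment recipe (https://arxiv.org/pdf/1904.08779):
    # time masks zero out spans along the frame axis, feature masks zero out spans along the
    # mel-bin axis, and both are applied only during training.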
@dataclass
class OmniProcessorOutput(ModelOutput):
input_ids: Optional["List|torch.Tensor"] = None
labels: Optional["List|torch.Tensor"] = None
attention_mask: Optional["List|torch.Tensor"] = None
position_ids: Optional["List|torch.Tensor"] = None
    seqlens: Optional["List|torch.Tensor"] = None  # used together with the Omni modeling code
# audio fields
audios: Optional["List|torch.Tensor"] = None
encoder_length: Optional["List|torch.Tensor"] = None
bridge_length: Optional["List|torch.Tensor"] = None
# image fields
images: Optional["List|torch.Tensor"] = None
patch_nums: Optional["List|torch.Tensor"] = None
images_size: Optional["List|torch.Tensor"] = None
crop_size: Optional["List|torch.Tensor"] = None
images_grid: Optional["List|torch.Tensor"] = None
# video fields
videos: Optional["List|torch.Tensor"] = None
videos_patch_nums: Optional["List|torch.Tensor"] = None
videos_size: Optional["List|torch.Tensor"] = None
videos_crop_size: Optional["List|torch.Tensor"] = None
videos_grid: Optional["List|torch.Tensor"] = None
# processor fields
raw_text: Optional[str] = None
index: Optional[int] = None
    def concatenate(self, other):  # only valid when the fields are lists
def concat_one(a, b):
if a is None and b is None:
return None
elif a is None and b is not None:
return b
elif a is not None and b is None:
return a
else:
return a + b
return OmniProcessorOutput(
input_ids=concat_one(self.input_ids, other.input_ids),
labels=concat_one(self.labels, other.labels),
audios=concat_one(self.audios, other.audios),
encoder_length=concat_one(self.encoder_length, other.encoder_length),
bridge_length=concat_one(self.bridge_length, other.bridge_length),
images=concat_one(self.images, other.images),
images_grid=concat_one(self.images_grid, other.images_grid),
patch_nums=concat_one(self.patch_nums, other.patch_nums),
videos=concat_one(self.videos, other.videos),
videos_grid=concat_one(self.videos_grid, other.videos_grid),
videos_patch_nums=concat_one(self.videos_patch_nums, other.videos_patch_nums),
position_ids=concat_one(self.position_ids, other.position_ids),
seqlens=concat_one(self.seqlens, other.seqlens),
images_size=concat_one(self.images_size, other.images_size),
videos_size=concat_one(self.videos_size, other.videos_size),
            index=self.index  # concatenation keeps the original index
)
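# Usage sketch (hypothetical field values): list-typed fields from two outputs are merged
# element-wise, e.g.
#   OmniProcessorOutput(input_ids=[1, 2]).concatenate(OmniProcessorOutput(input_ids=[3]))
# yields input_ids == [1, 2, 3]; fields that are None on both sides stay None, and the
# left-hand index is kept.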
class OmniMMProcessor(object):
def __init__(self,
tokenizer: transformers.PreTrainedTokenizer,
config,
training,
relative_path=None,
parallel=None,
**kwargs,
):
self.tokenizer = tokenizer
self.config = config
self.audio_processor = OmniAudioProcessor(config.audio_config)
self.visual_processor = None
if hasattr(config, "visual_config"):
self.visual_processor = OmniImageProcessor(config.visual_config)
self.video_processor = None
if hasattr(config, "video_config"):
self.video_processor = OmniImageProcessor(config.video_config)
self.training = training
self.relative_path = relative_path
self.parallel = parallel
# audio tag
self.audio_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_start_token_id)
self.audio_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_end_token_id)
self.audio_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_pad_token_id)
self.audio_delim_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audio_delim_token_id)
self.audiogen_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_start_token_id)
self.audiogen_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.audio_config.audiogen_end_token_id)
# image tag
self.image_start_tag = None
self.image_end_tag = None
self.image_pad_tag = None
self.video_start_tag = None
self.video_end_tag = None
        # the videoframe tags only exist to support pre-extracted frames given as input; they have no
        # token ids, and they are replaced with the image start/end tags when video frames are extracted
self.videoframe_start_tag = '<videoframe_start_omni>'
self.videoframe_end_tag = '<videoframe_end_omni>'
if hasattr(self.config, "visual_config"):
# special token for start_tag
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_start_token_id)
# special token for end_tag
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_end_token_id)
# special token for pad_tag
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_pad_token_id)
self.image_line_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_line_token_id)
self.image_delimiter_tag = self.tokenizer.convert_ids_to_tokens(self.config.visual_config.image_delimiter_token_id)
if hasattr(self.config, "video_config"):
self.video_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_start_token_id)
self.video_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_end_token_id)
self.image_start_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_start_token_id)
self.image_end_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_end_token_id)
self.image_pad_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.image_pad_token_id)
self.video_place_tag = self.tokenizer.convert_ids_to_tokens(self.config.video_config.video_place_token_id)
self.frame_pattern = getattr(self.config.video_config, 'frame_pattern', '<frame>')
# @lru_cache(maxsize=1024)
def _get_audio(self, audio_info):
try:
audio_info = ujson.loads(audio_info)
if 'path' in audio_info.keys():
audio_uri = None
if os.path.exists(audio_info['path']):
audio_uri = audio_info['path']
elif self.relative_path is not None:
audio_uri = os.path.join(self.relative_path, audio_info['path'].lstrip('/'))
if not os.path.exists(audio_uri):
audio_uri = None
if audio_uri is not None:
waveform = self.audio_processor.load_audio_waveform(audio_uri, True)
waveforms = self.audio_processor.split_with_overlap(waveform)
                    ret = OmniProcessorOutput()  # default init; the audios field is None
                    for i, waveform in enumerate(waveforms):
audio, input_length = self.audio_processor.extract_fbank_features(waveform)
audio = self.audio_processor.data_augment(audio, input_length, self.training)
encoder_length, bridge_length = self.audio_processor.inference_output_length(self.config.audio_config, input_length)
if bridge_length <= 0:
continue
current_ret = OmniProcessorOutput(
audios=[audio[:,:input_length]],
encoder_length=[encoder_length],
bridge_length=[bridge_length],
)
if ret.audios is None:
ret = current_ret
else:
                            ret = ret.concatenate(current_ret)  # concatenate multiple audio chunks
return ret
else:
raise ValueError("can not find path in audio_info")
except Exception as e:
print("**** get audio error: {}, info: {} *****".format(str(e), str(audio_info)))
return OmniProcessorOutput()
# @lru_cache(maxsize=1024)
def _get_image(self, image_info):
try:
try:
image_info = ujson.loads(image_info)
except:
image_info = re.sub(r"(?<!\\)'", '"', image_info)
image_info = ujson.loads(image_info)
if 'base64' in image_info.keys():
image_data = base64.b64decode(image_info['base64'])
image_feat, org_size, image_list = self.visual_processor.image_transform(image_data)
elif 'local' in image_info.keys():
image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['local'])
elif 'path' in image_info.keys() and os.path.exists(image_info['path']):
image_feat, org_size, image_list = self.visual_processor.image_transform(image_info['path'])
elif 'url' in image_info.keys():
image_bytes = self._get_vision_obj_byte('url', image_info['url'])
image_feat, org_size, image_list = self.visual_processor.image_transform(image_bytes)
else:
raise ValueError("can not find any path in image_info")
merge_length = self.visual_processor.merge_size**2
patch_nums = np.array(image_list).prod() // merge_length
            if org_size[0] * org_size[1] > 16**2:  # filter out extremely small images
return OmniProcessorOutput(
images=[image_feat],
patch_nums=[patch_nums],
crop_size=[image_list],
images_size= [org_size],
images_grid=[image_list]
)
else:
print("**** image too small: {}, info: {} *****".format(str(org_size), str(image_info)))
return OmniProcessorOutput()
except Exception as e:
print("**** get image error: {}, info: {} *****".format(str(e), str(image_info)))
return OmniProcessorOutput()
# @lru_cache(maxsize=1024)
def _get_video_frame(self, video_frame_infos):
try:
pattern = r'\{.*?\}'
matches = re.findall(pattern, video_frame_infos)
ret = OmniProcessorOutput()
            # parse each frame entry
for match in matches:
video_frame_info = ujson.loads(match)
# video_frame_info = ujson.loads(video_frame_info)
if 'local' in video_frame_info.keys():
image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['local'])
elif 'path' in video_frame_info.keys() and os.path.exists(video_frame_info['path']):
image_feat, org_size, image_list = self.video_processor.image_transform(video_frame_info['path'])
else:
raise ValueError("can not find any path in video_info")
merge_length = self.video_processor.merge_size**2
patch_nums = np.array(image_list).prod() // merge_length
                if org_size[0] * org_size[1] > 16**2:  # filter out extremely small frames
ret = ret.concatenate(
OmniProcessorOutput(
videos=[image_feat],
videos_patch_nums=[patch_nums],
videos_crop_size=[image_list],
videos_size= [org_size],
videos_grid=[image_list]
)
)
else:
print("**** video too small: {}, info: {} *****".format(str(org_size), str(video_frame_info)))
return ret
except Exception as e:
print("**** get video error: {}, info: {} *****".format(str(e), str(video_frame_info)))
return OmniProcessorOutput()
    # read raw video/image bytes from a local path, a base64 string, or a URL
def _get_vision_obj_byte(self, source, path):
vision_obj_byte = None
if source == "local":
if os.path.exists(path):
vision_obj_byte = open(path, "rb").read()
else:
vision_obj_byte = None
if source == "base64":
vision_obj_byte = base64.b64decode(path)
if source == "url":
vision_obj_byte = requests.get(url=path).content
return vision_obj_byte
    # split a video into frames and save them into a sibling directory
def _split_video_to_frames(self, video_info, max_frame_number=-1, decode_way="1fps"):
if decode_way=='1fps':
frame_suffix = f'_frames'
elif decode_way=='key':
frame_suffix = f'_keyframes'
        else:
            raise ValueError('invalid decode way!!!')
server = "local"
        if 'local' in video_info.keys():
            # local video path
            video_path = video_info['local']
            # local directory where the frames are stored
            frame_path = video_path.split('.')[0] + frame_suffix
mm_obj_byte = self._get_vision_obj_byte('local', video_path)
elif 'base64' in video_info.keys():
md5 = hashlib.md5(video_info['base64'].encode('utf-8')).hexdigest()
if self.relative_path is not None:
video_path = os.path.join(self.relative_path, md5)
else:
video_path = os.path.join(os.getcwd(), md5)
frame_path = md5 + frame_suffix
mm_obj_byte = self._get_vision_obj_byte('base64', video_info['base64'])
elif 'url' in video_info.keys():
md5 = hashlib.md5(video_info['url'].encode('utf-8')).hexdigest()
if self.relative_path is not None:
video_path = os.path.join(self.relative_path, md5)
else:
video_path = os.path.join(os.getcwd(), md5)
frame_path = md5 + frame_suffix
mm_obj_byte = self._get_vision_obj_byte('url', video_info['url'])
        else:
            raise ValueError('invalid video source !!!')
        if mm_obj_byte is None:  # failed to read the video file
return ""
        if not os.path.exists(frame_path) or len(os.listdir(frame_path)) == 0:
            # extract and save the frames
            os.makedirs(frame_path, exist_ok=True)
            frames, frame_times = read_video(io.BytesIO(mm_obj_byte), max_frame_number=-1, decode_way=decode_way)  # read all frames
for frame_idx, frame in enumerate(frames):
output_filename = os.path.join(frame_path, f"{frame_times[frame_idx]}.jpg")
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
cv2.imwrite(output_filename, frame)
        frame_paths = os.listdir(frame_path)
        # select frames
        frame_times = [int(filename.split('/')[-1].replace('.jpg', '')) for filename in frame_paths if filename.endswith('.jpg')]  # the file name is the second at which the frame was taken
        frame_times.sort()  # ascending order
frame_number = len(frame_times)
if frame_number > max_frame_number:
indices = np.linspace(0, frame_number - 1, max_frame_number, dtype=int)
else:
indices = np.linspace(0, frame_number - 1, frame_number, dtype=int)
        # build the replacement string by concatenating the per-frame patterns
replace_str = ""
        for frame_idx, idx in enumerate(indices):
            frame_time = frame_times[idx]  # frame_time is the frame's timestamp in seconds and also its file name
            frame_dict = {"local": os.path.join(frame_path, f'{frame_time}.jpg')}
            frame_str = self.frame_pattern.format(frame_idx) if '{}' in self.frame_pattern else self.frame_pattern  # '{}' is replaced with the frame index
            frame_str = frame_str.replace('<TIMEIDX>', str(frame_time))  # <TIMEIDX> is replaced with the second index
            frame_str = frame_str.replace('<TIMESTAMP>', time.strftime("%H:%M:%S", time.gmtime(frame_time)))  # <TIMESTAMP> is replaced with an HH:MM:SS timestamp
            frame_str = frame_str.replace('<frame>', f'{self.image_start_tag}{json.dumps(frame_dict)}{self.image_end_tag}')
            replace_str += frame_str
return replace_str
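    # Illustrative expansion (hypothetical frame_pattern and paths): with
    # frame_pattern = '<TIMESTAMP><frame>', a frame saved at t=65 s expands to
    # '00:01:05' + image_start_tag + '{"local": ".../65.jpg"}' + image_end_tag,
    # and the per-frame strings are concatenated into the returned replacement text.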
    def sample_frame(self, frames_str, max_frame=32):
def uniform_sample(lst, num_samples):
if num_samples > len(lst):
return lst
interval = len(lst) / num_samples
samples = [lst[int(i * interval)] for i in range(num_samples)]
return samples
p = rf'({self.image_start_tag}.*?{self.image_end_tag})'
frames_str_split = re.split(p,frames_str)
frame_idxs = [idx for idx in range(len(frames_str_split)) if self.image_start_tag in frames_str_split[idx]]
sample_frame_idxs = set(uniform_sample(frame_idxs, max_frame))
return ''.join([item for idx,item in enumerate(frames_str_split) if idx in sample_frame_idxs or self.image_start_tag not in frames_str_split[idx]])
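    # Illustrative behaviour: with max_frame=2, a frames_str containing four
    # image_start_tag...image_end_tag chunks keeps the 1st and 3rd chunk (uniform sampling over
    # the frame positions) while any interleaved plain text is left untouched.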
def _get_video_frame_str(self, video_info):
try:
            if self.videoframe_start_tag in video_info:  # if the video is given as pre-extracted frames, swap the frame tags for image tags
                frames_str = video_info
                frames_str = frames_str.replace(self.videoframe_start_tag, self.image_start_tag).replace(self.videoframe_end_tag, self.image_end_tag)
                return self.sample_frame(frames_str, max_frame=self.config.video_config.max_frame_num)
            video_info = ujson.loads(video_info)
            # build a string containing the per-frame image paths, capped at max_frame_number frames
            frames_str = self._split_video_to_frames(video_info, max_frame_number=self.config.video_config.max_frame_num, decode_way=self.config.video_config.decode_way)
return frames_str
except Exception as e:
print("**** get video error: {}, info: {} *****".format(str(e), str(video_info)))
return ""
def _replace_image(self, image_text):
image_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', image_text)
        ret = self._get_image(image_info)  # fetch the image features (cached result when caching is enabled)
if ret.patch_nums is None:
return ''
return ret, self.image_start_tag + self.image_pad_tag * ret.patch_nums[0] + self.image_end_tag
def _replace_video_frame(self, video_frame_text):
video_frame_info = re.sub(re.compile(self.image_start_tag + "|" + self.image_end_tag), '', video_frame_text)
        ret = self._get_video_frame(video_frame_info)  # fetch the frame features (cached result when caching is enabled)
if ret.videos_patch_nums is None:
return ''
video_frame_str = [self.image_start_tag + self.video_place_tag * ret.videos_patch_nums[i] + self.image_end_tag for i in range(len(ret.videos_patch_nums))]
return ret, ''.join(video_frame_str)
def split_multimodal_chunk(self, text_list, mm_label_list, trainable_list, mtype='audio'):
        # extract the JSON-formatted audio/image info from the text, load it and turn it into features,
        # estimate the number of encoder tokens, and fill in the corresponding number of pad tokens
if (self.audio_start_tag != None) and (mtype == 'audio'):
match_regex = re.compile(self.audio_start_tag + '.*?' + self.audio_end_tag,re.S)
drop_regex = re.compile(self.audio_start_tag + "|" + self.audio_end_tag,re.S)
elif (self.image_start_tag != None) and (mtype == 'image'):
match_regex = re.compile(self.image_start_tag + '.*?' + self.image_end_tag,re.S)
drop_regex = re.compile(self.image_start_tag + "|" + self.image_end_tag,re.S)
elif (self.audiogen_start_tag != None) and (mtype == 'audiogen'):
match_regex = re.compile(self.audiogen_start_tag + '.*?' + self.audiogen_end_tag,re.S)
drop_regex = re.compile(self.audiogen_start_tag + "|" + self.audiogen_end_tag,re.S)
elif (self.video_start_tag != None) and (mtype == 'video'):
match_regex = re.compile(self.video_start_tag + '.*?' + self.video_end_tag,re.S)
drop_regex = re.compile(self.video_start_tag + "|" + self.video_end_tag,re.S)
else:
raise ValueError("mtype not supportted!")
new_text_list = []
new_mm_label_list = []
new_trainable_flag_list = []
for text,mm_label,trainable in zip(text_list,mm_label_list,trainable_list):
for t,m in zip(*split_text(text, match_regex)):
new_trainable_flag_list.append(trainable)
if m:
new_text_list.append(re.sub(drop_regex, '', t))
new_mm_label_list.append(mtype)
else:
new_text_list.append(t)
new_mm_label_list.append(mm_label)
return new_text_list, new_mm_label_list, new_trainable_flag_list
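    # Illustrative example (hypothetical tag strings): for
    #   text_list = ['hi <audio_start>{"path": "a.wav"}<audio_end> bye'] and mtype='audio',
    # the returned lists are ['hi ', '{"path": "a.wav"}', ' bye'] with mm_labels
    # ['text', 'audio', 'text']; each piece inherits the trainable flag of its source chunk.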
def process_multimodal_chunk(self, text, mm_label, trainable):
ret = OmniProcessorOutput()
if mm_label == 'audio':
ret = self._get_audio(text)
if ret.bridge_length is not None:
ret.input_ids = self.tokenizer.encode(self.audio_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audio_end_tag,add_special_tokens=False)
else:
raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
elif mm_label == 'audiogen':
ret = self._get_audio(text)
if ret.bridge_length is not None:
ret.input_ids = self.tokenizer.encode(self.audiogen_start_tag,add_special_tokens=False) + self.tokenizer.encode(self.audio_pad_tag,add_special_tokens=False) * sum(ret.bridge_length) + self.tokenizer.encode(self.audiogen_end_tag,add_special_tokens=False)
else:
raise ValueError(f"Get audio data Failed at Process audio chunk {text}")
elif mm_label == 'image':
ret, input_str = self._replace_image(text)
if input_str:
ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
else:
raise ValueError("Get image data Failed at Process image chunk")
elif mm_label == 'video':
frame_str = self.video_start_tag+self._get_video_frame_str(text)+self.video_end_tag
ret, input_str = self._replace_video_frame(frame_str)
if input_str:
ret.input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
else:
raise ValueError("Get video data Failed at Process video chunk")
elif mm_label == 'text':
ret.input_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if len(ret.input_ids) > self.tokenizer.model_max_length - 1:  # filter out over-long text
                raise ValueError(f"Text too long, please check text length! 【{text[:5]+'...'*6+text[-5:]}】")
else:
raise ValueError(f"mm_label not supportted! must in ['audio', 'image', 'text'] but get {mm_label}")
return ret
def process_one(self, text, index=0, raw_only=False):
ret = OmniProcessorOutput(index=index)
all_text_list = []
all_mm_label_list = []
all_trainable_flag_list = []
text_list, match_flag = split_text(text, re.compile("<trainable_start>.*?<trainable_end>",re.S))
if len(text_list) == 1:
text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text_list[0])
all_text_list.append(text)
all_mm_label_list.append('text')
all_trainable_flag_list.append(True)
else:
for text, match in zip(text_list, match_flag):
text = re.sub(re.compile("<trainable_start>|<trainable_end>",re.S), '', text)
if text.strip() == '':
                    continue  # drop whitespace-only chunks
all_text_list.append(text)
all_mm_label_list.append('text')
all_trainable_flag_list.append(match)
        # process the multimodal chunks
        for mtype in self.config.multimodal:  # loop over modalities to collect audio / image / video results
all_text_list, all_mm_label_list, all_trainable_flag_list = self.split_multimodal_chunk(all_text_list, all_mm_label_list, all_trainable_flag_list, mtype)
if len(all_text_list) == 0:
print(f"Process {text} chunk error: No valid Text data!!!!!")
return OmniProcessorOutput(index=index)
for text, mm_label, trainable in zip(all_text_list, all_mm_label_list, all_trainable_flag_list):
try:
mret = self.process_multimodal_chunk(text, mm_label, trainable)
ret = ret.concatenate(mret)
except ValueError as e:
tt = text[:24].replace('\n','<LF>')
print(f"Process {tt if mm_label == 'text' else text} {mm_label} chunk error: {str(e)}")
return OmniProcessorOutput(index=index)
if raw_only:
ret.raw_text = self.tokenizer.decode(ret.input_ids, skip_special_tokens=False)
return ret
return ret
@torch.no_grad()
def __call__(self, example, parallel=128):
if isinstance(example, Dict):
pass
elif isinstance(example, str):
return self.process_one(example)
        elif isinstance(example, List):  # batch inference with asynchronous multi-threaded processing
            with cf.ThreadPoolExecutor(min(parallel, len(example))) as executor:
                future_list = [executor.submit(self.process_one, di, idx) for idx, di in enumerate(example)]
                batch_data = [key.result() for key in cf.as_completed(future_list)]
            valid_num = sum([1 if x.input_ids is not None else 0 for x in batch_data])
            assert(valid_num == len(batch_data))  # batch inference requires every sample to be processed successfully
            batch_data = sorted(batch_data, key=lambda x: x.index)  # restore the original sample order
            ret = OmniProcessorOutput()
            for i in range(len(batch_data)):
                ret = ret.concatenate(batch_data[i])
            self.tokenizer.padding_side = "left"
            max_len = min(max([len(x.input_ids) for x in batch_data]), self.tokenizer.model_max_length)
            padding_result = self.tokenizer.pad({"input_ids": [r.input_ids for r in batch_data]}, return_tensors='pt')
            ret.input_ids = padding_result["input_ids"]
            ret.attention_mask = padding_result["attention_mask"]  # batch inference does not pack sequences, so seqlens is not needed
            if ret.audios is not None:
                max_audios_len = max([x.shape[-1] for x in ret.audios])
                ret.audios = default_collate([np.pad(x, ((0, 0), (0, max_audios_len - x.shape[-1])), 'constant', constant_values=0) for x in ret.audios])
                ret.encoder_length = default_collate(ret.encoder_length)
                ret.bridge_length = default_collate(ret.bridge_length)
            if ret.images is not None:
                ret.images = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.images]
                ret.patch_nums = default_collate(ret.patch_nums)
            if ret.videos is not None:
                ret.videos = [torch.from_numpy(np.asarray(image, dtype=np.float32)) for image in ret.videos]
                ret.videos_patch_nums = default_collate(ret.videos_patch_nums)
            return ret
        else:
            raise ValueError("example format not supported yet")