| |
| |
| |
| |
| |
| import warnings |
| from typing import Any, List, Optional, Tuple, Union |
| import re |
| import json |
| import math |
| import librosa |
| import numpy as np |
| from PIL import Image |
| from decord import VideoReader, cpu |
| from torch import nn |
| import torch |
| import torchvision.transforms as T |
| from torchvision.transforms.functional import InterpolationMode |
| from transformers import (GenerationConfig, Qwen3ForCausalLM, WhisperFeatureExtractor) |
| from transformers.modeling_utils import PreTrainedModel |
| import onnxruntime |
| import torchaudio.compliance.kaldi as kaldi |
| import torchaudio |
| from transformers.utils.hub import cached_file |
|
|
| from .configuration_interactiveomni import InteractiveOmniConfig |
| from .modeling_intern_vit import InternVisionModel |
| from .modeling_whisper import AudioWhisperModel |
| from .modeling_voicelm import VoiceLM |
| from .conversation import get_conv_template |
|
|
| from .modeling_flow import CausalMaskedDiffWithXvec |
| from .modeling_hifigan import HiFTGenerator |
|
|
import logging

# Configure logging at import time and create this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Per-channel mean / std handed to the normalisation step of the image transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Marker strings that wrap / fill the image span of a prompt.
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'

# Marker strings that wrap / fill the audio span of a prompt.
AUDIO_START_TOKEN = '<audio>'
AUDIO_END_TOKEN = '</audio>'
AUDIO_CONTEXT_TOKEN = '<AUDIO_CONTEXT>'
|
|
|
|
class InteractiveOmniModel(PreTrainedModel):
    """Multimodal model combining vision, audio, language and speech-synthesis sub-modules."""

    # Configuration class associated with this model.
    config_class = InteractiveOmniConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    # Module class names listed as not to be split.
    _no_split_modules = [
        'InternVisionModel',
        'AudioWhisperModel',
        'Qwen3DecoderLayer',
        'Qwen2DecoderLayer',
    ]
|
|
def __init__(self, config: InteractiveOmniConfig, vision_model=None, language_model=None, audio_model=None):
    """Build every sub-module and projection layer from ``config``.

    Args:
        config: Model configuration; its sub-configs size the sub-modules.
        vision_model: Optional ready-made vision encoder; when ``None`` an
            ``InternVisionModel`` is created from ``config.vision_config``.
        language_model: Optional ready-made language model; when ``None`` a
            ``Qwen3ForCausalLM`` is created from ``config.llm_config``.
        audio_model: Optional ready-made audio encoder; when ``None`` an
            ``AudioWhisperModel`` is created from ``config.audio_config``.
    """
    super().__init__(config)

    def _projector(in_features, out_features):
        # Shared adapter layout: LayerNorm -> Linear -> GELU -> Linear.
        return nn.Sequential(
            nn.LayerNorm(in_features),
            nn.Linear(in_features, out_features),
            nn.GELU(),
            nn.Linear(out_features, out_features),
        )

    vision_cfg = config.vision_config
    image_size = config.force_image_size or vision_cfg.image_size

    # Plain (non-module) settings copied from the configuration.
    self.patch_size = vision_cfg.patch_size
    self.select_layer = config.select_layer
    self.template = config.template
    self.downsample_ratio = config.downsample_ratio
    self.ps_version = config.ps_version
    # Number of context tokens a single image tile contributes to the prompt.
    self.num_image_token = int(
        (image_size // self.patch_size) ** 2 * (self.downsample_ratio ** 2)
    )
    self.audio_feature_extractor = WhisperFeatureExtractor(**config.audio_preprocessor_config)
    self.transform = self.build_transform(input_size=image_size)

    # Speaker-embedding resources; filled in later (see from_pretrained).
    self.campplus_session = None
    self.default_speaker_embedding = None
    self.default_wav_path = None

    logger.info(f'num_image_token: {self.num_image_token}')
    logger.info(f'ps_version: {self.ps_version}')

    # Encoders and language model: reuse injected instances when provided.
    self.vision_model = (
        vision_model if vision_model is not None else InternVisionModel(vision_cfg)
    )
    self.audio_model = (
        audio_model if audio_model is not None else AudioWhisperModel(config.audio_config)
    )
    self.language_model = (
        language_model if language_model is not None else Qwen3ForCausalLM(config.llm_config)
    )

    # Speech generation stack; the flow and vocoder modules are cast to float32.
    self.voicelm_model = VoiceLM(config.voicelm_config)
    self.flow_model = CausalMaskedDiffWithXvec(config.flow_config).float()
    self.hifigan_model = HiFTGenerator(config.hifigan_config).float()

    llm_hidden_size = config.llm_config.hidden_size
    # Channel width of the vision features after pixel_shuffle.
    shuffled_vit_size = vision_cfg.hidden_size * int(1 / self.downsample_ratio) ** 2
    voicelm_input_size = config.voicelm_config.llm_input_size

    self.mlp1 = _projector(shuffled_vit_size, llm_hidden_size)
    self.mlp2 = _projector(config.audio_config.d_model, llm_hidden_size)
    self.mlp_llm2voicelm = _projector(llm_hidden_size, voicelm_input_size)
    # Sigmoid gate used by fusion() on two concatenated LLM-width tensors.
    self.gate = nn.Sequential(
        nn.Linear(2 * llm_hidden_size, llm_hidden_size),
        nn.Sigmoid(),
    )

    # Token ids are resolved from the tokenizer later (see chat()).
    self.img_context_token_id = None
    self.audio_context_token_id = None
    self.neftune_alpha = None

    self.post_init()
| |
def fusion(self, rep, emb):
    """Blend ``rep`` and ``emb`` element-wise with the sigmoid gate ``self.gate``.

    The gate is evaluated on the two inputs concatenated along the last
    dimension; the result is ``gate * rep + (1 - gate) * emb``.
    """
    mix = self.gate(torch.cat((rep, emb), dim=-1))
    return mix * rep + (1 - mix) * emb
| |
def __load_campplus_session(self, campplus_path: str):
    """Create a CPU onnxruntime session for ``campplus_path`` and cache it.

    The session is stored on ``self.campplus_session`` and also returned.
    """
    logger.info(f"load campplus session: {campplus_path}")
    session_options = onnxruntime.SessionOptions()
    session_options.intra_op_num_threads = 1
    session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    self.campplus_session = onnxruntime.InferenceSession(
        campplus_path,
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
    )
    return self.campplus_session
|
|
def extract_speaker_embedding(self, prompt_wav: str):
    """Compute a speaker embedding tensor for the audio file ``prompt_wav``.

    The waveform is averaged over channels, resampled to 16 kHz when its
    rate differs, converted to mean-normalised 80-bin fbank features and
    run through ``self.campplus_session``. The flattened first output is
    returned as a tensor with a leading dimension of 1.
    """
    logger.info(f"extract speaker embedding: {prompt_wav}")
    target_sr = 16000
    waveform, source_sr = torchaudio.load(prompt_wav)
    # Collapse all channels into one.
    waveform = waveform.mean(dim=0, keepdim=True)
    if source_sr != target_sr:
        assert source_sr > target_sr, 'wav sample rate {} must be greater than {}'.format(source_sr, target_sr)
        resampler = torchaudio.transforms.Resample(orig_freq=source_sr, new_freq=target_sr)
        waveform = resampler(waveform)

    fbank = kaldi.fbank(
        waveform,
        num_mel_bins=80,
        dither=0,
        sample_frequency=target_sr,
    )
    # Subtract the per-bin mean over frames.
    fbank = fbank - fbank.mean(dim=0, keepdim=True)

    session = self.campplus_session
    input_name = session.get_inputs()[0].name
    onnx_outputs = session.run(None, {input_name: fbank.unsqueeze(dim=0).cpu().numpy()})
    return torch.tensor([onnx_outputs[0].flatten().tolist()])
|
|
def build_transform(self, input_size):
    """Return the image pipeline: RGB conversion, bicubic resize to a square, tensor conversion, normalisation."""

    def _ensure_rgb(img):
        # Leave RGB images untouched; convert every other mode.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
| |
def find_closest_aspect_ratio(self, image, min_num=1, max_num=6, image_size=448):
    """Choose the ``[columns, rows]`` tile grid closest to the image's aspect ratio.

    The tile budget is the image area divided by one tile's area, rounded up
    and capped at ``max_num``. Grids whose tile count is the budget or one
    either side of it (never above ``max_num``) are compared by the absolute
    difference of log aspect ratios; the first best match wins.

    Args:
        image: Object exposing ``size`` as ``(width, height)``.
        min_num: Must be 1 (asserted).
        max_num: Upper bound on the number of tiles.
        image_size: Side length of one square tile.

    Returns:
        A two-element list ``[columns, rows]``; ``[1, 1]`` when one tile suffices.
    """
    assert min_num == 1
    width, height = image.size
    target_log_ratio = math.log(width / height)
    tile_budget = min(math.ceil(width * height / (image_size * image_size)), max_num)
    if tile_budget <= 1:
        return [1, 1]

    tile_counts = [
        count
        for count in (tile_budget - 1, tile_budget, tile_budget + 1)
        if count <= max_num
    ]
    # Every factorisation columns x rows of each admissible tile count.
    grids = [
        [cols, count // cols]
        for count in tile_counts
        for cols in range(1, count + 1)
        if count % cols == 0
    ]
    # min() keeps the first grid on ties, like a strict "<" scan would.
    return min(grids, key=lambda grid: abs(target_log_ratio - math.log(grid[0] / grid[1])))
|
|
def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Resize ``image`` to a tile grid and cut it into square tiles.

    The grid comes from ``find_closest_aspect_ratio``. Tiles are returned in
    row-major order. When ``use_thumbnail`` is set and more than one tile was
    produced, a whole-image resize to one tile is appended.
    """
    cols, rows = self.find_closest_aspect_ratio(image, min_num, max_num, image_size)
    tile_count = cols * rows

    canvas = image.resize((image_size * cols, image_size * rows))
    tiles = []
    for tile_idx in range(tile_count):
        row, col = divmod(tile_idx, cols)
        left = col * image_size
        top = row * image_size
        tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
| |
def load_image(self, image, input_size=448, max_num=12):
    """Return the tile list (thumbnail enabled) for a PIL image or anything ``Image.open`` accepts."""
    pil_image = image if isinstance(image, Image.Image) else Image.open(image).convert('RGB')
    return self.dynamic_preprocess(
        pil_image,
        image_size=input_size,
        use_thumbnail=True,
        max_num=max_num,
    )
|
|
def pixel_shuffle(self, x, scale_factor=0.5):
    """Rearrange a 4-D tensor, scaling dims 1 and 2 by ``scale_factor`` and widening the last dim.

    For input sizes ``(n, a, b, c)`` the result has sizes
    ``(n, a * s, b * s, c / s**2)`` with ``s = scale_factor``, except that with
    ``ps_version == 'v1'`` the two middle dimensions are left swapped and a
    warning is emitted.
    """
    batch, dim1, dim2, channels = x.size()
    shrunk_dim1 = int(dim1 * scale_factor)
    shrunk_dim2 = int(dim2 * scale_factor)

    # Fold part of dim 2 into the channel dimension.
    folded = x.view(batch, dim1, shrunk_dim2, int(channels / scale_factor))
    folded = folded.permute(0, 2, 1, 3).contiguous()
    # Fold part of dim 1 into the channel dimension.
    folded = folded.view(batch, shrunk_dim2, shrunk_dim1,
                         int(channels / (scale_factor * scale_factor)))

    if self.ps_version != 'v1':
        # Swap the two middle dimensions back to their original order.
        return folded.permute(0, 2, 1, 3).contiguous()

    warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                  'which results in a transposed image.')
    return folded
|
|
def extract_feature(self, pixel_values):
    """Encode ``pixel_values`` with the vision model and project the result through ``mlp1``.

    Uses the last hidden state when ``select_layer == -1``, otherwise the
    hidden state at index ``select_layer``. The first token is dropped, the
    rest is laid out as a square grid, pixel-shuffled by ``downsample_ratio``
    and flattened back to a sequence before projection.
    """
    use_last_layer = self.select_layer == -1
    vision_out = self.vision_model(
        pixel_values=pixel_values,
        output_hidden_states=not use_last_layer,
        return_dict=True,
    )
    if use_last_layer:
        tokens = vision_out.last_hidden_state
    else:
        tokens = vision_out.hidden_states[self.select_layer]
    # Drop the first token — presumably a class token; confirm against the vision model.
    tokens = tokens[:, 1:, :]

    if self.training and self.neftune_alpha is not None:
        # NOTE(review): noised_embed is not defined in the visible part of this class — confirm it exists.
        tokens = self.noised_embed(tokens, self.neftune_alpha)

    side = int(tokens.shape[1] ** 0.5)
    grid = tokens.reshape(tokens.shape[0], side, side, -1)
    grid = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)
    sequence = grid.reshape(grid.shape[0], -1, grid.shape[-1])
    return self.mlp1(sequence)
|
|
def get_T_after_cnn(self, L_in, dilation=1):
    """Return the sequence length after two stacked 1-D convolutions.

    Both layers use padding 1 and kernel size 3; the first has stride 1, the
    second stride 2. Each layer applies
    ``1 + (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride``.

    Args:
        L_in: Input length (number of frames).
        dilation: Dilation applied to both layers.

    Returns:
        The output length as an integer.
    """
    # (padding, kernel_size, stride) per layer; a literal replaces the former eval() of a constant string.
    conv_layers = ((1, 3, 1), (1, 3, 2))
    L_out = L_in
    for padding, kernel_size, stride in conv_layers:
        L_out = L_out + 2 * padding - dilation * (kernel_size - 1) - 1
        L_out = 1 + L_out // stride
    return L_out
|
|
def process_audio(self, audio, return_tensors, sampling_rate=16000):
    """Run the audio feature extractor and attach length bookkeeping.

    Adds two scalar long tensors to the extractor output:
    ``audio_len_after_cnn`` (from ``get_T_after_cnn``) and ``audio_token_num``.
    """
    # Cap at 480000 samples — presumably 30 s at 16 kHz; confirm against the extractor config.
    num_samples = min(audio.shape[0], 480000)
    # 160 samples per mel frame — presumably the extractor hop length; confirm.
    mel_len = num_samples // 160
    audio_len_after_cnn = self.get_T_after_cnn(mel_len)
    audio_token_num = (audio_len_after_cnn - 2) // 2 + 1

    inputs = self.audio_feature_extractor(audio, return_tensors=return_tensors, sampling_rate=sampling_rate)
    extras = (
        ('audio_len_after_cnn', audio_len_after_cnn),
        ('audio_token_num', audio_token_num),
    )
    for key, value in extras:
        inputs[key] = torch.tensor(value, dtype=torch.long)
    return inputs
|
|
def load_audio(self, audio_file, sampling_rate=16000):
    """Load ``audio_file`` with librosa at ``sampling_rate`` and return the model-ready audio inputs.

    Returns:
        A dict with ``audio_values`` (the extractor's ``input_features``),
        ``audio_len_after_cnn`` and ``audio_token_num``.
    """
    waveform, _ = librosa.load(audio_file, sr=sampling_rate)
    processed = self.process_audio(waveform, sampling_rate=sampling_rate, return_tensors="pt")
    return {
        'audio_values': processed['input_features'],
        'audio_len_after_cnn': processed['audio_len_after_cnn'],
        'audio_token_num': processed['audio_token_num'],
    }
| |
def extract_audio_feature(self, audio_values, audio_len_after_cnn):
    """Encode audio features with the audio model and project them through ``mlp2``.

    A padding mask is built with one row per batch item: 0 for the first
    ``audio_len_after_cnn[i]`` positions, 1 for the rest.
    """
    audio_values = audio_values.squeeze(1)
    batch_size = audio_values.size(0)
    longest = int(torch.max(audio_len_after_cnn).item())

    padding_mask = torch.ones([batch_size, longest]).to(
        dtype=audio_values.dtype, device=audio_values.device
    )
    for row in range(batch_size):
        valid_len = int(audio_len_after_cnn[row].item())
        padding_mask[row, :valid_len] = 0

    encoded = self.audio_model(audio_values, padding_mask, audio_len_after_cnn)
    return self.mlp2(encoded)
| |
def get_index(self, bound, fps, max_frame, first_idx=0, num_segments=32):
    """Pick ``num_segments`` frame indices spread over a clip.

    Args:
        bound: Optional ``(start, end)`` pair multiplied by ``fps`` to get
            frame positions; a falsy value selects the whole clip.
        fps: Frames per second.
        max_frame: Largest usable frame index.
        first_idx: Smallest usable frame index.
        num_segments: Number of indices to return.

    Returns:
        A NumPy array holding one index near the middle of each segment.
    """
    start, end = (bound[0], bound[1]) if bound else (-100000, 100000)
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    half_seg = seg_size / 2

    picks = []
    for seg in range(num_segments):
        picks.append(int(start_idx + half_seg + np.round(seg_size * seg)))
    return np.array(picks)
| |
def load_video(self, video_path, bound=None, num_segments=32):
    """Decode ``num_segments`` frames from ``video_path`` and return them as RGB PIL images."""
    reader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    last_frame = len(reader) - 1
    fps = float(reader.get_avg_fps())
    picked = self.get_index(bound, fps, last_frame, first_idx=0, num_segments=num_segments)
    return [
        Image.fromarray(reader[frame_index].asnumpy()).convert('RGB')
        for frame_index in picked
    ]
|
|
def find_second_last_occurrence(self, input_ids_list, target_id):
    """Return the index of the second-to-last ``target_id`` in ``input_ids_list``, or -1 if there are fewer than two."""
    hits = 0
    # Walk backwards so the second hit is the second-to-last occurrence.
    for pos in range(len(input_ids_list) - 1, -1, -1):
        if input_ids_list[pos] != target_id:
            continue
        hits += 1
        if hits == 2:
            return pos
    return -1
| |
def decode_speech_tokens(
    self,
    speech_tokens,
    speaker_embedding=None,
    flow_prompt_speech_token=None,
    prompt_speech_feat=None,
    finalize=True,
    token_offset=0,
):
    """Turn speech tokens into a waveform via the flow model and the HiFT vocoder.

    Args:
        speech_tokens: Token tensor; its second dimension is used as the token length.
        speaker_embedding: Speaker embedding; zeros of shape ``(1, 192)`` when ``None``.
        flow_prompt_speech_token: Prompt tokens; an empty ``(1, 0)`` int32 tensor when ``None``.
        prompt_speech_feat: Prompt features; an empty ``(1, 0, 80)`` tensor when ``None``.
        finalize: Forwarded to ``flow_model.inference``.
        token_offset: Number of leading tokens whose mel frames are dropped.

    Returns:
        The speech output of ``hifigan_model.inference``.
    """
    if speaker_embedding is None:
        speaker_embedding = torch.zeros(1, 192)
    if flow_prompt_speech_token is None:
        flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32)
    if prompt_speech_feat is None:
        prompt_speech_feat = torch.zeros(1, 0, 80)

    flow = self.flow_model
    # Side effect: the chunk sizes are set on the shared flow model on every call.
    token_chunk = 2 * flow.input_frame_rate
    flow.encoder.static_chunk_size = token_chunk
    flow.decoder.estimator.static_chunk_size = token_chunk * flow.token_mel_ratio

    device = speech_tokens.device

    def _length_of(tensor):
        # Length along dim 1 as a one-element int32 tensor on the working device.
        return torch.tensor([tensor.shape[1]], dtype=torch.int32).to(device)

    tts_mel, _ = flow.inference(
        token=speech_tokens.to(device),
        token_len=_length_of(speech_tokens),
        prompt_token=flow_prompt_speech_token.to(device),
        prompt_token_len=_length_of(flow_prompt_speech_token),
        prompt_feat=prompt_speech_feat.to(device),
        prompt_feat_len=_length_of(prompt_speech_feat),
        embedding=speaker_embedding.to(device),
        finalize=finalize,
    )
    mel_start = token_offset * self.config.flow_config.token_mel_ratio
    tts_mel = tts_mel[:, :, mel_start:]

    # Empty cache: no previous source signal is supplied to the vocoder.
    empty_source_cache = torch.zeros(1, 1, 0)
    tts_speech, _tts_source = self.hifigan_model.inference(
        speech_feat=tts_mel, cache_source=empty_source_cache
    )
    return tts_speech
| |
@torch.no_grad()
def generate(
    self,
    pixel_values: Optional[torch.FloatTensor],
    input_ids: torch.LongTensor,
    attention_mask: torch.LongTensor,
    visual_features: Optional[torch.FloatTensor] = None,
    audio_values: Optional[torch.FloatTensor] = None,
    audio_len_after_cnn: Optional[torch.LongTensor] = None,
    audio_token_num: Optional[torch.LongTensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    output_hidden_states: Optional[bool] = None,
    start_token_id: int = 151644,
    generate_audio: bool = False,
    speaker_embedding: torch.Tensor = torch.zeros(1, 192),
    mix_ratio: Optional[List[int]] = None,
    **generate_kwargs,
) -> Tuple[Any, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Generate text and, optionally, speech for a prepared multimodal prompt.

    Image and audio context-token positions in ``input_ids`` are overwritten
    with the projected vision / audio embeddings before the language model
    generates.

    Args:
        pixel_values: Image tiles for the vision encoder; ignored when
            ``visual_features`` is given. May be ``None``.
        input_ids: Prompt token ids. Only the first row is used to locate the
            current conversation turn.
        attention_mask: Attention mask for ``input_ids``.
        visual_features: Pre-computed vision embeddings, used instead of
            encoding ``pixel_values``.
        audio_values: Audio features; used only together with
            ``audio_len_after_cnn`` and ``audio_token_num``.
        audio_len_after_cnn: Per-audio length consumed by the audio encoder.
        audio_token_num: Per-audio number of embeddings inserted into the prompt.
        generation_config: Forwarded to ``language_model.generate``.
        output_hidden_states: Forwarded; forced on when ``generate_audio``.
        start_token_id: Id of the turn-start token used to find the current turn.
        generate_audio: Also synthesise speech for the generated text.
        speaker_embedding: Speaker embedding passed to ``decode_speech_tokens``.
        mix_ratio: Passed to ``voicelm_model.inference_bistream``; defaults to ``[5, 25]``.
        **generate_kwargs: Extra arguments for ``language_model.generate``.

    Returns:
        ``(output, speech_tokens, speech)``. Without ``generate_audio`` the
        first item is the raw ``language_model.generate`` result and the other
        two are ``None``. With it, the first item is ``outputs.sequences``;
        the other two are ``None`` if speech synthesis raised.
    """
    assert self.img_context_token_id is not None
    assert self.audio_context_token_id is not None
    if mix_ratio is None:
        # Fresh list per call instead of a shared mutable default.
        mix_ratio = [5, 25]

    vit_embeds = None
    if visual_features is not None:
        vit_embeds = visual_features
    elif pixel_values is not None:
        vit_embeds = self.extract_feature(pixel_values)

    # Index of the second-to-last start token, i.e. the start of the last user turn
    # when the prompt ends with an opened assistant turn. -1 when not found.
    # NOTE(review): with -1 the two slices further down have different lengths — confirm callers always supply two start tokens.
    cur_conv_start_id = self.find_second_last_occurrence(input_ids.tolist()[0], start_token_id)

    input_embeds = self.language_model.get_input_embeddings()(input_ids)
    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)
    input_ids = input_ids.reshape(B * N)

    if vit_embeds is not None:
        selected = (input_ids == self.img_context_token_id)
        input_embeds[selected] = vit_embeds.reshape(-1, C)

    if audio_values is not None and audio_len_after_cnn is not None and audio_token_num is not None:
        audio_embeds = self.extract_audio_feature(audio_values, audio_len_after_cnn)
        # Keep only the first audio_token_num[i] embeddings of each audio.
        output_audios = [
            audio_embeds[i][:int(audio_token_num[i].item())]
            for i in range(len(audio_token_num))
        ]
        output_audios = torch.cat(output_audios, dim=0)
        selected = (input_ids == self.audio_context_token_id)
        input_embeds[selected] = output_audios.reshape(-1, C)

    input_embeds = input_embeds.reshape(B, N, C)

    outputs = self.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        generation_config=generation_config,
        output_hidden_states=output_hidden_states or generate_audio,
        return_dict_in_generate=generate_audio,
        use_cache=True,
        **generate_kwargs,
    )
    if not generate_audio:
        return outputs, None, None

    # Last-layer hidden states: final prompt position from the first step,
    # followed by every later generation step.
    hidden_states = torch.cat(
        [outputs.hidden_states[0][-1][:, -1:, :]]
        + [outputs.hidden_states[i][-1] for i in range(1, len(outputs.hidden_states))],
        dim=1,
    )
    sampled_token = outputs.sequences
    if sampled_token.shape[1] == hidden_states.shape[1] + 1:
        # Drop the leading token so tokens and hidden states align one-to-one.
        sampled_token = sampled_token[:, 1:]
    sampled_token_embeddings = self.language_model.get_input_embeddings()(sampled_token)
    target_text_token_hidden_states = self.fusion(hidden_states, sampled_token_embeddings)

    # Hidden states are shifted by one position relative to the embeddings they are fused with.
    input_token_hidden_states = outputs.hidden_states[0][-1][:, cur_conv_start_id:-1, :]
    question_input_embeddings = input_embeds[:, cur_conv_start_id + 1:, :]
    input_token_hidden_states = self.fusion(input_token_hidden_states, question_input_embeddings)

    input_feature = self.mlp_llm2voicelm(input_token_hidden_states)
    target_text_feature = self.mlp_llm2voicelm(target_text_token_hidden_states)

    try:
        speech_tokens = self.voicelm_model.inference_bistream(
            input_feature, target_text_feature, mix_ratio=mix_ratio
        )
        speech_tokens = torch.LongTensor([speech_tokens]).to(input_feature.device)
        tts_speech = self.decode_speech_tokens(
            speech_tokens,
            speaker_embedding=speaker_embedding,
        )
    except Exception as e:
        # Best effort: text is still returned when speech synthesis fails;
        # the traceback is logged so the failure can be diagnosed.
        logger.warning(f"=========voice lm except:{e}", exc_info=True)
        return outputs.sequences, None, None
    return outputs.sequences, speech_tokens, tts_speech
|
|
def chat(
    self,
    tokenizer,
    generation_config,
    messages,
    max_patch_num=12,
    frame=8,
    generate_audio=False,
    speaker_embedding=torch.zeros(1, 192),
    print_flag=True,
):
    """Run one chat turn over a list of role-tagged, possibly multimodal messages.

    Args:
        tokenizer: Tokenizer used to resolve special-token ids, encode the
            prompt and decode the answer.
        generation_config: Mapping of generation keyword arguments. It is
            copied before ``eos_token_id`` is set, so the caller's object is
            not modified.
        messages: Non-empty list of ``{'role', 'content'}`` dicts. Roles must
            start with ``system`` or ``user``, alternate ``user`` /
            ``assistant`` and end with ``user``. ``content`` is a string or a
            list of items typed ``text`` / ``audio`` / ``image`` / ``video``.
        max_patch_num: Maximum number of tiles per image.
        frame: Number of frames sampled from each video.
        generate_audio: Also synthesise speech for the answer.
        speaker_embedding: Speaker embedding forwarded to ``generate``.
        print_flag: Log the query, the response and media statistics.

    Returns:
        The response string, or ``(response, speech)`` when ``generate_audio`` is set.

    Raises:
        RuntimeError: On empty ``messages`` or an invalid role sequence.
        ValueError: On an unknown role.
    """
    if self.flow_model.dtype != torch.float32 or self.hifigan_model.dtype != torch.float32:
        logger.info(f"reset flow model and higigan model dtype to float32")
        # NOTE(review): reset_vocoder is not defined in the visible part of this class — confirm it exists.
        self.reset_vocoder()
    if messages is None or len(messages) == 0:
        raise RuntimeError('no messages')

    # Roles allowed to follow each role.
    role_transfer_dict = {
        'system': ['user'],
        'user': ['assistant'],
        'assistant': ['user'],
    }
    first_role = ['system', 'user']
    last_role = ['user']
    if messages[-1]['role'] not in last_role:
        raise RuntimeError(f"last role error, expect {last_role}, but got {messages[-1]}")

    current_role = None
    dynamic_images = list()
    dynamic_nums = list()
    audio_values = list()
    audio_len_after_cnn = list()
    audio_token_num = list()
    template = get_conv_template(self.template)

    for index, message in enumerate(messages):
        text = ''
        audios = list()
        images = list()
        if index == 0:
            if message['role'] not in first_role:
                raise RuntimeError(f'first role error expect {first_role}, but got {message}')
        else:
            # current_role holds the list of roles allowed after the previous message.
            if message['role'] not in current_role:
                raise RuntimeError(f'role error expect {current_role}, but got {message}')
        current_role = message['role']

        if isinstance(message["content"], list):
            for item in message["content"]:
                if item['type'] == 'text':
                    if item.get('text', None) is None:
                        continue
                    text += item['text']
                elif item['type'] == 'audio':
                    if item.get('audio', None) is None:
                        continue
                    if isinstance(item['audio'], list):
                        assert len(item['audio']) == 1, f'only support 1 audio file in round, but got {item["audio"]}'
                        audio = item['audio'][0]
                    else:
                        audio = item['audio']
                    audios.append(audio)
                elif item['type'] == 'image':
                    if item.get('image', None) is None:
                        continue
                    if isinstance(item['image'], list):
                        images.extend(item['image'])
                    else:
                        images.append(item['image'])
                elif item['type'] == 'video':
                    if item.get('video', None) is None:
                        continue
                    if isinstance(item['video'], list):
                        assert len(item['video']) == 1, f'only support 1 video file in round, but got {item["video"]}'
                        video = item['video'][0]
                    else:
                        video = item['video']
                    # A video is handled as a sequence of sampled frames.
                    frames = self.load_video(video, num_segments=frame)
                    images.extend(frames)
        else:
            assert isinstance(message["content"], str), message["content"]
            text = message["content"]

        if len(audios) != 0:
            assert len(audios) == 1, f'only support 1 audio file in round, but got {audios}'
            if '<audio>' in text:
                matches = re.findall(r"<audio>", text)
                assert len(matches) == len(audios), f'<audio> error {text} {len(audios)}' + text
                # Make sure every placeholder is followed by a newline.
                text = re.sub(r'(<audio>)(?!\n)', r'\1\n', text)
            else:
                text = '<audio>\n' * len(audios) + text

            audio_path = audios[0]
            audio_input_dict = self.load_audio(audio_path)
            assert audio_input_dict['audio_token_num'].item() != 0, f'audio_token_num of {audio_path} is 0.'
            audio_values.append(audio_input_dict['audio_values'])
            audio_len_after_cnn.append(audio_input_dict['audio_len_after_cnn'])
            audio_token_num.append(audio_input_dict['audio_token_num'])

        # images is always a list here, so this also validates messages
        # that contain "<image>" placeholders without any image.
        if '<image>' in text:
            matches = re.findall(r"<image>", text)
            assert len(matches) == len(images), f'<image> error {text} {len(images)}' + text
            text = re.sub(r'(<image>)(?!\n)', r'\1\n', text)
        else:
            text = '<image>\n' * len(images) + text

        for image in images:
            dynamic_image = self.load_image(image, max_num=max_patch_num)
            dynamic_images += dynamic_image
            dynamic_nums.append(len(dynamic_image))

        if message['role'] == 'system':
            template.set_system_message(text)
        elif message['role'] == 'user':
            template.append_message(template.roles[0], text)
        elif message['role'] == 'assistant':
            template.append_message(template.roles[1], text)
        else:
            raise ValueError('unexpected role')

        current_role = role_transfer_dict[current_role]

    # Open the assistant turn the model is asked to complete.
    template.append_message(template.roles[1], None)

    if len(audio_values) != 0:
        audio_values = torch.cat(audio_values, dim=0).to(dtype=self.dtype).cuda()
        audio_len_after_cnn = torch.stack(audio_len_after_cnn, dim=0)
        audio_token_num = torch.stack(audio_token_num, dim=0)
    else:
        audio_values = None
        audio_len_after_cnn = None
        audio_token_num = None

    if len(dynamic_images) != 0:
        pixel_values = torch.stack([self.transform(image) for image in dynamic_images])
        # Use the model dtype, as for the audio features, rather than a fixed bfloat16.
        pixel_values = pixel_values.to(dtype=self.dtype).cuda()
    else:
        pixel_values = None
        dynamic_nums = None

    self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    self.audio_context_token_id = tokenizer.convert_tokens_to_ids(AUDIO_CONTEXT_TOKEN)

    eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
    start_token_id = tokenizer.convert_tokens_to_ids(["<|im_start|>"])[0]

    query = template.get_prompt()

    if audio_values is not None:
        if print_flag:
            logger.info(f'audio num: {len(audio_token_num)}')
        audio_tokens_list = list()
        for index in range(len(audio_token_num)):
            audio_token_num_i = audio_token_num[index]
            if print_flag:
                logger.info(f'audio_token_num: {audio_token_num_i}')
            audio_tokens = AUDIO_START_TOKEN + AUDIO_CONTEXT_TOKEN * audio_token_num_i + AUDIO_END_TOKEN
            audio_tokens_list.append(audio_tokens)

        # Replace the placeholders in order of appearance.
        audio_tokens_iter = iter(audio_tokens_list)
        query = re.sub(r"<audio>", lambda match: next(audio_tokens_iter), query)

    if pixel_values is not None:
        if print_flag:
            logger.info(f'image num: {len(dynamic_nums)}')
        image_tokens_list = list()
        total_dynamic_num = 0
        for dynamic_num in dynamic_nums:
            total_dynamic_num += dynamic_num
            if print_flag:
                logger.info(f'dynamic ViT batch size: {dynamic_num}')
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * dynamic_num + IMG_END_TOKEN
            image_tokens_list.append(image_tokens)
        assert total_dynamic_num == pixel_values.shape[0], f'dynamic num not equal, {total_dynamic_num}, {pixel_values.shape[0]}'

        image_tokens_iter = iter(image_tokens_list)
        query = re.sub(r"<image>", lambda match: next(image_tokens_iter), query)

    model_inputs = tokenizer(query, return_tensors='pt', add_special_tokens=False)
    input_ids = model_inputs['input_ids'].cuda()
    attention_mask = model_inputs['attention_mask'].cuda()

    # Copy so that setting eos_token_id does not modify the caller's mapping.
    generation_kwargs = dict(generation_config)
    generation_kwargs['eos_token_id'] = eos_token_id
    generation_output, speech_token, audio_bytes = self.generate(
        pixel_values=pixel_values,
        audio_values=audio_values,
        audio_len_after_cnn=audio_len_after_cnn,
        audio_token_num=audio_token_num,
        input_ids=input_ids,
        attention_mask=attention_mask,
        generate_audio=generate_audio,
        start_token_id=start_token_id,
        speaker_embedding=speaker_embedding,
        **generation_kwargs
    )
    response = tokenizer.batch_decode(generation_output, skip_special_tokens=False)[0]
    response = response.split("<|im_end|>")[0].replace('<|endoftext|>', '').strip()

    # Collapse the expanded media spans back to placeholders for logging.
    query_to_print = query
    if pixel_values is not None:
        query_to_print = query_to_print.replace(IMG_CONTEXT_TOKEN, '')
        query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
    if audio_values is not None:
        query_to_print = query_to_print.replace(AUDIO_CONTEXT_TOKEN, '')
        query_to_print = query_to_print.replace(f'{AUDIO_START_TOKEN}{AUDIO_END_TOKEN}', '<audio>')
    if print_flag:
        logger.info('query: ' + json.dumps(query_to_print, ensure_ascii=False))
        logger.info('response: ' + response)

    if generate_audio:
        return response, audio_bytes
    return response
| |
def __cache_file(self, pretrained_model_name_or_path: str, filename: str, **kw):
    """Resolve ``filename`` inside a model repo or directory to a local path.

    Args:
        pretrained_model_name_or_path: Repo id or local directory.
        filename: Name of the file to resolve.
        **kw: Optional hub options: ``subfolder``, ``cache_dir``,
            ``force_download``, ``proxies``, ``resume_download``,
            ``local_files_only``, ``revision`` and ``token`` (the legacy key
            ``use_auth_token`` is used when ``token`` is absent). Other keys
            are ignored.

    Returns:
        The path returned by ``cached_file``.

    Raises:
        ValueError: If ``cached_file`` returns ``None``.
    """
    # Honour the current "token" key and fall back to the legacy one.
    token = kw.get("token")
    if token is None:
        token = kw.get("use_auth_token")

    full_path = cached_file(
        pretrained_model_name_or_path,
        filename,
        subfolder=kw.get("subfolder"),
        cache_dir=kw.get("cache_dir"),
        force_download=kw.get("force_download", False),
        proxies=kw.get("proxies"),
        resume_download=kw.get("resume_download"),
        local_files_only=kw.get("local_files_only", False),
        token=token,
        revision=kw.get("revision"),
    )
    if full_path is None:
        raise ValueError(f"""{pretrained_model_name_or_path}/{filename} not exists""")
    return full_path
| |
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path,
    *model_args,
    config=None,
    cache_dir=None,
    ignore_mismatched_sizes=False,
    force_download=False,
    local_files_only=False,
    token=None,
    revision="main",
    use_safetensors=None,
    weights_only=True,
    **kwargs,
):
    """Load the model, then the speaker-encoder session and default speaker embedding.

    After the regular weight loading, ``campplus.onnx`` and ``taozi.wav`` are
    resolved from the same location. The ONNX session is created, and
    ``default_wav_path`` and ``default_speaker_embedding`` are set on the
    returned model.

    The hub options ``cache_dir``, ``force_download``, ``local_files_only``,
    ``revision`` and ``token`` are also applied to those two auxiliary files.
    """
    model = super().from_pretrained(
        pretrained_model_name_or_path,
        *model_args,
        config=config,
        cache_dir=cache_dir,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        use_safetensors=use_safetensors,
        weights_only=weights_only,
        **kwargs,
    )

    # The named hub options are not part of **kwargs, so forward them
    # explicitly to the auxiliary downloads.
    hub_kwargs = dict(kwargs)
    hub_kwargs.update(
        cache_dir=cache_dir,
        force_download=force_download,
        local_files_only=local_files_only,
        revision=revision,
    )
    if token is not None:
        # Set both spellings so the helper finds the token under either key.
        hub_kwargs["token"] = token
        hub_kwargs["use_auth_token"] = token

    campplus_path = model.__cache_file(pretrained_model_name_or_path, "campplus.onnx", **hub_kwargs)
    model.__load_campplus_session(campplus_path)
    default_wav_path = model.__cache_file(pretrained_model_name_or_path, "taozi.wav", **hub_kwargs)
    model.default_wav_path = default_wav_path
    model.default_speaker_embedding = model.extract_speaker_embedding(default_wav_path)

    return model