import logging import os.path from typing import List import torch from header import * import torch.nn.functional as F from .ImageBind import * from .ImageBind import data from .modeling_llama import LlamaForCausalLM from transformers import StoppingCriteria, StoppingCriteriaList # from diffusers import StableDiffusionPipeline from .custom_sd import StableDiffusionPipeline from .custom_vd import TextToVideoSDPipeline from .custom_ad import AudioLDMPipeline from .layers import * from .common.utils import * class StoppingCriteriaSub(StoppingCriteria): def __init__(self, stops: List = None, encounters: int = 1): super().__init__() self.stops = stops self.ENCOUNTERS = encounters def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): stop_count = 0 for stop in self.stops: _stop = torch.tensor(stop).to(input_ids[0].device) indices = torch.where(_stop[0] == input_ids) for i in indices: if len(i) > 0: if torch.all(input_ids[0][i:i + len(_stop)] == _stop): stop_count += 1 if stop_count >= self.ENCOUNTERS: return True return False class NextGPTModel(nn.Module): """LoRA for LLaMa model""" def __init__(self, **args): super(NextGPTModel, self).__init__() self.args = args self.max_length = args['max_length'] self.device = torch.cuda.current_device() self.stage = args['stage'] print('args max_length', args['max_length']) imagebind_ckpt_path = os.path.join(self.args['pretrained_ckpt_path'], 'imagebind_ckpt', self.args['imagebind_version']) print(f'Initializing visual encoder from {imagebind_ckpt_path} ...') self.visual_encoder, self.visual_hidden_size = \ imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path) # free vision encoder for name, param in self.visual_encoder.named_parameters(): param.requires_grad = False self.visual_encoder.eval() print('Visual encoder initialized.') self.vicuna_ckpt_path = os.path.join(self.args['pretrained_ckpt_path'], 'vicuna_ckpt', self.args['vicuna_version']) print(f'Initializing language decoder from {self.vicuna_ckpt_path} ...') self.llama_model = LlamaForCausalLM.from_pretrained(self.vicuna_ckpt_path) if self.args.get('freeze_lm'): print("Freezing the LLaMa ...") for param in self.llama_model.parameters(): param.requires_grad = False self.llama_model.eval() else: print("Instruct tuning the LLaMa ...") # add the lora module peft_config = LoraConfig( task_type=TaskType.CAUSAL_LM, inference_mode=False, r=self.args['lora_r'], lora_alpha=self.args['lora_alpha'], lora_dropout=self.args['lora_dropout'], target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'] ) self.llama_model = get_peft_model(self.llama_model, peft_config) self.llama_model.print_trainable_parameters() print('Language decoder initialized.') # use the new trained tokenizer tokenizer_path = self.vicuna_ckpt_path print(f'Initializing tokenizer from {tokenizer_path} ...') self.llama_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path, use_fast=False) self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token self.llama_tokenizer.padding_side = "right" # self.llama_tokenizer.add_special_tokens({"mask_token": "[MASK]"}) self._add_image_token() self._add_video_token() self._add_audio_token() self.llama_model.resize_token_embeddings(len(self.llama_tokenizer)) print('Tokenizer initialized.') self.llama_proj = nn.Linear( self.visual_hidden_size, self.llama_model.config.hidden_size ) if self.args.get('freeze_input_proj'): for param in self.llama_proj.parameters(): param.requires_grad = False self.input_embeddings = self.llama_model.get_input_embeddings() # the alignment module for LLM-TO-IMAGE self.sd_ckpt_path = self.args['image_diffusion'] self.gen_text_hidden_fcs = nn.ModuleList([]) for layer_idx in self.args['text_emb_to_img_layers']: if layer_idx == -1 or layer_idx == self.llama_model.config.num_hidden_layers: in_dim = self.llama_model.config.hidden_size self.gen_text_hidden_fcs.append( TextFcLayer(in_dim, 768, num_input_tokens=self.args['num_gen_img_tokens'], num_output_tokens=self.args['num_clip_tokens'], mode=self.args['text_fc_to_img_mode'])) # self.sd_pipe.text_encoder.config.hidden_size elif layer_idx < self.llama_model.config.num_hidden_layers: self.gen_text_hidden_fcs.append( TextFcLayer(self.llama_model.config.hidden_size, 768, num_input_tokens=self.args['num_gen_img_tokens'], num_output_tokens=self.args['num_clip_tokens'], mode=self.args['text_fc_to_img_mode'])) else: raise ValueError( f'Embedding of layer {layer_idx} was requested but model only has {self.llama_model.config.num_hidden_layers} layers.') # the alignment module for LLM-TO-VIDEO self.vd_ckpt_path = self.args['video_diffusion'] self.gen_text_hidden_fcs_video = nn.ModuleList([]) for layer_idx in self.args['text_emb_to_video_layers']: if layer_idx == -1 or layer_idx == self.llama_model.config.num_hidden_layers: in_dim = self.llama_model.config.hidden_size # 4096 self.gen_text_hidden_fcs_video.append( TextFcLayer(in_dim, 1024, num_input_tokens=self.args['num_gen_video_tokens'], num_output_tokens=self.args['num_clip_tokens'], mode=self.args['text_fc_to_video_mode'])) # self.vd_pipe.text_encoder.config.hidden_size elif layer_idx < self.llama_model.config.num_hidden_layers: self.gen_text_hidden_fcs_video.append( TextFcLayer(self.llama_model.config.hidden_size, 1024, num_input_tokens=self.args['num_gen_video_tokens'], num_output_tokens=self.args['num_clip_tokens'], mode=self.args['text_fc_to_video_mode'])) else: raise ValueError( f'Embedding of layer {layer_idx} was requested but model only has {self.llama_model.config.num_hidden_layers} layers.') # the alignment module for LLM-TO-AUDIO self.ad_ckpt_path = self.args['audio_diffusion'] self.gen_text_hidden_fcs_audio = nn.ModuleList([]) for layer_idx in self.args['text_emb_to_audio_layers']: if layer_idx == -1 or layer_idx == self.llama_model.config.num_hidden_layers: in_dim = self.llama_model.config.hidden_size self.gen_text_hidden_fcs_audio.append( TextFcLayer(in_dim, 512, num_input_tokens=self.args['num_gen_audio_tokens'], num_output_tokens=1, mode=self.args['text_fc_to_audio_mode'])) # self.ad_pipe.text_encoder.config.projection_dim elif layer_idx < self.llama_model.config.num_hidden_layers: self.gen_text_hidden_fcs_audio.append( TextFcLayer(self.llama_model.config.hidden_size, 512, num_input_tokens=self.args['num_gen_audio_tokens'], num_output_tokens=1, mode=self.args['text_fc_to_audio_mode'])) else: raise ValueError( f'Embedding of layer {layer_idx} was requested but model only has {self.llama_model.config.num_hidden_layers} layers.') if self.args.get('freeze_output_proj'): for name, param in self.gen_text_hidden_fcs.named_parameters(): param.requires_grad = False for name, param in self.gen_text_hidden_fcs_video.named_parameters(): param.requires_grad = False for name, param in self.gen_text_hidden_fcs_audio.named_parameters(): param.requires_grad = False def _add_image_token(self): # Add an image token for loss masking (and visualization) purposes. self.llama_tokenizer.add_tokens([""]) # add special image token to tokenizer self.llama_tokenizer.add_tokens([""]) # add special image token to tokenizer # Add [IMG] tokens to the vocabulary. self.args['gen_img_token_idx'] = [] for i in range(self.args['num_gen_img_tokens']): print(f'Adding [IMG{i}] token to vocabulary.') print(f'Before adding new token, tokenizer("[IMG{i}]") =', self.llama_tokenizer(f'[IMG{i}]', add_special_tokens=False)) num_added_tokens = self.llama_tokenizer.add_tokens(f'[IMG{i}]') print(f'After adding {num_added_tokens} new tokens, tokenizer("[IMG{i}]") =', self.llama_tokenizer(f'[IMG{i}]', add_special_tokens=False)) gen_token_idx = self.llama_tokenizer(f'[IMG{i}]', add_special_tokens=False).input_ids assert len(gen_token_idx) == 1, gen_token_idx self.args['gen_img_token_idx'].append(gen_token_idx[0]) def _add_video_token(self): # self.llama_tokenizer.add_tokens({""}) # add special video token to tokenizer # self.llama_tokenizer.add_tokens({""}) # add special video token to tokenizer # Add [VID] tokens to the vocabulary. self.args['gen_video_token_idx'] = [] for i in range(self.args['num_gen_video_tokens']): print(f'Adding [VID{i}] token to vocabulary.') print(f'Before adding new token, tokenizer("[VID{i}]") =', self.llama_tokenizer(f'[VID{i}]', add_special_tokens=False)) num_added_tokens = self.llama_tokenizer.add_tokens(f'[VID{i}]') print(f'After adding {num_added_tokens} new tokens, tokenizer("[VID{i}]") =', self.llama_tokenizer(f'[VID{i}]', add_special_tokens=False)) gen_token_idx = self.llama_tokenizer(f'[VID{i}]', add_special_tokens=False).input_ids assert len(gen_token_idx) == 1, gen_token_idx self.args['gen_video_token_idx'].append(gen_token_idx[0]) def _add_audio_token(self): # self.llama_tokenizer.add_tokens({""}) # add special audio token to tokenizer # self.llama_tokenizer.add_tokens({""}) # add special audio token to tokenizer # Add [AUD] tokens to the vocabulary. self.args['gen_audio_token_idx'] = [] for i in range(self.args['num_gen_audio_tokens']): print(f'Adding [AUD{i}] token to vocabulary.') print(f'Before adding new token, tokenizer("[AUD{i}]") =', self.llama_tokenizer(f'[AUD{i}]', add_special_tokens=False)) num_added_tokens = self.llama_tokenizer.add_tokens(f'[AUD{i}]') print(f'After adding {num_added_tokens} new tokens, tokenizer("[AUD{i}]") =', self.llama_tokenizer(f'[AUD{i}]', add_special_tokens=False)) gen_token_idx = self.llama_tokenizer(f'[AUD{i}]', add_special_tokens=False).input_ids assert len(gen_token_idx) == 1, gen_token_idx self.args['gen_audio_token_idx'].append(gen_token_idx[0]) def encode_video(self, video_paths): inputs = {ModalityType.VISION: data.load_and_transform_video_data(video_paths, self.device)} # convert into visual dtype inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs} with torch.no_grad(): embeddings = self.visual_encoder(inputs) video_embeds = embeddings[ModalityType.VISION] # bsz x 1024 inputs_llama = self.llama_proj(video_embeds).unsqueeze(1) # bsz x 1 x llama_size atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1 return inputs_llama, atts_llama def encode_audio(self, audio_paths): inputs = {ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, self.device)} # convert into visual dtype inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs} with torch.no_grad(): embeddings = self.visual_encoder(inputs) audio_embeds = embeddings[ModalityType.AUDIO] # bsz x 1024 inputs_llama = self.llama_proj(audio_embeds).unsqueeze(1) # bsz x 1 x llama_size atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1 return inputs_llama, atts_llama def encode_image(self, image_paths): inputs = {ModalityType.VISION: data.load_and_transform_vision_data(image_paths, self.device)} # convert into visual dtype inputs = {key: inputs[key].to(self.llama_model.dtype) for key in inputs} with torch.no_grad(): embeddings = self.visual_encoder(inputs) image_embeds = embeddings['vision'] # bsz x 1024 inputs_llama = self.llama_proj(image_embeds).unsqueeze(1) # bsz x 1 x llama_size atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device) # bsz x 1 return inputs_llama, atts_llama def prompt_wrap(self, img_embeds, input_ids, target_ids, attention_mask): ''' input_ids, target_ids, attention_mask: bsz x s2 ''' input_ids = input_ids.to(self.device) # bsz x s2 target_ids = target_ids.to(self.device) # bsz x s2 attention_mask = attention_mask.to(self.device) # bsz x s2 batch_size = input_ids.shape[0] bos = torch.ones([batch_size, 1], dtype=input_ids.dtype, device=input_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1 if self.args['freeze_lm']: p_after_embeds = self.llama_model.model.embed_tokens(input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim bos_embeds = self.llama_model.model.embed_tokens(bos) # bsz x 1 x embed_dim else: p_after_embeds = self.llama_model.model.model.embed_tokens(input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim if img_embeds is not None: p_before = '### Human: ' p_before_tokens = self.llama_tokenizer(p_before, return_tensors="pt", add_special_tokens=False).to( self.device) # peft model need deeper call if self.args['freeze_lm']: p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim else: p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand( batch_size, -1, -1) # bsz x s1 x embed_dim inputs_embeds = torch.cat([bos_embeds, p_before_embeds, img_embeds, p_after_embeds], dim=1).to( self.device) # bsz x (1+s1+1+s2) x embed_dim # create targets empty_targets = ( torch.ones([batch_size, 1 + p_before_embeds.size()[1] + 1], # 1 (bos) + s1 + 1 dtype=torch.long).to(self.device).fill_(-100) ) # bsz x (1 + s1) targets = torch.cat([empty_targets, target_ids], dim=1).to(self.device) # bsz x (1 + s1 + 1 + s2) assert inputs_embeds.size()[1] == targets.size()[1] atts_prefix = torch.ones([batch_size, 1 + p_before_embeds.size()[1] + 1], dtype=torch.long).to( self.device) # bsz x (1 + s1 + 1) attention_mask = torch.cat([atts_prefix, attention_mask], dim=1).to(self.device) assert attention_mask.size() == targets.size() # bsz x (1 + s1 + 1 + s2) else: p_before = '### Human: ' p_before_tokens = self.llama_tokenizer(p_before, return_tensors="pt", add_special_tokens=False).to( self.device) # peft model need deeper call if self.args['freeze_lm']: p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim else: p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand( batch_size, -1, -1) # bsz x s1 x embed_dim inputs_embeds = torch.cat([bos_embeds, p_before_embeds, p_after_embeds], dim=1).to( self.device) # bsz x (1+s1+s2) x embed_dim # create targets empty_targets = ( torch.ones([batch_size, 1 + p_before_embeds.size()[1]], # 1 (bos) + s1 dtype=torch.long).to(self.device).fill_(-100) ) # bsz x (1 + s1) targets = torch.cat([empty_targets, target_ids], dim=1).to(self.device) # bsz x (1 + s1 + s2) assert inputs_embeds.size()[1] == targets.size()[1] atts_prefix = torch.ones([batch_size, 1 + p_before_embeds.size()[1]], dtype=torch.long).to( self.device) # bsz x (1 + s1) attention_mask = torch.cat([atts_prefix, attention_mask], dim=1).to(self.device) assert attention_mask.size() == targets.size() # bsz x (1 + s1 + s2) return inputs_embeds, targets, attention_mask def _train_with_mode(self, texts, img_embeds=None, modality='text', num_gen_tokens='8', text_hidden_fcs=None, gen_token_idx=None, text_emb_layers=None, text_prompt_embeddins=None, loss_scale=1.0, stage=2): """ :param num_gen_tokens: the number of generation tokens :param modality: mode can be 'image' / 'video' / 'audio' / 'text' :param text_hidden_fcs: alignment module :param gen_token_idx: List :param text_emb_layers: the layer index of LLM hidden states :param text_prompt_embeddins: the textual caption/prompt embeddings :param loss_scale: the scale on the mse loss for alignment :param stage: the training stage :param """ if stage == 2: input_ids, target_ids, attention_mask = process_batch_stage_2(self.llama_tokenizer, texts, self.max_length, num_gen_tokens, modality ) elif stage == 3: input_ids, target_ids, attention_mask = process_batch_stage_3(self.llama_tokenizer, texts, self.max_length, self.args['num_gen_img_tokens'], self.args['num_gen_video_tokens'], self.args['num_gen_audio_tokens'] ) else: raise NotImplementedError inputs_embeds, targets, attention_mask = self.prompt_wrap(img_embeds, input_ids, target_ids, attention_mask) outputs = self.llama_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True, output_hidden_states=True, labels=targets, ) loss = outputs.loss # calculate the token accuracy chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1] # [B, S-1] labels = targets[:, 2:] gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(torch.long) # [B*S] valid_mask = (labels != -100).reshape(-1) valid_tokens = gen_acc & valid_mask # [B*S] gen_acc = valid_tokens.sum().item() / (valid_mask.sum().item() + 1.0) if modality == 'text': return loss, gen_acc, torch.zeros_like(loss) else: hidden_states = [] # text_hidden_fcs = self.gen_text_hidden_fcs # based on the targets to obtain the hidden state, targets includes the [BOS] token start_pos = (targets == gen_token_idx[0]).nonzero(as_tuple=False)[:, 1].tolist() end_pos = (targets == gen_token_idx[-1]).nonzero(as_tuple=False)[:, 1].tolist() # logging.info(f'targets : {targets}') # logging.info(f'start_pos : {start_pos}') # logging.info(f'end_pos : {end_pos}') assert 0 < len(start_pos) == len(end_pos) == input_ids.size(0) and len(end_pos) > 0, (start_pos, end_pos) for idx, fc_layer in zip(text_emb_layers, text_hidden_fcs): hidden_embedding = [] input_embedding = [] for b, (s, e) in enumerate(zip(start_pos, end_pos)): assert e - s + 1 == num_gen_tokens, (s, e) hidden_embedding.append(outputs.hidden_states[idx][b, s:e + 1, :]) input_embedding.append(self.input_embeddings(targets[b, s:e + 1])) hidden_embedding = torch.stack(hidden_embedding, dim=0) input_embedding = torch.stack(input_embedding, dim=0) hidden_states.append(fc_layer(hidden_embedding, input_embedding)) # (N, seq_len, 2048) embeddings = torch.stack(hidden_states, dim=-1).sum(dim=-1) # (N, 77, 768) # embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (N, T_I_V_A.txt, 256) # Obtain the embeddings produced by the text encoder of a frozen text-to-image generation model input_text = [conversation for conversation in texts] if modality == 'image': mse_loss = l2_loss(embeddings, torch.stack(text_prompt_embeddins, dim=0).to(self.device)) elif modality == 'video': mse_loss = l2_loss(embeddings, torch.stack(text_prompt_embeddins, dim=0).to(self.device)) else: text_prompt_embeddins = torch.stack(text_prompt_embeddins, dim=0).to(self.device) assert len(text_prompt_embeddins.shape) == 2, text_prompt_embeddins.shape text_prompt_embeddins = text_prompt_embeddins.view(text_prompt_embeddins.size(0), 1, text_prompt_embeddins.size(1)) mse_loss = l2_loss(embeddings, text_prompt_embeddins) mse_loss = mse_loss.mean() loss += loss_scale * mse_loss return loss, gen_acc, mse_loss def _enc_align_training_stage_1(self, inputs): """ In the stage 1: training the encoding-side alignment via image/video/audio caption tasks modality: the input modality for each caption task, it could be 'image', 'video' or 'audio'. """ dataset_type = inputs['dataset_types'][0] if dataset_type == 'ImageToText': image_paths = inputs['mm_paths'] mm_embeds, _ = self.encode_image(image_paths) elif dataset_type == 'VideoToText': video_paths = inputs['mm_paths'] mm_embeds, _ = self.encode_video(video_paths) elif dataset_type == 'AudioToText': audio_paths = inputs['mm_paths'] mm_embeds, _ = self.encode_audio(audio_paths) else: raise NotImplementedError input_ids, target_ids, attention_mask = process_batch_stage_1(self.llama_tokenizer, inputs['output_texts'], self.max_length, self.args['prompt']) # print(input_ids) inputs_embeds, targets, attention_mask = self.prompt_wrap(mm_embeds, input_ids, target_ids, attention_mask) outputs = self.llama_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True, output_hidden_states=True, labels=targets, ) loss = outputs.loss # calculate the token accuracy chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1] # [B, S-1] labels = targets[:, 2:] gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(torch.long) # [B*S] valid_mask = (labels != -100).reshape(-1) valid_tokens = gen_acc & valid_mask # [B*S] gen_acc = valid_tokens.sum().item() / (valid_mask.sum().item() + 1.0) return loss, gen_acc def _dec_align_training_stage_2(self, inputs): """ In the stage 2: training the decoding-side alignment via minimize the distance between the representation of signal tokens and caption from text encoder within the respective diffusion models. modality: the output modality for each caption. """ dataset_type = inputs['dataset_types'][0] if dataset_type == 'TextToImage': loss, gen_acc, mse_loss = self._train_with_mode(texts=inputs['output_texts'], modality='image', num_gen_tokens=self.args['num_gen_img_tokens'], text_hidden_fcs=self.gen_text_hidden_fcs, gen_token_idx=self.args['gen_img_token_idx'], text_emb_layers=self.args['text_emb_to_img_layers'], text_prompt_embeddins=inputs['caption_embs'], stage=self.stage) elif dataset_type == 'TextToVideo': loss, gen_acc, mse_loss = self._train_with_mode(texts=inputs['output_texts'], modality='video', num_gen_tokens=self.args['num_gen_video_tokens'], text_hidden_fcs=self.gen_text_hidden_fcs_video, gen_token_idx=self.args['gen_video_token_idx'], text_emb_layers=self.args['text_emb_to_video_layers'], text_prompt_embeddins=inputs['caption_embs'], stage=self.stage) elif dataset_type == 'TextToAudio': loss, gen_acc, mse_loss = self._train_with_mode(texts=inputs['output_texts'], modality='audio', num_gen_tokens=self.args['num_gen_audio_tokens'], text_hidden_fcs=self.gen_text_hidden_fcs_audio, gen_token_idx=self.args['gen_audio_token_idx'], text_emb_layers=self.args['text_emb_to_audio_layers'], text_prompt_embeddins=inputs['caption_embs'], stage=self.stage) else: raise NotImplementedError return loss, gen_acc, mse_loss def _instruction_tuning_stage_3(self, inputs): """ In the stage 3: instruction-following training via the instruction dataset. """ loss = 0 gen_acc = 0 mse_loss = [] dataset_type = inputs['dataset_types'][0] if dataset_type == 'TextToImage': loss, gen_acc, mse_loss = self._train_with_mode(inputs['output_texts'], None, 'image', self.args['num_gen_img_tokens'], self.gen_text_hidden_fcs, self.args['gen_img_token_idx'], self.args['text_emb_to_img_layers'], inputs['caption_embs'], stage=self.stage) elif dataset_type == 'TextToVideo': loss, gen_acc, mse_loss = self._train_with_mode(inputs['output_texts'], None, 'video', self.args['num_gen_video_tokens'], self.gen_text_hidden_fcs_video, self.args['gen_video_token_idx'], self.args['text_emb_to_video_layers'], inputs['caption_embs'], loss_scale=2, stage=self.stage) elif dataset_type == 'TextToAudio': loss, gen_acc, mse_loss = self._train_with_mode(inputs['output_texts'], None, 'audio', self.args['num_gen_audio_tokens'], self.gen_text_hidden_fcs_audio, self.args['gen_audio_token_idx'], self.args['text_emb_to_audio_layers'], inputs['caption_embs'], stage=self.stage) elif dataset_type == 'ImageToText': image_paths = inputs['mm_paths'] img_embeds, _ = self.encode_image(image_paths) loss, gen_acc, _ = self._train_with_mode(inputs['output_texts'], img_embeds, modality='text', stage=self.stage) elif dataset_type == 'TextToText': loss, gen_acc, _ = self._train_with_mode(inputs['output_texts'], None, modality='text', stage=self.stage) else: raise NotImplementedError return loss, gen_acc, mse_loss def _stage_4_training(self, inputs): """ In the stage 4, we employ the modality-switch dataset to instruction-tune the overall framework """ pass def forward(self, inputs): loss = 0 gen_acc = 0 mse_loss = None if self.stage == 1: loss, gen_acc = self._enc_align_training_stage_1(inputs) elif self.stage == 2: loss, gen_acc, mse_loss = self._dec_align_training_stage_2(inputs) elif self.stage == 3: loss, gen_acc, mse_loss = self._instruction_tuning_stage_3(inputs) else: raise NotImplementedError(f"stage {self.stage} is not implemented, now it only support [1, 2, 3]") return loss, gen_acc, mse_loss def extract_multimodal_feature(self, inputs): features = [] if inputs['image_paths']: image_embeds, _ = self.encode_image(inputs['image_paths']) features.append(image_embeds) if inputs['audio_paths']: audio_embeds, _ = self.encode_audio(inputs['audio_paths']) features.append(audio_embeds) if inputs['video_paths']: video_embeds, _ = self.encode_video(inputs['video_paths']) features.append(video_embeds) feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0) return feature_embeds def _prepare_image_embed(self, text, batch_size): pattern = r'Image>(.*?)<\/Image' matches = re.findall(pattern, text) features = [] p_before_token = self.llama_tokenizer('', add_special_tokens=False, return_tensors='pt').to(self.device) p_after_token = self.llama_tokenizer('', add_special_tokens=False, return_tensors='pt').to(self.device) if self.args['freeze_lm']: p_before_embeds = self.llama_model.model.embed_tokens(p_before_token.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim p_after_embeds = self.llama_model.model.embed_tokens(p_after_token.input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim else: p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_token.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_token.input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim for m in matches: print('image path: ', m) if m.startswith('temp'): m = os.path.join('./', m) print('image path: ', m) _temp_embedding, _ = self.encode_image([m]) features.append(_temp_embedding) feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0) return torch.cat([p_before_embeds, feature_embeds, p_after_embeds], dim=1) def _prepare_video_embed(self, text, batch_size): pattern = r'Video>(.*?)<\/Video' matches = re.findall(pattern, text) features = [] p_before_token = self.llama_tokenizer('', add_special_tokens=False, return_tensors='pt').to(self.device) p_after_token = self.llama_tokenizer('', add_special_tokens=False, return_tensors='pt').to(self.device) if self.args['freeze_lm']: p_before_embeds = self.llama_model.model.embed_tokens(p_before_token.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim p_after_embeds = self.llama_model.model.embed_tokens(p_after_token.input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim else: p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_token.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_token.input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim for m in matches: print('Video path: ', m) if m.startswith('temp'): m = os.path.join('./', m) print('Video path: ', m) _temp_embedding, _ = self.encode_video([m]) features.append(_temp_embedding) feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0) return torch.cat([p_before_embeds, feature_embeds, p_after_embeds], dim=1) def _prepare_audio_embed(self, text, batch_size): pattern = r'Audio>(.*?)<\/Audio' matches = re.findall(pattern, text) features = [] p_before_token = self.llama_tokenizer('', add_special_tokens=False, return_tensors='pt').to(self.device) p_after_token = self.llama_tokenizer('', add_special_tokens=False, return_tensors='pt').to(self.device) if self.args['freeze_lm']: p_before_embeds = self.llama_model.model.embed_tokens(p_before_token.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim p_after_embeds = self.llama_model.model.embed_tokens(p_after_token.input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim else: p_before_embeds = self.llama_model.model.model.embed_tokens(p_before_token.input_ids).expand(batch_size, -1, -1) # bsz x s1 x embed_dim p_after_embeds = self.llama_model.model.model.embed_tokens(p_after_token.input_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim for m in matches: print('Audio path: ', m) if m.startswith('temp'): m = os.path.join('./', m) print('Video path: ', m) _temp_embedding, _ = self.encode_audio([m]) features.append(_temp_embedding) feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0) return torch.cat([p_before_embeds, feature_embeds, p_after_embeds], dim=1) def prepare_generation_embedding(self, inputs): prompt = inputs['prompt'] text = prompt + '\n### Assistant:' print("text prompt: ", text) batch_size = 1 input_embeds = [] split_text = re.split(r' <|> ', text) for st in split_text: if st.startswith('Image>'): input_embeds.append(self._prepare_image_embed(st, batch_size)) elif st.startswith('Audio>'): input_embeds.append(self._prepare_audio_embed(st, batch_size)) elif st.startswith('Video>'): input_embeds.append(self._prepare_video_embed(st, batch_size)) else: text_tokens = self.llama_tokenizer(st, add_special_tokens=False, return_tensors='pt').to(self.device) bos = torch.ones([batch_size, 1], dtype=text_tokens.input_ids.dtype, device=text_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id # bsz x 1 if self.args['freeze_lm']: text_embeds = self.llama_model.model.embed_tokens(text_tokens.input_ids).expand(batch_size, -1, -1) bos_embeds = self.llama_model.model.embed_tokens(bos) # bsz x 1 x embed_dim else: text_embeds = self.llama_model.model.model.embed_tokens(text_tokens.input_ids).expand(batch_size, -1, -1) bos_embeds = self.llama_model.model.model.embed_tokens(bos) # bsz x 1 x embed_dim input_embeds.append(bos_embeds) input_embeds.append(text_embeds) inputs_embeds = torch.cat(input_embeds, dim=1) # bsz x (1+s2) x embed_dim return inputs_embeds def generate_tokens_embeddings(self, inputs, input_embeds, temperature: float = 0.0, top_p: float = 1.0): """ This function is used to generate the tokens and output embeddings that employed to generate images/videos/audios inputs: dict input_embeds: tensor return: out: the output tokens index output_embeddings: output embeddings for synthesizing images video_output_embedding: output embeddings for synthesizing video audio_output_embedding: output embeddings for synthesizing audio """ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=inputs['stops_id'], encounters=1)]) outputs = self.llama_model.generate( inputs_embeds=input_embeds, max_new_tokens=inputs['max_tgt_len'], top_p=inputs['top_p'], temperature=inputs['temperature'], # repeat_pen, do_sample=True, use_cache=True, stopping_criteria=stopping_criteria, output_hidden_states=True, return_dict_in_generate=True, output_attentions=True ) output_embeddings = [] video_output_embedding = [] audio_output_embedding = [] out = outputs.sequences for _hidden_states in outputs.hidden_states[1:]: for idx in self.args['text_emb_to_img_layers']: output_embeddings.append(_hidden_states[idx]) for idx in self.args['text_emb_to_video_layers']: video_output_embedding.append(_hidden_states[idx]) for idx in self.args['text_emb_to_audio_layers']: audio_output_embedding.append(_hidden_states[idx]) output_embeddings = torch.cat(output_embeddings, dim=1) video_output_embedding = torch.cat(video_output_embedding, dim=1) audio_output_embedding = torch.cat(audio_output_embedding, dim=1) return out, output_embeddings, video_output_embedding, audio_output_embedding def generate_images(self, generated_ids, embeddings, all_gen_idx, generation_model=None, guidance_scale=7.5, num_inference_steps=40): """ To generate the images based on the embeddings generated_ids: the index of the generated tokens embedding: the embeddings for synthesizing images all_gen_idx: the index of [IMG0] in the generated_ids """ last_ret_idx = 0 return_outputs = [] generation_model = StableDiffusionPipeline.from_pretrained(self.sd_ckpt_path, torch_dtype=torch.float16).to( "cuda") for gen_idx in all_gen_idx: assert generated_ids[0, gen_idx:gen_idx + self.args['num_gen_img_tokens']].cpu().detach().numpy().tolist() == self.args[ 'gen_img_token_idx'], ( generated_ids[0, gen_idx:gen_idx + self.args['num_gen_img_tokens']], self.args['gen_img_token_idx']) raw_emb = embeddings[:, gen_idx - 1:gen_idx - 1 + self.args['num_gen_img_tokens'], :] # (1, 8, 4096) # Produce generation embedding. gen_prefix = ' '.join([f'[IMG{i}]' for i in range(self.args['num_gen_img_tokens'])]) gen_prefx_ids = self.llama_tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to(self.device) gen_prefix_embs = self.input_embeddings(gen_prefx_ids) # (1, T_I_V_A.txt, D) gen_emb = self.gen_text_hidden_fcs[-1](raw_emb, gen_prefix_embs) # (1, 77, 768) if gen_emb.shape[1] != 77: bs = gen_emb.shape[0] clip_emb = 768 gen_emb = gen_emb.reshape(bs, -1, clip_emb) # (bs, T_I_V_A.txt, 768) seq_len = gen_emb.shape[1] gen_emb = torch.cat([gen_emb, torch.zeros((bs, 77 - seq_len, clip_emb), device=gen_emb.device, dtype=gen_emb.dtype)], dim=1) image_outputs = generation_model(prompt_embeds=gen_emb, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images caption = \ self.llama_tokenizer.batch_decode(generated_ids[:, last_ret_idx:gen_idx], skip_special_tokens=True)[ 0] last_ret_idx = gen_idx + 1 return_outputs.append(caption + f' {gen_prefix}') # return_outputs.append(truncate_caption(caption) + f' {gen_prefix}') return_outputs.append(image_outputs) return return_outputs def generate_videos(self, generated_ids, embeddings, all_gen_idx, generation_model=None, guidance_scale=7.5, num_inference_steps=40, height=320, width=576, num_frames=16): """ To generate videos based on the embeddings generated_ids: the index of the generated tokens embedding: the embeddings for synthesizing videos all_gen_idx: the index of [VID0] in the generated_ids """ return_outputs = [] last_ret_idx = 0 generation_model = TextToVideoSDPipeline.from_pretrained(self.vd_ckpt_path, torch_dtype=torch.float16).to( "cuda") for gen_idx in all_gen_idx: assert generated_ids[0, gen_idx:gen_idx + self.args['num_gen_video_tokens']].cpu().detach().numpy().tolist() == \ self.args[ 'gen_video_token_idx'], ( generated_ids[0, gen_idx:gen_idx + self.args['num_gen_video_tokens']], self.args['gen_video_token_idx']) raw_emb = embeddings[:, gen_idx - 1:gen_idx - 1 + self.args['num_gen_video_tokens'], :] # (1, 8, 4096) # print(f'gen_idx: {gen_idx}') # print('4', raw_emb.size()) # assert len(self.args['text_emb_to_video_layers']) == 1 # Produce generation embedding. gen_prefix = ' '.join([f'[VID{i}]' for i in range(self.args['num_gen_video_tokens'])]) gen_prefx_ids = self.llama_tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to(self.device) gen_prefix_embs = self.input_embeddings(gen_prefx_ids) # (1, T_I_V_A.txt, D) gen_emb = self.gen_text_hidden_fcs_video[-1](raw_emb, gen_prefix_embs) # (1, 77, 768) if gen_emb.shape[1] != 77: print(f"Padding {gen_emb.shape} with zeros") bs = gen_emb.shape[0] clip_emb = 768 gen_emb = gen_emb.reshape(bs, -1, clip_emb) # (bs, T_I_V_A.txt, 768) seq_len = gen_emb.shape[1] gen_emb = torch.cat([gen_emb, torch.zeros((bs, 77 - seq_len, clip_emb), device=gen_emb.device, dtype=gen_emb.dtype)], dim=1) print('Padded to', gen_emb.shape) video_outputs = generation_model(prompt_embeds=gen_emb, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, height=height, width=width, num_frames=num_frames).frames caption = \ self.llama_tokenizer.batch_decode(generated_ids[:, last_ret_idx:gen_idx], skip_special_tokens=True)[ 0] last_ret_idx = gen_idx + 1 return_outputs.append(caption + f' {gen_prefix}') # return_outputs.append(truncate_caption(caption) + f' {gen_prefix}') return_outputs.append(video_outputs) return return_outputs def generate_audios(self, generated_ids, embeddings, all_gen_idx, generation_model=None, guidance_scale=7.5, num_inference_steps=40, audio_length_in_s=5.0): """ To generate videos based on the embeddings generated_ids: the index of the generated tokens embedding: the embeddings for synthesizing audios all_gen_idx: the index of [AUD0] in the generated_ids """ return_outputs = [] last_ret_idx = 0 generation_model = AudioLDMPipeline.from_pretrained(self.ad_ckpt_path, torch_dtype=torch.float16).to("cuda") for gen_idx in all_gen_idx: assert generated_ids[0, gen_idx:gen_idx + self.args['num_gen_audio_tokens']].cpu().detach().numpy().tolist() == \ self.args[ 'gen_audio_token_idx'], ( generated_ids[0, gen_idx:gen_idx + self.args['num_gen_audio_tokens']], self.args['gen_audio_token_idx']) raw_emb = embeddings[:, gen_idx - 1:gen_idx - 1 + self.args['num_gen_audio_tokens'], :] # (1, 8, 4096) # print(f'gen_idx: {gen_idx}') # print('raw_emb 4', raw_emb.size()) # assert len(self.args['text_emb_to_video_layers']) == 1 # Produce generation embedding. gen_prefix = ' '.join([f'[AUD{i}]' for i in range(self.args['num_gen_audio_tokens'])]) gen_prefx_ids = self.llama_tokenizer(gen_prefix, add_special_tokens=False, return_tensors="pt").input_ids.to(self.device) gen_prefix_embs = self.input_embeddings(gen_prefx_ids) # (1, T_I_V_A.txt, D) gen_emb = self.gen_text_hidden_fcs_audio[-1](raw_emb, gen_prefix_embs) # (1, 77, 768) # print('gen_emb size:', gen_emb.size()) bs = gen_emb.shape[0] hid_emb_size = gen_emb.shape[2] gen_emb = gen_emb.view(bs, hid_emb_size) audio_outputs = generation_model(prompt_embeds=gen_emb, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, audio_length_in_s=audio_length_in_s).audios[0] caption = \ self.llama_tokenizer.batch_decode(generated_ids[:, last_ret_idx:gen_idx], skip_special_tokens=True)[ 0] last_ret_idx = gen_idx + 1 return_outputs.append(caption + f' {gen_prefix}') # return_outputs.append(truncate_caption(caption) + f' {gen_prefix}') return_outputs.append(audio_outputs) return return_outputs def generate(self, inputs): """ inputs = { 'image_paths': optional, 'audio_paths': optional 'video_paths': optional 'thermal_paths': optional 'mode': generation mode, 'prompt': human input prompt, 'max_tgt_len': generation length, 'top_p': top_p, 'temperature': temperature, Used to modulate logit distribution. 'modality_embeds': None or torch.tensor, 'modality_cache': save the image cache, 'filter_value': Value to assign to tokens that should never be generated, 'min_word_tokens': Minimum number of words to generate before allowing a [IMG] output. 'gen_scale_factor': float = 1.0, 'stops_id': the default value is [[835], [2277, 29937]] the stop token is '###', which has two types of tokenization ways, [835] and [2277, 29937] 'ENCOUNTERS': the times that the generated sentence will be ended. 'load_sd': whether use SD for image generation 'max_num_imgs': Maximum number of images to return in one generation pass. 'guidance_scale_for_img': the guidance ratio of conditioner, if it is None, the default value will be applied in SD 'num_inference_steps_for_img': the number of inference step for image generation in the stable diffusion model 'load_vd': whether use VD for video generation 'max_num_vids': Maximum number of videos to return in one generation pass. 'guidance_scale_for_vid': the guidance ratio of conditioner, if it is None, the default value will be applied in VD 'num_inference_steps_for_vid': the number of inference step for video generation in the stable diffusion model 'height': (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated video. 'width': (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated video. 'num_frames': (`int`, *optional*, defaults to 16): The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds amounts to 2 seconds of video. 'load_ad': whether use AD for audio generation 'max_num_auds': Maximum number of audios to return in one generation pass. 'guidance_scale_for_aud': the guidance ratio of conditioner, if it is None, the default value will be applied in AD 'num_inference_steps_for_aud': the number of inference step for audio generation in the stable diffusion model 'audio_length_in_s': the seconds for generated audio length } """ # init output with image tokens input_embeds = self.prepare_generation_embedding(inputs) generated_ids, generated_image_embeddings, generated_video_embeddings, generated_audio_embeddings = self.generate_tokens_embeddings( inputs, input_embeds) return_outputs = [] # Find up to max_num_rets [IMG] tokens, and their corresponding scores. all_gen_img_idx = [i for i, x in enumerate(generated_ids[0, :] == self.args['gen_img_token_idx'][0]) if x][ :inputs['max_num_imgs']] print('all_gen_img_idx: ', all_gen_img_idx) # Find up to max_num_rest [VID] tokens, and their corresponding scores. all_gen_vid_idx = [i for i, x in enumerate(generated_ids[0, :] == self.args['gen_video_token_idx'][0]) if x][ :inputs['max_num_vids']] print('all_gen_vid_idx: ', all_gen_vid_idx) # Find up to max_num_rest [AUD] tokens, and their corresponding scores. all_gen_aud_idx = [i for i, x in enumerate(generated_ids[0, :] == self.args['gen_audio_token_idx'][0]) if x][ :inputs['max_num_auds']] print('all_gen_aud_idx: ', all_gen_aud_idx) if len(all_gen_img_idx) == 0 and len(all_gen_vid_idx) == 0 and len(all_gen_aud_idx) == 0: # No [IMG], [VID], [AUD] tokens. caption = self.llama_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] # return_outputs.append(truncate_caption(caption)) return_outputs.append(caption) else: if len(all_gen_img_idx) > 0: img_outputs = self.generate_images(generated_ids, generated_image_embeddings, all_gen_img_idx, None, guidance_scale=inputs['guidance_scale_for_img'], num_inference_steps=inputs['num_inference_steps_for_img'], ) return_outputs.append({'img': img_outputs}) if len(all_gen_vid_idx) > 0: vid_outputs = self.generate_videos(generated_ids, generated_video_embeddings, all_gen_vid_idx, None, guidance_scale=inputs['guidance_scale_for_vid'], num_inference_steps=inputs['num_inference_steps_for_vid'], height=inputs['height'], width=inputs['width'], num_frames=inputs['num_frames']) return_outputs.append({'vid': vid_outputs}) if len(all_gen_aud_idx) > 0: aud_outputs = self.generate_audios(generated_ids, generated_audio_embeddings, all_gen_aud_idx, None, guidance_scale=inputs['guidance_scale_for_aud'], num_inference_steps=inputs['num_inference_steps_for_aud'], audio_length_in_s=inputs['audio_length_in_s']) return_outputs.append({'aud': aud_outputs}) return return_outputs