import logging
import os
from typing import List, Optional

import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import T5ForConditionalGeneration

from slam_llm.models.slam_model import (
    slam_model,
    setup_tokenizer,
    setup_encoder,
    setup_encoder_projector,
    setup_llm,
)
from slam_llm.utils.metric import compute_accuracy
from slam_llm.utils.train_utils import print_model_size
from utils.codec_utils import setup_codec
from utils.snac_utils import layershift
from utils.trick_utils import partial_freeze_weights, train_embedding_layer_only
from utils.tts_adapter_utils import setup_tts_adapter

logger = logging.getLogger(__name__)


def model_factory(train_config, model_config, ckpt_path, **kwargs):
    """Build the speech-to-speech SLAM model and its tokenizer for training.

    Args:
        train_config: training config; reads ``task_type``,
            ``train_audio_embed_only``, ``train_embed_only``, ``enable_fsdp``
            and ``enable_ddp``.
        model_config: model config; reads ``codec_decode``, ``tts_adapter``,
            ``tts_adapter_config`` and ``vocab_config``.
        ckpt_path: optional checkpoint path; its state dict is loaded
            non-strictly on top of the freshly built model.
        **kwargs: forwarded to every ``setup_*`` helper and to the model.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        NotImplementedError: if ``train_config.task_type`` is not ``"s2s"``,
            ``"asr"`` or ``"tts"``.
    """
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    # Only the speech-input tasks build an audio encoder.
    if train_config.task_type in ("s2s", "asr"):
        encoder = setup_encoder(train_config, model_config, **kwargs)
    elif train_config.task_type == "tts":
        encoder = None
    else:
        raise NotImplementedError(
            "Unsupported task_type: {!r}".format(train_config.task_type)
        )

    llm = setup_llm(train_config, model_config, **kwargs)

    # The projector consumes encoder outputs, so it only exists with an encoder.
    encoder_projector = (
        setup_encoder_projector(train_config, model_config, **kwargs)
        if encoder is not None
        else None
    )

    codec_decoder = (
        setup_codec(train_config, model_config, **kwargs)
        if model_config.codec_decode
        else None
    )

    tts_adapter = None
    if model_config.tts_adapter:
        tts_adapter = setup_tts_adapter(
            model_config.tts_adapter_config, model_config, **kwargs
        )

    model = slam_model_s2s(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        tts_adapter,
        codec_decoder,
        train_config,
        model_config,
        **kwargs,
    )

    if ckpt_path is not None:
        logger.info("loading other parts from: %s", ckpt_path)
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        # strict=False: presumably the checkpoint only holds a subset of the
        # parameters ("other parts") — confirm against the saving code.
        model.load_state_dict(ckpt_dict, strict=False)

    if train_config.train_audio_embed_only:
        partial_freeze_weights(
            model,
            model_config.vocab_config.padded_text_vocabsize,
            model_config.vocab_config.total_vocabsize,
        )

    if train_config.train_embed_only:
        train_embedding_layer_only(model)

    rank = (
        int(os.environ["RANK"])
        if train_config.enable_fsdp or train_config.enable_ddp
        else 0
    )
    print_model_size(model, train_config, rank)
    return model, tokenizer


class slam_model_s2s(slam_model):
    """SLAM model predicting one text layer and ``code_layer`` audio layers in parallel.

    Each sequence position carries ``code_layer + 1`` token ids: audio layers
    ``0 .. code_layer - 1`` followed by the text layer at index ``code_layer``.
    Their embeddings are averaged into one LLM input embedding, and the LLM
    output logits are split into a text slice followed by one slice per audio
    layer (sizes taken from ``model_config.vocab_config``).
    """

    def __init__(
        self,
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        tts_adapter,
        codec_decoder,
        train_config,
        model_config,
        **kwargs,
    ):
        super().__init__(
            encoder,
            llm,
            encoder_projector,
            tokenizer,
            train_config,
            model_config,
            **kwargs,
        )

        # Resize the LLM vocabulary to ``total_vocabsize`` when it differs from
        # the pretrained head size.
        self.original_vocabsize = self.llm.lm_head.weight.size(0)
        total_vocabsize = self.model_config.vocab_config.total_vocabsize
        if total_vocabsize != self.original_vocabsize:
            self.llm.resize_token_embeddings(total_vocabsize)
            if int(os.environ.get("RANK", "0")) == 0:
                logger.info(
                    "Resize llm embedding layer's vocab size to %s", total_vocabsize
                )

        self.codec_decoder = codec_decoder
        self.tts_adapter = tts_adapter
        self.code_layer = self.model_config.vocab_config.code_layer

    def _embed_input_ids(self, input_ids):
        """Return the LLM input embeddings for ``input_ids``.

        Handles T5 (``shared``) and decoder-only LLMs whose ``embed_tokens``
        sits one, two or three ``.model`` levels deep (the deeper levels
        presumably appear when the LLM is wrapped, e.g. by PEFT).
        """
        if isinstance(self.llm, T5ForConditionalGeneration):
            return self.llm.shared(input_ids)
        if hasattr(self.llm.model, "embed_tokens"):
            return self.llm.model.embed_tokens(input_ids)
        if hasattr(self.llm.model.model, "embed_tokens"):
            return self.llm.model.model.embed_tokens(input_ids)
        return self.llm.model.model.model.embed_tokens(input_ids)

    def _split_parallel_logits(self, logits):
        """Split LLM logits into ``(text_logits, [audio_logits per layer])``.

        The text slice is the first ``padded_text_vocabsize`` entries of the
        last dimension; audio layer ``i`` is the ``i``-th following block of
        ``padded_audio_vocabsize`` entries.
        """
        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
        text_logits = logits[..., :text_vocab_size]
        audio_logits = [
            logits[
                ...,
                text_vocab_size + audio_vocab_size * i : text_vocab_size
                + audio_vocab_size * (i + 1),
            ]
            for i in range(self.code_layer)
        ]
        return text_logits, audio_logits

    def _encode_audio_inputs(self, audio_mel, audio_mel_post_mask, audio, audio_mask):
        """Encode and project the audio inputs; return None when there are none.

        Without an encoder, the raw ``audio_mel`` (or ``audio``) is passed
        straight to the projector.
        """
        if audio_mel is None and audio is None:
            return None

        if self.encoder is None:
            encoder_outs = audio_mel if audio_mel is not None else audio
        else:
            if self.train_config.freeze_encoder:  # freeze encoder
                self.encoder.eval()

            encoder_outs = None
            encoder_name = self.model_config.encoder_name
            if encoder_name == "whisper":
                encoder_outs = self.encoder.extract_variable_length_features(
                    audio_mel.permute(0, 2, 1)
                )  # bs*seq*dim
            elif encoder_name == "wavlm":
                # (FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask
                encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask)
            elif encoder_name == "hubert":
                results = self.encoder(source=audio, padding_mask=1 - audio_mask)
                if self.model_config.encoder_type == "pretrain":
                    encoder_outs = results["x"]
                    audio_mel_post_mask = results["padding_mask"]
                elif self.model_config.encoder_type == "finetune":
                    encoder_outs = results["encoder_out"]
                    audio_mel_post_mask = results["padding_mask"]
                    # Swap the first two dims (the finetune output is
                    # presumably time-major — confirm).
                    encoder_outs = encoder_outs.transpose(0, 1)

        projector = self.model_config.encoder_projector
        if projector == "q-former":
            encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
        elif projector in ("linear", "cov1d-linear"):
            encoder_outs = self.encoder_projector(encoder_outs)
        return encoder_outs

    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                **kwargs,
                ):
        """Embed the multi-layer inputs and compute the parallel text/audio loss.

        ``input_ids`` is modified in place: ``-1`` entries are set to ``0``.
        ``position_ids``, ``past_key_values``, ``use_cache``,
        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` are
        accepted but not used by this method.

        Recognised ``kwargs``: ``audio_mel``, ``audio_mel_post_mask``,
        ``audio``, ``audio_mask``, ``modality_mask`` and ``inference_mode``.

        Returns:
            If ``inference_mode`` is truthy: ``(inputs_embeds, attention_mask)``
            with the layer-averaged embeddings.
            Otherwise: ``(model_outputs, text_acc, audio_acc, loss_recorder)``,
            where ``model_outputs.loss`` is the parallel loss, the accuracies
            stay ``-1`` unless ``self.metric`` is set and ``labels`` is given,
            and ``loss_recorder`` is the per-layer loss list from
            :meth:`compute_parallel_loss`.
        """
        audio_mel = kwargs.get("audio_mel", None)
        audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None)  # 2x downsample for whisper
        audio = kwargs.get("audio", None)
        audio_mask = kwargs.get("audio_mask", None)
        modality_mask = kwargs.get("modality_mask", None)

        encoder_outs = self._encode_audio_inputs(
            audio_mel, audio_mel_post_mask, audio, audio_mask
        )

        if input_ids is not None:
            # -1 is not a valid id for the embedding lookup; map it to 0.
            input_ids[input_ids == -1] = 0  # assumes [btz, code_layer + 1, seq_length]
            inputs_embeds = self._embed_input_ids(input_ids)  # [btz, code_layer + 1, seq_length, emb_dim]

        if modality_mask is not None and encoder_outs is not None:
            # Put the projected encoder features at the positions where
            # modality_mask is set, in every audio layer; keep the token
            # embeddings elsewhere. Assumes the masked positions form one
            # contiguous span per sample.
            modality_mask = modality_mask.unsqueeze(1).repeat(1, self.code_layer, 1)  # [btz, code_layer, seq_length]
            start_indices = (modality_mask == True).float().argmax(dim=2).tolist()  # noqa: E712
            modality_lengths = torch.clamp(
                modality_mask.sum(dim=2), max=encoder_outs.shape[1]
            ).tolist()

            encoder_outs_pad = torch.zeros_like(inputs_embeds)
            for i in range(encoder_outs.shape[0]):
                for j in range(self.code_layer):
                    start_idx = start_indices[i][j]
                    length = modality_lengths[i][j]
                    encoder_outs_pad[i, j, start_idx:start_idx + length] = encoder_outs[i, :length]

            inputs_embeds[:, :self.code_layer, :, :] = (
                encoder_outs_pad[:, :self.code_layer, :, :]
                + inputs_embeds[:, :self.code_layer, :, :] * (~modality_mask[:, :, :, None])
            )

        # Average the per-layer embeddings into one embedding per position.
        inputs_embeds = torch.mean(inputs_embeds, dim=1)  # [btz, seq_length, emb_dim]

        if kwargs.get("inference_mode", False):
            return inputs_embeds, attention_mask

        text_labels = labels[:, self.code_layer] if labels is not None else None
        audio_labels = labels[:, :self.code_layer] if labels is not None else None
        # The text layer is handed to the LLM as its labels; the LLM's own loss
        # is replaced below by the parallel loss.
        model_outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=text_labels,
        )

        # Parallel generation: split logits into the text and audio-layer slices.
        # TODO: add tts adapter forward
        xt, xa = self._split_parallel_logits(model_outputs.logits)

        total_loss, loss_recorder = self.compute_parallel_loss(xt, text_labels, xa, audio_labels)
        model_outputs.loss = total_loss

        text_acc = -1
        audio_acc = [-1 for _ in range(self.code_layer)]
        # Accuracies need labels; keep the -1 placeholders when none are given.
        if self.metric and labels is not None:
            with torch.no_grad():
                preds = torch.argmax(xt, -1)
                text_acc = compute_accuracy(
                    preds.detach()[:, :-1], text_labels.detach()[:, 1:], ignore_label=-100
                )
                audio_acc = [
                    compute_accuracy(
                        torch.argmax(xa[i], -1).detach()[:, :-1],
                        audio_labels[:, i, 1:],
                        ignore_label=-100,
                    )
                    for i in range(self.code_layer)
                ]

        return model_outputs, text_acc, audio_acc, loss_recorder

    def compute_parallel_loss(self, xt, text_labels, xa, audio_labels):
        """Average next-token cross-entropy over the text layer and all audio layers.

        Logits at position ``t`` are scored against labels at ``t + 1``;
        label ``-100`` is ignored.

        Returns:
            ``(total_loss, layer_loss)``. ``layer_loss[i]`` is audio layer
            ``i`` for ``i < code_layer`` and ``layer_loss[code_layer]`` is the
            text loss; entries stay ``0`` when the matching labels are None.
            ``total_loss`` is the sum divided by ``code_layer + 1``.
        """
        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
        layer_loss = [0 for _ in range(self.code_layer + 1)]

        text_loss = 0
        if text_labels is not None:
            text_loss = F.cross_entropy(
                xt[:, :-1, :].reshape(-1, text_vocab_size),
                text_labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
            layer_loss[self.code_layer] = text_loss

        total_audio_loss = 0
        # A tensor slice is never None, so test the labels tensor itself.
        if audio_labels is not None:
            for i in range(self.code_layer):
                single_audio_loss = F.cross_entropy(
                    xa[i][:, :-1, :].reshape(-1, audio_vocab_size),
                    audio_labels[:, i, 1:].reshape(-1),
                    ignore_index=-100,
                )
                layer_loss[i] = single_audio_loss
                total_audio_loss += single_audio_loss

        total_loss = (text_loss + total_audio_loss) / (self.code_layer + 1)
        return total_loss, layer_loss

    @torch.no_grad()
    def generate(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                **kwargs,
                ):
        """Decode the text layer and all audio layers in parallel, one step at a time.

        Greedy decoding or sampling only (no beam search). Token ids are read
        with ``[0]``, so only batch size 1 is effectively supported.

        Recognised ``kwargs`` (defaults in parentheses): ``max_new_tokens``
        (360), ``repetition_penalty`` (1.0), ``decode_text_only`` (False), plus
        the sampling options read by :meth:`sample_next_token`. All kwargs are
        also forwarded to :meth:`forward` with ``inference_mode=True``.

        Returns:
            A list of ``code_layer + 1`` 1-D tensors: audio layers
            ``0 .. code_layer - 1`` (ids within each layer's own logit slice),
            then the text tokens truncated before the first ``eot``.
        """
        kwargs["inference_mode"] = True

        inputs_embeds, attention_mask = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        generated_ids = [[] for _ in range(self.code_layer + 1)]
        current_input_text = None
        current_audio_tokens = [None for _ in range(self.code_layer)]
        past_key_values = None

        max_new_tokens = kwargs.get("max_new_tokens", 360)
        repetition_penalty = kwargs.get("repetition_penalty", 1.0)
        decode_text_only = kwargs.get("decode_text_only", False)

        vocab_config = self.model_config.vocab_config
        pad_t = vocab_config.pad_t
        pad_a = vocab_config.pad_a
        eot = vocab_config.eot
        eoa = vocab_config.eoa

        device = input_ids.device
        batch_size = input_ids.size(0)

        text_end = False   # Track whether text generation has ended
        audio_end = False  # Track whether audio generation has ended

        for _ in tqdm(range(max_new_tokens), desc="Generating"):
            if current_input_text is not None:
                # Feed back the previous step's tokens. layershift presumably
                # moves layer i into its own id range — confirm in snac_utils.
                audio_tokens = torch.cat(
                    [layershift(current_audio_tokens[i], i).unsqueeze(1) for i in range(self.code_layer)],
                    dim=1,
                )
                combined_input_ids = torch.cat([audio_tokens, current_input_text.unsqueeze(1)], dim=1)
                inputs_embeds = self._embed_input_ids(combined_input_ids)
                inputs_embeds = torch.mean(inputs_embeds, dim=1).unsqueeze(1)  # [btz, 1, emb_dim]

            outputs = self.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values  # reused on the next step

            xt_logits, xa_logits = self._split_parallel_logits(outputs.logits)

            if repetition_penalty != 1.0:
                xt_logits = self.repetition_penalty(xt_logits, generated_ids[self.code_layer], repetition_penalty)
                for i in range(self.code_layer):
                    xa_logits[i] = self.repetition_penalty(xa_logits[i], generated_ids[i], repetition_penalty)

            if not text_end:
                next_token_text = self.sample_next_token(xt_logits[:, -1, :], **kwargs)
            else:
                next_token_text = torch.tensor([pad_t], device=device)

            next_tokens_audio = []
            for i in range(self.code_layer):
                if not audio_end and not decode_text_only:
                    next_tokens_audio.append(self.sample_next_token(xa_logits[i][:, -1, :], **kwargs))
                else:
                    next_tokens_audio.append(
                        torch.full((batch_size,), pad_a, dtype=torch.long, device=device)
                    )

            # Audio is considered finished when the last layer emits eoa.
            if next_tokens_audio[-1] == eoa or decode_text_only:
                audio_end = True
            if next_token_text == eot:
                text_end = True

            current_input_text = next_token_text
            current_audio_tokens = next_tokens_audio

            # Extend the mask by one position for the token fed on the next step,
            # keeping its dtype/device; a None mask stays None.
            if attention_mask is not None:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
                    dim=1,
                )

            if audio_end and text_end:
                break

            for i in range(self.code_layer):
                generated_ids[i].append(next_tokens_audio[i].tolist()[0])  # Audio layers
            generated_ids[self.code_layer].append(next_token_text.tolist()[0])  # Text layer

        text_tokens = generated_ids[-1]
        if eot in text_tokens:
            generated_ids[-1] = text_tokens[: text_tokens.index(eot)]
        return [torch.tensor(layer) for layer in generated_ids]

    @torch.no_grad()
    def sample_next_token(self, logits, **kwargs):
        """Pick the next token from ``logits`` by greedy decoding or sampling.

        Recognised ``kwargs`` (defaults in parentheses): ``do_sample`` (False),
        ``temperature`` (1.0), ``top_k`` (50), ``top_p`` (1.0),
        ``num_samples`` (1). Leading dim 0 is squeezed when it has size 1.
        Top-k and top-p filtering act row-wise, so 2-D logits are handled
        correctly as well.
        """
        do_sample = kwargs.get("do_sample", False)
        temperature = kwargs.get("temperature", 1.0)
        top_k = kwargs.get("top_k", 50)
        top_p = kwargs.get("top_p", 1.0)
        num_samples = kwargs.get("num_samples", 1)

        logits = logits.squeeze(0) / temperature

        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, top_k).values[..., -1:]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Remove tokens past the top_p mass; shifting right by one keeps the
            # token that crosses the threshold, and the best token is always kept.
            sorted_to_remove = cumulative_probs > top_p
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            # Map the mask from sorted order back to vocabulary order, per row.
            to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
            logits = logits.masked_fill(to_remove, float("-inf"))

        if do_sample:
            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=num_samples)
        return torch.argmax(logits, dim=-1, keepdim=True)

    def repetition_penalty(self, logits, generated_ids, repetition_penalty):
        """Penalise already generated tokens in place at ``logits[0, -1]``.

        For every distinct id in ``generated_ids``, a negative logit is
        multiplied by ``repetition_penalty`` and a non-negative one is divided
        by it. Returns the same (modified) tensor.
        """
        if not generated_ids:
            return logits
        token_ids = torch.tensor(
            sorted(set(generated_ids)), dtype=torch.long, device=logits.device
        )
        scores = logits[0, -1, token_ids]
        logits[0, -1, token_ids] = torch.where(
            scores < 0, scores * repetition_penalty, scores / repetition_penalty
        )
        return logits