from typing import Any, Dict, List

import numpy as np
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio
from laion_clap import CLAP_Module
from transformers import AutoTokenizer

from modeling.enclap_bart import EnClapBartConfig, EnClapBartForConditionalGeneration


class EnClap:
    def __init__(
        self,
        ckpt_path: str,
        clap_audio_model: str = "HTSAT-tiny",
        clap_enable_fusion: bool = True,
        device: str = "cuda",
    ):
        config = EnClapBartConfig.from_pretrained(ckpt_path)
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
        self.model = (
            EnClapBartForConditionalGeneration.from_pretrained(ckpt_path)
            .to(self.device)
            .eval()
        )
        self.encodec = EncodecModel.encodec_model_24khz().to(self.device)
        self.encodec.set_target_bandwidth(12.0)
        self.clap_model = CLAP_Module(
            enable_fusion=clap_enable_fusion,
            amodel=clap_audio_model,
            device=self.device,
        )
        self.clap_model.load_ckpt()

        self.generation_config = {
            "_from_model_config": True,
            "bos_token_id": 0,
            "decoder_start_token_id": 2,
            "early_stopping": True,
            "eos_token_id": 2,
            "forced_bos_token_id": 0,
            "forced_eos_token_id": 2,
            "no_repeat_ngram_size": 3,
            "num_beams": 4,
            "pad_token_id": 1,
            "max_length": 50,
        }
        self.max_seq_len = config.max_position_embeddings - 3

    @torch.no_grad()
    def infer_from_audio_file(
        self, audio_file: str, generation_config: Dict[str, Any] = None
    ) -> List[str]:
        if generation_config is None:
            generation_config = self.generation_config
        audio, res = torchaudio.load(audio_file)
        return self.infer_from_audio(audio[0], res, generation_config)

    @torch.no_grad()
    def infer_from_audio(
        self, audio: torch.Tensor, res: int, generation_config: Dict[str, Any] = None
    ) -> List[str]:
        if generation_config is None:
            generation_config = self.generation_config
        # Normalize integer PCM to the [-1, 1] float range expected by EnCodec and CLAP.
        if audio.dtype == torch.short:
            audio = audio / 2**15
        if audio.dtype == torch.int:
            audio = audio / 2**31

        # EnCodec expects (batch, channels, time) at its native sample rate.
        encodec_audio = (
            convert_audio(
                audio.unsqueeze(0), res, self.encodec.sample_rate, self.encodec.channels
            )
            .unsqueeze(0)
            .to(self.device)
        )
        encodec_frames = self.encodec.encode(encodec_audio)
        # Concatenate the codebook indices of every frame and transpose to (batch, time, codebooks).
        encodec_frames = torch.cat(
            [codebook for codebook, _ in encodec_frames], dim=-1
        ).mT

        # CLAP operates on 48 kHz audio.
        clap_audio = torchaudio.transforms.Resample(res, 48000)(audio).unsqueeze(0)
        clap_embedding = self.clap_model.get_audio_embedding_from_data(
            clap_audio, use_tensor=True
        )

        return self._infer(encodec_frames, clap_embedding, generation_config)

    @torch.no_grad()
    def _infer(
        self,
        encodec_frames: torch.LongTensor,
        clap_embedding: torch.Tensor,
        generation_config: Dict[str, Any] = None,
    ) -> List[str]:
        # Wrap the EnCodec codes with BOS / EOS rows so they align with BART's special tokens.
        input_ids = torch.cat(
            [
                torch.ones(
                    (encodec_frames.shape[0], 2, encodec_frames.shape[-1]),
                    dtype=torch.long,
                ).to(self.device)
                * self.tokenizer.bos_token_id,
                encodec_frames[:, : self.max_seq_len],
                torch.ones(
                    (encodec_frames.shape[0], 1, encodec_frames.shape[-1]),
                    dtype=torch.long,
                ).to(self.device)
                * self.tokenizer.eos_token_id,
            ],
            dim=1,
        )
        # Mark which positions hold EnCodec codes (1) as opposed to special tokens (0).
        encodec_mask = torch.LongTensor(
            [[0, 0] + [1] * (input_ids.shape[1] - 3) + [0]]
        ).to(self.device)

        enclap_bart_inputs = {
            "input_ids": input_ids,
            "encodec_mask": encodec_mask,
            "clap_embedding": clap_embedding,
        }

        results = self.model.generate(**enclap_bart_inputs, **generation_config)
        caption = self.tokenizer.batch_decode(results, skip_special_tokens=True)

        return caption

    @torch.no_grad()
    def infer_from_encodec(
        self,
        file_path,
        clap_path: str = "clap",
        generation_config: Dict[str, Any] = None,
    ):
        # Note: this path relies on self.max_encodec_length and self.eos_padding,
        # which are assumed to be set elsewhere; they are not defined in __init__ above.
        if generation_config is None:
            generation_config = self.generation_config
        input_ids = np.load(file_path)
        if input_ids.shape[0] > self.max_encodec_length:
            input_ids = input_ids[: self.max_encodec_length, :]
        input_length = input_ids.shape[0]
        input_ids = np.concatenate([input_ids, self.eos_padding], axis=0)
        input_ids = torch.LongTensor(input_ids)
        input_ids = input_ids.unsqueeze(0).to(self.device)
        attention_mask = (
            torch.ones(input_length + 3, dtype=torch.int64).unsqueeze(0).to(self.device)
        )
        eos_mask = [0] * (input_length + 3)
        eos_mask[input_length + 2] = 1
        eos_mask = torch.BoolTensor(eos_mask).unsqueeze(0)

        # Load the precomputed CLAP embedding that sits alongside the EnCodec file.
        clap_path = file_path.replace("encodec_16", clap_path)
        clap = np.load(clap_path)
        clap = torch.Tensor(clap).unsqueeze(0).to(self.device)

        input = {
            "input_ids": input_ids,
            "clap": clap,
            "attention_mask": attention_mask,
            "eos_mask": eos_mask,
        }

        generated_ids = self.model.generate(**input, **generation_config)
        text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        return text
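# Minimal usage sketch (not part of the original module): caption a clip end to end
# from a WAV file. The checkpoint directory "ckpt/enclap" and "sample.wav" are
# placeholder paths; substitute a trained EnClapBart checkpoint and your own audio.
if __name__ == "__main__":
    enclap = EnClap(ckpt_path="ckpt/enclap", device="cuda")
    captions = enclap.infer_from_audio_file("sample.wav")
    print(captions)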