# enclap/inference.py

from typing import Any, Dict, Optional
import numpy as np
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio
from laion_clap import CLAP_Module
from transformers import AutoTokenizer

from modeling.enclap_bart import EnClapBartConfig, EnClapBartForConditionalGeneration


class EnClap:
def __init__(
self,
ckpt_path: str,
clap_audio_model: str = "HTSAT-tiny",
        clap_enable_fusion: bool = True,
device: str = "cuda",
):
config = EnClapBartConfig.from_pretrained(ckpt_path)
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
self.model = (
EnClapBartForConditionalGeneration.from_pretrained(ckpt_path)
.to(self.device)
.eval()
)
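        # 24 kHz EnCodec at a 12 kbps target bandwidth -> 16 RVQ codebooks per frame.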
self.encodec = EncodecModel.encodec_model_24khz().to(self.device)
self.encodec.set_target_bandwidth(12.0)
        self.clap_model = CLAP_Module(
            enable_fusion=clap_enable_fusion,
            amodel=clap_audio_model,
            device=self.device,
        )
self.clap_model.load_ckpt()
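        # Default decoding parameters; any of these can be overridden per call
        # via the generation_config argument of the infer_* methods.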
self.generation_config = {
"_from_model_config": True,
"bos_token_id": 0,
"decoder_start_token_id": 2,
"early_stopping": True,
"eos_token_id": 2,
"forced_bos_token_id": 0,
"forced_eos_token_id": 2,
"no_repeat_ngram_size": 3,
"num_beams": 4,
"pad_token_id": 1,
"max_length": 50,
}
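        # Reserve three positions for the two BOS rows and the one EOS row
        # that _infer places around the EnCodec frames.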
        self.max_seq_len = config.max_position_embeddings - 3

@torch.no_grad()
def infer_from_audio_file(
        self, audio_file: str, generation_config: Optional[Dict[str, Any]] = None
) -> str:
if generation_config is None:
generation_config = self.generation_config
        audio, sample_rate = torchaudio.load(audio_file)
        # Caption from the first channel only.
        return self.infer_from_audio(audio[0], sample_rate)

@torch.no_grad()
def infer_from_audio(
        self,
        audio: torch.Tensor,
        sample_rate: int,
        generation_config: Optional[Dict[str, Any]] = None,
) -> str:
if generation_config is None:
generation_config = self.generation_config
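        # Normalize integer PCM to floats in [-1, 1].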
if audio.dtype == torch.short:
audio = audio / 2**15
        elif audio.dtype == torch.int:
            audio = audio / 2**31
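        # Resample/remix to EnCodec's input format, encode to RVQ codebook
        # indices, and reshape to [batch, time, codebooks].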
        encodec_audio = (
            convert_audio(
                audio.unsqueeze(0),
                sample_rate,
                self.encodec.sample_rate,
                self.encodec.channels,
            )
            .unsqueeze(0)
            .to(self.device)
        )
encodec_frames = self.encodec.encode(encodec_audio)
encodec_frames = torch.cat(
[codebook for codebook, _ in encodec_frames], dim=-1
).mT
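        # CLAP expects 48 kHz audio; compute a single global audio embedding.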
        clap_audio = torchaudio.transforms.Resample(sample_rate, 48000)(audio).unsqueeze(0)
        clap_embedding = self.clap_model.get_audio_embedding_from_data(
            clap_audio, use_tensor=True
        )
        return self._infer(encodec_frames, clap_embedding, generation_config)

@torch.no_grad()
def _infer(
self,
encodec_frames: torch.LongTensor,
clap_embedding: torch.Tensor,
        generation_config: Optional[Dict[str, Any]] = None,
) -> str:
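        # Prepend two BOS rows and append one EOS row around the (truncated)
        # EnCodec codes.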
input_ids = torch.cat(
[
torch.ones(
(encodec_frames.shape[0], 2, encodec_frames.shape[-1]),
dtype=torch.long,
).to(self.device)
* self.tokenizer.bos_token_id,
encodec_frames[:, : self.max_seq_len],
torch.ones(
(encodec_frames.shape[0], 1, encodec_frames.shape[-1]),
dtype=torch.long,
).to(self.device)
* self.tokenizer.eos_token_id,
],
dim=1,
)
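        # Mark which positions hold EnCodec codes (1) as opposed to the
        # special-token rows (0).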
encodec_mask = torch.LongTensor(
[[0, 0] + [1] * (input_ids.shape[1] - 3) + [0]]
).to(self.device)
enclap_bart_inputs = {
"input_ids": input_ids,
"encodec_mask": encodec_mask,
"clap_embedding": clap_embedding,
}
results = self.model.generate(**enclap_bart_inputs, **generation_config)
        captions = self.tokenizer.batch_decode(results, skip_special_tokens=True)
        # batch_decode returns a list; this entry point handles a single clip.
        return captions[0]

    @torch.no_grad()
    def infer_from_encodec(
        self,
        file_path: str,
        clap_path: str = "clap",
        generation_config: Optional[Dict[str, Any]] = None,
    ) -> str:
        if generation_config is None:
            generation_config = self.generation_config
        # Load pre-computed EnCodec frames of shape [time, codebooks].
        encodec_frames = (
            torch.from_numpy(np.load(file_path)).long().unsqueeze(0).to(self.device)
        )
        # The matching CLAP embedding is stored under the same path with the
        # "encodec_16" component replaced by the CLAP directory name.
        clap_file = file_path.replace("encodec_16", clap_path)
        clap_embedding = (
            torch.from_numpy(np.load(clap_file)).unsqueeze(0).to(self.device)
        )
        # _infer handles truncation to max_seq_len and the BOS/EOS framing.
        return self._infer(encodec_frames, clap_embedding, generation_config)
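

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint path and audio file below are
    # placeholders, not files shipped with this repository.
    enclap = EnClap(ckpt_path="path/to/enclap-checkpoint", device="cuda")
    caption = enclap.infer_from_audio_file("example.wav")
    print(caption)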