# audio-flamingo-demo / inference_utils.py
# Last commit: 0195d32 ("update") by ZhifengKong
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import string
import yaml
from copy import deepcopy
import torch
from transformers import AutoTokenizer, set_seed
# Seed the random number generators once, at import time, via transformers.set_seed;
# presumably so that sampled generations are repeatable across runs.
set_seed(0)
from data import AudioTextDataProcessor
from src.factory import create_model_and_transforms
def prepare_tokenizer(model_config):
    """Load the text tokenizer described by ``model_config`` and register extra tokens.

    Args:
        model_config: mapping providing ``'tokenizer_path'`` and ``'cache_dir'``.
            Both are read with ``[]``, so a missing key raises ``KeyError``.

    Returns:
        The loaded ``AutoTokenizer``. It has ``<audio>`` and ``<|endofchunk|>``
        registered as additional special tokens. ``<PAD>`` / ``<SEP>`` are added
        as the pad / sep tokens only when the loaded tokenizer lacks them.
    """
    # trust_remote_code=True lets the tokenizer repository execute its own
    # Python code: only point tokenizer_path at trusted sources.
    tokenizer = AutoTokenizer.from_pretrained(
        model_config['tokenizer_path'],
        local_files_only=False,
        trust_remote_code=True,
        cache_dir=model_config['cache_dir'],
    )

    audio_markers = ["<audio>", "<|endofchunk|>"]
    tokenizer.add_special_tokens({"additional_special_tokens": audio_markers})

    # Fill in pad / sep tokens only where the base tokenizer does not define them.
    for attr_name, fallback in (("pad_token", "<PAD>"), ("sep_token", "<SEP>")):
        if getattr(tokenizer, attr_name) is None:
            tokenizer.add_special_tokens({attr_name: fallback})

    return tokenizer
def prepare_model(model_config, clap_config, checkpoint_path, device=0):
    """Build the model, move it to ``device`` and load weights from a checkpoint.

    Args:
        model_config: keyword arguments unpacked into ``create_model_and_transforms``.
        clap_config: forwarded to ``create_model_and_transforms`` as ``clap_config``.
        checkpoint_path: file read with ``torch.load``; it must contain a
            ``"model_state_dict"`` entry.
        device: target passed to ``model.to`` (default ``0``).

    Returns:
        The model, in eval mode on ``device``, with the checkpoint weights loaded.
        Loading uses ``strict=False``, so parameters absent from the checkpoint
        keep the values the factory built.

    Raises:
        KeyError: if the checkpoint has no ``"model_state_dict"`` entry.
    """
    import warnings

    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning

    # The factory also returns a tokenizer; it is not needed here.
    model, _ = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=False,
        gradient_checkpointing=False,
        freeze_lm_embeddings=False,
    )
    model.eval()
    model = model.to(device)

    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Only strip *leading* wrapper prefixes. A blanket str.replace would also
    # mangle keys such as "x.sub_module.w" -> "x.sub_w". With strict=False,
    # those weights would then be silently skipped.
    state_dict = {
        _strip_wrapper_prefix(key): value
        for key, value in checkpoint["model_state_dict"].items()
    }
    del checkpoint  # release any other checkpoint contents before loading

    result = model.load_state_dict(state_dict, strict=False)
    # strict=False tolerates parameters the checkpoint does not provide.
    # Checkpoint keys the model does not recognise, however, mean weights were
    # dropped, so surface them instead of ignoring them silently.
    unexpected = getattr(result, "unexpected_keys", None)
    if unexpected:
        warnings.warn(
            f"{len(unexpected)} checkpoint key(s) did not match any model "
            f"parameter and were ignored, e.g. {list(unexpected)[:5]}"
        )
    return model


def _strip_wrapper_prefix(key, prefix="module."):
    """Remove leading ``prefix`` segments (e.g. those added by DistributedDataParallel)."""
    while key.startswith(prefix):
        key = key[len(prefix):]
    return key
def inference(model, tokenizer, item, processed_item, inference_kwargs, device=0):
    """Run generation on one processed example and return the decoded text(s).

    Args:
        model: object whose ``generate`` accepts ``audio_x``, ``audio_x_mask`` and
            ``lang_x`` plus generation keyword arguments.
        tokenizer: tokenizer used to decode generated ids. It must define
            ``sep_token``, ``eos_token`` and ``pad_token``.
        item: not used by this function; kept for interface compatibility.
        processed_item: 5-tuple ``(filename, audio_clips, audio_embed_mask,
            input_ids, attention_mask)``. Only the audio clips, the audio mask
            and the input ids are used.
        inference_kwargs: extra keyword arguments for ``model.generate``, or
            ``None``. They may override the defaults ``max_new_tokens=128`` and
            ``eos_token_id=tokenizer.eos_token_id``.
        device: target device for the input tensors (default ``0``).

    Returns:
        list[str]: one string per generated sequence. Each is the decoded text
        after the last ``sep_token``, with eos, pad and ``<|endofchunk|>``
        tokens removed.
    """
    _filename, audio_clips, audio_embed_mask, input_ids, _attention_mask = processed_item
    # NOTE(review): attention_mask is unpacked but not forwarded to generate(); confirm intended.
    audio_clips = audio_clips.to(device, dtype=None, non_blocking=True)
    audio_embed_mask = audio_embed_mask.to(device, dtype=None, non_blocking=True)
    # squeeze() drops size-1 dims; a single leading dim is re-added for generate() below.
    input_ids = input_ids.to(device, dtype=None, non_blocking=True).squeeze()

    # Defaults first, so caller-supplied kwargs override them instead of
    # raising "got multiple values for keyword argument".
    generate_kwargs = {
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 128,
        **(inference_kwargs or {}),
    }
    with torch.no_grad():  # pure inference: no autograd graph needed
        outputs = model.generate(
            audio_x=audio_clips.unsqueeze(0),
            audio_x_mask=audio_embed_mask.unsqueeze(0),
            lang_x=input_ids.unsqueeze(0),
            **generate_kwargs,
        )

    strip_tokens = (tokenizer.eos_token, tokenizer.pad_token, "<|endofchunk|>")
    outputs_decoded = []
    for output in outputs:
        # Keep only the text after the last separator (presumably the generated answer).
        text = tokenizer.decode(output).split(tokenizer.sep_token)[-1]
        for token in strip_tokens:
            text = text.replace(token, "")
        outputs_decoded.append(text)
    return outputs_decoded