# NOTE: removed non-code web-scrape artifacts that preceded this file
# ("Spaces:" / "Build error" status-banner lines from the hosting page).
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import string
import yaml
from copy import deepcopy

import torch
from transformers import AutoTokenizer, set_seed

# Seed all RNGs before importing project modules so that any
# import-time randomness below is deterministic.
set_seed(0)

from data import AudioTextDataProcessor
from src.factory import create_model_and_transforms
def prepare_tokenizer(model_config):
    """Load the text tokenizer named in *model_config* and register the
    model's special tokens.

    Args:
        model_config: mapping with at least ``'tokenizer_path'`` (HF repo id
            or local path) and ``'cache_dir'``.

    Returns:
        A HuggingFace tokenizer with ``<audio>`` and ``<|endofchunk|>``
        registered as additional special tokens, and ``<PAD>`` / ``<SEP>``
        added as fallbacks when the base tokenizer lacks pad/sep tokens.
    """
    text_tokenizer = AutoTokenizer.from_pretrained(
        model_config['tokenizer_path'],
        local_files_only=False,
        trust_remote_code=True,
        cache_dir=model_config['cache_dir'],
    )
    # Markers used to splice audio embeddings into the text stream.
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
    )
    # Some base tokenizers ship without pad/sep tokens; add them so that
    # padding and prompt/answer separation (see inference()) always work.
    if text_tokenizer.pad_token is None:
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    if text_tokenizer.sep_token is None:
        text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})
    return text_tokenizer
def prepare_model(model_config, clap_config, checkpoint_path, device=0):
    """Build the model, load weights from *checkpoint_path*, and move it to
    *device* in eval mode.

    Args:
        model_config: keyword arguments forwarded to
            create_model_and_transforms().
        clap_config: audio (CLAP) encoder configuration, forwarded as-is.
        checkpoint_path: path to a torch checkpoint containing a
            ``'model_state_dict'`` entry.
        device: torch device (index or name) to place the model on.

    Returns:
        The model in eval mode on *device* with the checkpoint weights loaded.
    """
    # Silence the fork-after-tokenization warning from HF tokenizers.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=False,
        gradient_checkpointing=False,
        freeze_lm_embeddings=False,
    )
    model.eval()
    model = model.to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoint files.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Checkpoints saved under DistributedDataParallel prefix every parameter
    # with "module."; strip that so keys match the bare model.
    state_dict = {
        k.replace("module.", ""): v
        for k, v in checkpoint["model_state_dict"].items()
    }
    # strict=False tolerates missing/unexpected keys — presumably parameters
    # of frozen submodules not saved in the checkpoint; confirm against the
    # training script.
    model.load_state_dict(state_dict, strict=False)
    return model
def inference(model, tokenizer, item, processed_item, inference_kwargs, device=0):
    """Generate text for one preprocessed audio/text item.

    Args:
        model: model exposing a HF-style ``generate()`` that accepts
            ``audio_x`` / ``audio_x_mask`` / ``lang_x`` inputs.
        tokenizer: tokenizer produced by prepare_tokenizer().
        item: raw (unprocessed) item — unused here, kept for interface
            compatibility with callers.
        processed_item: tuple ``(filename, audio_clips, audio_embed_mask,
            input_ids, attention_mask)`` from the data processor; only the
            middle three fields are consumed.
        inference_kwargs: extra keyword arguments forwarded to ``generate()``
            (sampling settings etc.).
        device: torch device the model lives on.

    Returns:
        List of decoded strings, one per generated sequence, keeping only the
        text after the last sep token (the model's answer) with eos/pad and
        ``<|endofchunk|>`` markers stripped.
    """
    _, audio_clips, audio_embed_mask, input_ids, _ = processed_item
    audio_clips = audio_clips.to(device, non_blocking=True)
    audio_embed_mask = audio_embed_mask.to(device, non_blocking=True)
    input_ids = input_ids.to(device, non_blocking=True).squeeze()

    outputs = model.generate(
        audio_x=audio_clips.unsqueeze(0),
        audio_x_mask=audio_embed_mask.unsqueeze(0),
        lang_x=input_ids.unsqueeze(0),
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=128,
        **inference_kwargs,
    )

    def _clean(output):
        # Prompt and answer are separated by the sep token; keep the answer,
        # then strip the remaining special-token markers.
        text = tokenizer.decode(output).split(tokenizer.sep_token)[-1]
        for marker in (tokenizer.eos_token, tokenizer.pad_token, '<|endofchunk|>'):
            text = text.replace(marker, '')
        return text

    return [_clean(output) for output in outputs]