import yaml
from typing import Dict, List
import torch
import torch.nn as nn
import numpy as np
import librosa
from scipy.io.wavfile import write

from utils import ignore_warnings
ignore_warnings()
from utils import parse_yaml, load_ss_model
from models.clap_encoder import CLAP_Encoder

# Sample rate (Hz) used both to resample the input mixture on load and to
# write the separated output; keeping it in one place stops the two drifting.
SAMPLE_RATE = 32000


def build_audiosep(config_yaml, checkpoint_path, device):
    """Build an AudioSep separation model and load its checkpoint.

    Args:
        config_yaml: Path to the YAML config, parsed with ``parse_yaml``.
        checkpoint_path: Checkpoint path forwarded to ``load_ss_model``.
        device: Device the returned model is moved to.

    Returns:
        The model returned by ``load_ss_model``, in eval mode on ``device``.
    """
    configs = parse_yaml(config_yaml)

    # The CLAP query encoder is switched to eval mode before being handed to
    # the separation model loader.
    query_encoder = CLAP_Encoder().eval()
    model = load_ss_model(
        configs=configs,
        checkpoint_path=checkpoint_path,
        query_encoder=query_encoder,
    ).eval().to(device)

    print(f'Load AudioSep model from [{checkpoint_path}]')
    return model


def inference(model, audio_file, text, output_file, device='cuda'):
    """Separate the sound described by ``text`` from ``audio_file``.

    The input is loaded as mono audio resampled to ``SAMPLE_RATE``, conditioned
    on a text embedding from ``model.query_encoder``, passed through
    ``model.ss_model``, and the resulting waveform is written to
    ``output_file`` as 16-bit PCM WAV at ``SAMPLE_RATE``.

    Args:
        model: Object exposing ``query_encoder.get_query_embed`` and
            ``ss_model`` (e.g. the result of ``build_audiosep``).
        audio_file: Path of the mixture audio file to read.
        text: Natural-language description of the target sound.
        output_file: Path of the WAV file to write.
        device: Device the mixture tensor and query embedding are placed on.
    """
    print(f'Separate audio from [{audio_file}] with textual query [{text}]')
    mixture, _ = librosa.load(audio_file, sr=SAMPLE_RATE, mono=True)

    with torch.no_grad():
        # get_query_embed is called with a batch (list) of one text query.
        conditions = model.query_encoder.get_query_embed(
            modality='text',
            text=[text],
            device=device,
        )
        # Add two leading singleton axes to the mono signal; presumably the
        # model expects (batch, channels, samples) — TODO confirm.
        mixture_tensor = torch.as_tensor(mixture, dtype=torch.float32)
        input_dict = {
            "mixture": mixture_tensor[None, None, :].to(device),
            "condition": conditions,
        }
        sep_segment = model.ss_model(input_dict)["waveform"]
        # Drop the two leading singleton axes added above (assumes the output
        # mirrors the input layout — TODO confirm) and move to host memory.
        sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()

    # Clip to [-1, 1] before scaling: without it, any sample beyond full scale
    # overflows int16 and wraps to the opposite sign (audible clicks).
    pcm = np.round(np.clip(sep_segment, -1.0, 1.0) * 32767).astype(np.int16)
    write(output_file, SAMPLE_RATE, pcm)
    print(f'Write separated audio to [{output_file}]')


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = build_audiosep(
        config_yaml='config/audiosep_base.yaml',
        checkpoint_path='checkpoint/step=3920000.ckpt',
        device=device,
    )

    audio_file = '/mnt/bn/data-xubo/project/AudioShop/YT_audios/Y3VHpLxtd498.wav'
    text = 'pigeons are cooing in the background'
    output_file = 'separated_audio.wav'

    inference(model, audio_file, text, output_file, device)