File size: 2,016 Bytes
ae29df4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import yaml
from typing import Dict, List
import torch
import torch.nn as nn
import numpy as np
import librosa
from scipy.io.wavfile import write
from utils import ignore_warnings; ignore_warnings()
from utils import parse_yaml, load_ss_model
from models.clap_encoder import CLAP_Encoder
def build_audiosep(config_yaml, checkpoint_path, device):
    """Build the AudioSep model, load its checkpoint and move it to ``device``.

    Args:
        config_yaml: Path of the YAML configuration, handed to ``parse_yaml``.
        checkpoint_path: Path of the checkpoint, handed to ``load_ss_model``.
        device: Target device passed to ``.to()`` on the loaded model.

    Returns:
        The object produced by ``load_ss_model`` after ``.eval().to(device)``.
    """
    parsed_config = parse_yaml(config_yaml)

    # The query encoder is switched to eval mode before it is handed to the
    # loader; it is attached to the model through ``query_encoder=``.
    clap_encoder = CLAP_Encoder().eval()

    separator = load_ss_model(
        configs=parsed_config,
        checkpoint_path=checkpoint_path,
        query_encoder=clap_encoder,
    )
    separator = separator.eval()
    separator = separator.to(device)

    print(f'Load AudioSep model from [{checkpoint_path}]')
    return separator
def inference(model, audio_file, text, output_file, device='cuda'):
    """Separate the sound described by ``text`` from ``audio_file`` and save it.

    The audio is loaded as mono at 32 kHz, the query is embedded through
    ``model.query_encoder.get_query_embed``, ``model.ss_model`` produces the
    separated waveform, and the result is written to ``output_file`` as a
    16-bit PCM WAV file.

    Args:
        model: Object exposing ``query_encoder.get_query_embed`` and
            ``ss_model`` (for example the value returned by
            ``build_audiosep``).
        audio_file: Path of the mixture audio, passed to ``librosa.load``.
        text: Textual query; a single string, wrapped into a one-item list
            before it is given to the query encoder.
        output_file: Destination path of the separated WAV file.
        device: Device the input tensor is moved to and the query embedding
            is requested on. NOTE(review): presumably must match the device
            the model lives on — confirm against callers.
    """
    sample_rate = 32000  # used both for loading and for the written WAV
    int16_max = 32767

    print(f'Separate audio from [{audio_file}] with textual query [{text}]')
    mixture, _ = librosa.load(audio_file, sr=sample_rate, mono=True)

    with torch.no_grad():
        conditions = model.query_encoder.get_query_embed(
            modality='text',
            text=[text],
            device=device,
        )
        input_dict = {
            # Prepend two singleton axes to the 1-D mono signal (presumably
            # batch and channel — confirm against ss_model's expected input).
            "mixture": torch.Tensor(mixture)[None, None, :].to(device),
            "condition": conditions,
        }
        sep_segment = model.ss_model(input_dict)["waveform"]

    # Drop the two leading singleton axes and bring the samples back to NumPy.
    sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()

    # Clip before scaling: a sample outside [-1, 1] would otherwise overflow
    # int16 in the cast and come out as a wrapped/garbage value (a loud click).
    pcm = np.round(np.clip(sep_segment, -1.0, 1.0) * int16_max).astype(np.int16)
    write(output_file, sample_rate, pcm)
    print(f'Write separated audio to [{output_file}]')
if __name__ == '__main__':
    # Script entry point: build the model once, then run a single separation.
    # Every input is a command-line option whose default is the value that
    # used to be hard-coded, so running with no arguments behaves as before.
    import argparse  # local import: only needed when run as a script

    parser = argparse.ArgumentParser(
        description='Separate a sound from an audio file using a text query.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '--config_yaml',
        default='config/audiosep_base.yaml',
        help='path of the model configuration YAML',
    )
    parser.add_argument(
        '--checkpoint_path',
        default='checkpoint/step=3920000.ckpt',
        help='path of the model checkpoint',
    )
    parser.add_argument(
        '--audio_file',
        default='/mnt/bn/data-xubo/project/AudioShop/YT_audios/Y3VHpLxtd498.wav',
        help='path of the mixture audio to separate',
    )
    parser.add_argument(
        '--text',
        default='pigeons are cooing in the background',
        help='textual query describing the sound to extract',
    )
    parser.add_argument(
        '--output_file',
        default='separated_audio.wav',
        help='path of the WAV file to write',
    )
    args = parser.parse_args()

    # Fall back to CPU when CUDA is not available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = build_audiosep(
        config_yaml=args.config_yaml,
        checkpoint_path=args.checkpoint_path,
        device=device,
    )

    inference(model, args.audio_file, args.text, args.output_file, device)
|