Update audio_foundation_models.py

audio_foundation_models.py  +285  -2  CHANGED
@@ -4,6 +4,8 @@ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
 sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'NeuralSeq'))
 sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'text_to_audio/Make_An_Audio'))
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'audio_detection'))
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mono2binaural'))
 import matplotlib
 import librosa
 from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
@@ -40,7 +42,16 @@ from utils.hparams import set_hparams
 from utils.hparams import hparams as hp
 from utils.os_utils import move_file
 import scipy.io.wavfile as wavfile
-
+from audio_infer.utils import config as detection_config
+from audio_infer.pytorch.models import PVT
+from src.models import BinauralNetwork
+from sound_extraction.model.LASSNet import LASSNet
+from sound_extraction.utils.stft import STFT
+from sound_extraction.utils.wav_io import load_wav, save_wav
+from target_sound_detection.src import models as tsd_models
+from target_sound_detection.src.models import event_labels
+from target_sound_detection.src.utils import median_filter, decode_with_timestamps
+import clip

 def prompts(name, description):
     def decorator(func):
@@ -520,4 +531,276 @@ class A2T:
     def inference(self, audio_path):
         audio = whisper.load_audio(audio_path)
         caption_text = self.model(audio)
-        return caption_text[0]
+        return caption_text[0]
+
+class SoundDetection:
+    def __init__(self, device):
+        self.device = device
+        self.sample_rate = 32000
+        self.window_size = 1024
+        self.hop_size = 320
+        self.mel_bins = 64
+        self.fmin = 50
+        self.fmax = 14000
+        self.model_type = 'PVT'
+        self.checkpoint_path = 'audio_detection/audio_infer/useful_ckpts/audio_detection.pth'
+        self.classes_num = detection_config.classes_num
+        self.labels = detection_config.labels
+        self.frames_per_second = self.sample_rate // self.hop_size
+        # Model = eval(self.model_type)
+        self.model = PVT(sample_rate=self.sample_rate, window_size=self.window_size,
+            hop_size=self.hop_size, mel_bins=self.mel_bins, fmin=self.fmin, fmax=self.fmax,
+            classes_num=self.classes_num)
+        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model'])
+        self.model.to(device)
+
+    @prompts(name="Detect The Sound Event From The Audio",
+             description="useful for when you want to know what event in the audio and the sound event start or end time, "
+                         "receives audio_path as input. "
+                         "The input to this tool should be a string, "
+                         "representing the audio_path. " )
+
+    def inference(self, audio_path):
+        # Forward
+        (waveform, _) = librosa.core.load(audio_path, sr=self.sample_rate, mono=True)
+        waveform = waveform[None, :]    # (1, audio_length)
+        waveform = torch.from_numpy(waveform)
+        waveform = waveform.to(self.device)
+        # Forward
+        with torch.no_grad():
+            self.model.eval()
+            batch_output_dict = self.model(waveform, None)
+        framewise_output = batch_output_dict['framewise_output'].data.cpu().numpy()[0]
+        """(time_steps, classes_num)"""
+        # print('Sound event detection result (time_steps x classes_num): {}'.format(
+        #     framewise_output.shape))
+        import numpy as np
+        import matplotlib.pyplot as plt
+        sorted_indexes = np.argsort(np.max(framewise_output, axis=0))[::-1]
+        top_k = 10  # Show top results
+        top_result_mat = framewise_output[:, sorted_indexes[0 : top_k]]
+        """(time_steps, top_k)"""
+        # Plot result
+        stft = librosa.core.stft(y=waveform[0].data.cpu().numpy(), n_fft=self.window_size,
+            hop_length=self.hop_size, window='hann', center=True)
+        frames_num = stft.shape[-1]
+        fig, axs = plt.subplots(2, 1, sharex=True, figsize=(10, 4))
+        axs[0].matshow(np.log(np.abs(stft)), origin='lower', aspect='auto', cmap='jet')
+        axs[0].set_ylabel('Frequency bins')
+        axs[0].set_title('Log spectrogram')
+        axs[1].matshow(top_result_mat.T, origin='upper', aspect='auto', cmap='jet', vmin=0, vmax=1)
+        axs[1].xaxis.set_ticks(np.arange(0, frames_num, self.frames_per_second))
+        axs[1].xaxis.set_ticklabels(np.arange(0, frames_num / self.frames_per_second))
+        axs[1].yaxis.set_ticks(np.arange(0, top_k))
+        axs[1].yaxis.set_ticklabels(np.array(self.labels)[sorted_indexes[0 : top_k]])
+        axs[1].yaxis.grid(color='k', linestyle='solid', linewidth=0.3, alpha=0.3)
+        axs[1].set_xlabel('Seconds')
+        axs[1].xaxis.set_ticks_position('bottom')
+        plt.tight_layout()
+        image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
+        plt.savefig(image_filename)
+        return image_filename
+
+class SoundExtraction:
+    def __init__(self, device):
+        self.device = device
+        self.model_file = 'sound_extraction/useful_ckpts/LASSNet.pt'
+        self.stft = STFT()
+        import torch.nn as nn
+        self.model = nn.DataParallel(LASSNet(device)).to(device)
+        checkpoint = torch.load(self.model_file)
+        self.model.load_state_dict(checkpoint['model'])
+        self.model.eval()
+
+    @prompts(name="Extract Sound Event From Mixture Audio Based On Language Description",
+             description="useful for when you extract target sound from a mixture audio, you can describe the target sound by text, "
+                         "receives audio_path and text as input. "
+                         "The input to this tool should be a comma seperated string of two, "
+                         "representing mixture audio path and input text." )
+
+    def inference(self, inputs):
+        #key = ['ref_audio', 'text']
+        val = inputs.split(",")
+        audio_path = val[0] # audio_path, text
+        text = val[1]
+        waveform = load_wav(audio_path)
+        waveform = torch.tensor(waveform).transpose(1,0)
+        mixed_mag, mixed_phase = self.stft.transform(waveform)
+        text_query = ['[CLS] ' + text]
+        mixed_mag = mixed_mag.transpose(2,1).unsqueeze(0).to(self.device)
+        est_mask = self.model(mixed_mag, text_query)
+        est_mag = est_mask * mixed_mag
+        est_mag = est_mag.squeeze(1)
+        est_mag = est_mag.permute(0, 2, 1)
+        est_wav = self.stft.inverse(est_mag.cpu().detach(), mixed_phase)
+        est_wav = est_wav.squeeze(0).squeeze(0).numpy()
+        #est_path = f'output/est{i}.wav'
+        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        print('audio_filename ', audio_filename)
+        save_wav(est_wav, audio_filename)
+        return audio_filename
+
+
+class Binaural:
+    def __init__(self, device):
+        self.device = device
+        self.model_file = 'mono2binaural/useful_ckpts/m2b/binaural_network.net'
+        self.position_file = ['mono2binaural/useful_ckpts/m2b/tx_positions.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions2.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions3.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions4.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions5.txt']
+        self.net = BinauralNetwork(view_dim=7,
+                                   warpnet_layers=4,
+                                   warpnet_channels=64,
+                                   )
+        self.net.load_from_file(self.model_file)
+        self.sr = 48000
+
+    @prompts(name="Sythesize Binaural Audio From A Mono Audio Input",
+             description="useful for when you want to transfer your mono audio into binaural audio, "
+                         "receives audio_path as input. "
+                         "The input to this tool should be a string, "
+                         "representing the audio_path. " )
+
+    def inference(self, audio_path):
+        mono, sr = librosa.load(path=audio_path, sr=self.sr, mono=True)
+        mono = torch.from_numpy(mono)
+        mono = mono.unsqueeze(0)
+        import numpy as np
+        import random
+        rand_int = random.randint(0,4)
+        view = np.loadtxt(self.position_file[rand_int]).transpose().astype(np.float32)
+        view = torch.from_numpy(view)
+        if not view.shape[-1] * 400 == mono.shape[-1]:
+            mono = mono[:,:(mono.shape[-1]//400)*400]  #
+            if view.shape[1]*400 > mono.shape[1]:
+                m_a = view.shape[1] - mono.shape[-1]//400
+                rand_st = random.randint(0,m_a)
+                view = view[:,m_a:m_a+(mono.shape[-1]//400)]  #
+        # binauralize and save output
+        self.net.eval().to(self.device)
+        mono, view = mono.to(self.device), view.to(self.device)
+        chunk_size = 48000  # forward in chunks of 1s
+        rec_field = 1000  # add 1000 samples as "safe bet" since warping has undefined rec. field
+        rec_field -= rec_field % 400  # make sure rec_field is a multiple of 400 to match audio and view frequencies
+        chunks = [
+            {
+                "mono": mono[:, max(0, i-rec_field):i+chunk_size],
+                "view": view[:, max(0, i-rec_field)//400:(i+chunk_size)//400]
+            }
+            for i in range(0, mono.shape[-1], chunk_size)
+        ]
+        for i, chunk in enumerate(chunks):
+            with torch.no_grad():
+                mono = chunk["mono"].unsqueeze(0)
+                view = chunk["view"].unsqueeze(0)
+                binaural = self.net(mono, view).squeeze(0)
+                if i > 0:
+                    binaural = binaural[:, -(mono.shape[-1]-rec_field):]
+                chunk["binaural"] = binaural
+        binaural = torch.cat([chunk["binaural"] for chunk in chunks], dim=-1)
+        binaural = torch.clamp(binaural, min=-1, max=1).cpu()
+        #binaural = chunked_forwarding(net, mono, view)
+        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        import torchaudio
+        torchaudio.save(audio_filename, binaural, sr)
+        #soundfile.write(audio_filename, binaural, samplerate = 48000)
+        print(f"Processed Binaural.run, audio_filename: {audio_filename}")
+        return audio_filename
+
+class TargetSoundDetection:
+    def __init__(self, device):
+        self.device = device
+        self.MEL_ARGS = {
+            'n_mels': 64,
+            'n_fft': 2048,
+            'hop_length': int(22050 * 20 / 1000),
+            'win_length': int(22050 * 40 / 1000)
+        }
+        self.EPS = np.spacing(1)
+        self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
+        self.event_labels = event_labels
+        self.id_to_event = {i : label for i, label in enumerate(self.event_labels)}
+        config = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/run_config.pth', map_location='cpu')
+        config_parameters = dict(config)
+        config_parameters['tao'] = 0.6
+        if 'thres' not in config_parameters.keys():
+            config_parameters['thres'] = 0.5
+        if 'time_resolution' not in config_parameters.keys():
+            config_parameters['time_resolution'] = 125
+        model_parameters = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/run_model_7_loss=-0.0724.pt'
+            , map_location=lambda storage, loc: storage) # load parameter
+        self.model = getattr(tsd_models, config_parameters['model'])(config_parameters,
+            inputdim=64, outputdim=2, time_resolution=config_parameters['time_resolution'], **config_parameters['model_args'])
+        self.model.load_state_dict(model_parameters)
+        self.model = self.model.to(self.device).eval()
+        self.re_embeds = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/text_emb.pth')
+        self.ref_mel = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/ref_mel.pth')
+
+    def extract_feature(self, fname):
+        import soundfile as sf
+        y, sr = sf.read(fname, dtype='float32')
+        print('y ', y.shape)
+        ti = y.shape[0]/sr
+        if y.ndim > 1:
+            y = y.mean(1)
+        y = librosa.resample(y, sr, 22050)
+        lms_feature = np.log(librosa.feature.melspectrogram(y, **self.MEL_ARGS) + self.EPS).T
+        return lms_feature,ti
+
+    def build_clip(self, text):
+        text = clip.tokenize(text).to(self.device)  # ["a diagram with dog", "a dog", "a cat"]
+        text_features = self.clip_model.encode_text(text)
+        return text_features
+
+    def cal_similarity(self, target, retrievals):
+        ans = []
+        for name in retrievals.keys():
+            tmp = retrievals[name]
+            s = torch.cosine_similarity(target.squeeze(), tmp.squeeze(), dim=0)
+            ans.append(s.item())
+        return ans.index(max(ans))
+
+    @prompts(name="Target Sound Detection",
+             description="useful for when you want to know when the target sound event in the audio happens. You can use language descriptions to instruct the model, "
+                         "receives text description and audio_path as input. "
+                         "The input to this tool should be a comma seperated string of two, "
+                         "representing audio path and the text description. " )
+
+    def inference(self, text, audio_path):
+        target_emb = self.build_clip(text) # torch type
+        idx = self.cal_similarity(target_emb, self.re_embeds)
+        target_event = self.id_to_event[idx]
+        embedding = self.ref_mel[target_event]
+        embedding = torch.from_numpy(embedding)
+        embedding = embedding.unsqueeze(0).to(self.device).float()
+        inputs,ti = self.extract_feature(audio_path)
+        inputs = torch.from_numpy(inputs)
+        inputs = inputs.unsqueeze(0).to(self.device).float()
+        decision, decision_up, logit = self.model(inputs, embedding)
+        pred = decision_up.detach().cpu().numpy()
+        pred = pred[:,:,0]
+        frame_num = decision_up.shape[1]
+        time_ratio = ti / frame_num
+        filtered_pred = median_filter(pred, window_size=1, threshold=0.5)
+        time_predictions = []
+        for index_k in range(filtered_pred.shape[0]):
+            decoded_pred = []
+            decoded_pred_ = decode_with_timestamps(target_event, filtered_pred[index_k,:])
+            if len(decoded_pred_) == 0: # neg deal
+                decoded_pred_.append((target_event, 0, 0))
+            decoded_pred.append(decoded_pred_)
+            for num_batch in range(len(decoded_pred)): # when we test our model,the batch_size is 1
+                cur_pred = pred[num_batch]
+                # Save each frame output, for later visualization
+                label_prediction = decoded_pred[num_batch] # frame predict
+                for event_label, onset, offset in label_prediction:
+                    time_predictions.append({
+                        'onset': onset*time_ratio,
+                        'offset': offset*time_ratio,})
+        ans = ''
+        for i,item in enumerate(time_predictions):
+            ans = ans + 'segment' + str(i+1) + ' start_time: ' + str(item['onset']) + ' end_time: ' + str(item['offset']) + '\t'
+        return ans
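
For context, a minimal usage sketch of the four tool classes this commit adds; it is not part of the diff. It assumes the repository root is on the Python path, the checkpoints referenced in each __init__ have been downloaded to the paths shown above, and the 'audio' and 'image' output folders exist, since the inference methods write their results there. The example file paths and the query text are placeholders.

import os
import torch
from audio_foundation_models import SoundDetection, SoundExtraction, Binaural, TargetSoundDetection

device = "cuda:0" if torch.cuda.is_available() else "cpu"
os.makedirs("audio", exist_ok=True)   # extraction / binaural results are written here
os.makedirs("image", exist_ok=True)   # SoundDetection saves its event plot here

# Sound event detection: single audio path in, path to a plot image out.
detector = SoundDetection(device)
plot_path = detector.inference("assets/example.wav")            # placeholder path

# Language-queried sound extraction: one "audio_path,text" string, as the tool description states.
extractor = SoundExtraction(device)
extracted_wav = extractor.inference("assets/mixture.wav,a dog barking")

# Mono-to-binaural synthesis: single mono audio path in, binaural .wav path out.
binauralizer = Binaural(device)
binaural_wav = binauralizer.inference("assets/example.wav")

# Target sound detection: the method itself takes (text, audio_path) positionally.
tsd = TargetSoundDetection(device)
segments = tsd.inference("a dog barking", "assets/mixture.wav")
print(plot_path, extracted_wav, binaural_wav, segments)

Note that TargetSoundDetection.inference takes the text first and the audio path second, while its prompt description lists the audio path first, so a caller that splits the comma-separated tool input in the described order would need to swap the arguments.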