lmzjms committed
Commit 1b4468e
1 Parent(s): 064d7bc

Update audio_foundation_models.py

Files changed (1)
  1. audio_foundation_models.py +285 -2
audio_foundation_models.py CHANGED
@@ -4,6 +4,8 @@ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
 sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'NeuralSeq'))
 sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'text_to_audio/Make_An_Audio'))
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'audio_detection'))
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'mono2binaural'))
 import matplotlib
 import librosa
 from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
@@ -40,7 +42,16 @@ from utils.hparams import set_hparams
 from utils.hparams import hparams as hp
 from utils.os_utils import move_file
 import scipy.io.wavfile as wavfile
-
+from audio_infer.utils import config as detection_config
+from audio_infer.pytorch.models import PVT
+from src.models import BinauralNetwork
+from sound_extraction.model.LASSNet import LASSNet
+from sound_extraction.utils.stft import STFT
+from sound_extraction.utils.wav_io import load_wav, save_wav
+from target_sound_detection.src import models as tsd_models
+from target_sound_detection.src.models import event_labels
+from target_sound_detection.src.utils import median_filter, decode_with_timestamps
+import clip

 def prompts(name, description):
     def decorator(func):
@@ -520,4 +531,276 @@ class A2T:
     def inference(self, audio_path):
         audio = whisper.load_audio(audio_path)
         caption_text = self.model(audio)
-        return caption_text[0]
+        return caption_text[0]
+
+class SoundDetection:
+    def __init__(self, device):
+        self.device = device
+        self.sample_rate = 32000
+        self.window_size = 1024
+        self.hop_size = 320
+        self.mel_bins = 64
+        self.fmin = 50
+        self.fmax = 14000
+        self.model_type = 'PVT'
+        self.checkpoint_path = 'audio_detection/audio_infer/useful_ckpts/audio_detection.pth'
+        self.classes_num = detection_config.classes_num
+        self.labels = detection_config.labels
+        self.frames_per_second = self.sample_rate // self.hop_size
+        # Model = eval(self.model_type)
+        self.model = PVT(sample_rate=self.sample_rate, window_size=self.window_size,
+                         hop_size=self.hop_size, mel_bins=self.mel_bins, fmin=self.fmin, fmax=self.fmax,
+                         classes_num=self.classes_num)
+        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model'])
+        self.model.to(device)
+
+    @prompts(name="Detect The Sound Event From The Audio",
+             description="useful for when you want to know what event in the audio and the sound event start or end time, "
+                         "receives audio_path as input. "
+                         "The input to this tool should be a string, "
+                         "representing the audio_path. ")
+
+    def inference(self, audio_path):
+        # Forward
+        (waveform, _) = librosa.core.load(audio_path, sr=self.sample_rate, mono=True)
+        waveform = waveform[None, :]  # (1, audio_length)
+        waveform = torch.from_numpy(waveform)
+        waveform = waveform.to(self.device)
+        # Forward
+        with torch.no_grad():
+            self.model.eval()
+            batch_output_dict = self.model(waveform, None)
+        framewise_output = batch_output_dict['framewise_output'].data.cpu().numpy()[0]
+        """(time_steps, classes_num)"""
+        # print('Sound event detection result (time_steps x classes_num): {}'.format(
+        #     framewise_output.shape))
+        import numpy as np
+        import matplotlib.pyplot as plt
+        sorted_indexes = np.argsort(np.max(framewise_output, axis=0))[::-1]
+        top_k = 10  # Show top results
+        top_result_mat = framewise_output[:, sorted_indexes[0:top_k]]
+        """(time_steps, top_k)"""
+        # Plot result
+        stft = librosa.core.stft(y=waveform[0].data.cpu().numpy(), n_fft=self.window_size,
+                                 hop_length=self.hop_size, window='hann', center=True)
+        frames_num = stft.shape[-1]
+        fig, axs = plt.subplots(2, 1, sharex=True, figsize=(10, 4))
+        axs[0].matshow(np.log(np.abs(stft)), origin='lower', aspect='auto', cmap='jet')
+        axs[0].set_ylabel('Frequency bins')
+        axs[0].set_title('Log spectrogram')
+        axs[1].matshow(top_result_mat.T, origin='upper', aspect='auto', cmap='jet', vmin=0, vmax=1)
+        axs[1].xaxis.set_ticks(np.arange(0, frames_num, self.frames_per_second))
+        axs[1].xaxis.set_ticklabels(np.arange(0, frames_num / self.frames_per_second))
+        axs[1].yaxis.set_ticks(np.arange(0, top_k))
+        axs[1].yaxis.set_ticklabels(np.array(self.labels)[sorted_indexes[0:top_k]])
+        axs[1].yaxis.grid(color='k', linestyle='solid', linewidth=0.3, alpha=0.3)
+        axs[1].set_xlabel('Seconds')
+        axs[1].xaxis.set_ticks_position('bottom')
+        plt.tight_layout()
+        image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
+        plt.savefig(image_filename)
+        return image_filename
+
+class SoundExtraction:
+    def __init__(self, device):
+        self.device = device
+        self.model_file = 'sound_extraction/useful_ckpts/LASSNet.pt'
+        self.stft = STFT()
+        import torch.nn as nn
+        self.model = nn.DataParallel(LASSNet(device)).to(device)
+        checkpoint = torch.load(self.model_file)
+        self.model.load_state_dict(checkpoint['model'])
+        self.model.eval()
+
+    @prompts(name="Extract Sound Event From Mixture Audio Based On Language Description",
+             description="useful for when you extract target sound from a mixture audio, you can describe the target sound by text, "
+                         "receives audio_path and text as input. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing mixture audio path and input text.")
+
+    def inference(self, inputs):
+        # key = ['ref_audio', 'text']
+        val = inputs.split(",")
+        audio_path = val[0]  # audio_path, text
+        text = val[1]
+        waveform = load_wav(audio_path)
+        waveform = torch.tensor(waveform).transpose(1, 0)
+        mixed_mag, mixed_phase = self.stft.transform(waveform)
+        text_query = ['[CLS] ' + text]
+        mixed_mag = mixed_mag.transpose(2, 1).unsqueeze(0).to(self.device)
+        est_mask = self.model(mixed_mag, text_query)
+        est_mag = est_mask * mixed_mag
+        est_mag = est_mag.squeeze(1)
+        est_mag = est_mag.permute(0, 2, 1)
+        est_wav = self.stft.inverse(est_mag.cpu().detach(), mixed_phase)
+        est_wav = est_wav.squeeze(0).squeeze(0).numpy()
+        # est_path = f'output/est{i}.wav'
+        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        print('audio_filename ', audio_filename)
+        save_wav(est_wav, audio_filename)
+        return audio_filename
+
+
+class Binaural:
+    def __init__(self, device):
+        self.device = device
+        self.model_file = 'mono2binaural/useful_ckpts/m2b/binaural_network.net'
+        self.position_file = ['mono2binaural/useful_ckpts/m2b/tx_positions.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions2.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions3.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions4.txt',
+                              'mono2binaural/useful_ckpts/m2b/tx_positions5.txt']
+        self.net = BinauralNetwork(view_dim=7,
+                                   warpnet_layers=4,
+                                   warpnet_channels=64,
+                                   )
+        self.net.load_from_file(self.model_file)
+        self.sr = 48000
+
+    @prompts(name="Synthesize Binaural Audio From A Mono Audio Input",
+             description="useful for when you want to transfer your mono audio into binaural audio, "
+                         "receives audio_path as input. "
+                         "The input to this tool should be a string, "
+                         "representing the audio_path. ")
+
+    def inference(self, audio_path):
+        mono, sr = librosa.load(path=audio_path, sr=self.sr, mono=True)
+        mono = torch.from_numpy(mono)
+        mono = mono.unsqueeze(0)
+        import numpy as np
+        import random
+        rand_int = random.randint(0, 4)
+        view = np.loadtxt(self.position_file[rand_int]).transpose().astype(np.float32)
+        view = torch.from_numpy(view)
+        if not view.shape[-1] * 400 == mono.shape[-1]:
+            mono = mono[:, :(mono.shape[-1]//400)*400]  #
+            if view.shape[1]*400 > mono.shape[1]:
+                m_a = view.shape[1] - mono.shape[-1]//400
+                rand_st = random.randint(0, m_a)
+                view = view[:, m_a:m_a+(mono.shape[-1]//400)]  #
+        # binauralize and save output
+        self.net.eval().to(self.device)
+        mono, view = mono.to(self.device), view.to(self.device)
+        chunk_size = 48000  # forward in chunks of 1s
+        rec_field = 1000  # add 1000 samples as "safe bet" since warping has undefined rec. field
+        rec_field -= rec_field % 400  # make sure rec_field is a multiple of 400 to match audio and view frequencies
+        chunks = [
+            {
+                "mono": mono[:, max(0, i-rec_field):i+chunk_size],
+                "view": view[:, max(0, i-rec_field)//400:(i+chunk_size)//400]
+            }
+            for i in range(0, mono.shape[-1], chunk_size)
+        ]
+        for i, chunk in enumerate(chunks):
+            with torch.no_grad():
+                mono = chunk["mono"].unsqueeze(0)
+                view = chunk["view"].unsqueeze(0)
+                binaural = self.net(mono, view).squeeze(0)
+                if i > 0:
+                    binaural = binaural[:, -(mono.shape[-1]-rec_field):]
+                chunk["binaural"] = binaural
+        binaural = torch.cat([chunk["binaural"] for chunk in chunks], dim=-1)
+        binaural = torch.clamp(binaural, min=-1, max=1).cpu()
+        # binaural = chunked_forwarding(net, mono, view)
+        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        import torchaudio
+        torchaudio.save(audio_filename, binaural, sr)
+        # soundfile.write(audio_filename, binaural, samplerate=48000)
+        print(f"Processed Binaural.run, audio_filename: {audio_filename}")
+        return audio_filename
+
+class TargetSoundDetection:
+    def __init__(self, device):
+        self.device = device
+        self.MEL_ARGS = {
+            'n_mels': 64,
+            'n_fft': 2048,
+            'hop_length': int(22050 * 20 / 1000),
+            'win_length': int(22050 * 40 / 1000)
+        }
+        self.EPS = np.spacing(1)
+        self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
+        self.event_labels = event_labels
+        self.id_to_event = {i: label for i, label in enumerate(self.event_labels)}
+        config = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/run_config.pth', map_location='cpu')
+        config_parameters = dict(config)
+        config_parameters['tao'] = 0.6
+        if 'thres' not in config_parameters.keys():
+            config_parameters['thres'] = 0.5
+        if 'time_resolution' not in config_parameters.keys():
+            config_parameters['time_resolution'] = 125
+        model_parameters = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/run_model_7_loss=-0.0724.pt',
+                                      map_location=lambda storage, loc: storage)  # load parameter
+        self.model = getattr(tsd_models, config_parameters['model'])(config_parameters,
+            inputdim=64, outputdim=2, time_resolution=config_parameters['time_resolution'], **config_parameters['model_args'])
+        self.model.load_state_dict(model_parameters)
+        self.model = self.model.to(self.device).eval()
+        self.re_embeds = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/text_emb.pth')
+        self.ref_mel = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/ref_mel.pth')
+
+    def extract_feature(self, fname):
+        import soundfile as sf
+        y, sr = sf.read(fname, dtype='float32')
+        print('y ', y.shape)
+        ti = y.shape[0] / sr
+        if y.ndim > 1:
+            y = y.mean(1)
+        y = librosa.resample(y, sr, 22050)
+        lms_feature = np.log(librosa.feature.melspectrogram(y, **self.MEL_ARGS) + self.EPS).T
+        return lms_feature, ti
+
+    def build_clip(self, text):
+        text = clip.tokenize(text).to(self.device)  # ["a diagram with dog", "a dog", "a cat"]
+        text_features = self.clip_model.encode_text(text)
+        return text_features
+
+    def cal_similarity(self, target, retrievals):
+        ans = []
+        for name in retrievals.keys():
+            tmp = retrievals[name]
+            s = torch.cosine_similarity(target.squeeze(), tmp.squeeze(), dim=0)
+            ans.append(s.item())
+        return ans.index(max(ans))
+
+    @prompts(name="Target Sound Detection",
+             description="useful for when you want to know when the target sound event in the audio happens. You can use language descriptions to instruct the model, "
+                         "receives text description and audio_path as input. "
+                         "The input to this tool should be a comma separated string of two, "
+                         "representing audio path and the text description. ")
+
+    def inference(self, text, audio_path):
+        target_emb = self.build_clip(text)  # torch type
+        idx = self.cal_similarity(target_emb, self.re_embeds)
+        target_event = self.id_to_event[idx]
+        embedding = self.ref_mel[target_event]
+        embedding = torch.from_numpy(embedding)
+        embedding = embedding.unsqueeze(0).to(self.device).float()
+        inputs, ti = self.extract_feature(audio_path)
+        inputs = torch.from_numpy(inputs)
+        inputs = inputs.unsqueeze(0).to(self.device).float()
+        decision, decision_up, logit = self.model(inputs, embedding)
+        pred = decision_up.detach().cpu().numpy()
+        pred = pred[:, :, 0]
+        frame_num = decision_up.shape[1]
+        time_ratio = ti / frame_num
+        filtered_pred = median_filter(pred, window_size=1, threshold=0.5)
+        time_predictions = []
+        for index_k in range(filtered_pred.shape[0]):
+            decoded_pred = []
+            decoded_pred_ = decode_with_timestamps(target_event, filtered_pred[index_k, :])
+            if len(decoded_pred_) == 0:  # neg deal
+                decoded_pred_.append((target_event, 0, 0))
+            decoded_pred.append(decoded_pred_)
+            for num_batch in range(len(decoded_pred)):  # when we test our model, the batch_size is 1
+                cur_pred = pred[num_batch]
+                # Save each frame output, for later visualization
+                label_prediction = decoded_pred[num_batch]  # frame predict
+                for event_label, onset, offset in label_prediction:
+                    time_predictions.append({
+                        'onset': onset * time_ratio,
+                        'offset': offset * time_ratio, })
+        ans = ''
+        for i, item in enumerate(time_predictions):
+            ans = ans + 'segment' + str(i + 1) + ' start_time: ' + str(item['onset']) + ' end_time: ' + str(item['offset']) + '\t'
+        return ans
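
For context, a minimal usage sketch of the four tool classes added in this commit, not part of the diff itself. It assumes the repository root is on sys.path so the file imports as audio_foundation_models, the checkpoints referenced above have been downloaded, the output folders 'audio' and 'image' exist, and 'examples/mix.wav' is a placeholder path.

# Hypothetical usage sketch (not part of this commit); see assumptions above.
import torch
from audio_foundation_models import (SoundDetection, SoundExtraction,
                                     Binaural, TargetSoundDetection)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

sed = SoundDetection(device)        # frame-level sound event detection, returns a plot image path
lass = SoundExtraction(device)      # language-queried source extraction from a mixture
m2b = Binaural(device)              # mono -> binaural rendering
tsd = TargetSoundDetection(device)  # onset/offset times for a described target sound

print(sed.inference("examples/mix.wav"))
print(lass.inference("examples/mix.wav, a dog barking"))
print(m2b.inference("examples/mix.wav"))
print(tsd.inference("a dog barking", "examples/mix.wav"))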