import glob
import os.path

import librosa
import numpy as np
import soundfile as sf
import torch
from pydub import AudioSegment
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Use the Wav2Vec2 processor (feature extractor + tokenizer) rather than a plain
# tokenizer: it turns raw waveforms into the input_values the model expects.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60")
model.eval()

new_sample_rate = 16000  # wav2vec2 expects 16 kHz audio

# Dementia paths
# /home/bmoell/data/media.talkbank.org/dementia/English/Pitt
# cookie dementia: /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Dementia/cookie
# cookie control:  /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie


def convert_mp3_to_wav(audio_file):
    """Convert an mp3 file to wav, written next to the original file."""
    sound = AudioSegment.from_mp3(audio_file)
    sound.export(audio_file + ".wav", format="wav")


def feature_extractor(path):
    """Extract wav2vec2 embeddings for every wav file in a directory."""
    print("the path is", path)
    wav_files = glob.glob(path + "/*.wav")
    for wav_file in wav_files:
        print("the wav file is", wav_file)
        # Skip files whose embeddings have already been saved.
        if not os.path.isfile(wav_file + ".wav2vec2.pt"):
            get_wav2vecembeddings_from_audiofile(wav_file)


def change_sample_rate(y, sample_rate, new_sample_rate):
    """Resample audio from sample_rate to new_sample_rate."""
    return librosa.resample(y, orig_sr=sample_rate, target_sr=new_sample_rate)


def stereo_to_mono(audio_input):
    """Average the channels of a multi-channel signal into a mono signal."""
    return np.squeeze(audio_input.mean(axis=1, keepdims=True))


def get_wav2vecembeddings_from_audiofile(wav_file):
    """Run a wav file through wav2vec2 and save the last hidden state to disk."""
    print("the file is", wav_file)
    speech, sample_rate = sf.read(wav_file)
    if speech.ndim > 1:
        speech = stereo_to_mono(speech)
    # Resample to 16 kHz, the rate the model was trained on.
    resampled = change_sample_rate(speech, sample_rate, new_sample_rate)
    # Pass the resampled waveform (not the file path) to the processor.
    inputs = processor(
        resampled.astype(np.float32),
        return_tensors="pt",
        padding=True,
        sampling_rate=new_sample_rate,
    )
    with torch.no_grad():
        encoded_states = model(
            input_values=inputs["input_values"],
            attention_mask=inputs.get("attention_mask"),
            output_hidden_states=True,
        )
    # hidden_states[-1] is the output of the final transformer layer.
    last_hidden_state = encoded_states.hidden_states[-1]
    print("getting wav2vec2 embeddings")
    print(last_hidden_state)
    torch.save(last_hidden_state, wav_file + ".wav2vec2.pt")


feature_extractor("/home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie")
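
# A minimal follow-up sketch (not part of the original pipeline): the saved
# ".wav2vec2.pt" tensors could be loaded back and pooled over time into one
# fixed-size vector per recording, e.g. as features for a downstream classifier.
# The pooling choice (mean over the time axis) is an illustrative assumption.
def load_pooled_embedding(embedding_path):
    """Load a saved (1, time, hidden) tensor and average it over the time axis."""
    hidden_states = torch.load(embedding_path)
    return hidden_states.mean(dim=1).squeeze(0)  # shape: (hidden_size,)

# Example usage (path is hypothetical):
# vector = load_pooled_embedding("some_file.wav.wav2vec2.pt")
# print(vector.shape)  # torch.Size([1024]) for the large model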