# hubert-dementia-screening / feature_extractor.py
# Last commit f9c55bd by birgermoell: "Updated to input wav file directly"
import functools
import glob
import os.path
import pdb
import sys

import librosa
import numpy
import numpy as np
import pandas as pd
import soundfile as sf
import torch
from pydub import AudioSegment
from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC
# Both objects below are loaded from the same pretrained wav2vec2 checkpoint.
_WAV2VEC2_CHECKPOINT = "facebook/wav2vec2-large-960h-lv60"
# NOTE(review): AutoTokenizer loads this checkpoint's text (CTC) tokenizer,
# which maps strings to token ids; it does not turn waveforms into model
# input_values.
processor = AutoTokenizer.from_pretrained(_WAV2VEC2_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_WAV2VEC2_CHECKPOINT)
# Sampling rate (Hz) that audio is resampled to before it is encoded.
new_sample_rate = 16_000
# Dataset locations (Pitt corpus, English):
# corpus root:      /home/bmoell/data/media.talkbank.org/dementia/English/Pitt
# dementia, cookie: /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Dementia/cookie
# control, cookie:  /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie
def convert_mp3_to_wav(audio_file):
    """Convert an MP3 file to a WAV file written next to the original.

    The output path is ``audio_file + ".wav"`` (e.g. ``talk.mp3`` becomes
    ``talk.mp3.wav``), so it matches a ``*.wav`` glob over the same directory.

    Args:
        audio_file: path of the MP3 file to decode.

    Returns:
        The path of the written WAV file.
    """
    wav_path = audio_file + ".wav"
    sound = AudioSegment.from_mp3(audio_file)
    # AudioSegment.export hands back the output file object it opened; close it
    # explicitly so the WAV data is flushed and the descriptor is released.
    sound.export(wav_path, format="wav").close()
    return wav_path
def feature_extractor(path):
    """Ensure every ``*.wav`` file directly inside *path* has wav2vec2 embeddings.

    For each WAV file lacking a sibling ``<file>.wav2vec2.pt``,
    ``get_wav2vecembeddings_from_audiofile`` is called to create it. Files
    that already have one are skipped. The scan is not recursive.

    Args:
        path: directory to scan.

    Returns:
        The sorted list of WAV file paths that were found.
    """
    print("the path is", path)
    # Escape the directory so glob metacharacters in it ("[", "*", "?") match
    # literally; sort so files are processed in a deterministic order.
    wav_files = sorted(glob.glob(os.path.join(glob.escape(path), "*.wav")))
    for wav_file in wav_files:
        print("the wavfile is", wav_file)
        # Skip files whose wav2vec2 embeddings were already saved.
        if not os.path.isfile(wav_file + ".wav2vec2.pt"):
            get_wav2vecembeddings_from_audiofile(wav_file)
    return wav_files
def change_sample_rate(y, sample_rate, new_sample_rate):
    """Resample audio signal *y* from *sample_rate* Hz to *new_sample_rate* Hz.

    Args:
        y: audio samples, in any form ``librosa.resample`` accepts.
        sample_rate: sampling rate of *y* in Hz.
        new_sample_rate: target sampling rate in Hz.

    Returns:
        The resampled signal, or *y* itself when the two rates are equal.
    """
    if sample_rate == new_sample_rate:
        # Nothing to resample.
        return y
    # orig_sr/target_sr are keyword-only in librosa >= 0.10; passing them
    # positionally raises TypeError there, while keywords work in every version.
    return librosa.resample(y, orig_sr=sample_rate, target_sr=new_sample_rate)
def stereo_to_mono(audio_input):
    """Downmix multi-channel audio to mono by averaging the channels.

    Args:
        audio_input: array shaped ``(frames, channels)``, the layout
            ``soundfile.read`` returns for multi-channel files. A 1-D array
            is treated as already mono and returned unchanged.

    Returns:
        1-D array with one value per frame.
    """
    audio = np.asarray(audio_input)
    if audio.ndim == 1:
        return audio
    # Average over the channel axis only. Unlike mean(keepdims=True) followed
    # by squeeze, this keeps a single-frame input as shape (1,), not 0-d.
    return audio.mean(axis=1)
@functools.lru_cache(maxsize=None)
def _load_audio_feature_extractor(name_or_path):
    """Load (once per checkpoint) the Wav2Vec2 raw-audio feature extractor."""
    return Wav2Vec2FeatureExtractor.from_pretrained(name_or_path)


def get_wav2vecembeddings_from_audiofile(wav_file):
    """Compute and save wav2vec2 last-layer hidden states for an audio file.

    Steps:
        1. Read the audio with soundfile and downmix it to mono if it has
           more than one channel.
        2. Resample it to ``new_sample_rate`` Hz.
        3. Encode the waveform with the feature extractor of the checkpoint
           ``model`` was loaded from, and run ``model``.
        4. Save the last entry of its hidden states to
           ``wav_file + '.wav2vec2.pt'`` with ``torch.save``.

    Args:
        wav_file: path of an audio file readable by soundfile.

    Returns:
        The hidden-state tensor that was saved.
    """
    print("the file is", wav_file)
    speech, sample_rate = sf.read(wav_file)
    if speech.ndim > 1:
        speech = stereo_to_mono(speech)
    # change sample rate to 16 000 hertz
    resampled = change_sample_rate(speech, sample_rate, new_sample_rate)
    print("the speech is", speech)
    # Encode the resampled waveform itself (not the file path) with the
    # checkpoint's audio feature extractor; a text tokenizer cannot do this.
    extractor = _load_audio_feature_extractor(model.name_or_path)
    input_values = extractor(
        resampled,
        return_tensors="pt",
        padding=True,
        sampling_rate=new_sample_rate,
    )
    print("input values", input_values)
    with torch.no_grad():
        encoded_states = model(
            input_values=input_values["input_values"],
            # Only present when the extractor is configured to return it.
            attention_mask=input_values.get("attention_mask"),
            output_hidden_states=True,
        )
    # hidden_states holds the embedding output plus one tensor per encoder
    # layer; the last entry is the output of the final layer.
    last_hidden_state = encoded_states.hidden_states[-1]
    print("getting wav2vec2 embeddings")
    print(last_hidden_state)
    torch.save(last_hidden_state, wav_file + '.wav2vec2.pt')
    return last_hidden_state
if __name__ == "__main__":
    # Run only when executed as a script, not on import. An optional first
    # command-line argument overrides the default directory.
    feature_extractor(
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie"
    )