hubert-dementia-screening / feature_extractor.py
birgermoell's picture
WIP feature extractor for wav2vec2
a875c0d
import pandas as pd
import soundfile as sf
import pdb
from pydub import AudioSegment
from transformers import AutoTokenizer, Wav2Vec2ForCTC
import torch
import numpy as np
import glob
import numpy
import os.path
# Hugging Face checkpoint shared by the tokenizer and the CTC model below.
_CHECKPOINT = "facebook/wav2vec2-large-960h-lv60"
# NOTE(review): AutoTokenizer loads the text (CTC) tokenizer for this checkpoint;
# it does not preprocess audio, so it is only suitable for decoding CTC output.
processor = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_CHECKPOINT)
# Dataset locations (TalkBank dementia data, English Pitt corpus):
#   corpus root:           /home/bmoell/data/media.talkbank.org/dementia/English/Pitt
#   cookie task, dementia: /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Dementia/cookie
#   cookie task, control:  /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie
def convert_mp3_to_wav(audio_file):
    """Decode an MP3 file and write a WAV copy next to it.

    The output name is the input name with ``.wav`` appended
    (``talk.mp3`` -> ``talk.mp3.wav``).

    Args:
        audio_file: Path to the source MP3, as a string.

    Returns:
        The path of the written WAV file.
    """
    wav_path = audio_file + ".wav"
    sound = AudioSegment.from_mp3(audio_file)
    # export() hands back the open output file handle; close it so the WAV is
    # flushed and released deterministically rather than at garbage collection.
    sound.export(wav_path, format="wav").close()
    return wav_path
def feature_extractor(path):
    """Extract wav2vec2 embeddings for every ``.wav`` file directly inside *path*.

    A file is skipped when ``<file>.wav2vec2.pt`` already exists next to it,
    so an interrupted pass can simply be re-run.

    Args:
        path: Directory to scan (not recursive).
    """
    print("the path is", path)
    # glob.escape() makes directory names containing glob metacharacters
    # ([, *, ?) match literally; sorted() gives a deterministic order.
    wav_files = sorted(glob.glob(os.path.join(glob.escape(path), "*.wav")))
    for wav_file in wav_files:
        # Report the file being processed (previously printed the whole list).
        print("the wavfile is", wav_file)
        # wav2vec2 embeddings: only compute them when not cached on disk.
        if os.path.isfile(wav_file + ".wav2vec2.pt"):
            continue
        get_wav2vecembeddings_from_audiofile(wav_file)
# Lazily created audio feature extractor, reused for every file in the process.
_audio_feature_extractor = None


def _get_audio_feature_extractor():
    """Return the audio feature extractor for the checkpoint ``model`` came from.

    The module-level ``processor`` is an ``AutoTokenizer`` (text only) and
    cannot turn a waveform into ``input_values``, so the matching feature
    extractor is loaded here once and cached.
    """
    global _audio_feature_extractor
    if _audio_feature_extractor is None:
        from transformers import Wav2Vec2FeatureExtractor

        _audio_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            model.config.name_or_path
        )
    return _audio_feature_extractor


def get_wav2vecembeddings_from_audiofile(wav_file, min_file_size=250):
    """Compute wav2vec2 last-layer hidden states for one WAV file and save them.

    The tensor is written with ``torch.save`` to ``<wav_file>.wav2vec2.pt``.

    Args:
        wav_file: Path to the WAV file.
        min_file_size: Files of this many bytes or fewer are skipped
            (default 250, the previously hard-coded threshold).

    Returns:
        The saved hidden-state tensor, or ``None`` if the file was skipped.

    Raises:
        ValueError: raised by the feature extractor when the file's sample
            rate differs from the rate the checkpoint expects; resample first.
    """
    print("the file is", wav_file)
    # Check the size before decoding so near-empty files cost nothing.
    file_size = os.stat(wav_file).st_size
    print("the size is", file_size)
    if file_size <= min_file_size:
        return None

    speech, sample_rate = sf.read(wav_file)
    if speech.ndim > 1:
        # sf.read yields (frames, channels) for multi-channel audio; the model
        # takes a single channel, so average the channels down to mono.
        speech = speech.mean(axis=1)

    # Feed the decoded waveform (previously the file *name* was tokenized);
    # passing sampling_rate lets the extractor reject audio at the wrong rate.
    input_values = _get_audio_feature_extractor()(
        speech, sampling_rate=sample_rate, return_tensors="pt", padding=True
    )
    print("input values", input_values)

    with torch.no_grad():
        encoded_states = model(
            input_values=input_values["input_values"],
            # Some checkpoints' extractors return no mask; the model accepts None.
            attention_mask=input_values.get("attention_mask"),
            output_hidden_states=True,
        )
    # hidden_states is the embedding output followed by one tensor per encoder
    # layer, so [-1] is the deepest layer. The CTC logits are not used here.
    last_hidden_state = encoded_states.hidden_states[-1]
    print("getting wav2vec2 embeddings")
    print(last_hidden_state)
    torch.save(last_hidden_state, wav_file + ".wav2vec2.pt")
    return last_hidden_state
if __name__ == "__main__":
    import sys

    # Run only when executed as a script, not on import. The default is the
    # Pitt control-group cookie directory; a directory argument overrides it.
    feature_extractor(
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie"
    )