File size: 2,573 Bytes
a875c0d
 
 
 
 
 
 
 
8093841
a875c0d
 
 
 
 
8093841
a875c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8093841
 
 
 
 
f9c55bd
 
 
 
 
a875c0d
 
 
f9c55bd
 
 
8093841
 
 
f9c55bd
a875c0d
8093841
 
 
 
 
f9c55bd
8093841
 
 
 
 
 
 
a875c0d
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import pandas as pd
import soundfile as sf
import pdb
from pydub import AudioSegment
from transformers import AutoTokenizer, Wav2Vec2ForCTC
import torch
import numpy as np
import glob
import librosa
import numpy
import os.path

# Checkpoint shared by the tokenizer and the CTC model below; both are loaded
# once, when this module is imported.
# NOTE(review): AutoTokenizer returns this checkpoint's *text* (CTC)
# tokenizer. It maps strings to token ids and does not preprocess raw audio;
# audio for this model is prepared by its feature extractor
# (Wav2Vec2FeatureExtractor / Wav2Vec2Processor).
_CHECKPOINT = "facebook/wav2vec2-large-960h-lv60"
processor = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_CHECKPOINT)
# Target sample rate (Hz) that audio is resampled to before embedding.
new_sample_rate = 16000
# Data locations (TalkBank Pitt dementia corpus, "cookie" recordings):
#   root:     /home/bmoell/data/media.talkbank.org/dementia/English/Pitt
#   dementia: /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Dementia/cookie
#   control:  /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie


def convert_mp3_to_wav(audio_file):
    """Decode the MP3 at *audio_file* and write a WAV copy beside it.

    The output path is the input path with ".wav" appended
    (``talk.mp3`` -> ``talk.mp3.wav``); the source file is not modified.
    Nothing is returned.
    """
    decoded = AudioSegment.from_mp3(audio_file)
    wav_path = audio_file + ".wav"
    decoded.export(wav_path, format="wav")


def feature_extractor(path):
    """Compute wav2vec2 embeddings for every ``*.wav`` file directly in *path*.

    Only the top level of *path* is scanned (no recursion). A file is skipped
    when ``<file>.wav2vec2.pt`` already exists, so re-running the function
    resumes an interrupted pass instead of recomputing everything.

    Args:
        path: directory containing the WAV files.
    """
    print("the path is", path)

    # glob.escape stops directory names containing glob metacharacters
    # ("[", "*", "?") from being interpreted as a pattern; sorting makes the
    # processing order deterministic (raw glob order is arbitrary).
    pattern = os.path.join(glob.escape(path), "*.wav")
    for wav_file in sorted(glob.glob(pattern)):
        print("the wavfile is", wav_file)
        # wav2vec2 embeddings, computed only if not already cached on disk
        if not os.path.isfile(wav_file + ".wav2vec2.pt"):
            get_wav2vecembeddings_from_audiofile(wav_file)


def change_sample_rate(y, sample_rate, new_sample_rate):
    """Resample the audio signal *y* from *sample_rate* to *new_sample_rate* Hz.

    Args:
        y: audio samples as a NumPy array.
        sample_rate: current sample rate of *y* in Hz.
        new_sample_rate: desired sample rate in Hz.

    Returns:
        The resampled signal, as returned by ``librosa.resample``.
    """
    # orig_sr/target_sr are keyword-only since librosa 0.10; passing them by
    # keyword also works on older releases.
    return librosa.resample(y, orig_sr=sample_rate, target_sr=new_sample_rate)

def stereo_to_mono(audio_input):
    """Downmix a multi-channel signal to mono by averaging its channels.

    Args:
        audio_input: array shaped (frames, channels), the layout
            ``soundfile.read`` returns for multi-channel files. A 1-D array is
            treated as already mono and returned unchanged.

    Returns:
        1-D array of length ``frames`` holding each frame's channel mean.
    """
    audio_input = np.asarray(audio_input)
    if audio_input.ndim == 1:
        return audio_input
    # Plain mean over the channel axis (rather than squeezing a keepdims
    # result) keeps a single-frame input 1-D with shape (1,) instead of
    # collapsing it to a 0-d scalar.
    return audio_input.mean(axis=1)

_FEATURE_EXTRACTOR_CACHE = {}


def _get_feature_extractor(name_or_path):
    """Return the wav2vec2 audio feature extractor for *name_or_path*, cached per name.

    ``from_pretrained`` performs a hub/cache lookup, so the instance is
    memoised to avoid repeating that lookup for every file in a batch.
    """
    extractor = _FEATURE_EXTRACTOR_CACHE.get(name_or_path)
    if extractor is None:
        # Local import: only this helper needs the feature-extractor class.
        from transformers import Wav2Vec2FeatureExtractor

        extractor = Wav2Vec2FeatureExtractor.from_pretrained(name_or_path)
        _FEATURE_EXTRACTOR_CACHE[name_or_path] = extractor
    return extractor


def get_wav2vecembeddings_from_audiofile(wav_file):
    """Embed one audio file with wav2vec2 and save the result next to it.

    Steps:
    1. Read *wav_file* with soundfile.
    2. Downmix multi-channel audio to mono.
    3. Resample to ``new_sample_rate`` Hz.
    4. Prepare the samples with the checkpoint's feature extractor.
    5. Run ``model`` and ``torch.save`` its final hidden state to
       ``<wav_file>.wav2vec2.pt``.

    Args:
        wav_file: path to an audio file readable by ``soundfile``.
    """
    print("the file is", wav_file)
    speech, sample_rate = sf.read(wav_file)

    if speech.ndim > 1:
        speech = stereo_to_mono(speech)
    # change sample rate to 16 000 hertz
    resampled = change_sample_rate(speech, sample_rate, new_sample_rate)

    # The model must be fed the audio samples, prepared by the checkpoint's
    # feature extractor. The module-level ``processor`` is a text tokenizer
    # (AutoTokenizer): calling it on ``wav_file`` would tokenise the path
    # string, not the audio.
    extractor = _get_feature_extractor(model.name_or_path)
    inputs = extractor(
        resampled,
        sampling_rate=new_sample_rate,
        return_tensors="pt",
    )
    print("input values shape", tuple(inputs["input_values"].shape))

    with torch.no_grad():
        encoded_states = model(
            input_values=inputs["input_values"],
            # Present only when the extractor's config enables attention
            # masks; otherwise None, which the model accepts.
            attention_mask=inputs.get("attention_mask"),
            output_hidden_states=True,
        )
    # For Wav2Vec2ForCTC the first output element is the CTC logits, not a
    # hidden state; hidden_states[-1] is the final encoder layer's output.
    last_hidden_state = encoded_states.hidden_states[-1]
    print("getting wav2vec2 embeddings")
    print(last_hidden_state.shape)
    torch.save(last_hidden_state, wav_file + ".wav2vec2.pt")

    





if __name__ == "__main__":
    import sys

    # Run the extraction pass only when executed as a script, not on import.
    # The directory may be given as the first CLI argument; it defaults to
    # the Pitt control "cookie" recordings.
    _target_dir = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie"
    )
    feature_extractor(_target_dir)