birgermoell commited on
Commit
a875c0d
1 Parent(s): 47a6a98

WIP feature extractor for wav2vec2

Browse files
Files changed (2) hide show
  1. feature_extractor.py +68 -0
  2. readme.MD +14 -0
feature_extractor.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import glob
import os.path
import pdb

import numpy
import numpy as np
import pandas as pd
import soundfile as sf
import torch
from pydub import AudioSegment
from transformers import AutoTokenizer, Wav2Vec2ForCTC, Wav2Vec2Processor
11
+
12
# BUG FIX: AutoTokenizer only handles *text*; it cannot turn a raw waveform
# into model inputs. Wav2Vec2Processor bundles the feature extractor
# (waveform -> input_values) with the CTC tokenizer, which is what the
# wav2vec2 model actually needs.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60")

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60")
15
+
16
+ # Dementia path
17
+ # /home/bmoell/data/media.talkbank.org/dementia/English/Pitt
18
+ # cookie dementia /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Dementia/cookie
19
+ # /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie
20
+
21
+
22
def convert_mp3_to_wav(audio_file):
    """Decode *audio_file* (mp3) and write a wav copy at ``audio_file + ".wav"``."""
    AudioSegment.from_mp3(audio_file).export(audio_file + ".wav", format="wav")
25
+
26
+
27
def feature_extractor(path):
    """Extract wav2vec2 embeddings for every ``.wav`` file directly under *path*.

    Each embedding is cached next to its audio file as ``<file>.wav2vec2.pt``;
    files that already have a cache are skipped so re-runs are incremental.

    :param path: directory containing ``.wav`` files (not searched recursively)
    """
    print("the path is", path)

    wav_files = glob.glob(os.path.join(path, "*.wav"))
    for wav_file in wav_files:
        # BUG FIX: the original printed the entire list on every iteration;
        # print the file currently being processed instead.
        print("the wavfile is", wav_file)
        # Only compute embeddings that are not already cached on disk.
        if not os.path.isfile(wav_file + ".wav2vec2.pt"):
            get_wav2vecembeddings_from_audiofile(wav_file)
37
+
38
def get_wav2vecembeddings_from_audiofile(wav_file):
    """Run wav2vec2 over one wav file and cache the last hidden state.

    Saves the final transformer layer's hidden-state tensor to
    ``wav_file + ".wav2vec2.pt"``. Files of 250 bytes or less are skipped
    (presumably header-only / near-empty recordings — TODO confirm threshold).

    :param wav_file: path to a ``.wav`` file readable by soundfile
    """
    print("the file is", wav_file)
    speech, sample_rate = sf.read(wav_file)

    # BUG FIX: the original passed the *path string* to the processor, so the
    # model encoded the tokenized filename text instead of the audio. Feed the
    # decoded waveform together with its sampling rate.
    inputs = processor(speech, sampling_rate=sample_rate,
                       return_tensors="pt", padding=True)
    print("input values", inputs)

    file_size = os.stat(wav_file).st_size
    print("the size is", file_size)

    if file_size > 250:
        # Inference only — no gradients needed.
        with torch.no_grad():
            encoded_states = model(
                input_values=inputs["input_values"],
                attention_mask=inputs["attention_mask"],
                output_hidden_states=True,
            )

        # hidden_states[-1] is the output of the final transformer layer.
        last_hidden_state = encoded_states.hidden_states[-1]
        print("getting wav2vec2 embeddings")
        print(last_hidden_state)
        torch.save(last_hidden_state, wav_file + '.wav2vec2.pt')
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
if __name__ == "__main__":
    # Extract embeddings for the Pitt corpus "cookie" control-group recordings.
    # Guarded so importing this module does not trigger the full extraction run.
    feature_extractor("/home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie")
68
+
readme.MD ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # todo
2
+ install things to run on tpu / huggingface / datasets
3
+ load in data
4
+ train
5
+
6
+ # Important readmes
7
+ https://github.com/huggingface/transformers/tree/f42a0abf4bd765ad08e14b347a3acbe9fade31b9/examples/research_projects/jax-projects/wav2vec2
8
+
9
+ # path to files
10
+ # cookie control
11
+ data/media.talkbank.org/dementia/English/Pitt/Control/cookie
12
+
13
+ # cookie dementia
14
+ data/media.talkbank.org/dementia/English/Pitt/Dementia/cookie