Update README.md
Browse files
README.md
CHANGED
@@ -59,7 +59,8 @@ feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
59 |
)
|
60 |
|
61 |
# prepare data
|
62 |
-
data = load_dataset("distil-whisper/librispeech_long", "clean",
|
|
|
63 |
|
64 |
def batch_collater(data):
|
65 |
tensors = []
|
@@ -68,8 +69,9 @@ def batch_collater(data):
|
|
68 |
return tensors
|
69 |
|
70 |
audio_array = batch_collater(data)
|
71 |
-
inputs = feature_extractor(audio_array, sampling_rate=16_000,
|
72 |
-
|
|
|
73 |
input_values = inputs['input_values']
|
74 |
input_lengths = torch.sum(inputs['attention_mask'], dim=-1)
|
75 |
|
@@ -79,7 +81,8 @@ input_values, input_lengths = input_values.to(device), input_lengths.to(device)
|
|
79 |
with torch.no_grad():
|
80 |
model.eval()
|
81 |
output = model(input_values=input_values,
|
82 |
-
|
|
|
83 |
```
|
84 |
|
85 |
### Downstream Use
|
@@ -105,8 +108,10 @@ def extract_all_chars(batch):
|
|
105 |
|
106 |
librispeech100h_train = load_dataset("openslr/librispeech_asr", split="train.clean.100")
|
107 |
librispeech100h_test = load_dataset("openslr/librispeech_asr", split="validation.clean")
|
108 |
-
librispeech100h_train = librispeech100h_train.remove_columns(
|
109 |
-
|
|
|
|
|
110 |
|
111 |
librispeech100h_train = librispeech100h_train.map(pre_processing)
|
112 |
librispeech100h_test = librispeech100h_test.map(pre_processing)
|
@@ -135,7 +140,8 @@ feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
135 |
)
|
136 |
|
137 |
tokenizer = Wav2Vec2CTCTokenizer("./ls_vocab.json",
|
138 |
-
|
|
|
139 |
|
140 |
model = AutoModelForCTC.from_pretrained(
|
141 |
repo_id,
|
|
|
59 |
)
|
60 |
|
61 |
# prepare data
|
62 |
+
data = load_dataset("distil-whisper/librispeech_long", "clean",
|
63 |
+
split="validation")
|
64 |
|
65 |
def batch_collater(data):
|
66 |
tensors = []
|
|
|
69 |
return tensors
|
70 |
|
71 |
audio_array = batch_collater(data)
|
72 |
+
inputs = feature_extractor(audio_array, sampling_rate=16_000,
|
73 |
+
return_attention_mask=True,
|
74 |
+
return_tensors='pt', do_normalize=False)
|
75 |
input_values = inputs['input_values']
|
76 |
input_lengths = torch.sum(inputs['attention_mask'], dim=-1)
|
77 |
|
|
|
81 |
with torch.no_grad():
|
82 |
model.eval()
|
83 |
output = model(input_values=input_values,
|
84 |
+
input_lengths=input_lengths,
|
85 |
+
output_hidden_states=True)
|
86 |
```
|
87 |
|
88 |
### Downstream Use
|
|
|
108 |
|
109 |
librispeech100h_train = load_dataset("openslr/librispeech_asr", split="train.clean.100")
|
110 |
librispeech100h_test = load_dataset("openslr/librispeech_asr", split="validation.clean")
|
111 |
+
librispeech100h_train = librispeech100h_train.remove_columns(
|
112 |
+
['file', 'speaker_id', 'chapter_id', 'id'])
|
113 |
+
librispeech100h_test = librispeech100h_test.remove_columns(
|
114 |
+
['file', 'speaker_id', 'chapter_id', 'id'])
|
115 |
|
116 |
librispeech100h_train = librispeech100h_train.map(pre_processing)
|
117 |
librispeech100h_test = librispeech100h_test.map(pre_processing)
|
|
|
140 |
)
|
141 |
|
142 |
tokenizer = Wav2Vec2CTCTokenizer("./ls_vocab.json",
|
143 |
+
unk_token="[UNK]", pad_token="[PAD]",
|
144 |
+
word_delimiter_token="|")
|
145 |
|
146 |
model = AutoModelForCTC.from_pretrained(
|
147 |
repo_id,
|