update data prep and inference notebook
Browse files
README.md
CHANGED
@@ -42,7 +42,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
42 |
43 | processor = Wav2Vec2Processor.from_pretrained("arijitx/wav2vec2-large-xlsr-bengali")
44 | model = Wav2Vec2ForCTC.from_pretrained("arijitx/wav2vec2-large-xlsr-bengali")
45 | - model = model.to("cuda")
46 |
47 | resampler = torchaudio.transforms.Resample(TEST_AUDIO_SR, 16_000)
48 | def speech_file_to_array_fn(batch):
@@ -53,7 +53,7 @@ def speech_file_to_array_fn(batch):
53 | speech_array = speech_file_to_array_fn("test_file.wav")
54 | inputs = processor(speech_array, sampling_rate=16_000, return_tensors="pt", padding=True)
55 | with torch.no_grad():
56 | -     logits = model(inputs.input_values
57 |
58 |
59 | predicted_ids = torch.argmax(logits, dim=-1)
|
|
42 |
43 | processor = Wav2Vec2Processor.from_pretrained("arijitx/wav2vec2-large-xlsr-bengali")
44 | model = Wav2Vec2ForCTC.from_pretrained("arijitx/wav2vec2-large-xlsr-bengali")
45 | + # model = model.to("cuda")
46 |
47 | resampler = torchaudio.transforms.Resample(TEST_AUDIO_SR, 16_000)
48 | def speech_file_to_array_fn(batch):
53 | speech_array = speech_file_to_array_fn("test_file.wav")
54 | inputs = processor(speech_array, sampling_rate=16_000, return_tensors="pt", padding=True)
55 | with torch.no_grad():
56 | +     logits = model(inputs.input_values).logits
57 |
58 |
59 | predicted_ids = torch.argmax(logits, dim=-1)