arijitx committed on
Commit
ce2fe76
1 Parent(s): feb4be3

update data prep and inference notebook

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -42,7 +42,7 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
42
 
43
  processor = Wav2Vec2Processor.from_pretrained("arijitx/wav2vec2-large-xlsr-bengali")
44
  model = Wav2Vec2ForCTC.from_pretrained("arijitx/wav2vec2-large-xlsr-bengali")
45
- model = model.to("cuda")
46
 
47
  resampler = torchaudio.transforms.Resample(TEST_AUDIO_SR, 16_000)
48
  def speech_file_to_array_fn(batch):
@@ -53,7 +53,7 @@ def speech_file_to_array_fn(batch):
53
  speech_array = speech_file_to_array_fn("test_file.wav")
54
  inputs = processor(speech_array, sampling_rate=16_000, return_tensors="pt", padding=True)
55
  with torch.no_grad():
56
- logits = model(inputs.input_values.to('cuda')).logits
57
 
58
 
59
  predicted_ids = torch.argmax(logits, dim=-1)
 
42
 
43
  processor = Wav2Vec2Processor.from_pretrained("arijitx/wav2vec2-large-xlsr-bengali")
44
  model = Wav2Vec2ForCTC.from_pretrained("arijitx/wav2vec2-large-xlsr-bengali")
45
+ # model = model.to("cuda")
46
 
47
  resampler = torchaudio.transforms.Resample(TEST_AUDIO_SR, 16_000)
48
  def speech_file_to_array_fn(batch):
 
53
  speech_array = speech_file_to_array_fn("test_file.wav")
54
  inputs = processor(speech_array, sampling_rate=16_000, return_tensors="pt", padding=True)
55
  with torch.no_grad():
56
+ logits = model(inputs.input_values).logits
57
 
58
 
59
  predicted_ids = torch.argmax(logits, dim=-1)