nimool committed on
Commit
6daeff1
1 Parent(s): 0180e45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -19
app.py CHANGED
@@ -30,32 +30,39 @@ def resampler(input_file_path, output_file_path):
30
  def parse_transcription(logits):
31
  predicted_ids = torch.argmax(logits, dim=-1)
32
  transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
 
33
  return transcription
34
 
35
 
36
- # def parse(wav_file):
37
- # input_values = read_file_and_process(wav_file)
38
- # with torch.no_grad():
39
- # logits = model(**input_values).logits
40
- # sentence = parse_transcription(logits)
41
- # check_spell = spell_checker.check(sentence)
42
- # if check_spell[0] is False:
43
- # corrected = check_spell[1]
44
- # else:
45
- # corrected = sentence
46
- # return corrected
47
-
48
  def parse(wav_file):
49
  input_values = read_file_and_process(wav_file)
50
  with torch.no_grad():
51
  logits = model(**input_values).logits
52
- # sentence = parse_transcription(logits)
53
- check_spell = spell_checker.check(parse_transcription(logits))
54
- # if check_spell[0] is False:
55
- # corrected = check_spell[1]
56
- # else:
57
- # corrected = sentence
58
- return spell_checker.check(parse_transcription(logits))[1] if spell_checker.check(parse_transcription(logits))[0] is False else parse_transcription(logits)
 
 
 
 
 
 
 
 
 
 
59
 
60
  model_id = "jonatasgrosman/wav2vec2-large-xlsr-53-persian"
61
  processor = Wav2Vec2Processor.from_pretrained(model_id)
 
30
  def parse_transcription(logits):
31
  predicted_ids = torch.argmax(logits, dim=-1)
32
  transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
33
+ del(logits)
34
  return transcription
35
 
36
 
37
+ def corrector(sentence):
38
+ check_spell = spell_checker.check(sentence)
39
+ if check_spell[0] is False:
40
+ corrected = check_spell[1]
41
+ return corrected
42
+ else:
43
+ return sentence
44
+
 
 
 
 
45
  def parse(wav_file):
46
  input_values = read_file_and_process(wav_file)
47
  with torch.no_grad():
48
  logits = model(**input_values).logits
49
+ sentence = parse_transcription(logits)
50
+ corrected_sent = corrector(sentence)
51
+ return corrected_sent
52
+
53
+
54
+ # def parse(wav_file):
55
+ # check_spell = ''
56
+ # input_values = read_file_and_process(wav_file)
57
+ # with torch.no_grad():
58
+ # logits = model(**input_values).logits
59
+ # # sentence = parse_transcription(logits)
60
+ # check_spell = spell_checker.check(parse_transcription(logits))
61
+ # # if check_spell[0] is False:
62
+ # # corrected = check_spell[1]
63
+ # # else:
64
+ # # corrected = sentence
65
+ # return spell_checker.check(parse_transcription(logits))[1] if spell_checker.check(parse_transcription(logits))[0] is False else parse_transcription(logits)
66
 
67
  model_id = "jonatasgrosman/wav2vec2-large-xlsr-53-persian"
68
  processor = Wav2Vec2Processor.from_pretrained(model_id)