salmanmapkar committed on
Commit
8520752
1 Parent(s): 5b083c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -0
app.py CHANGED
@@ -27,9 +27,23 @@ import numpy as np
27
  import json
28
  from datetime import timedelta
29
 
 
 
30
  __FILES = set()
31
  wispher_models = list(whisper._MODELS.keys())
32
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def CreateFile(filename):
34
  __FILES.add(filename)
35
  return filename
 
27
  import json
28
  from datetime import timedelta
29
 
30
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
31
+
32
  __FILES = set()
33
  wispher_models = list(whisper._MODELS.keys())
34
 
35
+ def correct_grammar(input_text,num_return_sequences=num_return_sequences):
36
+ torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
37
+ tokenizer = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
38
+ model = T5ForConditionalGeneration.from_pretrained('deep-learning-analytics/GrammarCorrector').to(torch_device)
39
+ batch = tokenizer([input_text],truncation=True,padding='max_length',max_length=len(input_text), return_tensors="pt").to(torch_device)
40
+ results = model.generate(**batch,max_length=len(input_text),num_beams=2, num_return_sequences=num_return_sequences, temperature=1.5)
41
+ generated_sequences = []
42
+ for generated_sequence_idx, generated_sequence in enumerate(results):
43
+ text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
44
+ generated_sequences.append(text)
45
+ return "".join(generated_sequences)
46
+
47
  def CreateFile(filename):
48
  __FILES.add(filename)
49
  return filename