awacke1 commited on
Commit
023687d
1 Parent(s): 25c2aec

Create GrammarTokenize.py

Browse files
Files changed (1) hide show
  1. GrammarTokenize.py +60 -0
GrammarTokenize.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uvicorn
2
+ from fastapi import File
3
+ from fastapi import FastAPI
4
+ from fastapi import UploadFile
5
+ import torch
6
+ import os
7
+ import sys
8
+ import glob
9
+ import transformers
10
+ from transformers import AutoTokenizer
11
+ from transformers import AutoModelForSeq2SeqLM
12
+
13
+
14
+ print("Loading models...")
15
+ app = FastAPI()
16
+
17
+ device = "cpu"
18
+ correction_model_tag = "prithivida/grammar_error_correcter_v1"
19
+ correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag)
20
+ correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag)
21
+
22
+ def set_seed(seed):
23
+ torch.manual_seed(seed)
24
+ if torch.cuda.is_available():
25
+ torch.cuda.manual_seed_all(seed)
26
+
27
+ print("Models loaded !")
28
+
29
+
30
+ @app.get("/")
31
+ def read_root():
32
+ return {"Gramformer !"}
33
+
34
+ @app.get("/{correct}")
35
+ def get_correction(input_sentence):
36
+ set_seed(1212)
37
+ scored_corrected_sentence = correct(input_sentence)
38
+ return {"scored_corrected_sentence": scored_corrected_sentence}
39
+
40
+ def correct(input_sentence, max_candidates=1):
41
+ correction_prefix = "gec: "
42
+ input_sentence = correction_prefix + input_sentence
43
+ input_ids = correction_tokenizer.encode(input_sentence, return_tensors='pt')
44
+ input_ids = input_ids.to(device)
45
+
46
+ preds = correction_model.generate(
47
+ input_ids,
48
+ do_sample=True,
49
+ max_length=128,
50
+ top_k=50,
51
+ top_p=0.95,
52
+ early_stopping=True,
53
+ num_return_sequences=max_candidates)
54
+
55
+ corrected = set()
56
+ for pred in preds:
57
+ corrected.add(correction_tokenizer.decode(pred, skip_special_tokens=True).strip())
58
+
59
+ corrected = list(corrected)
60
+ return (corrected[0], 0) #Corrected Sentence, Dummy score