AItool committed on
Commit
7291080
·
verified ·
1 Parent(s): 799982d

using translator

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
+
6
# Models offered for English→English rewriting.
# Maps the label shown in the UI dropdown to the Hugging Face model id to load.
# NOTE(review): confirm "Helsinki-NLP/opus-mt-en-en" resolves on the Hub —
# OPUS-MT checkpoints are usually published per distinct language pair.
MODEL_OPTIONS = {
    "Helsinki-NLP/opus-mt-en-en (light, CPU-friendly)": "Helsinki-NLP/opus-mt-en-en",
    "facebook/mbart-large-50-many-to-many-mmt (heavier)": "facebook/mbart-large-50-many-to-many-mmt",
}

# Already-built pipelines keyed by model id, so each model is loaded only once
# per process.
loaded_pipelines = {}
14
+
15
def get_pipeline(model_id: str):
    """Return the translation pipeline for ``model_id``, building it on first use.

    The tokenizer and seq2seq model are loaded from ``model_id``, wrapped in a
    ``"translation"`` pipeline on CPU (``device=-1``), exercised once with a
    short warm-up sentence, and stored in ``loaded_pipelines`` so later calls
    with the same id reuse the same object.

    Args:
        model_id: Hugging Face model identifier to load.

    Returns:
        The cached pipeline object for ``model_id``.
    """
    cached = loaded_pipelines.get(model_id)
    if cached is not None:
        return cached

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype="auto",
    )
    pipe = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)

    # Multilingual mBART tokenizers need explicit source/target language codes
    # when called through the translation pipeline; without them the warm-up
    # call raises and the model can never be cached. Single-pair models take
    # no language arguments.
    warmup_kwargs = {}
    if "mbart" in model_id:
        warmup_kwargs = {"src_lang": "en_XX", "tgt_lang": "en_XX"}

    # Warm-up: run one short generation so the first user request does not pay
    # the first-call cost.
    pipe("This is a test.", max_length=32, **warmup_kwargs)

    loaded_pipelines[model_id] = pipe
    return pipe
28
+
29
def polish(sentence: str, model_choice: str) -> str:
    """Rewrite ``sentence`` using the model selected by ``model_choice``.

    Args:
        sentence: Text to rewrite.
        model_choice: A key of ``MODEL_OPTIONS`` (the label shown in the UI).

    Returns:
        The generated text with surrounding whitespace stripped, or ``""``
        when ``sentence`` is empty or contains only whitespace.

    Raises:
        KeyError: If ``sentence`` is non-blank and ``model_choice`` is not a
            key of ``MODEL_OPTIONS``.
    """
    # Nothing to rewrite: avoid loading a model and generating from an empty
    # prompt, which would only produce arbitrary text.
    if not sentence or not sentence.strip():
        return ""

    model_id = MODEL_OPTIONS[model_choice]
    translator = get_pipeline(model_id)

    if "mbart" in model_id:
        # mBART is many-to-many, so the target language is chosen by forcing
        # the first generated token to the English language code. The
        # tokenizer and model are used directly instead of the pipeline call.
        tokenizer = translator.tokenizer
        encoded = tokenizer(sentence, return_tensors="pt")
        output_ids = translator.model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
            max_length=128,
            num_beams=4,
        )
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    else:
        result = translator(sentence, max_length=128)
        text = result[0]["translation_text"]

    return text.strip()
44
+
45
# Gradio UI: a sentence box plus a model selector feeding `polish`, with the
# rewritten sentence shown in a single output box.
sentence_input = gr.Textbox(lines=2, placeholder="Enter a sentence to correct...")
model_selector = gr.Dropdown(
    choices=list(MODEL_OPTIONS.keys()),
    value="Helsinki-NLP/opus-mt-en-en (light, CPU-friendly)",
    label="Choose Model",
)
corrected_output = gr.Textbox(label="Corrected English")

demo = gr.Interface(
    fn=polish,
    inputs=[sentence_input, model_selector],
    outputs=corrected_output,
    title="English→English Grammar Polisher",
    description="Uses translation models (Helsinki-NLP opus-mt-en-en and facebook mbart-large-50) to rewrite English sentences into fluent, corrected English.",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()