gauravchand11 commited on
Commit
c98f2e3
·
verified ·
1 Parent(s): 7f47488

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -23
app.py CHANGED
@@ -2,15 +2,18 @@ import streamlit as st
2
  import PyPDF2
3
  import docx
4
  import io
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
6
  import torch
7
  from pathlib import Path
8
  import tempfile
9
  from typing import Union, Tuple
10
- import language_tool_python
11
 
12
- # Initialize language tool for grammar correction
13
- language_tool = language_tool_python.LanguageTool('en-US')
 
 
 
14
 
15
  # Define supported languages and their codes
16
  SUPPORTED_LANGUAGES = {
@@ -19,26 +22,53 @@ SUPPORTED_LANGUAGES = {
19
  'Marathi': 'mar_Deva'
20
  }
21
 
 
 
 
 
 
 
 
22
  @st.cache_resource
23
  def load_models():
24
- """Load and cache the translation and context interpretation models."""
25
  # Load Gemma model for context interpretation
26
- gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
 
 
 
27
  gemma_model = AutoModelForCausalLM.from_pretrained(
28
  "google/gemma-2b",
29
  device_map="auto",
30
- torch_dtype=torch.float16
 
31
  )
32
 
33
  # Load NLLB model for translation
34
- nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
 
 
 
35
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
36
  "facebook/nllb-200-distilled-600M",
37
  device_map="auto",
38
- torch_dtype=torch.float16
 
39
  )
40
 
41
- return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model)
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def extract_text_from_file(uploaded_file) -> str:
44
  """Extract text content from uploaded file based on its type."""
@@ -106,17 +136,39 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
106
  translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
107
  return translated_text
108
 
109
- def correct_grammar(text: str, target_lang: str) -> str:
110
- """Correct grammar and ensure tense consistency in the translated text."""
111
- # For English target language, use LanguageTool
112
- if target_lang == 'eng_Latn':
113
- matches = language_tool.check(text)
114
- corrected_text = language_tool.correct(text)
115
- return corrected_text
116
-
117
- # For other languages, return as-is (you may want to add specific grammar
118
- # correction for Hindi and Marathi in a production environment)
119
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  def save_as_docx(text: str) -> io.BytesIO:
122
  """Save translated text as a DOCX file."""
@@ -134,7 +186,12 @@ def main():
134
 
135
  # Load models
136
  with st.spinner("Loading models... This may take a few minutes."):
137
- gemma_tuple, nllb_tuple = load_models()
 
 
 
 
 
138
 
139
  # File upload
140
  uploaded_file = st.file_uploader(
@@ -182,7 +239,8 @@ def main():
182
  with st.spinner("Correcting grammar..."):
183
  corrected_text = correct_grammar(
184
  translated_text,
185
- SUPPORTED_LANGUAGES[target_language]
 
186
  )
187
 
188
  # Display result
 
2
  import PyPDF2
3
  import docx
4
  import io
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
6
  import torch
7
  from pathlib import Path
8
  import tempfile
9
  from typing import Union, Tuple
10
+ import os
11
 
12
+ # Get Hugging Face token from environment variables
13
+ HF_TOKEN = os.environ.get('HF_TOKEN')
14
+ if not HF_TOKEN:
15
+ st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
16
+ st.stop()
17
 
18
  # Define supported languages and their codes
19
  SUPPORTED_LANGUAGES = {
 
22
  'Marathi': 'mar_Deva'
23
  }
24
 
25
+ # Language codes for MT5
26
+ MT5_LANG_CODES = {
27
+ 'eng_Latn': 'en',
28
+ 'hin_Deva': 'hi',
29
+ 'mar_Deva': 'mr'
30
+ }
31
+
32
  @st.cache_resource
33
  def load_models():
34
+ """Load and cache the translation, context interpretation, and grammar correction models."""
35
  # Load Gemma model for context interpretation
36
+ gemma_tokenizer = AutoTokenizer.from_pretrained(
37
+ "google/gemma-2b",
38
+ token=HF_TOKEN
39
+ )
40
  gemma_model = AutoModelForCausalLM.from_pretrained(
41
  "google/gemma-2b",
42
  device_map="auto",
43
+ torch_dtype=torch.float16,
44
+ token=HF_TOKEN
45
  )
46
 
47
  # Load NLLB model for translation
48
+ nllb_tokenizer = AutoTokenizer.from_pretrained(
49
+ "facebook/nllb-200-distilled-600M",
50
+ token=HF_TOKEN
51
+ )
52
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
53
  "facebook/nllb-200-distilled-600M",
54
  device_map="auto",
55
+ torch_dtype=torch.float16,
56
+ token=HF_TOKEN
57
  )
58
 
59
+ # Load MT5 model for grammar correction
60
+ mt5_tokenizer = AutoTokenizer.from_pretrained(
61
+ "google/mt5-small",
62
+ token=HF_TOKEN
63
+ )
64
+ mt5_model = T5ForConditionalGeneration.from_pretrained(
65
+ "google/mt5-small",
66
+ device_map="auto",
67
+ torch_dtype=torch.float16,
68
+ token=HF_TOKEN
69
+ )
70
+
71
+ return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model), (mt5_tokenizer, mt5_model)
72
 
73
  def extract_text_from_file(uploaded_file) -> str:
74
  """Extract text content from uploaded file based on its type."""
 
136
  translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
137
  return translated_text
138
 
139
+ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
140
+ """
141
+ Correct grammar using MT5 model for all supported languages.
142
+ Uses a text-to-text approach with language-specific prompts.
143
+ """
144
+ tokenizer, model = mt5_tuple
145
+ lang_code = MT5_LANG_CODES[target_lang]
146
+
147
+ # Language-specific prompts for grammar correction
148
+ prompts = {
149
+ 'en': f"grammar: {text}",
150
+ 'hi': f"व्याकरण सुधार: {text}",
151
+ 'mr': f"व्याकरण सुधारणा: {text}"
152
+ }
153
+
154
+ prompt = prompts[lang_code]
155
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(model.device)
156
+
157
+ outputs = model.generate(
158
+ **inputs,
159
+ max_length=512,
160
+ num_beams=5,
161
+ temperature=0.7,
162
+ top_p=0.9,
163
+ do_sample=True
164
+ )
165
+
166
+ corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
167
+
168
+ # Clean up any artifacts from the model output
169
+ corrected_text = corrected_text.replace("grammar:", "").replace("व्याकरण सुधार:", "").replace("व्याकरण सुधारणा:", "").strip()
170
+
171
+ return corrected_text
172
 
173
  def save_as_docx(text: str) -> io.BytesIO:
174
  """Save translated text as a DOCX file."""
 
186
 
187
  # Load models
188
  with st.spinner("Loading models... This may take a few minutes."):
189
+ try:
190
+ gemma_tuple, nllb_tuple, mt5_tuple = load_models()
191
+ except Exception as e:
192
+ st.error(f"Error loading models: {str(e)}")
193
+ st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
194
+ st.stop()
195
 
196
  # File upload
197
  uploaded_file = st.file_uploader(
 
239
  with st.spinner("Correcting grammar..."):
240
  corrected_text = correct_grammar(
241
  translated_text,
242
+ SUPPORTED_LANGUAGES[target_language],
243
+ mt5_tuple
244
  )
245
 
246
  # Display result