Build error
Update app.py
app.py CHANGED
@@ -2,15 +2,18 @@ import streamlit as st
 import PyPDF2
 import docx
 import io
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
 import torch
 from pathlib import Path
 import tempfile
 from typing import Union, Tuple
-import
+import os
 
-#
-
+# Get Hugging Face token from environment variables
+HF_TOKEN = os.environ.get('HF_TOKEN')
+if not HF_TOKEN:
+    st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
+    st.stop()
 
 # Define supported languages and their codes
 SUPPORTED_LANGUAGES = {
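A note on the token plumbing introduced above: Hugging Face Spaces exposes repository secrets to the running app as environment variables, which is why the token is read with `os.environ.get('HF_TOKEN')` rather than hard-coded, and `st.stop()` halts the script so the error message is shown instead of a later traceback. `google/gemma-2b` is a gated checkpoint, so `from_pretrained` fails without a valid token whose account has accepted the model license.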
@@ -19,26 +22,53 @@ SUPPORTED_LANGUAGES = {
     'Marathi': 'mar_Deva'
 }
 
+# Language codes for MT5
+MT5_LANG_CODES = {
+    'eng_Latn': 'en',
+    'hin_Deva': 'hi',
+    'mar_Deva': 'mr'
+}
+
 @st.cache_resource
 def load_models():
-    """Load and cache the translation
+    """Load and cache the translation, context interpretation, and grammar correction models."""
     # Load Gemma model for context interpretation
-    gemma_tokenizer = AutoTokenizer.from_pretrained(
+    gemma_tokenizer = AutoTokenizer.from_pretrained(
+        "google/gemma-2b",
+        token=HF_TOKEN
+    )
     gemma_model = AutoModelForCausalLM.from_pretrained(
         "google/gemma-2b",
         device_map="auto",
-        torch_dtype=torch.float16
+        torch_dtype=torch.float16,
+        token=HF_TOKEN
     )
 
     # Load NLLB model for translation
-    nllb_tokenizer = AutoTokenizer.from_pretrained(
+    nllb_tokenizer = AutoTokenizer.from_pretrained(
+        "facebook/nllb-200-distilled-600M",
+        token=HF_TOKEN
+    )
     nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
         "facebook/nllb-200-distilled-600M",
         device_map="auto",
-        torch_dtype=torch.float16
+        torch_dtype=torch.float16,
+        token=HF_TOKEN
     )
 
+    # Load MT5 model for grammar correction
+    mt5_tokenizer = AutoTokenizer.from_pretrained(
+        "google/mt5-small",
+        token=HF_TOKEN
+    )
+    mt5_model = T5ForConditionalGeneration.from_pretrained(
+        "google/mt5-small",
+        device_map="auto",
+        torch_dtype=torch.float16,
+        token=HF_TOKEN
+    )
+
-    return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model)
+    return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model), (mt5_tokenizer, mt5_model)
 
 def extract_text_from_file(uploaded_file) -> str:
     """Extract text content from uploaded file based on its type."""
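For readers unfamiliar with the decorator used in this hunk: `@st.cache_resource` memoizes the returned objects for the lifetime of the server process, so the three checkpoints are downloaded and loaded once rather than on every Streamlit rerun. A minimal sketch of the pattern (the helper name and single model here are illustrative, not from the app):

```python
import os
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

@st.cache_resource  # runs once per process; later reruns reuse the same objects
def load_translator(model_id: str = "facebook/nllb-200-distilled-600M"):
    token = os.environ.get("HF_TOKEN")  # None is fine for public checkpoints
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token)
    return tokenizer, model
```

One compatibility caveat: `token=` is the current keyword in transformers; older releases spelled it `use_auth_token=`.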
@@ -106,17 +136,39 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
     translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
     return translated_text
 
-def correct_grammar(text: str, target_lang: str) -> str:
-    """
-
-
-
-
-
-
-    #
-
-
+def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
+    """
+    Correct grammar using MT5 model for all supported languages.
+    Uses a text-to-text approach with language-specific prompts.
+    """
+    tokenizer, model = mt5_tuple
+    lang_code = MT5_LANG_CODES[target_lang]
+
+    # Language-specific prompts for grammar correction
+    prompts = {
+        'en': f"grammar: {text}",
+        'hi': f"व्याकरण सुधार: {text}",
+        'mr': f"व्याकरण सुधारणा: {text}"
+    }
+
+    prompt = prompts[lang_code]
+    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(model.device)
+
+    outputs = model.generate(
+        **inputs,
+        max_length=512,
+        num_beams=5,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True
+    )
+
+    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Clean up any artifacts from the model output
+    corrected_text = corrected_text.replace("grammar:", "").replace("व्याकरण सुधार:", "").replace("व्याकरण सुधारणा:", "").strip()
+
+    return corrected_text
 
 def save_as_docx(text: str) -> io.BytesIO:
     """Save translated text as a DOCX file."""
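Only the tail of `translate_text` is visible in this hunk. For context, the conventional NLLB-200 generation call selects the target language with `forced_bos_token_id`, roughly as in the sketch below; this is generic NLLB usage, not necessarily the app's exact implementation:

```python
# Generic NLLB-200 translation sketch; nllb_tuple and the FLORES-200
# language codes ('eng_Latn', 'hin_Deva', ...) match this app's names.
tokenizer, model = nllb_tuple
tokenizer.src_lang = "eng_Latn"  # source language
inputs = tokenizer(text, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    forced_bos_token_id=tokenizer.convert_tokens_to_ids("hin_Deva"),  # target language
    max_length=512,
)
translated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
```

The final line matches the decode-and-return shown at the top of the hunk.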
@@ -134,7 +186,12 @@ def main():
 
     # Load models
     with st.spinner("Loading models... This may take a few minutes."):
-        gemma_tuple, nllb_tuple = load_models()
+        try:
+            gemma_tuple, nllb_tuple, mt5_tuple = load_models()
+        except Exception as e:
+            st.error(f"Error loading models: {str(e)}")
+            st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
+            st.stop()
 
     # File upload
     uploaded_file = st.file_uploader(
@@ -182,7 +239,8 @@
         with st.spinner("Correcting grammar..."):
             corrected_text = correct_grammar(
                 translated_text,
-                SUPPORTED_LANGUAGES[target_language]
+                SUPPORTED_LANGUAGES[target_language],
+                mt5_tuple
             )
 
         # Display result
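Taken together, the changes thread `mt5_tuple` from `load_models()` through `main()` into `correct_grammar()`. Using only the signatures visible in this diff, the intended flow is approximately:

```python
# Pipeline wiring implied by the diff; uploaded_file comes from st.file_uploader.
gemma_tuple, nllb_tuple, mt5_tuple = load_models()

text = extract_text_from_file(uploaded_file)                           # file -> raw text
translated = translate_text(text, "eng_Latn", "hin_Deva", nllb_tuple)  # NLLB translation
corrected = correct_grammar(translated, "hin_Deva", mt5_tuple)         # MT5 grammar pass
```

Two behavioral caveats worth noting: passing `num_beams=5` together with `do_sample=True` makes `generate` use beam-multinomial sampling rather than plain beam search, and base `google/mt5-small` is pretrained on span corruption only, so the `grammar:`-style prompts rely on zero-shot behavior the checkpoint may not have without fine-tuning.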