johnybanda committed on
Commit
9173e6e
1 Parent(s): 8384a2d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
4
+ import re
5
+
6
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the pretrained translation model, move it to the selected device,
# then load the tokenizer that belongs to the same checkpoint.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)

# Default language codes used by translate(): source first, then target.
src_lang = "en_XX"
tgt_lang = "ta_IN"
17
+
18
# Lower-case technical terms that preprocess_text() tags so they can be
# highlighted in the output. Matching is case-insensitive; extend as needed.
technical_terms = {
    "machine translation",
    "natural language processing",
    "nlp",
    "transformer architecture",
    "machine learning",
    "deep learning",
    "artificial intelligence",
    "ai",
    "neural network",
    "algorithms",
    "data science",
    "big data",
    "cloud computing",
    "internet of things",
    "iot",
    "blockchain",
    "cybersecurity",
    "virtual reality",
    "vr",
    "augmented reality",
    "ar",
    "robotics",
    "automation",
    "quantum computing",
    "5g",
    "edge computing",
    "devops",
    "microservices",
    "api",
    "serverless",
    "container",
    "docker",
    "kubernetes",
    "ml",
    "computer vision",
    "natural language understanding",
    "nlu",
    "speech recognition",
    "sentiment analysis",
    "chatbot",
    "reinforcement learning",
    "supervised learning",
    "unsupervised learning",
    "convolutional neural network",
    "cnn",
    "recurrent neural network",
    "rnn",
    "long short-term memory",
    "lstm",
    "generative adversarial network",
    "gan",
    "transfer learning",
    "federated learning",
    "explainable ai",
    "xai",
}
32
+
33
def preprocess_text(text, terms=None):
    """Wrap every technical term found in ``text`` in ``<keep>...</keep>`` tags.

    Matching is case-insensitive and restricted to whole words; the matched
    text keeps its original casing inside the tags.

    Args:
        text: Input string to scan.
        terms: Optional iterable of terms to tag. Defaults to the
            module-level ``technical_terms`` set.

    Returns:
        ``text`` with each matched term wrapped in ``<keep>`` tags. Tags are
        never nested: a short term inside an already matched longer term
        (e.g. "ai" inside "explainable ai") is not tagged again.
    """
    if terms is None:
        terms = technical_terms
    if not terms:
        # An empty alternation would match at every word boundary.
        return text

    # One alternation, longest term first, applied in a single pass. The regex
    # engine tries alternatives in order and never rescans text it has already
    # consumed, so a longer term always wins and its inner words are left
    # alone. Substituting term by term (one pass per term) re-matched short
    # terms inside earlier replacements and produced nested tags.
    alternatives = "|".join(
        re.escape(term) for term in sorted(terms, key=len, reverse=True)
    )
    pattern = re.compile(r'\b(?:' + alternatives + r')\b', re.IGNORECASE)
    return pattern.sub(lambda m: f"<keep>{m.group()}</keep>", text)
39
+
40
def postprocess_text(text):
    """Turn each ``<keep>...</keep>`` span in ``text`` into ``**...**``.

    The tagged content itself is left unchanged; only the surrounding tags
    are swapped for double asterisks. Text without tags is returned as is.
    """
    keep_span = re.compile(r'<keep>(.*?)</keep>')
    highlighted = keep_span.sub(r'**\1**', text)
    return highlighted
43
+
44
def translate(text, src_lang=src_lang, tgt_lang=tgt_lang):
    """Translate ``text`` with the module-level model and tokenizer.

    Technical terms are tagged by ``preprocess_text`` before translation, and
    ``postprocess_text`` converts the tags that remain in the model output
    into ``**...**`` markers.

    Args:
        text: Source text to translate.
        src_lang: Language code of ``text``; defaults to the module-level
            ``src_lang``.
        tgt_lang: Language code to translate into; defaults to the
            module-level ``tgt_lang``.

    Returns:
        The post-processed translation, or ``""`` when ``text`` is empty or
        contains only whitespace.

    Raises:
        KeyError: If ``tgt_lang`` is not a key of ``tokenizer.lang_code_to_id``.
    """
    # Nothing to translate: skip tokenization and beam search entirely rather
    # than letting the model generate output for an empty prompt.
    if not text or not text.strip():
        return ""

    # The tokenizer must be told the source language before encoding;
    # without this the ``src_lang`` argument would be silently ignored.
    tokenizer.src_lang = src_lang

    # Tag technical terms before the text reaches the model.
    preprocessed_text = preprocess_text(text)

    # Tokenize (truncated to 512 tokens) and move tensors to the model device.
    inputs = tokenizer(
        preprocessed_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Beam-search generation; the forced BOS token selects the target language.
    translated = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
        max_length=512,
        num_beams=5,
        length_penalty=1.0,
        early_stopping=True,
    )

    # A single input string was encoded, so take the first decoded sequence.
    translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]

    return postprocess_text(translated_text)
67
+
68
+ # Gradio interface
69
def gradio_translate(text):
    """Gradio callback: translate ``text`` using translate()'s default languages."""
    translation = translate(text)
    return translation
71
+
72
# Build the UI: a five-line input box and a five-line output box wired to
# gradio_translate.
iface = gr.Interface(
    title="English to Tamil Translation with Technical Terms Preserved",
    description="This app translates English text to Tamil while preserving technical terms.",
    fn=gradio_translate,
    inputs=gr.Textbox(label="English Text", lines=5),
    outputs=gr.Textbox(label="Tamil Translation", lines=5),
)

# Start serving the interface.
iface.launch()