DheivaCodes commited on
Commit
7f39ef1
Β·
verified Β·
1 Parent(s): f48b498

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -13
app.py CHANGED
@@ -10,7 +10,6 @@ from sacrebleu import corpus_bleu
10
  import os
11
  import tempfile
12
 
13
-
14
  # Load Models
15
  lang_detect_model = AutoModelForSequenceClassification.from_pretrained("papluca/xlm-roberta-base-language-detection")
16
  lang_detect_tokenizer = AutoTokenizer.from_pretrained("papluca/xlm-roberta-base-language-detection")
@@ -50,7 +49,7 @@ dimension = corpus_embeddings.shape[1]
50
  index = faiss.IndexFlatL2(dimension)
51
  index.add(corpus_embeddings)
52
 
53
- # Detect Language
54
  def detect_language(text):
55
  inputs = lang_detect_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
56
  with torch.no_grad():
@@ -59,7 +58,7 @@ def detect_language(text):
59
  pred = torch.argmax(probs, dim=1).item()
60
  return id2lang[pred]
61
 
62
- # Translate
63
  def translate(text, src_code, tgt_code):
64
  trans_tokenizer.src_lang = src_code
65
  encoded = trans_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -75,8 +74,8 @@ def search_semantic(query, top_k=3):
75
  query_embedding = embed_model.encode([query])
76
  distances, indices = index.search(query_embedding, top_k)
77
  return [(corpus[i], float(distances[0][idx])) for idx, i in enumerate(indices[0])]
78
-
79
- # Create downloadable output file
80
  def save_output_to_file(detected_lang, translated, sem_results, bleu_score):
81
  with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f:
82
  f.write(f"Detected Language: {detected_lang}\n")
@@ -88,24 +87,25 @@ def save_output_to_file(detected_lang, translated, sem_results, bleu_score):
88
  f.write(f"\nBLEU Score: {bleu_score}")
89
  return f.name
90
 
 
91
  def full_pipeline(user_input_text, target_lang_code, human_ref=""):
92
  if not user_input_text.strip():
93
  return "Empty input", "", [], "", "", None
94
 
95
  if len(user_input_text) > 2048:
96
- return " Input too long", "Please enter shorter text (under 2000 characters).", [], "", "", None
97
 
98
  detected_lang = detect_language(user_input_text)
99
  src_nllb = xlm_to_nllb.get(detected_lang, "eng_Latn")
100
 
101
  translated = translate(user_input_text, src_nllb, target_lang_code)
102
  if not translated:
103
- return detected_lang, " Translation failed", [], "", "", None
104
 
105
  sem_results = search_semantic(translated)
106
  result_list = [f"{i+1}. {txt} (Score: {score:.2f})" for i, (txt, score) in enumerate(sem_results)]
107
 
108
- # Plot
109
  labels = [f"{i+1}" for i in range(len(sem_results))]
110
  scores = [score for _, score in sem_results]
111
  plt.figure(figsize=(6, 4))
@@ -128,8 +128,7 @@ def full_pipeline(user_input_text, target_lang_code, human_ref=""):
128
  download_file_path = save_output_to_file(detected_lang, translated, sem_results, bleu_score)
129
  return detected_lang, translated, "\n".join(result_list), plot_path, bleu_score, download_file_path
130
 
131
-
132
- # Gradio Interface
133
  gr.Interface(
134
  fn=full_pipeline,
135
  inputs=[
@@ -143,8 +142,8 @@ gr.Interface(
143
  gr.Textbox(label="Top Semantic Matches"),
144
  gr.Image(label="Semantic Similarity Plot"),
145
  gr.Textbox(label="BLEU Score"),
146
- gr.File(label="Download Translation Report") # NEW OUTPUT
147
  ],
148
- title=" Multilingual Translator + Semantic Search",
149
  description="Detects language β†’ Translates β†’ Finds related Sanskrit concepts β†’ BLEU optional β†’ Downloadable report."
150
- ).launch()
 
10
  import os
11
  import tempfile
12
 
 
13
  # Load Models
14
  lang_detect_model = AutoModelForSequenceClassification.from_pretrained("papluca/xlm-roberta-base-language-detection")
15
  lang_detect_tokenizer = AutoTokenizer.from_pretrained("papluca/xlm-roberta-base-language-detection")
 
49
  index = faiss.IndexFlatL2(dimension)
50
  index.add(corpus_embeddings)
51
 
52
+ # Language Detection
53
  def detect_language(text):
54
  inputs = lang_detect_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
55
  with torch.no_grad():
 
58
  pred = torch.argmax(probs, dim=1).item()
59
  return id2lang[pred]
60
 
61
+ # Translation
62
  def translate(text, src_code, tgt_code):
63
  trans_tokenizer.src_lang = src_code
64
  encoded = trans_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
74
  query_embedding = embed_model.encode([query])
75
  distances, indices = index.search(query_embedding, top_k)
76
  return [(corpus[i], float(distances[0][idx])) for idx, i in enumerate(indices[0])]
77
+
78
+ # Save Report
79
  def save_output_to_file(detected_lang, translated, sem_results, bleu_score):
80
  with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".txt") as f:
81
  f.write(f"Detected Language: {detected_lang}\n")
 
87
  f.write(f"\nBLEU Score: {bleu_score}")
88
  return f.name
89
 
90
+ # Full Pipeline
91
  def full_pipeline(user_input_text, target_lang_code, human_ref=""):
92
  if not user_input_text.strip():
93
  return "Empty input", "", [], "", "", None
94
 
95
  if len(user_input_text) > 2048:
96
+ return "Input too long", "Please enter shorter text (under 2000 characters).", [], "", "", None
97
 
98
  detected_lang = detect_language(user_input_text)
99
  src_nllb = xlm_to_nllb.get(detected_lang, "eng_Latn")
100
 
101
  translated = translate(user_input_text, src_nllb, target_lang_code)
102
  if not translated:
103
+ return detected_lang, "Translation failed", [], "", "", None
104
 
105
  sem_results = search_semantic(translated)
106
  result_list = [f"{i+1}. {txt} (Score: {score:.2f})" for i, (txt, score) in enumerate(sem_results)]
107
 
108
+ # Plot similarity
109
  labels = [f"{i+1}" for i in range(len(sem_results))]
110
  scores = [score for _, score in sem_results]
111
  plt.figure(figsize=(6, 4))
 
128
  download_file_path = save_output_to_file(detected_lang, translated, sem_results, bleu_score)
129
  return detected_lang, translated, "\n".join(result_list), plot_path, bleu_score, download_file_path
130
 
131
+ # Gradio UI
 
132
  gr.Interface(
133
  fn=full_pipeline,
134
  inputs=[
 
142
  gr.Textbox(label="Top Semantic Matches"),
143
  gr.Image(label="Semantic Similarity Plot"),
144
  gr.Textbox(label="BLEU Score"),
145
+ gr.File(label="Download Translation Report")
146
  ],
147
+ title="Multilingual Translator + Semantic Search",
148
  description="Detects language β†’ Translates β†’ Finds related Sanskrit concepts β†’ BLEU optional β†’ Downloadable report."
149
+ ).launch()