Kumarkishalaya committed on
Commit
cc54661
1 Parent(s): 23e7bd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -5,15 +5,20 @@ import pickle
5
  import json
6
  import keras
7
  from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
8
 
9
 
10
  # Define the model repository and tokenizer checkpoint
11
  model_checkpoint = "himanishprak23/neural_machine_translation"
12
  tokenizer_checkpoint = "Helsinki-NLP/opus-mt-en-hi"
13
 
14
- tokenizer_base_nmt = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
15
- model_base_nmt = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
16
-
17
  # Load the tokenizer from Helsinki-NLP and model from Hugging Face repository
18
  tokenizer_nmt = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
19
  model_nmt = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
@@ -33,9 +38,9 @@ max_len_eng = 20
33
  max_len_hin = 22
34
 
35
  def translate_text_base_nmt(input_text):
36
- tokenized_input = tokenizer_base_nmt(input_text, return_tensors='tf', max_length=128, truncation=True)
37
- generated_tokens = model_base_nmt.generate(**tokenized_input, max_length=128)
38
- predicted_text = tokenizer_nmt.decode(generated_tokens[0], skip_special_tokens=True)
39
  return predicted_text
40
 
41
  def translate_text_nmt(input_text):
 
5
  import json
6
  import keras
7
  from huggingface_hub import hf_hub_download
8
+ from transformers import pipeline
9
+
10
+
11
+ model_name = "Helsinki-NLP/opus-mt-en-hi"
12
+
13
+ tokenizer_base_nmt = MarianMTModel.from_pretrained(model_name)
14
+ model_base_nmt = AutoTokenizer.from_pretrained(model_name)
15
+
16
 
17
 
18
  # Define the model repository and tokenizer checkpoint
19
  model_checkpoint = "himanishprak23/neural_machine_translation"
20
  tokenizer_checkpoint = "Helsinki-NLP/opus-mt-en-hi"
21
 
 
 
 
22
  # Load the tokenizer from Helsinki-NLP and model from Hugging Face repository
23
  tokenizer_nmt = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
24
  model_nmt = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
 
38
  max_len_hin = 22
39
 
40
  def translate_text_base_nmt(input_text):
41
+ batch = tokenizer_base_nmt([input_text], return_tensors="pt")
42
+ generated_ids = model_base_nmt.generate(**batch)
43
+ predicted_text = tokenizer_base_nmt.batch_decode(generated_ids, skip_special_tokens=True)[0]
44
  return predicted_text
45
 
46
  def translate_text_nmt(input_text):