import gradio as gr
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM, MarianMTModel
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
import pickle
import json
import keras
from huggingface_hub import hf_hub_download
from transformers import pipeline
import torch
import os
model_name = "Helsinki-NLP/opus-mt-en-hi"
tokenizer_base_nmt = AutoTokenizer.from_pretrained(model_name)
model_base_nmt = MarianMTModel.from_pretrained(model_name)
# Define the model repository and tokenizer checkpoint
model_checkpoint = "himanishprak23/neural_machine_translation"
tokenizer_checkpoint = "Helsinki-NLP/opus-mt-en-hi"
# Load the tokenizer from Helsinki-NLP and model from Hugging Face repository
tokenizer_nmt = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
model_nmt = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# Load the model, tokenizers & variables for the trained LSTM translation model.
#repo_id = "Kumarkishalaya/lstm-eng-to-hin"
#lstm_filename = "seq2seq_model.keras"
# Re-download the file
#lstm_model_path = hf_hub_download(repo_id=repo_id, filename=lstm_filename, force_download=True)
model_lstm = load_model('seq2seq_model.h5')
with open('eng_tokenizer.pkl', 'rb') as file:
    eng_tokenizer = pickle.load(file)
with open('hin_tokenizer.pkl', 'rb') as file:
    hin_tokenizer = pickle.load(file)
max_len_eng = 20
max_len_hin = 22
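# Note (assumption): max_len_eng / max_len_hin are taken to be the padded sequence lengths
# used when the LSTM was trained; they must match the training-time values, otherwise the
# encoder input shape and the decoding loop below will not line up with the saved model.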
def translate_text_base_nmt(input_text):
    batch = tokenizer_base_nmt([input_text], return_tensors="pt")
    generated_ids = model_base_nmt.generate(**batch)
    predicted_text = tokenizer_base_nmt.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return predicted_text
def translate_text_nmt(input_text):
    tokenized_input = tokenizer_nmt(input_text, return_tensors='tf', max_length=128, truncation=True)
    generated_tokens = model_nmt.generate(**tokenized_input, max_length=128)
    predicted_text = tokenizer_nmt.decode(generated_tokens[0], skip_special_tokens=True)
    return predicted_text
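# Usage sketch (illustrative only; the example sentence and variable names below are not part
# of the original app and nothing here runs at import time):
#   hindi_base = translate_text_base_nmt("The weather is nice today.")
#   hindi_finetuned = translate_text_nmt("The weather is nice today.")
# Both helpers take a plain English string and return the decoded Hindi string.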
def translate_text_lstm(sentence, model, eng_tokenizer, hin_tokenizer, max_len_eng, max_len_hin):
    # Tokenize and pad the input sentence
    input_seq = eng_tokenizer.texts_to_sequences([sentence])
    input_seq = pad_sequences(input_seq, maxlen=max_len_eng, padding='post')
    # Initialize target sequence with start token
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = hin_tokenizer.word_index['start']
    # Create reverse word index for Hindi
    reverse_word_index = dict([(idx, word) for word, idx in hin_tokenizer.word_index.items()])
    decoded_sentence = []
    for _ in range(max_len_hin):
        output = model.predict([input_seq, target_seq], verbose=0)
        sampled_token_index = np.argmax(output[0, -1, :])
        sampled_word = reverse_word_index.get(sampled_token_index, '')
        if sampled_word == 'end' or sampled_word == '' or len(decoded_sentence) >= max_len_hin - 1:
            break
        decoded_sentence.append(sampled_word)
        # Update target sequence
        target_seq = np.zeros((1, len(decoded_sentence) + 1))
        for t, word in enumerate(decoded_sentence):
            target_seq[0, t] = hin_tokenizer.word_index.get(word, 0)  # Use 0 for unknown words
        target_seq[0, len(decoded_sentence)] = sampled_token_index
    return ' '.join(decoded_sentence)
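# The loop above is plain greedy decoding: at each step the model sees the padded English
# sequence plus the Hindi tokens generated so far, the highest-probability next token is kept,
# and decoding stops at the 'end' token, an unknown token, or the length limit. (This assumes
# the saved seq2seq model expects [encoder_input, decoder_input] and returns per-step token
# probabilities, matching how it was trained.)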
def translate_text(input_text):
    translation_lstm = translate_text_lstm(input_text, model_lstm, eng_tokenizer, hin_tokenizer, max_len_eng, max_len_hin)
    translation_nmt_base = translate_text_base_nmt(input_text)
    translation_nmt_finetuned = translate_text_nmt(input_text)
    return translation_lstm, translation_nmt_base, translation_nmt_finetuned
# Create the Gradio interface
iface = gr.Interface(
    fn=translate_text,
    inputs=gr.components.Textbox(lines=2, placeholder="Enter text to translate from English to Hindi..."),
    outputs=[
        gr.components.Textbox(label="Translation (LSTM Model)"),
        gr.components.Textbox(label="Translation (Base Helsinki Model)"),
        gr.components.Textbox(label="Translation (Fine-tuned Helsinki Model)")
    ],
    title="English to Hindi Translator",
    description="Enter English text and get the Hindi translation from three different models: LSTM, Base Helsinki-NLP, and Fine-tuned Helsinki-NLP."
)
# Launch the Gradio app
iface.launch()
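# Optional (sketch): when running locally rather than on Hugging Face Spaces, a shareable
# public link can be requested with the standard Gradio launch option below.
# iface.launch(share=True)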