"""Gradio demo that translates English text to French with a saved Keras model."""

import pickle

import gradio as gr
import numpy as np
from keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences


class Translator:
    """Translate English sentences to French with a pre-trained seq2seq model.

    The two tokenizers and the model are read from files in the current
    working directory when an instance is created.
    """

    # Length every encoded input is padded (or truncated) to before prediction.
    # NOTE(review): presumably 15 words plus start and end markers -- confirm
    # against the training configuration.
    _PADDED_LENGTH = 15 + 2

    def __init__(self):
        self.tokenizer_obj_en, self.tokenizer_obj_fr = self.prepare_tokenizers()
        self.model = self.load_model()
        # NOTE(review): both markers are empty strings. They look like markup
        # (e.g. "<start>"/"<end>") that was stripped somewhere; confirm the
        # values used during training before changing them.
        self.start_token = ''
        self.end_token = ''

    def prepare_tokenizers(self):
        """Load the pickled English and French tokenizers.

        Returns:
            A ``(english_tokenizer, french_tokenizer)`` tuple.

        Raises:
            OSError: If either pickle file cannot be opened.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load tokenizer files that come from a trusted source.
        with open('eng_tokenizer.pickle', 'rb') as handle:
            tokenizer_obj_en = pickle.load(handle)
        with open('fr_tokenizer.pickle', 'rb') as handle:
            tokenizer_obj_fr = pickle.load(handle)
        return tokenizer_obj_en, tokenizer_obj_fr

    def load_model(self):
        """Load ``best_model2.h5`` and compile it.

        The model is loaded with ``compile=False`` and then compiled
        explicitly with the loss, optimizer and metric given below.

        Returns:
            The compiled Keras model.
        """
        # The bare name resolves to the imported Keras function, not to this
        # method: method names are not visible as plain names in the body.
        seq2seq_model = load_model("best_model2.h5", compile=False)
        seq2seq_model.compile(
            loss='sparse_categorical_crossentropy',
            optimizer='adam',
            metrics=['accuracy'],
        )
        return seq2seq_model

    def _register_end_token(self):
        """Give the end token a French vocabulary index, exactly once.

        The previous code recomputed ``len(word_index) + 1`` and overwrote the
        entry on every request, so the index changed between the first and
        second call. ``setdefault`` keeps whichever index was assigned first.
        """
        word_index = self.tokenizer_obj_fr.word_index
        # NOTE(review): len + 1 lies past the fitted vocabulary when indices
        # run 1..len; confirm the model can actually emit this index.
        word_index.setdefault(self.end_token, len(word_index) + 1)

    def prob_to_sentence(self, output_probs, tokenizer):
        """Decode model output into a space-separated sentence.

        Args:
            output_probs: Array of scores; the most likely word index is
                taken with ``argmax`` over axis 2 and the result is flattened.
            tokenizer: Tokenizer providing ``index_word`` and ``word_index``.

        Returns:
            The decoded words joined by single spaces. Index 0 is skipped and
            decoding stops at the end-token index when the tokenizer has one.

        Raises:
            KeyError: If a predicted index is missing from ``index_word``.
        """
        index_word = tokenizer.index_word
        # Looked up once instead of per word; None simply never matches.
        end_index = tokenizer.word_index.get(self.end_token)
        word_indices = np.argmax(output_probs, axis=2)

        words = []
        # tolist() yields plain ints, which match the int keys of index_word.
        for idx in word_indices.flatten().tolist():
            if idx == 0:
                continue
            if idx == end_index:
                break
            words.append(index_word[idx])
        return ' '.join(words)

    def translate_sentence(self, input):  # `input` is kept for caller compatibility
        """Translate one English sentence.

        Args:
            input: Text to translate. ``None`` or blank text yields ``''``.

        Returns:
            The decoded French sentence.
        """
        if input is None or not input.strip():
            return ''

        self._register_end_token()

        # The English tokenizer is deliberately NOT refitted here. Calling
        # fit_on_texts per request rebuilds word_index from updated counts,
        # which can renumber the vocabulary the model was trained on and add
        # indices the model has never seen. Unknown words are dropped by
        # texts_to_sequences instead.
        tokens = input.split() + [self.end_token]
        input_index = self.tokenizer_obj_en.texts_to_sequences([tokens])
        input_tokenized_padded = pad_sequences(
            input_index, maxlen=self._PADDED_LENGTH, padding='post'
        )
        output_probs = self.model.predict(input_tokenized_padded, verbose=0)
        return self.prob_to_sentence(output_probs, self.tokenizer_obj_fr)

    def predict(self, input):
        """Translate ``input``; this is the callable handed to Gradio."""
        return self.translate_sentence(input)


def main():
    """Build the Gradio interface around a ``Translator`` and launch it."""
    model = Translator()
    # gr.inputs was deprecated in Gradio 3.x and is absent from Gradio 4.x;
    # prefer the top-level component and fall back for older releases.
    textbox_cls = getattr(gr, 'Textbox', None) or gr.inputs.Textbox
    interface = gr.Interface(
        fn=model.predict,
        inputs=textbox_cls(lines=2, placeholder='Text to translate'),
        outputs='text',
    )
    interface.launch()


if __name__ == '__main__':
    main()