|
import os
import subprocess
import sys

# Install TensorFlow at start-up through the running interpreter's own pip
# (``python -m pip``), so the package lands in the environment this script
# imports from rather than in whichever ``pip`` happens to be first on PATH.
# check=False keeps this best-effort: a failed install does not abort here.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "tensorflow"],
    check=False,
)

# Lower TensorFlow's native log verbosity; set before TensorFlow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
|
|
import json |
|
import numpy as np |
|
import gradio as gr |
|
import tensorflow as tf |
|
from tensorflow import keras |
|
from huggingface_hub.keras_mixin import from_pretrained_keras |
|
|
|
# Upper bound on the number of sentence pairs read from the corpus.
num_samples = 10000
# Tab-separated corpus file: source text, target text, and a third field
# that is ignored.
data_path = 'fra.txt'

input_texts = []
target_texts = []
input_characters = set()
target_characters = set()

with open(data_path, "r", encoding="utf-8") as f:
    lines = f.read().split("\n")

# The final element of the split is never used (presumably the empty string
# left by a trailing newline -- TODO confirm against the data file).
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text, _ = line.split("\t")
    # Wrap each target in a tab (start marker) and a newline (end marker);
    # decode_sequence seeds decoding with the tab and stops on the newline.
    target_text = "\t" + target_text + "\n"
    input_texts.append(input_text)
    target_texts.append(target_text)
    # Collect the distinct characters seen on each side.
    input_characters.update(input_text)
    target_characters.update(target_text)

# Sorting gives every character a stable position in its vocabulary.
input_characters = sorted(input_characters)
target_characters = sorted(target_characters)

# Character -> one-hot index lookup tables.
input_token_index = {char: i for i, char in enumerate(input_characters)}
target_token_index = {char: i for i, char in enumerate(target_characters)}

# Vocabulary sizes and the longest sequence on each side.
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max(len(txt) for txt in input_texts)
max_decoder_seq_length = max(len(txt) for txt in target_texts)
|
|
|
# Load the pretrained seq2seq model published under "keras-io/cl_s2s".
model = from_pretrained_keras("keras-io/cl_s2s")

# summary() prints the layer table itself and returns None, so wrapping it in
# print() would only add a stray "None" line to the output.
model.summary()

# Width of the decoder's state inputs created below; presumably equal to the
# LSTM size the checkpoint was trained with -- TODO confirm.
latent_dim = 256
|
|
|
# Layers reused from the loaded model for step-by-step decoding.
decoder_lstm = model.layers[3]
decoder_dense = model.layers[4]

# Encoder model: input sequence -> the two state tensors produced by
# model.layers[2] (its first output is not needed for inference).
encoder_inputs = model.input[0]
encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output
encoder_states = [state_h_enc, state_c_enc]
encoder_model = keras.Model(inputs=encoder_inputs, outputs=encoder_states)

# Decoder model: (decoder input, state h, state c) -> (token scores, new h,
# new c). The states are explicit inputs so that decode_sequence can feed the
# states returned by one step into the next.
decoder_inputs = tf.identity(model.input[1])
decoder_state_input_h = keras.Input(shape=(latent_dim,))
decoder_state_input_c = keras.Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
    decoder_inputs,
    initial_state=decoder_states_inputs,
)
decoder_states = [state_h_dec, state_c_dec]
decoder_outputs = decoder_dense(decoder_outputs)

decoder_model = keras.Model(
    inputs=[decoder_inputs, *decoder_states_inputs],
    outputs=[decoder_outputs, *decoder_states],
)
|
|
|
|
|
|
|
# Inverse lookup tables (token index -> character); the target-side table is
# what decode_sequence uses to turn predicted indices back into text.
reverse_input_char_index = {
    index: char for char, index in input_token_index.items()
}
reverse_target_char_index = {
    index: char for char, index in target_token_index.items()
}
|
|
|
|
|
def decode_sequence(input_seq):
    """Translate one sentence with the encoder/decoder inference models.

    The text is one-hot encoded character by character and passed through
    ``encoder_model`` to obtain the initial decoder states. Decoding is then
    greedy: ``decoder_model`` is called once per output character and the
    highest-scoring token is fed back in as the next input.

    Args:
        input_seq: Source sentence as a string. Characters that are not keys
            of ``input_token_index`` are dropped, and the remainder is cut to
            ``max_encoder_seq_length`` characters, so arbitrary user text
            cannot raise a lookup or indexing error.

    Returns:
        The decoded text, including the trailing newline end marker when the
        model produced one. An empty string is returned when there is
        nothing to encode (empty/missing text, or no supported characters).
    """
    if not input_seq:
        # Nothing to translate; the padding step below needs at least one
        # encoded character to know where padding starts.
        return ""

    # Keep only characters that have a one-hot slot, and no more of them than
    # the encoder input array can hold.
    known_chars = [char for char in input_seq if char in input_token_index]
    known_chars = known_chars[:max_encoder_seq_length]
    if not known_chars:
        return ""

    # One-hot encode the sentence as a batch of one.
    infer_input_data = np.zeros(
        (1, max_encoder_seq_length, num_encoder_tokens), dtype="float32"
    )
    for t, char in enumerate(known_chars):
        infer_input_data[0, t, input_token_index[char]] = 1.0
    # Fill the unused time steps with the space character.
    infer_input_data[0, len(known_chars):, input_token_index[" "]] = 1.0

    # Encoder states become the decoder's initial states.
    states_value = encoder_model.predict(infer_input_data)

    # Seed the decoder with the tab start marker.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, target_token_index["\t"]] = 1.0

    decoded_chars = []
    while True:
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value)

        # Greedy choice: take the most probable token for this step.
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_chars.append(sampled_char)

        # Stop on the newline end marker, or once the output is longer than
        # the longest target sequence seen in the corpus.
        if sampled_char == "\n" or len(decoded_chars) > max_decoder_seq_length:
            break

        # Feed the sampled token and the updated states into the next step.
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.0
        states_value = [h, c]

    return "".join(decoded_chars)
|
|
|
|
|
# One two-line text box for the source sentence, one text box for the result.
input_1 = gr.Textbox(lines=2)
output_1 = gr.Textbox()

# Web UI that sends the entered text to decode_sequence and shows its output.
iface = gr.Interface(
    fn=decode_sequence,
    inputs=input_1,
    outputs=output_1,
    examples=[
        ["Be kind."],
        ["Hug me."],
    ],
    title="Character Level Recurrent Seq2Seq Model",
    article="Author: <a href=\"https://huggingface.co/reichenbach\">Rishav Chandra Varma</a>",
)

# Start serving the interface with debug mode enabled.
iface.launch(debug=True)
|
|