LOHJC
rework example
8843b46
import gradio as gr
import tensorflow as tf
import dependency
# NOTE: loading the complete saved model ("cn_to_en_transformer.keras" via
# tf.keras.models.load_model) was hard to get working because it depends on
# the environment. Instead, the architecture is rebuilt below and its weights
# are restored from a checkpoint.

# Size constants and tokenizers provided by the `dependency` module.
EMBEDDING_DEPTH = dependency.EMBEDDING_DEPTH
MAX_TOKENIZE_LENGTH = dependency.MAX_TOKENIZE_LENGTH
tokenizer_cn = dependency.tokenizer_cn
tokenizer_en = dependency.tokenizer_en

# Model hyper-parameters. NOTE(review): these presumably must match the values
# the checkpoint was trained with, otherwise load_weights will not line up.
num_layers, num_heads, dropout_rate = 1, 8, 0.1
d_model = EMBEDDING_DEPTH
# `dff` (presumably the feed-forward width) deliberately reuses the max token length.
dff = MAX_TOKENIZE_LENGTH

# Build a fresh network with the same architecture, then restore the weights.
transformer = dependency.Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=tokenizer_cn.vocab_size,
    target_vocab_size=tokenizer_en.vocab_size,
    dropout_rate=dropout_rate,
)
transformer.load_weights("./checkpoints/cn_to_en_transformer_checkpoint")
def preprocess(text):
    """Tokenize Chinese *text* into a batch of one sequence for the encoder.

    Args:
        text: the raw Chinese sentence to translate.

    Returns:
        A tensor of token ids (special tokens included) with a leading batch
        axis of size 1, i.e. shape ``(1, n)`` assuming ``tokenizer_cn.encode``
        returns a flat list of ids.
    """
    token_ids = tokenizer_cn.encode(text, add_special_tokens=True)
    return tf.expand_dims(tf.constant(token_ids), axis=0)
def inference(text):
    """Greedily decode English token ids for an encoded Chinese sentence.

    Args:
        text: encoder input as produced by ``preprocess`` -- presumably a
            ``(1, seq_len)`` tensor of Chinese token ids (TODO confirm).

    Returns:
        An int64 tensor of English token ids with a leading batch axis of 1.
        It starts with the start token. It ends with the end token if that
        token is generated within ``MAX_TOKENIZE_LENGTH`` steps; otherwise it
        is cut off at that many generated tokens.
    """
    # Encoding the empty string yields only the special tokens. The first one
    # seeds the decoder and the second one is the stop token.
    special_ids = tokenizer_en.encode("", add_special_tokens=True)
    start_token = tf.constant([special_ids[0]], dtype=tf.int64)
    end_token = tf.constant([special_ids[1]], dtype=tf.int64)

    # Each entry is one step's token id with shape (1,). Stacking gives
    # (steps, 1), which is transposed into the (1, steps) decoder input.
    # A TensorArray (rather than a Python list) keeps this loop traceable if
    # the function is ever wrapped in `tf.function`; it is not wrapped here.
    decoded = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
    decoded = decoded.write(0, start_token)

    for step in tf.range(MAX_TOKENIZE_LENGTH):
        decoder_input = tf.transpose(decoded.stack())
        scores = transformer([text, decoder_input], training=False)
        # Keep only the last position, (batch_size, 1, vocab_size), and pick
        # the highest-scoring token as the next one.
        next_id = tf.argmax(scores[:, -1:, :], axis=-1)
        decoded = decoded.write(step + 1, next_id[0])
        if next_id == end_token:
            break

    return tf.transpose(decoded.stack())
def postprocess(text):
    """Turn decoded English token ids back into a plain string.

    Args:
        text: batched token ids as returned by ``inference``. Only the first
            row (batch index 0) is decoded.

    Returns:
        The decoded English sentence with special tokens removed.
    """
    first_sequence = text[0]
    return tokenizer_en.decode(first_sequence, skip_special_tokens=True)
def translate(text):
    """Translate a Chinese sentence to English (the "Translate" button handler).

    Args:
        text: Chinese input from the textbox. It may be empty or whitespace
            only. It may also be ``None`` if the component has no value.

    Returns:
        The English translation, or ``""`` when there is nothing to translate.
    """
    # `None` has no `.strip()`, so check it before testing for blank input.
    if text is None or not text.strip():
        return ""
    return postprocess(inference(preprocess(text)))
# HTML header rendered at the top of the demo page: bilingual title, a short
# disclaimer, and links to the accompanying article and notebook.
# Every <a> and <li> element is explicitly closed so the markup is well-formed.
DESCRIPTION = "".join([
    "<h1>中英翻译器</h1>",
    "<h1>Chinese to English translator</h1>",
    "<p>This translator was built using a transformer implemented from scratch</p>",
    "<p>This is just a demonstration of using a transformer; the translation is not 100% correct</p>",
    "<ul>",
    "<li><a href=\"https://medium.com/@jiachiewloh/nlp-chinese-to-english-translation-by-using-transformer-6503c1f4a139\">Article</a></li>",
    "<li><a href=\"https://www.kaggle.com/code/jclohjc/cn-en-translation-using-transformer\">Code</a></li>",
    "</ul>",
])
# Page layout: the description header, a row with the Chinese input beside the
# English output, a row of action buttons, and a set of clickable examples.
with gr.Blocks(css="styles.css") as demo:
    gr.HTML(DESCRIPTION)

    # Input and output text boxes, side by side.
    with gr.Row():
        input_text = gr.Text(
            label="中文 (Chinese)",
            info="请输入您想翻译的句子(Please enter the text to be translated)",
        )
        output_text = gr.Text(
            label="English (英文)",
            info="Here is the translated text (这是翻译后的句子)",
        )

    # "Translate" runs the model on the input box; the clear button empties both boxes.
    with gr.Row():
        gr.Button("Translate").click(
            fn=translate,
            inputs=input_text,
            outputs=output_text,
        )
        gr.ClearButton().add([input_text, output_text])

    # Each example is a (Chinese sentence, reference English translation) pair
    # mapped onto the input and output boxes respectively.
    gr.Examples(
        examples=[
            ["祝您有个美好的一天", "Have a nice day"],
            ["早上好,很高兴见到你", "Good Morning, nice to meet you"],
            ["你叫什么名字", "What is your name"],
            ["我喜欢爬山", "I like climbing"],
            ["我爱你", "I love you"],
            ["我是一个好人", "I am a good person"],
            ["我们是一家人", "We are family"],
        ],
        inputs=[input_text, output_text],
        outputs=[output_text],
    )

demo.launch()