# addition-lstm / app.py
# vdprabhu's picture
# Update app.py
# ad514b8
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
import gradio as gr
## -- DIGIT ENCODER AND DECODER -- ##
class CharacterTable:
    """Two-way mapping between a fixed character set and one-hot rows.

    The table can:
    + turn a string into a 2D one-hot array (one row per character),
    + turn one-hot rows or probability rows back into a string,
    + turn a sequence of character indices back into a string.
    """

    def __init__(self, chars):
        """Build the lookup tables.

        # Arguments
            chars: Iterable of characters that may appear in the input.
                Duplicates are dropped and the remainder is sorted, so the
                index of a character is its position in that sorted order.
        """
        self.chars = sorted(set(chars))
        self.char_indices = {ch: idx for idx, ch in enumerate(self.chars)}
        self.indices_char = dict(enumerate(self.chars))

    def encode(self, C, num_rows):
        """Return the one-hot encoding of string C.

        # Arguments
            C: string to encode; every character must be in the table.
            num_rows: Number of rows of the returned array. Rows beyond
                `len(C)` are left as all zeros, which keeps the row count
                identical for every encoded sample.

        # Returns
            A `(num_rows, len(self.chars))` array of zeros with a single 1
            in row `i` at the column of the i-th character of C.
        """
        one_hot = np.zeros((num_rows, len(self.chars)))
        for row, ch in enumerate(C):
            one_hot[row, self.char_indices[ch]] = 1
        return one_hot

    def decode(self, x, calc_argmax=True):
        """Map rows of scores, or plain indices, back to a string.

        # Arguments
            x: Array of probabilities / one-hot rows (argmax is taken over
                the last axis), or an iterable of character indices when
                `calc_argmax=False`.
            calc_argmax: If `True` (the default), reduce `x` with
                `argmax(axis=-1)` before looking the indices up.

        # Returns
            The characters for the resulting indices, joined into one string.
        """
        indices = x.argmax(axis=-1) if calc_argmax else x
        return "".join(self.indices_char[idx] for idx in indices)
## -- INFERENCE CODE -- ##
def check_num_validity(a, b):
    """Validate the two operand strings before they are fed to the model.

    Each operand must be a non-empty string consisting only of the ASCII
    digits 0-9 and must be at most five characters long.

    # Arguments
        a: First operand, as a string.
        b: Second operand, as a string.

    # Returns
        A `(is_valid, message)` tuple: `(True, None)` when both operands are
        acceptable, otherwise `(False, <reason>)`.
    """
    ascii_digits = "0123456789"

    def is_plain_number(text):
        # `str.isdigit()` is also true for non-ASCII digits (superscripts,
        # Arabic-Indic digits, ...). Those are not in this file's character
        # table ("0123456789+ "), so encoding them would raise KeyError.
        # Only the ten ASCII digits are accepted here.
        return bool(text) and all(ch in ascii_digits for ch in text)

    if not (is_plain_number(a) and is_plain_number(b)):
        return False, "Input has to be positive integers (each can be max 5 digits)"
    if len(a) > 5 or len(b) > 5:
        return False, "Input can contain maximum of five digits"
    return True, None
def add_2_nums(a, b):
    """Validate both operands and, if they pass, run the model on them.

    # Arguments
        a: First operand, as a string (surrounding whitespace is ignored).
        b: Second operand, as a string (surrounding whitespace is ignored).

    # Returns
        On success, a tuple `(guess, rev_query, query)`: the decoded model
        output, the reversed padded query that was encoded, and the padded
        query "a+b". On validation failure, `("", "Error", <message>)`.
    """
    a, b = a.strip(), b.strip()

    is_valid, error_message = check_num_validity(a, b)
    if not is_valid:
        return "", "Error", error_message

    # Build "a+b", right-pad with spaces to MAXLEN, then reverse it; the
    # reversed string is what gets one-hot encoded.
    padded_query = f"{a}+{b}".ljust(MAXLEN)
    reversed_query = padded_query[::-1]
    one_hot = ctable.encode(reversed_query, MAXLEN)
    batch = one_hot[np.newaxis, ...]  # add a leading axis of size 1

    # Take the highest-scoring index along the last axis, then decode the
    # first (only) element of the batch back to characters.
    predicted_indices = np.argmax(model.predict(batch), axis=-1)
    decoded_sum = ctable.decode(predicted_indices[0], calc_argmax=False)
    return decoded_sum, reversed_query, padded_query
## -- LSTM INFERENCE SETUP --##
# Each operand may have at most DIGITS digits.
DIGITS = 5
# NOTE(review): REVERSE is not read anywhere in the code shown in this file;
# add_2_nums always reverses the query. Kept as a module-level name.
REVERSE = True
# Longest possible query is 'int+int' (e.g., '345+678'): two operands of up to
# DIGITS digits each plus the '+' sign between them.
MAXLEN = 2 * DIGITS + 1

# Character set used to encode queries and decode predictions: the ten digits,
# the '+' sign and the space used for padding.
chars = "0123456789+ "
ctable = CharacterTable(chars)

# Model, fetched via huggingface_hub when this module is executed.
model = from_pretrained_keras("keras-io/addition-lstm")
## -- GRADIO SETUP -- ##
# Texts shown on the page.
title = "Addition using LSTM"
more_text = (
    "It is interesting to note that the input is reversed before feeding to LSTM. "
    "Sequence order inversion introduces shorter term dependencies between source and target for this problem."
)
description = f"LSTM model is used to add two numbers, provided as strings.\n\n{more_text}"
article = """
<p style='text-align: center'>
<a href='https://keras.io/examples/nlp/addition_rnn/' target='_blank'>Keras Example by Smerity and others</a>
<br>
Space by Vrinda Prabhu
</p>
"""

# Two text inputs (the operands) and three text outputs, in the order of the
# tuple returned by add_2_nums.
inputs = [gr.Textbox(), gr.Textbox()]
outputs = [
    gr.Textbox(label="LSTM Output"),
    gr.Textbox(label="LSTM Input"),
    gr.Textbox(label="User Query"),
]
# One row per example, one value per input component.
examples = [[53511, 98888], [452, 12]]

# NOTE(review): `enable_queue` on launch() may be deprecated or removed in
# newer Gradio releases in favour of `.queue()` - confirm against the pinned
# Gradio version before upgrading.
gr.Interface(
    fn=add_2_nums,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    article=article,
    allow_flagging="never",
    analytics_enabled=False,
    title=title,
    description=description,
).launch(enable_queue=True)