# Hugging Face Space -- status shown on the hosting page: Runtime error
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
import gradio as gr
## -- DIGIT ENCODER AND DECODER -- ## | |
class CharacterTable:
    """Map a fixed character set to and from one-hot / index encodings.

    Given a set of characters:
    + Encode them to a one-hot integer representation
    + Decode the one-hot or integer representation to their character output
    + Decode a vector of probabilities to their character output
    """

    def __init__(self, chars):
        """Initialize character table.

        # Arguments
            chars: Characters that can appear in the input. Duplicates are
                dropped and the remaining characters are sorted, so the index
                of a character is its rank in sorted order.
        """
        self.chars = sorted(set(chars))
        self.char_indices = {c: i for i, c in enumerate(self.chars)}
        self.indices_char = {i: c for i, c in enumerate(self.chars)}

    def encode(self, C: str, num_rows: int) -> np.ndarray:
        """One-hot encode given string C.

        # Arguments
            C: string, to be encoded.
            num_rows: Number of rows in the returned one-hot encoding. This is
                used to keep the # of rows for each data the same.

        # Returns
            Array of shape `(num_rows, len(self.chars))`; row `i` is the
            one-hot vector of `C[i]`, rows past `len(C)` stay all-zero.

        # Raises
            IndexError: if `C` has more than `num_rows` characters.
            KeyError: if `C` contains a character outside the table.
        """
        # Fail before allocating, with a message that names the real problem
        # (same exception type numpy would raise on the out-of-range row).
        if len(C) > num_rows:
            raise IndexError(
                f"cannot encode {len(C)} characters into {num_rows} rows"
            )
        x = np.zeros((num_rows, len(self.chars)))
        for row, char in enumerate(C):
            x[row, self.char_indices[char]] = 1
        return x

    def decode(self, x, calc_argmax: bool = True) -> str:
        """Decode the given vector or 2D array to their character output.

        # Arguments
            x: A vector or a 2D array of probabilities or one-hot representations;
                or a vector of character indices (used with `calc_argmax=False`).
            calc_argmax: Whether to find the character index with maximum
                probability, defaults to `True`.

        # Returns
            The decoded string, one character per row / index.
        """
        if calc_argmax:
            # asarray: accept plain lists as well as ndarrays.
            # atleast_1d: a single probability vector reduces to a scalar
            # index, which must still be iterable below.
            x = np.atleast_1d(np.asarray(x).argmax(axis=-1))
        return "".join(self.indices_char[index] for index in x)
## -- INFERENCE CODE -- ## | |
def _is_ascii_digits(text):
    """Return True when `text` is non-empty and made only of ASCII 0-9."""
    return bool(text) and all(ch in "0123456789" for ch in text)


def check_num_validity(a, b):
    """Validates the inputs before feeding to the model.

    Checks that both inputs consist solely of the ASCII digits 0-9 and are
    not more than 5 digits long.

    `str.isdigit()` is deliberately not used: it also accepts non-ASCII
    digits (e.g. superscript or Arabic-Indic digits), which are not plain
    0-9 and therefore cannot be handled as ordinary decimal input downstream.

    # Arguments
        a: first operand, as a string.
        b: second operand, as a string.

    # Returns
        `(True, None)` when both inputs are valid, otherwise
        `(False, message)` where `message` describes the problem.
    """
    if not (_is_ascii_digits(a) and _is_ascii_digits(b)):
        return False, "Input has to be positive integers (each can be max 5 digits)"
    if len(a) > 5 or len(b) > 5:
        return False, "Input can contain maximum of five digits"
    return True, None
def add_2_nums(a, b):
    """Performs the model inference after input validation.

    # Arguments
        a: first operand as entered by the user.
        b: second operand as entered by the user.

    # Returns
        On valid input, a tuple `(guess, rev_query, query)`:
            guess: the decoded model prediction.
            rev_query: the padded query reversed, i.e. the string that is
                actually encoded and fed to the model.
            query: `"a+b"` right-padded with spaces to `MAXLEN`.
        On invalid input, `("", "Error", message)` where `message` comes
        from `check_num_validity`.
    """
    # Tolerate None and non-string values (such as ints) instead of
    # crashing on `.strip()`; anything unusable is rejected by validation.
    a = "" if a is None else str(a).strip()
    b = "" if b is None else str(b).strip()

    is_valid, error_message = check_num_validity(a, b)
    if not is_valid:
        return "", "Error", error_message

    # Input encoding: pad to the fixed width, then reverse the sequence.
    query = f"{a}+{b}".ljust(MAXLEN)
    rev_query = query[::-1]
    inp = np.expand_dims(ctable.encode(rev_query, MAXLEN), axis=0)  # batch of one

    # Prediction and output decoding.
    # NOTE(review): assumes model.predict returns per-position class scores
    # with the class axis last -- confirm against the loaded model.
    preds = np.argmax(model.predict(inp), axis=-1)
    guess = ctable.decode(preds[0], calc_argmax=False)
    return guess, rev_query, query
## -- LSTM INFERENCE SETUP --##
# Each operand has at most DIGITS digits.
DIGITS = 5
# Flag carried over alongside the other settings; nothing in this section reads it.
REVERSE = True
# A query looks like 'int + int' (e.g., '345+678'): two operands of up to
# DIGITS characters each plus the '+' sign between them.
MAXLEN = 2 * DIGITS + 1

# Vocabulary shared by the encoder and the decoder: digits, '+', and the
# space used for padding.
chars = "0123456789+ "
ctable = CharacterTable(chars)

# Pretrained Keras model fetched from the Hugging Face Hub.
model = from_pretrained_keras("vdprabhu/addition-lstm")
## -- GRADIO SETUP -- ##
inputs = [gr.Textbox(), gr.Textbox()]
outputs = [
    gr.Textbox(label="LSTM Output"),
    gr.Textbox(label="LSTM Input"),
    gr.Textbox(label="User Query"),
]
# The inputs are textboxes, so the example rows are given as strings.
examples = [["53511", "98888"], ["452", "12"]]
title = "Addition using LSTM"
more_text = "It is interesting to note that the input is reversed before feeding to LSTM. Sequence order inversion introduces shorter term dependencies between source and target for this problem."
description = f"LSTM model is used to add two numbers, provided as strings.\n\n{more_text}"
article = """
<p style='text-align: center'>
<a href='https://keras.io/examples/nlp/addition_rnn/' target='_blank'>Keras Example by Smerity and others</a>
<br>
Space by Vrinda Prabhu
</p>
"""
demo = gr.Interface(
    fn=add_2_nums,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    article=article,
    allow_flagging="never",
    analytics_enabled=False,
    title=title,
    description=description,
)
# Queueing is switched on through `.queue()`: the `enable_queue` keyword of
# `launch()` is deprecated and is rejected by newer Gradio releases.
demo.queue().launch()