Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
from huggingface_hub import from_pretrained_keras
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
## -- DIGIT ENCODER AND DECODER -- ##
|
8 |
+
class CharacterTable:
|
9 |
+
"""Given a set of characters:
|
10 |
+
+ Encode them to a one-hot integer representation
|
11 |
+
+ Decode the one-hot or integer representation to their character output
|
12 |
+
+ Decode a vector of probabilities to their character output
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, chars):
|
16 |
+
"""Initialize character table.
|
17 |
+
# Arguments
|
18 |
+
chars: Characters that can appear in the input.
|
19 |
+
"""
|
20 |
+
self.chars = sorted(set(chars))
|
21 |
+
self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
|
22 |
+
self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
|
23 |
+
|
24 |
+
def encode(self, C, num_rows):
|
25 |
+
"""One-hot encode given string C.
|
26 |
+
# Arguments
|
27 |
+
C: string, to be encoded.
|
28 |
+
num_rows: Number of rows in the returned one-hot encoding. This is
|
29 |
+
used to keep the # of rows for each data the same.
|
30 |
+
"""
|
31 |
+
x = np.zeros((num_rows, len(self.chars)))
|
32 |
+
for i, c in enumerate(C):
|
33 |
+
x[i, self.char_indices[c]] = 1
|
34 |
+
return x
|
35 |
+
|
36 |
+
def decode(self, x, calc_argmax=True):
|
37 |
+
"""Decode the given vector or 2D array to their character output.
|
38 |
+
# Arguments
|
39 |
+
x: A vector or a 2D array of probabilities or one-hot representations;
|
40 |
+
or a vector of character indices (used with `calc_argmax=False`).
|
41 |
+
calc_argmax: Whether to find the character index with maximum
|
42 |
+
probability, defaults to `True`.
|
43 |
+
"""
|
44 |
+
if calc_argmax:
|
45 |
+
x = x.argmax(axis=-1)
|
46 |
+
return "".join(self.indices_char[x] for x in x)
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
## -- INFERENCE CODE -- ##
|
51 |
+
def check_num_validity(a, b):
|
52 |
+
"""Validates the inputs before feeding to the model.
|
53 |
+
Checks if both the inputs are positive integers and are not more than 5 digits.
|
54 |
+
"""
|
55 |
+
if a.isdigit() and b.isdigit():
|
56 |
+
if len(a) <= 5 and len(b) <= 5:
|
57 |
+
return True, None
|
58 |
+
else:
|
59 |
+
return False, "Input can contain maximum of five digits"
|
60 |
+
else:
|
61 |
+
return False, "Input has to be positive integers (each can be max 5 digits)"
|
62 |
+
|
63 |
+
def add_2_nums(a, b):
|
64 |
+
"""Performs the model inference after input validation.
|
65 |
+
"""
|
66 |
+
a = a.strip()
|
67 |
+
b = b.strip()
|
68 |
+
|
69 |
+
ip_val_op = check_num_validity(a, b)
|
70 |
+
|
71 |
+
if ip_val_op[0]:
|
72 |
+
# Input encoding
|
73 |
+
q = f"{a.strip()}+{b.strip()}"
|
74 |
+
query = q + " " * (MAXLEN - len(q))
|
75 |
+
rev_query = query[::-1]
|
76 |
+
|
77 |
+
inp = ctable.encode(rev_query, MAXLEN)
|
78 |
+
inp = np.expand_dims(inp, axis=0)
|
79 |
+
|
80 |
+
# Prediction and output decoding
|
81 |
+
preds = np.argmax(model.predict(inp), axis=-1)
|
82 |
+
guess = ctable.decode(preds[0], calc_argmax=False)
|
83 |
+
|
84 |
+
return guess, rev_query, query
|
85 |
+
else:
|
86 |
+
return "", "Error", ip_val_op[1]
|
87 |
+
|
88 |
+
|
89 |
+
## -- LSTM INFERENCE SETUP --##
|
90 |
+
# Setup for encoding and decoding the inputs
|
91 |
+
chars = "0123456789+ "
|
92 |
+
ctable = CharacterTable(chars)
|
93 |
+
DIGITS = 5
|
94 |
+
REVERSE = True
|
95 |
+
|
96 |
+
# Maximum length of input is 'int + int' (e.g., '345+678'). Maximum length of
|
97 |
+
# int is DIGITS.
|
98 |
+
MAXLEN = DIGITS + 1 + DIGITS
|
99 |
+
|
100 |
+
# Model
|
101 |
+
model = from_pretrained_keras("vdprabhu/addition-lstm")
|
102 |
+
|
103 |
+
## -- GRADIO SETUP -- ##
|
104 |
+
inputs = [gr.Textbox(), gr.Textbox()]
|
105 |
+
outputs = [gr.Textbox(label="LSTM Output"), gr.Textbox(label="LSTM Input"), gr.Textbox(label="User Query")]
|
106 |
+
|
107 |
+
examples = [[53511,98888], [452,12]]
|
108 |
+
title = "Addition using LSTM"
|
109 |
+
|
110 |
+
more_text = "It is interesting to note that the input is reversed before feeding to LSTM. Sequence order inversion introduces shorter term dependencies between source and target for this problem."
|
111 |
+
description = f"LSTM model is used to add two numbers, provided as strings.\n\n{more_text}"
|
112 |
+
|
113 |
+
article = """
|
114 |
+
<p style='text-align: center'>
|
115 |
+
<a href='https://keras.io/examples/nlp/addition_rnn/' target='_blank'>Keras Example by Smerity and others</a>
|
116 |
+
<br>
|
117 |
+
Space by Vrinda Prabhu
|
118 |
+
</p>
|
119 |
+
"""
|
120 |
+
|
121 |
+
gr.Interface(fn=add_2_nums, inputs=inputs, outputs=outputs, examples=examples, article=article, allow_flagging="never", analytics_enabled=False,
|
122 |
+
title=title, description=description).launch(enable_queue=True)
|