vdprabhu commited on
Commit
e9ca565
1 Parent(s): f8eab5b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from huggingface_hub import from_pretrained_keras
4
+
5
+ import gradio as gr
6
+
7
+ ## -- DIGIT ENCODER AND DECODER -- ##
8
+ class CharacterTable:
9
+ """Given a set of characters:
10
+ + Encode them to a one-hot integer representation
11
+ + Decode the one-hot or integer representation to their character output
12
+ + Decode a vector of probabilities to their character output
13
+ """
14
+
15
+ def __init__(self, chars):
16
+ """Initialize character table.
17
+ # Arguments
18
+ chars: Characters that can appear in the input.
19
+ """
20
+ self.chars = sorted(set(chars))
21
+ self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
22
+ self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
23
+
24
+ def encode(self, C, num_rows):
25
+ """One-hot encode given string C.
26
+ # Arguments
27
+ C: string, to be encoded.
28
+ num_rows: Number of rows in the returned one-hot encoding. This is
29
+ used to keep the # of rows for each data the same.
30
+ """
31
+ x = np.zeros((num_rows, len(self.chars)))
32
+ for i, c in enumerate(C):
33
+ x[i, self.char_indices[c]] = 1
34
+ return x
35
+
36
+ def decode(self, x, calc_argmax=True):
37
+ """Decode the given vector or 2D array to their character output.
38
+ # Arguments
39
+ x: A vector or a 2D array of probabilities or one-hot representations;
40
+ or a vector of character indices (used with `calc_argmax=False`).
41
+ calc_argmax: Whether to find the character index with maximum
42
+ probability, defaults to `True`.
43
+ """
44
+ if calc_argmax:
45
+ x = x.argmax(axis=-1)
46
+ return "".join(self.indices_char[x] for x in x)
47
+
48
+
49
+
50
+ ## -- INFERENCE CODE -- ##
51
+ def check_num_validity(a, b):
52
+ """Validates the inputs before feeding to the model.
53
+ Checks if both the inputs are positive integers and are not more than 5 digits.
54
+ """
55
+ if a.isdigit() and b.isdigit():
56
+ if len(a) <= 5 and len(b) <= 5:
57
+ return True, None
58
+ else:
59
+ return False, "Input can contain maximum of five digits"
60
+ else:
61
+ return False, "Input has to be positive integers (each can be max 5 digits)"
62
+
63
+ def add_2_nums(a, b):
64
+ """Performs the model inference after input validation.
65
+ """
66
+ a = a.strip()
67
+ b = b.strip()
68
+
69
+ ip_val_op = check_num_validity(a, b)
70
+
71
+ if ip_val_op[0]:
72
+ # Input encoding
73
+ q = f"{a.strip()}+{b.strip()}"
74
+ query = q + " " * (MAXLEN - len(q))
75
+ rev_query = query[::-1]
76
+
77
+ inp = ctable.encode(rev_query, MAXLEN)
78
+ inp = np.expand_dims(inp, axis=0)
79
+
80
+ # Prediction and output decoding
81
+ preds = np.argmax(model.predict(inp), axis=-1)
82
+ guess = ctable.decode(preds[0], calc_argmax=False)
83
+
84
+ return guess, rev_query, query
85
+ else:
86
+ return "", "Error", ip_val_op[1]
87
+
88
+
89
+ ## -- LSTM INFERENCE SETUP --##
90
+ # Setup for encoding and decoding the inputs
91
+ chars = "0123456789+ "
92
+ ctable = CharacterTable(chars)
93
+ DIGITS = 5
94
+ REVERSE = True
95
+
96
+ # Maximum length of input is 'int + int' (e.g., '345+678'). Maximum length of
97
+ # int is DIGITS.
98
+ MAXLEN = DIGITS + 1 + DIGITS
99
+
100
+ # Model
101
+ model = from_pretrained_keras("vdprabhu/addition-lstm")
102
+
103
+ ## -- GRADIO SETUP -- ##
104
+ inputs = [gr.Textbox(), gr.Textbox()]
105
+ outputs = [gr.Textbox(label="LSTM Output"), gr.Textbox(label="LSTM Input"), gr.Textbox(label="User Query")]
106
+
107
+ examples = [[53511,98888], [452,12]]
108
+ title = "Addition using LSTM"
109
+
110
+ more_text = "It is interesting to note that the input is reversed before feeding to LSTM. Sequence order inversion introduces shorter term dependencies between source and target for this problem."
111
+ description = f"LSTM model is used to add two numbers, provided as strings.\n\n{more_text}"
112
+
113
+ article = """
114
+ <p style='text-align: center'>
115
+ <a href='https://keras.io/examples/nlp/addition_rnn/' target='_blank'>Keras Example by Smerity and others</a>
116
+ <br>
117
+ Space by Vrinda Prabhu
118
+ </p>
119
+ """
120
+
121
+ gr.Interface(fn=add_2_nums, inputs=inputs, outputs=outputs, examples=examples, article=article, allow_flagging="never", analytics_enabled=False,
122
+ title=title, description=description).launch(enable_queue=True)