import gradio as gr
import numpy as np  # np is used throughout but was never imported in the original
from todset import todset
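# todset is assumed (judging from its use below) to parse "question→answer"
# lines into a dict mapping each question to a response index, together with
# the list of response strings.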
from keras.models import Sequential
from keras.layers import Embedding, Dense, Dropout, Flatten
from keras.preprocessing.text import Tokenizer
from keras_self_attention import SeqSelfAttention
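
# Assumed hyperparameters: emb_size and inp_len are referenced below but were
# never defined in the original file, so these values are illustrative defaults.
emb_size = 128  # embedding vector dimension
inp_len = 16    # questions are padded/truncated to this many tokens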
def train(data: str, message: str):
    """Train a small classifier on "question→answer" pairs, then answer `message`."""
    if "→" not in data and "\n" not in data:
        return "Dataset should be like:\nquestion→answer\nquestion→answer\netc."
    dset, responses = todset(data)
    dset_size = len(responses)  # number of output classes; dset_size was never defined in the original
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(list(dset.keys()))
    vocab_size = len(tokenizer.word_index) + 1
    # Model: embedding → self-attention → flatten → ReLU MLP → softmax over the responses.
    model = Sequential()
    model.add(Embedding(input_dim=vocab_size, output_dim=emb_size, input_length=inp_len))
    model.add(SeqSelfAttention())
    model.add(Flatten())
    model.add(Dense(1024, activation="relu"))
    model.add(Dropout(0.5))
    model.add(Dense(512, activation="relu"))
    model.add(Dense(512, activation="relu"))
    model.add(Dense(256, activation="relu"))
    model.add(Dense(dset_size, activation="softmax"))
    # Each question becomes a zero-padded token sequence of length inp_len;
    # each answer index becomes a one-hot target vector.
    X = []
    y = []
    for key in dset:
        tokens = tokenizer.texts_to_sequences([key])[0]
        X.append(np.array((list(tokens) + [0] * inp_len)[:inp_len]))
        output_array = np.zeros(dset_size)
        output_array[dset[key]] = 1
        y.append(output_array)
    X = np.array(X)
    y = np.array(y)
    model.compile(loss="categorical_crossentropy", metrics=["accuracy"])
    model.fit(X, y, epochs=10, batch_size=8, workers=4, use_multiprocessing=True)
    # Tokenize and pad the incoming message the same way, then pick the
    # highest-scoring response. predict() expects a batch, hence the extra axis.
    tokens = tokenizer.texts_to_sequences([message])[0]
    padded = np.array([(list(tokens) + [0] * inp_len)[:inp_len]])
    prediction = model.predict(padded)[0]
    max_o = 0
    max_v = 0
    for ind, i in enumerate(prediction):
        if max_v < i:
            max_v = i
            max_o = ind
    return responses[max_o]  # the original returned responses[ind], the loop's last index, not the argmax
iface = gr.Interface(fn=train, inputs=["text", "text"], outputs="text")  # the original passed fn=greet, which is never defined
iface.launch()