File size: 1,924 Bytes
d3bc923
 
9e6c667
 
 
 
d3bc923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
import numpy as np
from keras.layers import Embedding, Dense, Dropout, Flatten, PReLU
from keras.models import Sequential
from keras.preprocessing.text import Tokenizer
from keras_self_attention import SeqSelfAttention, SeqWeightedAttention

from todset import todset

def _encode(tokenizer, text: str, inp_len: int) -> np.ndarray:
    """Tokenize ``text`` and right-pad with 0 / truncate to exactly ``inp_len`` ids."""
    tokens = tokenizer.texts_to_sequences([text])[0]
    return np.array((list(tokens) + [0] * inp_len)[:inp_len])


def train(data: str, message: str, emb_size: int = 128, inp_len: int = 10) -> str:
    """Train a small self-attention classifier on ``data`` and answer ``message``.

    ``todset(data)`` yields ``dset`` (question text -> class index) and
    ``responses`` (indexed by that class index). The model learns to map a
    tokenized question to its class; the class predicted for ``message``
    selects the reply.

    Args:
        data: Training text, expected as ``question→answer`` lines.
        message: The user message to answer once training finishes.
        emb_size: Embedding dimension. This was previously an undefined global;
            the default is a chosen value, not taken from the original.
        inp_len: Number of token ids each input is padded/truncated to.
            This was previously an undefined global; the default is a chosen
            value, not taken from the original.

    Returns:
        The predicted response, or a usage message when ``data`` is not a
        usable ``question→answer`` dataset.
    """
    usage = "Dataset should be like:\nquestion→answer\nquestion→answer\netc."
    # Every pair needs an arrow. The old `"→" not in data and "\n" not in data`
    # check let arrow-less multi-line text through to todset.
    if "→" not in data:
        return usage
    dset, responses = todset(data)
    if not dset:
        # Nothing to train on; fitting on an empty array would fail.
        return usage

    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(list(dset.keys()))
    vocab_size = len(tokenizer.word_index) + 1  # +1: index 0 is the padding id
    # One softmax unit per response, since dset values index into responses.
    # Assumes todset's class indices lie in range(len(responses)); TODO confirm.
    dset_size = len(responses)

    model = Sequential()
    model.add(Embedding(input_dim=vocab_size, output_dim=emb_size, input_length=inp_len))
    model.add(SeqSelfAttention())
    model.add(Flatten())
    model.add(Dense(1024, activation="relu"))
    model.add(Dropout(0.5))
    model.add(Dense(512, activation="relu"))
    model.add(Dense(512, activation="relu"))
    model.add(Dense(256, activation="relu"))
    model.add(Dense(dset_size, activation="softmax"))

    # Inputs: one padded id sequence per question. Targets: one-hot class rows.
    X = np.array([_encode(tokenizer, key, inp_len) for key in dset])
    y = np.zeros((len(dset), dset_size))
    for row, class_index in enumerate(dset.values()):
        y[row, class_index] = 1

    model.compile(loss="categorical_crossentropy", metrics=["accuracy",])
    # workers/use_multiprocessing only applied to generator inputs in Keras 2
    # and were removed from fit() in Keras 3, so they are not passed here.
    model.fit(X, y, epochs=10, batch_size=8)

    # predict() expects a batch, so add a leading axis and read row 0 back.
    prediction = model.predict(_encode(tokenizer, message, inp_len)[np.newaxis, :])
    return responses[int(np.argmax(prediction[0]))]

# Two text inputs map to train(data, message): the question→answer dataset,
# then the message to answer. `greet` was never defined; the handler is `train`.
iface = gr.Interface(fn=train, inputs=["text", "text"], outputs="text")

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module has no side effect.
    iface.launch()