kenichiro committed on
Commit
926183f
•
1 Parent(s): e96863e

Add application file

Browse files
README.md CHANGED
@@ -1,13 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: Clinical Segnemt
3
- emoji: 🌖
4
- colorFrom: purple
5
- colorTo: yellow
6
- sdk: streamlit
7
- sdk_version: 1.17.0
8
- app_file: app.py
9
- pinned: false
10
- license: cc-by-3.0
11
- ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NLP based Chatbot in PyTorch
2
+ <img src="https://miro.medium.com/max/1400/1*VqLvWcTKgVpv1idxII591A.jpeg" width="470" height="350">
3
+
4
+
5
+ ## Simple chatbot implementation with PyTorch.
6
+
7
+ * The implementation should be easy to follow for beginners and provide a basic understanding of chatbots.
8
+
9
+ * The implementation is straightforward with a Feed Forward Neural net with 2 hidden layers.
10
+
11
+ * Customization for your own use case is super easy. Just modify intents.json with possible patterns and responses and re-run the training (see below for more info).
12
+
13
+ In [this article](https://medium.com/@mlvictoriamaslova/nlp-based-chatbot-in-pytorch-bonus-flask-and-javascript-deployment-474c4e59ceff) on Medium I explain some of the NLP concepts that underlie building chatbots.
14
+
15
  ---
 
 
 
 
 
 
 
 
 
 
16
 
17
+
18
+ ## Installation
19
+
20
+ ### Create an environment
21
+
22
+ Whatever you prefer (e.g. conda or venv)
23
+
24
+ ```
25
+ $ mkdir myproject
26
+ $ cd myproject
27
+ $ python3 -m venv venv
28
+ ```
29
+
30
+ ### Activate it
31
+
32
+ Mac / Linux:
33
+ ```
34
+ . venv/bin/activate
35
+ ```
36
+ Windows:
37
+
38
+ ```
39
+ venv\Scripts\activate
40
+ ```
41
+
42
+ ### Install PyTorch and dependencies
43
+
44
+ For installation of PyTorch, see the [official website](https://pytorch.org/).
45
+
46
+ You also need nltk:
47
+ ```
48
+ pip install nltk
49
+ ```
50
+ If you get an error during the first run, you also need to install `nltk.tokenize.punkt`. Run this once in your terminal:
51
+
52
+ ```
53
+ $ python
54
+ >>> import nltk
55
+ >>> nltk.download('punkt')
56
+ ```
57
+
58
+ ### Usage
59
+
60
+ Run
61
+ ```
62
+ python train.py
63
+ ```
64
+ This will write a `data.pth` file. Then run
65
+ ```
66
+ python chat.py
67
+ ```
__pycache__/chat.cpython-38.pyc ADDED
Binary file (1.46 kB). View file
 
app.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Flask front end for the chatbot.

Serves the chat widget page and a JSON endpoint that forwards a user
message to ``chat.get_response``.
"""
from flask import Flask, render_template, request, jsonify

from chat import get_response

app = Flask(__name__)


@app.get("/")
def index_get():
    """Render the chat widget page (``templates/base.html``)."""
    return render_template("base.html")


@app.post("/predict")
def predict():
    """Answer one chat message.

    Expects a JSON body of the form ``{"message": "<text>"}`` and replies
    with ``{"answer": "<reply>"}``.  A missing or malformed body, or a
    missing / non-string / blank ``message``, yields HTTP 400 with an
    ``{"error": ...}`` payload instead of an unhandled exception.
    """
    # silent=True returns None for a missing or unparsable body instead of
    # raising, so every bad request is reported through the same 400 path.
    payload = request.get_json(silent=True)
    text = payload.get("message") if isinstance(payload, dict) else None
    if not isinstance(text, str) or not text.strip():
        return jsonify({"error": "Request body must be JSON with a non-empty 'message' string."}), 400

    response = get_response(text)
    message = {"answer": response}
    return jsonify(message)


if __name__ == "__main__":
    # NOTE(review): debug=True enables the interactive Werkzeug debugger,
    # which allows code execution from the browser -- keep it to local
    # development and never expose this entry point publicly.
    app.run(debug=True)
chat.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Chat backend: map a user message to a canned response from intents.json.

Importing this module loads ``intents.json``, the vocabulary metadata in
``data.pth`` and the model returned by ``run_segbot.get_model``.
"""
import random
import json

import torch

from nltk_utils import bag_of_words, tokenize
from run_segbot import get_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with open('intents.json', 'r') as json_data:
    intents = json.load(json_data)

# get_response() needs the training vocabulary and the ordered tag list.
# They were previously read from data.pth, but that code had been commented
# out, leaving `all_words` and `tags` undefined (NameError on first message).
FILE = "data.pth"
data = torch.load(FILE, map_location=device)

all_words = data['all_words']
tags = data['tags']

# NOTE(review): get_model() unpickles model.pickle.  get_response() calls the
# result as `model(X)` on a bag-of-words vector and expects class scores of
# shape (1, len(tags)); confirm the pickled object really is such a
# classifier (the earlier code built a NeuralNet from data["model_state"]).
model = get_model()

model.eval()

bot_name = "Sam"


def get_response(msg):
    """Return a reply for the user message *msg*.

    The message is tokenized, encoded as a bag-of-words vector over
    ``all_words`` and scored by ``model``.  If the softmax probability of the
    best tag exceeds 0.75, a random response of the matching intent is
    returned; otherwise (or if no intent carries that tag) the fallback
    string "I do not understand..." is returned.
    """
    sentence = tokenize(msg)
    X = bag_of_words(sentence, all_words)
    X = X.reshape(1, X.shape[0])
    X = torch.from_numpy(X).to(device)

    # Pure inference: no gradients are needed.
    with torch.no_grad():
        output = model(X)
    _, predicted = torch.max(output, dim=1)

    tag = tags[predicted.item()]

    probs = torch.softmax(output, dim=1)
    prob = probs[0][predicted.item()]
    if prob.item() > 0.75:
        for intent in intents['intents']:
            if tag == intent["tag"]:
                return random.choice(intent['responses'])

    return "I do not understand..."


if __name__ == "__main__":
    print("Let's chat! (type 'quit' to exit)")
    while True:
        sentence = input("You: ")
        if sentence == "quit":
            break

        resp = get_response(sentence)
        print(resp)
data.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f20bb4bda5d1517c4bb6d201139d136b0840d48cda09237e92bbb5b0b1fd63f4
3
+ size 5015
index2word.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75789974bed3cd0bc31ad888f26cf977a1c14fb35bc504849fa066cab1f845dd
3
+ size 47914175
intents.json ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "intents": [
3
+ {
4
+ "tag": "greeting",
5
+ "patterns": [
6
+ "Hi",
7
+ "Hey",
8
+ "How are you",
9
+ "Is anyone there?",
10
+ "Hello",
11
+ "Good day"
12
+ ],
13
+ "responses": [
14
+ "Hey :-)",
15
+ "Hello, thanks for visiting",
16
+ "Hi there, what can I do for you?",
17
+ "Hi there, how can I help?"
18
+ ]
19
+ },
20
+ {
21
+ "tag": "goodbye",
22
+ "patterns": ["Bye", "See you later", "Goodbye"],
23
+ "responses": [
24
+ "See you later, thanks for visiting",
25
+ "Have a nice day",
26
+ "Bye! Come back again soon."
27
+ ]
28
+ },
29
+ {
30
+ "tag": "thanks",
31
+ "patterns": ["Thanks", "Thank you", "That's helpful", "Thank's a lot!"],
32
+ "responses": ["Happy to help!", "Any time!", "My pleasure"]
33
+ },
34
+ {
35
+ "tag": "items",
36
+ "patterns": [
37
+ "Which items do you have?",
38
+ "What kinds of items are there?",
39
+ "What do you sell?"
40
+ ],
41
+ "responses": [
42
+ "We sell coffee and tea",
43
+ "We have coffee and tea"
44
+ ]
45
+ },
46
+ {
47
+ "tag": "payments",
48
+ "patterns": [
49
+ "Do you take credit cards?",
50
+ "Do you accept Mastercard?",
51
+ "Can I pay with Paypal?",
52
+ "Are you cash only?"
53
+ ],
54
+ "responses": [
55
+ "We accept VISA, Mastercard and Paypal",
56
+ "We accept most major credit cards, and Paypal"
57
+ ]
58
+ },
59
+ {
60
+ "tag": "delivery",
61
+ "patterns": [
62
+ "How long does delivery take?",
63
+ "How long does shipping take?",
64
+ "When do I get my delivery?"
65
+ ],
66
+ "responses": [
67
+ "Delivery takes 2-4 days",
68
+ "Shipping takes 2-4 days"
69
+ ]
70
+ },
71
+ {
72
+ "tag": "funny",
73
+ "patterns": [
74
+ "Tell me a joke!",
75
+ "Tell me something funny!",
76
+ "Do you know a joke?"
77
+ ],
78
+ "responses": [
79
+ "Why did the hipster burn his mouth? He drank the coffee before it was cool.",
80
+ "What did the buffalo say when his son left for college? Bison."
81
+ ]
82
+ }
83
+ ]
84
+ }
model.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc13a6fa988683240ebd80f50a53a864fc5c9b6ad90c0e3d72c624749b542a9d
3
+ size 4948315605
nltk_utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ import numpy as np
3
+ #nltk.download('all')
4
+ from nltk.stem.porter import PorterStemmer
5
+ stemmer = PorterStemmer()
6
def tokenize(sentence):
    """Split *sentence* into a list of tokens using ``nltk.word_tokenize``.

    A token can be a word, a punctuation character, or a number.
    """
    tokens = nltk.word_tokenize(sentence)
    return tokens
12
+
13
def stem(word):
    """Return the Porter stem of *word*, lower-cased first."""
    lowered = word.lower()
    return stemmer.stem(lowered)
16
+
17
def bag_of_words(tokenized_sentence, all_words):
    """Encode a tokenized sentence as a bag-of-words vector.

    Each token is stemmed with ``stem``; the result is a float32 NumPy
    vector of length ``len(all_words)`` holding 1.0 at every position whose
    vocabulary entry occurs among the stemmed tokens and 0.0 elsewhere.
    """
    # A set gives O(1) membership tests; the original scanned the token list
    # once per vocabulary word.
    stems = {stem(w) for w in tokenized_sentence}

    bag = np.zeros(len(all_words), dtype=np.float32)
    for idx, w in enumerate(all_words):
        if w in stems:
            bag[idx] = 1.0
    return bag
26
+
27
+
run_segbot.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from nltk.tokenize import word_tokenize
3
+ import pickle
4
+ import numpy as np
5
+ import random
6
+ import torch
7
+ from solver import TrainSolver
8
+
9
+ from model import PointerNetworks
10
+ import gensim
11
+ from tqdm import tqdm
12
+
13
class Lang:
    """Vocabulary that maps words to integer indices and counts occurrences.

    Three reserved entries exist from the start: "UNKNOWN" (0),
    "RE_DIGITS" (1) and "PADDING" (2000001).  New words receive consecutive
    indices starting at 3.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {"RE_DIGITS": 1, "UNKNOWN": 0, "PADDING": 2000001}
        self.word2count = {"RE_DIGITS": 1, "UNKNOWN": 1, "PADDING": 1}
        self.index2word = {2000001: "PADDING", 1: "RE_DIGITS", 0: "UNKNOWN"}
        # Next free index; the three reserved tokens are already registered.
        self.n_words = 3

    def addSentence(self, sentence):
        """Register every space-separated word of *sentence*."""
        cleaned = sentence.strip('\n').strip('\r')
        for token in cleaned.split(' '):
            self.addWord(token)

    def addWord(self, word):
        """Count *word*, assigning it the next free index on first sight."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        new_index = self.n_words
        self.word2index[word] = new_index
        self.word2count[word] = 1
        self.index2word[new_index] = word
        self.n_words = new_index + 1
33
+
34
+
35
+
36
def mytokenizer(inS, all_dict):
    """Map out-of-vocabulary tokens to 'UNKNOWN' and build boundary labels.

    :param inS: sequence of tokens (already tokenized).
    :param all_dict: container of known tokens (tested with ``in``).
    :return: tuple ``(ori_list, re_unk_list, labey_edus)`` where
        ``ori_list`` is the tokens as given, ``re_unk_list`` is the same
        sequence with every token that is neither in *all_dict* nor the
        literal 'RE_DIGITS' replaced by 'UNKNOWN', and ``labey_edus`` is a
        0/1 list of the same length whose last element is 1.  Empty input
        yields three empty lists (it previously raised IndexError).
    """
    ori_list = list(inS)
    re_unk_list = [
        t if (t in all_dict or t == 'RE_DIGITS') else 'UNKNOWN'
        for t in ori_list
    ]

    labey_edus = [0] * len(re_unk_list)
    if labey_edus:
        # The final token always closes a segment.
        labey_edus[-1] = 1

    return ori_list, re_unk_list, labey_edus
60
+
61
+
62
+
63
def get_mapping(X, Y, D):
    """Convert tokens to indices and wrap both inputs as one-row arrays.

    :param X: iterable of tokens.
    :param Y: label sequence for the same tokens.
    :param D: token -> index mapping; tokens missing from it are mapped to
        ``D['UNKNOWN']``.
    :return: ``(X_map, Y_map)``, NumPy arrays built from ``[indices]`` and
        ``[Y]`` respectively (a leading batch axis of size 1).
    """
    indices = [D[token] if token in D else D['UNKNOWN'] for token in X]
    return np.array([indices]), np.array([Y])
78
+
79
+
80
+
81
+
82
+
83
def get_model():
    """Load and return the object pickled in ``model.pickle``.

    The file is read from the current working directory.  Any class the
    pickle refers to must be importable when this runs.

    The statements that used to follow the ``return`` (an accuracy check on
    undefined ``X_tes`` / ``Y_tes`` / ``index2word`` / ``fukugen``) were
    unreachable and have been removed.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file -- only ever load a model.pickle from a trusted source.
    with open('model.pickle', 'rb') as f:
        return pickle.load(f)
104
+
105
+
106
+
solver.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.optim as optim
2
+ import numpy as np
3
+ import torch
4
+ from torch.autograd import Variable
5
+
6
+ import random
7
+ from torch.nn.utils import clip_grad_norm
8
+ import copy
9
+ from tqdm import tqdm
10
+
11
+ import os
12
+ import pickle
13
+
14
+
15
+
16
def _pack_index_rows(rows):
    """Pack per-sequence index arrays into one array, tolerating ragged rows."""
    if len({len(r) for r in rows}) <= 1:
        # Uniform lengths (or no rows): a regular array, as before.
        return np.array(rows)
    # Ragged rows: np.array() raises ValueError on NumPy >= 1.24, so build the
    # 1-D object array of arrays explicitly (what older NumPy produced).
    packed = np.empty(len(rows), dtype=object)
    for i, row in enumerate(rows):
        packed[i] = row
    return packed


def get_decoder_index_XY(batchY):
    '''
    Derive decoder start/end indices from 0/1 boundary label sequences.

    :param batchY: sequence of label vectors, each like [0 0 1 0 0 0 0 1],
        where 1 marks the last position of a segment.
    :return: (returnX, returnY). For every sequence, returnY holds the
        positions of the 1s (segment ends) and returnX holds the matching
        segment starts: 0 followed by each end position + 1, except the last.
        For the example above: returnY = [2 7], returnX = [0 3].
    '''
    returnX = []
    returnY = []
    for curY in batchY:
        # np.asarray lets plain Python lists work too; `list == 1` is a
        # scalar False and would silently yield no boundaries.
        decoderY = np.where(np.asarray(curY) == 1)[0]

        if len(decoderY) == 1:
            decoderX = np.array([0])
        else:
            decoderX = np.append([0], decoderY[0:-1] + 1)
        returnX.append(decoderX)
        returnY.append(decoderY)

    return _pack_index_rows(returnX), _pack_index_rows(returnY)
44
+
45
def align_variable_numpy(X, maxL, paddingNumber):
    """Right-pad every sequence in *X* with *paddingNumber* up to length *maxL*.

    :param X: iterable of sequences.
    :param maxL: target length of each padded row.
    :param paddingNumber: value appended to rows shorter than *maxL*.
    :return: NumPy array built from the padded rows.
    """
    padded_rows = [
        list(row) + [paddingNumber] * (maxL - len(row))
        for row in X
    ]
    return np.array(padded_rows)
56
+
57
+
58
+ def sample_a_sorted_batch_from_numpy(numpyX,numpyY,batch_size,use_cuda):
+ # Build one padded batch from parallel collections numpyX (token-index rows) and numpyY (0/1 label rows).
+ # Returns (numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL):
+ #   numpy_batch_x   - the selected, unpadded x rows (deep copies)
+ #   batch_x         - x rows padded to maxL with 2000001, as an int64 torch Variable (moved to CUDA if use_cuda)
+ #   batch_y         - y rows padded to maxL with 2 (NumPy array)
+ #   index_decoder_* - per-row decoder indices from get_decoder_index_XY
+ #   all_lens / maxL - per-row label lengths and their maximum
+ # NOTE(review): indentation is not recoverable from this rendering. If the second
+ # `select_index = ...` assignment below sits at function level, it overrides the random
+ # sampling, so batch_size is ignored and every row is selected in order -- confirm intent.
59
+
60
+
61
+ if batch_size != None:
62
+ select_index = random.sample(range(len(numpyY)), batch_size)
63
+ else:
64
+ select_index = np.array(range(len(numpyY)))
65
+
66
+ select_index = np.array(range(len(numpyX)))
67
+
68
+ batch_x = [copy.deepcopy(numpyX[i]) for i in select_index]
69
+ batch_y = [copy.deepcopy(numpyY[i]) for i in select_index]
70
+
71
+ #print(batch_y)
72
+ index_decoder_X,index_decoder_Y = get_decoder_index_XY(batch_y)
73
+ #index_decoder = [get_decoder_index_XY(i) for i in batch_y]
74
+ #index_decoder_X = [i[0] for i in index_decoder]
75
+ #index_decoder_Y = [i[1] for i in index_decoder]
76
+ #print(index_decoder_Y)
77
+
78
+
79
+ #all_lens = []
80
+ all_lens = np.array([len(x) for x in batch_y])
81
+ #for x in batch_y:
82
+ # print(x)
83
+ # try:
84
+ # all_lens.append(len(x))
85
+ # except:
86
+ # all_lens.append(1)
87
+ #all_lens = np.array(all_lens)
88
+
89
+ maxL = np.max(all_lens)
90
+
91
+ #idx = all_lens
92
+ #print(idx)
93
+ # NOTE(review): np.sort on the argsort permutation yields 0..n-1 again, so the rows keep
94
+ # their input order; despite the function name the batch is NOT sorted by length.
95
+ # Presumably deliberate (callers pair rows with other data by position) -- confirm.
+ idx = np.argsort(all_lens)
94
+ idx = np.sort(idx)
95
+ #print(idx)
96
+ #idx = idx[::-1] # decreasing
97
+ #print(idx)
98
+ batch_x = [batch_x[i] for i in idx]
99
+ batch_y = [batch_y[i] for i in idx]
100
+ all_lens = all_lens[idx]
101
+
102
+ # NOTE(review): if the rows have different numbers of boundaries these lists are ragged;
+ # np.array() on ragged input raises ValueError on NumPy >= 1.24 -- verify.
+ index_decoder_X = np.array([index_decoder_X[i] for i in idx])
103
+ index_decoder_Y = np.array([index_decoder_Y[i] for i in idx])
104
+ #print(index_decoder_Y)
105
+
106
+ numpy_batch_x = batch_x
107
+
108
+
109
+
110
+ # 2000001 is the padding value for x; labels are padded with 2.
+ batch_x = align_variable_numpy(batch_x,maxL,2000001)
111
+ batch_y = align_variable_numpy(batch_y,maxL,2)
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+ # Debug output: prints the number of rows in the batch on every call.
+ print(len(batch_x))
120
+ #batch_x = Variable(torch.from_numpy(batch_x.astype(np.int64)))
121
+ batch_x = Variable(torch.from_numpy(np.array(batch_x, dtype="int64")))
122
+
123
+
124
+ if use_cuda:
125
+ batch_x = batch_x.cuda()
126
+
127
+
128
+
129
+ return numpy_batch_x,batch_x,batch_y,index_decoder_X,index_decoder_Y,all_lens,maxL
130
+
131
+
132
+
133
+
134
+ class TrainSolver(object):
+ # Bundles a model with its training/dev data and hyper-parameters, and provides
+ # the training loop (train) plus boundary precision/recall/F1 evaluation (check_accuracy).
+ # NOTE(review): indentation is lost in this rendering, so loop/return nesting inside
+ # the methods below cannot be confirmed from here.
135
+ def __init__(self, model,train_x,train_y,dev_x,dev_y,save_path,batch_size,eval_size,epoch, lr,lr_decay_epoch,weight_decay,use_cuda):
+ # Stores every constructor argument on the instance under the same name.
136
+
137
+ self.lr = lr
138
+ self.model = model
139
+ self.epoch = epoch
140
+ self.train_x = train_x
141
+ self.train_y = train_y
142
+ self.use_cuda = use_cuda
143
+ self.batch_size = batch_size
144
+ self.lr_decay_epoch = lr_decay_epoch
145
+ self.eval_size = eval_size
146
+
147
+
148
+ self.dev_x, self.dev_y = dev_x, dev_y
149
+
150
+ # (self.model was already assigned above; this repeat is redundant.)
+ self.model = model
151
+ self.save_path = save_path
152
+ self.weight_decay =weight_decay
153
+
154
+
155
+
156
+
157
+ def sample_dev(self):
+ # Returns (x_rows, y_rows): eval_size training examples drawn at random without replacement.
+ # The two empty-list initialisations below are immediately overwritten.
158
+ test_tr_x = []
159
+ test_tr_y = []
160
+ select_index = random.sample(range(len(self.train_y)),self.eval_size)
161
+ test_tr_x = [self.train_x[n] for n in select_index]
162
+ test_tr_y = [self.train_y[n] for n in select_index]
163
+
164
+ return test_tr_x,test_tr_y
165
+
166
+
167
+
168
+
169
+
170
+
171
+
172
+ def get_batch_micro_metric(self,pre_b, ground_b, x,index2word, fukugen, nloop):
+ # Per-sequence counts for micro-averaged boundary metrics.
+ # Returns (All_C, All_R, All_G): for each sequence the number of correctly predicted
+ # boundaries, of predicted boundaries, and of gold boundaries, after removing the last
+ # gold boundary (END_B) from both sets.
+ # The second half of the loop builds an annotated string `ex` (false positives tagged
+ # "<FP>"); the code that wrote it to a file is commented out, so `ex` is currently unused.
+ # NOTE(review): index_of_1[-1] raises IndexError for a sequence with no label equal to 1.
173
+
174
+ tokendic = {}
175
+ #with open('index2word.pickle', 'rb') as f:
176
+ # index2word = pickle.load(f)
177
+ for n,i in enumerate(index2word):
178
+ tokendic[n] = i
179
+ All_C = []
180
+ All_R = []
181
+ All_G = []
182
+ """
183
+ for i,cur_seq_y in enumerate(zip(ground_b,fukugen[nloop])):
184
+ #print(fukugen[nloop])
185
+ fuku = cur_seq_y[1]
186
+ cur_seq_y = cur_seq_y[0]
187
+ index_of_1 = np.where(cur_seq_y==1)[0]
188
+ #print(index_of_1)
189
+ index_pre = pre_b[i]
190
+ inp = x[i]
191
+ #print(len(inp))
192
+ """
193
+ print(len(pre_b), len(ground_b), len(fukugen))
194
+ #global leng
195
+ #print(fukugen)
196
+ for i,cur_seq_y in enumerate(ground_b):
197
+ #print(fukugen[nloop])
198
+ fuku = fukugen[i]
199
+ #cur_seq_y = cur_seq_y[0]
200
+ index_of_1 = np.where(cur_seq_y==1)[0]
201
+ #print(index_of_1)
202
+ index_pre = pre_b[i]
203
+ inp = x[i]
204
+ #print(len(inp))
205
+
206
+ index_pre = np.array(index_pre)
207
+ END_B = index_of_1[-1]
208
+ index_pre = index_pre[index_pre != END_B]
209
+ index_of_1 = index_of_1[index_of_1 != END_B]
210
+
211
+ no_correct = len(np.intersect1d(list(index_of_1), list(index_pre)))
212
+ All_C.append(no_correct)
213
+ All_R.append(len(index_pre))
214
+ All_G.append(len(index_of_1))
215
+
216
+ index_of_1 = list(index_of_1)
217
+ index_pre = list(index_pre)
218
+
219
+ FN = []
220
+ FP = []
221
+ TP = []
222
+ sent = []
223
+ ex = ""
224
+ for j in inp:
225
+ sent.append(tokendic[int(j.to('cpu').detach().numpy().copy())])
226
+ for k in index_of_1:
227
+ if k not in index_pre:
228
+ FN.append(k)
229
+ if k in index_pre:
230
+ TP.append(k)
231
+ for k in index_pre:
232
+ if k not in index_of_1:
233
+ FP.append(k)
234
+ #if len(FN) == 0 and len(FP) == 0:
235
+ # continue
236
+ #for n,i in enumerate(sent):
237
+ for n,k in enumerate(zip(sent, fuku)):
238
+ f = k[1]
239
+ i = k[0]
+ # NOTE(review): k is the (sent, fuku) tuple, so comparing it with the string "<pad>"
+ # is never true and padding is never skipped; presumably `i == "<pad>"` was meant.
240
+ if k == "<pad>":
241
+ continue
242
+ if n in FP:
243
+ ex += f + "<FP>"
244
+ else:
245
+ ex += f
246
+ """
247
+ if n in FN:
248
+ #ex += i + "<FN>"
249
+ ex += i
250
+ elif n in FP:
251
+ ex += i + "<FP>"
252
+ elif n in TP:
253
+ ex += i + "<TP>"
254
+ else:
255
+ ex += i
256
+ """
257
+ #with open(str(nloop)+"_sep_nounk.txt", "a")as f:
258
+ # f.write(ex+"\n")
259
+ #print(i)
260
+ #leng += 1
261
+
262
+ return All_C,All_R,All_G
263
+
264
+
265
+
266
+
267
+
268
+ def get_batch_metric(self,pre_b, ground_b):
+ # Per-sequence precision, recall and F1 of predicted boundaries (pre_b) against the
+ # positions of 1s in each gold label row (ground_b). Returns three parallel lists.
+ # NOTE(review): raises ZeroDivisionError when a sequence has no predictions, no gold
+ # boundaries, or no overlap at all (cur_pre + cur_rec == 0).
269
+
270
+ b_pr =[]
271
+ b_re =[]
272
+ b_f1 =[]
273
+ for i,cur_seq_y in enumerate(ground_b):
274
+ index_of_1 = np.where(cur_seq_y==1)[0]
275
+ index_pre = pre_b[i]
276
+
277
+ no_correct = len(np.intersect1d(index_of_1,index_pre))
278
+
279
+ cur_pre = no_correct / len(index_pre)
280
+ cur_rec = no_correct / len(index_of_1)
281
+ cur_f1 = 2*cur_pre*cur_rec/ (cur_pre+cur_rec)
282
+
283
+ b_pr.append(cur_pre)
284
+ b_re.append(cur_rec)
285
+ b_f1.append(cur_f1)
286
+
287
+ return b_pr,b_re,b_f1
288
+
289
+
290
+
291
+ def check_accuracy(self,data2X,data2Y,index2word, fukugen2):
+ # Runs self.model.predict over the data in batches of self.batch_size and computes
+ # micro-averaged boundary precision / recall / F1.
+ # data2X, data2Y and fukugen2 are indexed by nloop, i.e. each holds one dataset per group.
+ # Returns (mean loss, precision, recall, f1,
+ #          (all_x_save, all_index_decoder_y, all_boundary, all_boundary_start, all_align_matrix)).
+ # NOTE(review): the group count 108 is hard-coded; inputs with fewer groups raise
+ # IndexError and extra groups are ignored.
+ # NOTE(review): with indentation lost it is unclear whether the accumulators and the
+ # return are inside the nloop loop (result = first group) or after it (result = last
+ # group only, if the accumulators are reset per group) -- confirm against the real file.
+ # NOTE(review): np.sum(all_C)/np.sum(all_R) and the F1 ratio produce nan (with a NumPy
+ # warning) when nothing is predicted or nothing is correct.
292
+ for nloop in tqdm(range(0,108)):
293
+ dataY = data2Y[nloop]
294
+ dataX = data2X[nloop]
295
+ fukugen = fukugen2[nloop]
296
+ #print(len(dataX), len(dataY), len(fukugen))
297
+ need_loop = int(np.ceil(len(dataY) / self.batch_size))
298
+ #need_loop = int(np.ceil(len(dataY) / 1))
299
+ all_ave_loss =[]
300
+ all_boundary =[]
301
+ all_boundary_start = []
302
+ all_align_matrix = []
303
+ all_index_decoder_y =[]
304
+ all_x_save = []
305
+
306
+ all_C =[]
307
+ all_R =[]
308
+ all_G =[]
309
+
310
+ for lp in range(need_loop):
311
+ startN = lp*self.batch_size
312
+ endN = (lp+1)*self.batch_size
313
+ if endN > len(dataY):
314
+ endN = len(dataY)
315
+ #print(fukugen)
316
+ fukuge = fukugen[startN:endN]
317
+ #print(startN, endN)
318
+ #print(len(fukugen))
319
+ #print(fukugen)
320
+ #for nloop in tqdm(range(0,26431)):
321
+ numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
322
+ dataX[startN:endN], dataY[startN:endN], None, self.use_cuda)
323
+ #numpy_batch_x, batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
324
+ # dataX, dataY, None, self.use_cuda)
325
+
326
+ batch_ave_loss, batch_boundary, batch_boundary_start, batch_align_matrix = self.model.predict(batch_x,
327
+ index_decoder_Y,
328
+ all_lens)
329
+
330
+ all_ave_loss.extend([batch_ave_loss.data.item()]) #[batch_ave_loss.data[0]]
331
+ all_boundary.extend(batch_boundary)
332
+ all_boundary_start.extend(batch_boundary_start)
333
+ all_align_matrix.extend(batch_align_matrix)
334
+ all_index_decoder_y.extend(index_decoder_Y)
335
+ all_x_save.extend(numpy_batch_x)
336
+
337
+
338
+
339
+ #print(batch_y)
340
+ ba_C,ba_R,ba_G = self.get_batch_micro_metric(batch_boundary,batch_y,batch_x,index2word, fukuge, nloop)
341
+
342
+ all_C.extend(ba_C)
343
+ all_R.extend(ba_R)
344
+ all_G.extend(ba_G)
345
+
346
+
347
+ ba_pre = np.sum(all_C)/ np.sum(all_R)
348
+ ba_rec = np.sum(all_C)/ np.sum(all_G)
349
+ ba_f1 = 2*ba_pre*ba_rec/ (ba_pre+ba_rec)
350
+
351
+
352
+ return np.mean(all_ave_loss),ba_pre,ba_rec,ba_f1, (all_x_save,all_index_decoder_y,all_boundary, all_boundary_start, all_align_matrix)
353
+
354
+
355
+
356
+
357
+
358
+
359
+
360
+ def adjust_learning_rate(self,optimizer,epoch,lr_decay=0.5, lr_decay_epoch=5):
+ # Step decay: multiplies the learning rate of every param group by lr_decay whenever
+ # epoch is a non-zero multiple of lr_decay_epoch. Mutates optimizer in place.
361
+
362
+ if (epoch % lr_decay_epoch == 0) and (epoch != 0):
363
+ for param_group in optimizer.param_groups:
364
+ param_group['lr'] *= lr_decay
365
+
366
+
367
+
368
+ def train(self,n):
+ # Trains self.model with Adam for self.epoch epochs, evaluating on the dev set
+ # via check_accuracy. `n` is passed through to check_accuracy and returned unchanged.
+ # Returns (best_i, best_f1, n).
369
+
370
+ self.test_train_x, self.test_train_y = self.sample_dev()
371
+
372
+ optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr, weight_decay=self.weight_decay)
373
+
374
+
375
+
376
+ num_each_batch = int(np.round(len(self.train_y) / self.batch_size))
377
+
378
+ #os.mkdir(self.save_path)
379
+
380
+ best_i =0
381
+ best_f1 =0
382
+
383
+ for epoch in range(self.epoch):
384
+ print(epoch)
385
+ self.adjust_learning_rate(optimizer, epoch, 0.8, self.lr_decay_epoch)
386
+
387
+ track_epoch_loss = []
388
+ for iter in tqdm(range(num_each_batch)):
389
+ #print("epoch:%d,iteration:%d" % (epoch, iter))
390
+
391
+ self.model.zero_grad()
392
+
393
+ numpy_batch_x,batch_x, batch_y, index_decoder_X, index_decoder_Y, all_lens, maxL = sample_a_sorted_batch_from_numpy(
394
+ self.train_x, self.train_y, self.batch_size, self.use_cuda)
395
+
396
+ neg_loss = self.model.neg_log_likelihood(batch_x, index_decoder_X, index_decoder_Y,all_lens)
397
+
398
+
399
+
400
+ neg_loss_v = float(neg_loss.data.item())
401
+ #print(neg_loss_v)
402
+ track_epoch_loss.append(neg_loss_v)
403
+
404
+ neg_loss.backward()
405
+
+ # Clip the gradient norm to 5 before the optimizer step.
+ # NOTE(review): clip_grad_norm is deprecated in PyTorch in favour of clip_grad_norm_.
406
+ clip_grad_norm(self.model.parameters(), 5)
407
+ optimizer.step()
408
+
409
+
410
+ #TODO: after each epoch,check accuracy
411
+
412
+
413
+ self.model.eval()
414
+
415
+ #tr_batch_ave_loss, tr_pre, tr_rec, tr_f1 ,visdata= self.check_accuracy(self.test_train_x,self.test_train_y)
416
+
+ # NOTE(review): check_accuracy is declared as (data2X, data2Y, index2word, fukugen2) but is
+ # called here with three arguments, so this line raises TypeError as written.
417
+ dev_batch_ave_loss, dev_pre, dev_rec, dev_f1, visdata =self.check_accuracy(self.dev_x,self.dev_y,n)
418
+ print("f1="+str(dev_f1))
419
+ print("loss="+str(dev_batch_ave_loss))
420
+ """
421
+ if best_f1 < dev_f1:
422
+ best_f1 = dev_f1
423
+ best_rec = dev_rec
424
+ best_pre = dev_pre
425
+ best_i = epoch
426
+
427
+
428
+
429
+ save_data = [epoch,dev_batch_ave_loss,dev_pre,dev_rec,dev_f1]
430
+
431
+
432
+ save_file_name = 'bs_{}_es_{}_lr_{}_lrdc_{}_wd_{}_epoch_loss_acc_pk_wd.txt'.format(self.batch_size,self.eval_size,self.lr,self.lr_decay_epoch,self.weight_decay)
433
+ """
434
+ #with open(os.path.join(self.save_path,save_file_name), 'a') as f:
435
+ # f.write(','.join(map(str,save_data))+'\n')
436
+
437
+
438
+ #if epoch % 1 ==0 and epoch !=0:
439
+ # torch.save(self.model, os.path.join(self.save_path,r'model_epoch_%d.torchsave'%(epoch)))
440
+
441
+
442
+ self.model.train()
443
+
444
+ #return best_i,best_pre,best_rec,best_f1
+ # NOTE(review): the best-score bookkeeping above is inside a string literal, so best_i and
+ # best_f1 are never updated and are always returned as 0.
445
+ return best_i,best_f1,n
static/app.js ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/**
 * Chat widget controller: toggles the chat panel, posts user messages to the
 * backend `/predict` endpoint and renders the conversation.
 */
class Chatbox {
    constructor() {
        this.args = {
            openButton: document.querySelector('.chatbox__button'),
            chatBox: document.querySelector('.chatbox__support'),
            sendButton: document.querySelector('.send__button')
        }

        // false = panel hidden, true = panel visible
        this.state = false;
        // Conversation history, oldest first: { name, message }
        this.messages = [];
    }

    /** Wire up the open button, the send button and the Enter key. */
    display() {
        const {openButton, chatBox, sendButton} = this.args;

        openButton.addEventListener('click', () => this.toggleState(chatBox))

        sendButton.addEventListener('click', () => this.onSendButton(chatBox))

        const node = chatBox.querySelector('input');
        node.addEventListener("keyup", ({key}) => {
            if (key === "Enter") {
                this.onSendButton(chatBox)
            }
        })
    }

    /** Show or hide the chat panel. */
    toggleState(chatbox) {
        this.state = !this.state;

        if (this.state) {
            chatbox.classList.add('chatbox--active')
        } else {
            chatbox.classList.remove('chatbox--active')
        }
    }

    /** Send the current input text to the backend and render the reply. */
    onSendButton(chatbox) {
        const textField = chatbox.querySelector('input');
        const text1 = textField.value
        if (text1 === "") {
            return;
        }

        this.messages.push({ name: "User", message: text1 });

        // Use the application root published by the page template instead of a
        // hard-coded http://127.0.0.1:5000, which only works on the developer's
        // own machine. Falls back to a root-relative URL.
        const root = (typeof $SCRIPT_ROOT === 'string') ? $SCRIPT_ROOT : '';

        fetch(root + '/predict', {
            method: 'POST',
            body: JSON.stringify({ message: text1 }),
            mode: 'cors',
            headers: {
              'Content-Type': 'application/json'
            },
          })
          .then(r => {
            // Treat HTTP errors as failures rather than rendering "undefined".
            if (!r.ok) {
                throw new Error('Request failed with status ' + r.status);
            }
            return r.json();
          })
          .then(r => {
            this.messages.push({ name: "Sam", message: r.answer });
            this.updateChatText(chatbox)
            textField.value = ''

        }).catch((error) => {
            console.error('Error:', error);
            this.updateChatText(chatbox)
            textField.value = ''
          });
    }

    /** Re-render the message list, newest message first in DOM order. */
    updateChatText(chatbox) {
        const chatmessage = chatbox.querySelector('.chatbox__messages');

        // Build nodes with textContent instead of concatenating into innerHTML,
        // so message text is never interpreted as markup.
        chatmessage.textContent = '';
        this.messages.slice().reverse().forEach((item) => {
            const bubble = document.createElement('div');
            bubble.className = (item.name === "Sam")
                ? 'messages__item messages__item--visitor'
                : 'messages__item messages__item--operator';
            bubble.textContent = item.message;
            chatmessage.appendChild(bubble);
        });
    }
}


const chatbox = new Chatbox();
chatbox.display();
static/images/chatbox-icon.svg ADDED
static/style.css ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * {
2
+ box-sizing: border-box;
3
+ margin: 0;
4
+ padding: 0;
5
+ }
6
+
7
+ body {
8
+ font-family: 'Nunito', sans-serif;
9
+ font-weight: 400;
10
+ font-size: 100%;
11
+ background: #F1F1F1;
12
+ }
13
+
14
+ *, html {
15
+ --primaryGradient: linear-gradient(93.12deg, #581B98 0.52%, #9C1DE7 100%);
16
+ --secondaryGradient: linear-gradient(268.91deg, #581B98 -2.14%, #9C1DE7 99.69%);
17
+ --primaryBoxShadow: 0px 10px 15px rgba(0, 0, 0, 0.1);
18
+ --secondaryBoxShadow: 0px -10px 15px rgba(0, 0, 0, 0.1);
19
+ --primary: #581B98;
20
+ }
21
+
22
+ /* CHATBOX
23
+ =============== */
24
+ .chatbox {
25
+ position: absolute;
26
+ bottom: 30px;
27
+ right: 30px;
28
+ }
29
+
30
+ /* CONTENT IS CLOSE */
31
+ .chatbox__support {
32
+ display: flex;
33
+ flex-direction: column;
34
+ background: #eee;
35
+ width: 300px;
36
+ height: 350px;
37
+ z-index: -123456;
38
+ opacity: 0;
39
+ transition: all .5s ease-in-out;
40
+ }
41
+
42
+ /* CONTENT ISOPEN */
43
+ .chatbox--active {
44
+ transform: translateY(-40px);
45
+ z-index: 123456;
46
+ opacity: 1;
47
+
48
+ }
49
+
50
+ /* BUTTON */
51
+ .chatbox__button {
52
+ text-align: right;
53
+ }
54
+
55
+ .send__button {
56
+ padding: 6px;
57
+ background: transparent;
58
+ border: none;
59
+ outline: none;
60
+ cursor: pointer;
61
+ }
62
+
63
+
64
+ /* HEADER */
65
+ .chatbox__header {
66
+ position: sticky;
67
+ top: 0;
68
+ background: orange;
69
+ }
70
+
71
+ /* MESSAGES */
72
+ .chatbox__messages {
73
+ margin-top: auto;
74
+ display: flex;
75
+ overflow-y: scroll;
76
+ flex-direction: column-reverse;
77
+ }
78
+
79
+ .messages__item {
80
+ background: orange;
81
+ max-width: 60.6%;
82
+ width: fit-content;
83
+ }
84
+
85
+ .messages__item--operator {
86
+ margin-left: auto;
87
+ }
88
+
89
+ .messages__item--visitor {
90
+ margin-right: auto;
91
+ }
92
+
93
+ /* FOOTER */
94
+ .chatbox__footer {
95
+ position: sticky;
96
+ bottom: 0;
97
+ }
98
+
99
+ .chatbox__support {
100
+ background: #f9f9f9;
101
+ height: 450px;
102
+ width: 350px;
103
+ box-shadow: 0px 0px 15px rgba(0, 0, 0, 0.1);
104
+ border-top-left-radius: 20px;
105
+ border-top-right-radius: 20px;
106
+ }
107
+
108
+ /* HEADER */
109
+ .chatbox__header {
110
+ background: var(--primaryGradient);
111
+ display: flex;
112
+ flex-direction: row;
113
+ align-items: center;
114
+ justify-content: center;
115
+ padding: 15px 20px;
116
+ border-top-left-radius: 20px;
117
+ border-top-right-radius: 20px;
118
+ box-shadow: var(--primaryBoxShadow);
119
+ }
120
+
121
+ .chatbox__image--header {
122
+ margin-right: 10px;
123
+ }
124
+
125
+ .chatbox__heading--header {
126
+ font-size: 1.2rem;
127
+ color: white;
128
+ }
129
+
130
+ .chatbox__description--header {
131
+ font-size: .9rem;
132
+ color: white;
133
+ }
134
+
135
+ /* Messages */
136
+ .chatbox__messages {
137
+ padding: 0 20px;
138
+ }
139
+
140
+ .messages__item {
141
+ margin-top: 10px;
142
+ background: #E0E0E0;
143
+ padding: 8px 12px;
144
+ max-width: 70%;
145
+ }
146
+
147
+ .messages__item--visitor,
148
+ .messages__item--typing {
149
+ border-top-left-radius: 20px;
150
+ border-top-right-radius: 20px;
151
+ border-bottom-right-radius: 20px;
152
+ }
153
+
154
+ .messages__item--operator {
155
+ border-top-left-radius: 20px;
156
+ border-top-right-radius: 20px;
157
+ border-bottom-left-radius: 20px;
158
+ background: var(--primary);
159
+ color: white;
160
+ }
161
+
162
+ /* FOOTER */
163
+ .chatbox__footer {
164
+ display: flex;
165
+ flex-direction: row;
166
+ align-items: center;
167
+ justify-content: space-between;
168
+ padding: 20px 20px;
169
+ background: var(--secondaryGradient);
170
+ box-shadow: var(--secondaryBoxShadow);
171
+ border-bottom-right-radius: 10px;
172
+ border-bottom-left-radius: 10px;
173
+ margin-top: 20px;
174
+ }
175
+
176
+ .chatbox__footer input {
177
+ width: 80%;
178
+ border: none;
179
+ padding: 10px 10px;
180
+ border-radius: 30px;
181
+ text-align: left;
182
+ }
183
+
184
+ .chatbox__send--footer {
185
+ color: white;
186
+ }
187
+
188
+ .chatbox__button button,
189
+ .chatbox__button button:focus,
190
+ .chatbox__button button:visited {
191
+ padding: 10px;
192
+ background: white;
193
+ border: none;
194
+ outline: none;
195
+ border-top-left-radius: 50px;
196
+ border-top-right-radius: 50px;
197
+ border-bottom-left-radius: 50px;
198
+ box-shadow: 0px 10px 15px rgba(0, 0, 0, 0.1);
199
+ cursor: pointer;
200
+ }
templates/base.html ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
4
+
5
+ <head>
6
+ <meta charset="UTF-8">
7
+ <title>Chatbot</title>
8
+ </head>
9
+ <body>
10
+ <div class="container">
11
+ <div class="chatbox">
12
+ <div class="chatbox__support">
13
+ <div class="chatbox__header">
14
+ <div class="chatbox__image--header">
15
+ <img src="https://img.icons8.com/color/48/000000/circled-user-female-skin-type-5--v1.png" alt="image">
16
+ </div>
17
+ <div class="chatbox__content--header">
18
+ <h4 class="chatbox__heading--header">Chat support</h4>
19
+ <p class="chatbox__description--header">Hi. My name is Sam. How can I help you?</p>
20
+ </div>
21
+ </div>
22
+ <div class="chatbox__messages">
23
+ <div></div>
24
+ </div>
25
+ <div class="chatbox__footer">
26
+ <input type="text" placeholder="Write a message...">
27
+ <button class="chatbox__send--footer send__button">Send</button>
28
+ </div>
29
+ </div>
30
+ <div class="chatbox__button">
31
+ <button><img src="{{ url_for('static', filename='images/chatbox-icon.svg') }}" /></button>
32
+ </div>
33
+ </div>
34
+ </div>
35
+
36
+ <script>
37
+ $SCRIPT_ROOT = {{ request.script_root|tojson }};
38
+ </script>
39
+ <script type="text/javascript" src="{{ url_for('static', filename='app.js') }}"></script>
40
+
41
+ </body>
42
+ </html>