lilingxi01 committed on
Commit
154ca7b
·
1 Parent(s): b6436ac

[General] Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .DS_Store
2
+
3
+ venv/
4
+ __pycache__/
.streamlit/config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor="#262730"
3
+ backgroundColor="#ffffff"
4
+ secondaryBackgroundColor="#f6f6f8"
5
+ textColor="#090909"
6
+ font="sans serif"
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from modules.prediction import prepare, predict
4
+
5
# Values stored under st.session_state['running_status'].
STATUS_STOPPED = 120001  # nothing has been submitted during this run
STATUS_SUBMIT = 120002   # the form was submitted during this run
STATUS_ERROR = 120003    # defined but not used anywhere in this script

# Every script run starts in the "stopped" state; a form submission below
# switches it to STATUS_SUBMIT for the remainder of this run.
st.session_state['running_status'] = STATUS_STOPPED

# Streamlit re-executes this script from the top on every interaction, so a
# plain module-level flag would be reset each time and the checkpoint would
# be reloaded on every rerun. Session state survives reruns, so the model is
# prepared only once per browser session.
if 'has_prepared' not in st.session_state:
    print('>>> [PREPARE] Preparing...')
    prepare()
    st.session_state['has_prepared'] = True

st.title('Entity Referring Classifier')
st.caption('It knows exactly when you are calling it. - Version 2.0.1208.01')

st.markdown('---')

# Wide demo column, a narrow spacer column, and an informational side column.
livedemo_col1, livedemo_col2, livedemo_col3 = st.columns([12, 1, 6])

with livedemo_col1:
    st.subheader('Live Demo')

    with st.form("my_form"):
        entity = st.text_input('Entity Name', 'Jimmy')
        sentence = st.text_input('Text Input', 'Hey Jimmy.',
                                 help='The classifier is going to analyze this sentence.')
        if st.form_submit_button('Submit it'):
            st.session_state['running_status'] = STATUS_SUBMIT

    if st.session_state['running_status'] == STATUS_STOPPED:
        st.info('Type something and submit to start!')
    elif st.session_state['running_status'] == STATUS_SUBMIT:
        # predict() returns either 'CALLING' or 'MENTIONING'.
        if predict(sentence, entity) == 'CALLING':
            st.success('It is a **calling**!')
        else:
            st.success('It is a **mentioning**!')

with livedemo_col2:
    # Spacer column: intentionally left empty.
    st.empty()

with livedemo_col3:
    st.markdown("""
    #### Get Started
    """)
    st.markdown("""
    Hi! I'm the Entity Referring Classifier.
    I can help you find out when you are calling it.
    """)
    st.markdown("""
    #### Terms
    """)
    st.markdown("""
    ##### `Calling`
    """)
    st.markdown("""
    ##### `Mentioning`
    """)
modules/prediction/ERCBCM.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from transformers import BertForSequenceClassification
3
+
4
class ERCBCM(nn.Module):
    """Thin ``nn.Module`` wrapper around ``BertForSequenceClassification``.

    The wrapped model is created from the ``'bert-base-uncased'`` checkpoint
    and stored as ``self.encoder``.
    """

    def __init__(self):
        super().__init__()

        # The attribute name is part of the state-dict key prefix, so it must
        # stay ``encoder`` for saved checkpoints to load.
        self.encoder = BertForSequenceClassification.from_pretrained('bert-base-uncased')

    def forward(self, text, label):
        """Run the encoder on ``text`` with ``label`` passed as ``labels``.

        Returns the first two elements of the encoder output as
        ``(loss, text_fea)``.
        """
        encoder_output = self.encoder(text, labels=label)
        loss, text_fea = encoder_output[:2]
        return loss, text_fea
modules/prediction/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
# Absolute directory that contains this file.
myPath = os.path.dirname(os.path.abspath(__file__))

# Make the directory two levels above this file importable so the absolute
# `modules.*` imports below resolve. The path is normalised, and inserted
# only once so repeated imports do not pile up duplicate sys.path entries.
_repo_root = os.path.normpath(os.path.join(myPath, '..', '..'))
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)
5
+
6
+ # ==========
7
+
8
+ import torch
9
+
10
+ from modules.prediction.model_loader import load_checkpoint
11
+ from modules.prediction.ERCBCM import ERCBCM
12
+ from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
13
+
14
# Folder that holds the trained checkpoint (relative to the working directory).
erc_root_folder = './model'

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ==========

# Single module-level model instance shared by prepare() and predict().
model_for_evaluate = ERCBCM()
model_for_evaluate = model_for_evaluate.to(device)
21
+
22
def prepare():
    """Load the trained weights into the shared model and ready it for inference.

    Reads ``<erc_root_folder>/model.pt`` into ``model_for_evaluate`` and then
    switches the model to evaluation mode so layers such as dropout behave
    deterministically during prediction.
    """
    load_checkpoint(erc_root_folder + '/model.pt', model_for_evaluate, device)
    # Be explicit about inference mode instead of relying on how the
    # underlying model happened to be constructed.
    model_for_evaluate.eval()
24
+
25
def predict(sentence, name, max_length=128):
    """Classify whether ``sentence`` is calling or merely mentioning ``name``.

    Args:
        sentence: Raw input text.
        name: Entity name; occurrences are masked by ``normalize_v2``.
        max_length: Fixed number of token ids fed to the model. Shorter
            inputs are padded with ``PAD_TOKEN_ID``; longer inputs are
            truncated (keeping the final token produced by the tokenizer).

    Returns:
        ``'CALLING'`` when the arg-max class index is 1, otherwise
        ``'MENTIONING'``.
    """
    # The wrapped model requires a label argument; a dummy class 0 is passed
    # and the resulting loss is discarded.
    label = torch.tensor([0], dtype=torch.long).to(device)

    token_ids = tokenizer.encode(normalize_v2(sentence, name))
    if len(token_ids) > max_length:
        # Without truncation the padding below is a no-op and over-long
        # inputs reach the model at their full length.
        token_ids = token_ids[:max_length - 1] + token_ids[-1:]
    token_ids += [PAD_TOKEN_ID] * (max_length - len(token_ids))

    text = torch.tensor([token_ids], dtype=torch.long).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        _, output = model_for_evaluate(text, label)

    pred = torch.argmax(output, 1).tolist()[0]
    return 'CALLING' if pred == 1 else 'MENTIONING'
modules/prediction/model_loader.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ # Save and Load Functions
4
+
5
def save_checkpoint(save_path, model, valid_loss):
    """Save the model weights together with its validation loss.

    Args:
        save_path: Destination file path. When ``None`` nothing is saved.
        model: Object exposing ``state_dict()``.
        valid_loss: Validation loss stored alongside the weights.

    Returns:
        None.
    """
    if save_path is None:
        return
    state_dict = {
        'model_state_dict': model.state_dict(),
        'valid_loss': valid_loss,
    }
    torch.save(state_dict, save_path)
    print("[SAVE] Model has been saved successfully to '{}'".format(save_path))
12
+
13
def load_checkpoint(load_path, model, device):
    """Load weights saved by ``save_checkpoint`` into ``model``.

    Args:
        load_path: Checkpoint file path. When ``None`` nothing is loaded.
        model: Object exposing ``load_state_dict()``; updated in place.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The ``'valid_loss'`` value stored in the checkpoint, or ``None`` when
        ``load_path`` is ``None``.
    """
    if load_path is None:
        return
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(load_path, map_location=device)
    print("[LOAD] Model has been loaded successfully from '{}'".format(load_path))
    model.load_state_dict(state_dict['model_state_dict'])
    return state_dict['valid_loss']
20
+
21
def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Save the training metric histories to ``save_path``.

    Args:
        save_path: Destination file path. When ``None`` nothing is saved.
        train_loss_list: Stored under ``'train_loss_list'``.
        valid_loss_list: Stored under ``'valid_loss_list'``.
        global_steps_list: Stored under ``'global_steps_list'``.

    Returns:
        None.
    """
    if save_path is None:
        return
    state_dict = {
        'train_loss_list': train_loss_list,
        'valid_loss_list': valid_loss_list,
        'global_steps_list': global_steps_list,
    }
    torch.save(state_dict, save_path)
    print("[SAVE] Model with matrics has been saved successfully to '{}'".format(save_path))
29
+
30
def load_metrics(load_path, device):
    """Load metric histories saved by ``save_metrics``.

    Args:
        load_path: Metrics file path. When ``None`` nothing is loaded.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        A ``(train_loss_list, valid_loss_list, global_steps_list)`` tuple, or
        ``None`` when ``load_path`` is ``None``.
    """
    if load_path is None:
        return
    # NOTE: torch.load unpickles the file; only load files you trust.
    state_dict = torch.load(load_path, map_location=device)
    print("[LOAD] Model with matrics has been loaded successfully from '{}'".format(load_path))
    return (
        state_dict['train_loss_list'],
        state_dict['valid_loss_list'],
        state_dict['global_steps_list'],
    )
modules/tokenizer/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizer
2
+
3
# Shared BERT tokenizer loaded from the 'bert-base-uncased' vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Parameters preparation.
# Presumably the maximum number of token ids per input -- confirm with callers.
MAX_SENT_LENGTH = 128
# Vocabulary id of the tokenizer's padding token.
PAD_TOKEN_ID = tokenizer.convert_tokens_to_ids(
    tokenizer.pad_token
)
8
+
9
def normalize_v2(text, entity):
    """Lowercase ``text`` and replace each occurrence of ``entity`` with the mask token.

    Matching is case-insensitive because both arguments are lowercased first.

    Args:
        text: Input sentence.
        entity: Entity name to mask.

    Returns:
        The lowercased text, with every occurrence of the lowercased entity
        replaced by ``tokenizer.mask_token``. The lowercased text is returned
        unchanged when the entity is empty, blank, or not present.
    """
    text = text.lower()
    entity = entity.lower()
    # An empty entity is "in" every string and str.replace('', mask) would
    # insert the mask between every character; a blank entity would mask
    # every space. Treat both as "nothing to mask".
    if not entity.strip() or entity not in text:
        return text
    # TODO: not sure if this will be decoded by BERT.
    return text.replace(entity, tokenizer.mask_token)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchtext
4
+ ipywidgets
5
+ transformers