Spaces:
Runtime error
Runtime error
Commit
·
154ca7b
1
Parent(s):
b6436ac
[General] Initial commit
Browse files- .gitignore +4 -0
- .streamlit/config.toml +6 -0
- app.py +62 -0
- modules/prediction/ERCBCM.py +13 -0
- modules/prediction/__init__.py +36 -0
- modules/prediction/model_loader.py +35 -0
- modules/tokenizer/__init__.py +15 -0
- requirements.txt +5 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
|
3 |
+
venv/
|
4 |
+
__pycache__/
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
primaryColor="#262730"
|
3 |
+
backgroundColor="#ffffff"
|
4 |
+
secondaryBackgroundColor="#f6f6f8"
|
5 |
+
textColor="#090909"
|
6 |
+
font="sans serif"
|
app.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
from modules.prediction import prepare, predict
|
4 |
+
|
5 |
+
# Values stored under st.session_state['running_status'].
STATUS_STOPPED = 120001
STATUS_SUBMIT = 120002
STATUS_ERROR = 120003  # Declared for completeness; nothing in this script sets it.

# Streamlit re-executes this script from the top on every user interaction,
# so a plain module-level flag would be reset to False on each rerun and the
# checkpoint would be reloaded every time. Keeping the flag in session state
# makes the preparation run once per browser session instead.
if 'has_prepared' not in st.session_state:
    print('>>> [PREPARE] Preparing...')
    prepare()
    st.session_state['has_prepared'] = True
has_prepared = True

# Every run starts idle; the submit button below flips the status for the
# single rerun that the form submission triggers.
st.session_state['running_status'] = STATUS_STOPPED

st.title('Entity Referring Classifier')
st.caption('It knows exactly when you are calling it. - Version 2.0.1208.01')

st.markdown('---')

# Wide demo column, a thin spacer column, and a narrower help column.
livedemo_col1, livedemo_col2, livedemo_col3 = st.columns([12, 1, 6])

with livedemo_col1:
    st.subheader('Live Demo')

    with st.form("my_form"):
        entity = st.text_input('Entity Name', 'Jimmy')
        sentence = st.text_input('Text Input', 'Hey Jimmy.',
                                 help='The classifier is going to analyze this sentence.')
        if st.form_submit_button('Submit it'):
            st.session_state['running_status'] = STATUS_SUBMIT

    if st.session_state['running_status'] == STATUS_STOPPED:
        st.info('Type something and submit to start!')
    elif st.session_state['running_status'] == STATUS_SUBMIT:
        if predict(sentence, entity) == 'CALLING':
            st.success('It is a **calling**!')
        else:
            st.success('It is a **mentioning**!')

with livedemo_col2:
    # Spacer column: intentionally left blank.
    st.empty()

with livedemo_col3:
    st.markdown("""
    #### Get Started
    """)
    st.markdown("""
    Hi! I'm the Entity Referring Classifier.
    I can help you find out when you are calling it.
    """)
    st.markdown("""
    #### Terms
    """)
    st.markdown("""
    ##### `Calling`
    """)
    st.markdown("""
    ##### `Mentioning`
    """)
|
modules/prediction/ERCBCM.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
from transformers import BertForSequenceClassification
|
3 |
+
|
4 |
+
class ERCBCM(nn.Module):
    """Thin ``nn.Module`` wrapper around a BERT sequence classifier.

    The wrapped model is created from the ``bert-base-uncased`` pretrained
    weights and stored as ``self.encoder``.
    """

    def __init__(self):
        super().__init__()

        # NOTE: the attribute name ``encoder`` becomes the prefix of every
        # key in this module's state_dict, so renaming it would make
        # previously saved checkpoints fail to load.
        self.encoder = BertForSequenceClassification.from_pretrained('bert-base-uncased')

    def forward(self, text, label=None):
        """Run the classifier on ``text``.

        Args:
            text: Token-id tensor passed straight to the wrapped model.
                Presumably shaped (batch, seq_len) -- confirm against callers.
            label: Optional class-label tensor. When given, the wrapped
                model also computes a loss.

        Returns:
            A ``(loss, logits)`` pair. ``loss`` is ``None`` when ``label``
            is not supplied.
        """
        outputs = self.encoder(text, labels=label)
        if label is None:
            # Without labels the wrapped model produces no loss, so the
            # logits are the first element of its output.
            return None, outputs[0]
        loss, text_fea = outputs[:2]
        return loss, text_fea
|
modules/prediction/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
myPath = os.path.dirname(os.path.abspath(__file__))
|
4 |
+
sys.path.insert(0, myPath + '/../../')
|
5 |
+
|
6 |
+
# ==========
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from modules.prediction.model_loader import load_checkpoint
|
11 |
+
from modules.prediction.ERCBCM import ERCBCM
|
12 |
+
from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
|
13 |
+
|
14 |
+
# Folder holding the fine-tuned checkpoint, relative to the working directory.
erc_root_folder = './model'

# Use the GPU when one is visible to torch, otherwise run on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

# ==========

# Single shared model instance, built at import time and moved to `device`.
model_for_evaluate = ERCBCM()
model_for_evaluate = model_for_evaluate.to(device)
|
21 |
+
|
22 |
+
def prepare():
    """Load ``<erc_root_folder>/model.pt`` into the shared ``model_for_evaluate``."""
    checkpoint_path = erc_root_folder + '/model.pt'
    load_checkpoint(checkpoint_path, model_for_evaluate, device)
|
24 |
+
|
25 |
+
def predict(sentence, name, max_length=128):
    """Classify whether ``sentence`` is calling or merely mentioning ``name``.

    Args:
        sentence: Raw input text.
        name: Entity name; it is masked out of the sentence by
            ``normalize_v2`` before tokenization.
        max_length: Fixed number of token ids fed to the model. Shorter
            inputs are padded with ``PAD_TOKEN_ID``; longer inputs are
            truncated so they can never exceed the model's input limit.

    Returns:
        ``'CALLING'`` when the predicted class index is 1, otherwise
        ``'MENTIONING'``.
    """
    # Cheap guard: make sure inference never runs with training-only
    # behaviour (e.g. dropout) enabled, whatever other code did to the model.
    model_for_evaluate.eval()

    token_ids = tokenizer.encode(normalize_v2(sentence, name))
    if len(token_ids) > max_length:
        # Keep the final token (presumably the [SEP] that encode() appends
        # -- confirm) so the truncated sequence is still well-formed.
        token_ids = token_ids[:max_length - 1] + token_ids[-1:]
    token_ids += [PAD_TOKEN_ID] * (max_length - len(token_ids))

    # Build the tensors directly with the right dtype and device.
    text = torch.tensor([token_ids], dtype=torch.long, device=device)
    # The model's forward() expects a label; this dummy value only feeds the
    # loss, which is discarded below.
    label = torch.tensor([0], dtype=torch.long, device=device)

    # No gradients are needed for prediction; skipping them saves memory.
    with torch.no_grad():
        _, output = model_for_evaluate(text, label)

    pred = torch.argmax(output, 1).tolist()[0]
    return 'CALLING' if pred == 1 else 'MENTIONING'
|
modules/prediction/model_loader.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
# Save and Load Functions
|
4 |
+
|
5 |
+
def save_checkpoint(save_path, model, valid_loss):
    """Save a model's weights together with its validation loss.

    Args:
        save_path: Destination file. When ``None`` nothing is written.
        model: Module whose ``state_dict()`` is stored.
        valid_loss: Validation loss recorded alongside the weights.

    Returns:
        None.
    """
    if save_path is None:
        return
    state_dict = {
        'model_state_dict': model.state_dict(),
        'valid_loss': valid_loss,
    }
    torch.save(state_dict, save_path)
    print(f"[SAVE] Model has been saved successfully to '{save_path}'")
|
12 |
+
|
13 |
+
def load_checkpoint(load_path, model, device):
    """Restore a model's weights from a checkpoint written by ``save_checkpoint``.

    Args:
        load_path: Checkpoint file. When ``None`` nothing is loaded.
        model: Module that receives the stored ``model_state_dict``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The ``valid_loss`` stored in the checkpoint, or ``None`` when
        ``load_path`` is ``None``.
    """
    if load_path is None:
        return
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(load_path, map_location=device)
    model.load_state_dict(state_dict['model_state_dict'])
    # Report success only after the weights have actually been applied.
    print(f"[LOAD] Model has been loaded successfully from '{load_path}'")
    return state_dict['valid_loss']
|
20 |
+
|
21 |
+
def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Save the training-history lists to a single file.

    Args:
        save_path: Destination file. When ``None`` nothing is written.
        train_loss_list: Stored under the key ``'train_loss_list'``.
        valid_loss_list: Stored under the key ``'valid_loss_list'``.
        global_steps_list: Stored under the key ``'global_steps_list'``.

    Returns:
        None.
    """
    if save_path is None:
        return
    state_dict = {
        'train_loss_list': train_loss_list,
        'valid_loss_list': valid_loss_list,
        'global_steps_list': global_steps_list,
    }
    torch.save(state_dict, save_path)
    print(f"[SAVE] Model with matrics has been saved successfully to '{save_path}'")
|
29 |
+
|
30 |
+
def load_metrics(load_path, device):
    """Load the training-history lists written by ``save_metrics``.

    Args:
        load_path: Metrics file. When ``None`` nothing is loaded.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        A ``(train_loss_list, valid_loss_list, global_steps_list)`` tuple,
        or ``None`` when ``load_path`` is ``None``.
    """
    if load_path is None:
        return
    # NOTE(security): torch.load unpickles the file; only load metrics
    # files from a trusted source.
    state_dict = torch.load(load_path, map_location=device)
    print(f"[LOAD] Model with matrics has been loaded successfully from '{load_path}'")
    return (
        state_dict['train_loss_list'],
        state_dict['valid_loss_list'],
        state_dict['global_steps_list'],
    )
|
modules/tokenizer/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import BertTokenizer
|
2 |
+
|
3 |
+
# Shared tokenizer, loaded once at import time from the pretrained vocabulary.
_PRETRAINED_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(_PRETRAINED_NAME)

# Parameters preparation.
MAX_SENT_LENGTH = 128
# Vocabulary id of the tokenizer's padding token.
PAD_TOKEN_ID = tokenizer.pad_token_id
|
8 |
+
|
9 |
+
def normalize_v2(text, entity):
    """Lower-case ``text`` and mask out every occurrence of ``entity``.

    Matching is a case-insensitive plain substring match; each match is
    replaced with the tokenizer's mask token.

    Args:
        text: Input sentence.
        entity: Entity name to mask. When empty or whitespace-only the
            lower-cased text is returned unchanged.

    Returns:
        The lower-cased text with ``entity`` replaced by the mask token.
    """
    text = text.lower()
    entity = entity.lower()
    # str.replace('', x) would insert x between every character, and a
    # whitespace-only entity would mask every space, so bail out early.
    if not entity.strip():
        return text
    # replace() is already a no-op when the entity does not occur.
    return text.replace(entity, tokenizer.mask_token)  # TODO: not sure if this will be decoded by BERT.
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
torch
|
3 |
+
torchtext
|
4 |
+
ipywidgets
|
5 |
+
transformers
|