Minh Q. Le committed
Commit: a446b0b
1 Parent(s): e26ce64

Pushed COSMIC code

This view is limited to 50 files because it contains too many changes. See raw diff.
- Model/.gitignore +1 -0
- Model/COSMIC/erc_training/commonsense_model.py +345 -0
- Model/COSMIC/erc_training/dataloader.py +276 -0
- Model/COSMIC/erc_training/model.py +229 -0
- Model/COSMIC/erc_training/predict_epik.py +198 -0
- Model/COSMIC/feature_extraction/comet/__init__.py +0 -0
- Model/COSMIC/feature_extraction/comet/csk_feature_extract.py +110 -0
- Model/COSMIC/feature_extraction/comet/src/__init__.py +0 -0
- Model/COSMIC/feature_extraction/comet/src/data/atomic.py +337 -0
- Model/COSMIC/feature_extraction/comet/src/data/conceptnet.py +342 -0
- Model/COSMIC/feature_extraction/comet/src/data/config.py +186 -0
- Model/COSMIC/feature_extraction/comet/src/data/data.py +85 -0
- Model/COSMIC/feature_extraction/comet/src/data/utils.py +134 -0
- Model/COSMIC/feature_extraction/comet/src/evaluate/atomic_evaluate.py +40 -0
- Model/COSMIC/feature_extraction/comet/src/evaluate/conceptnet_evaluate.py +82 -0
- Model/COSMIC/feature_extraction/comet/src/evaluate/conceptnet_generate.py +112 -0
- Model/COSMIC/feature_extraction/comet/src/evaluate/evaluate.py +85 -0
- Model/COSMIC/feature_extraction/comet/src/evaluate/generate.py +72 -0
- Model/COSMIC/feature_extraction/comet/src/evaluate/sampler.py +329 -0
- Model/COSMIC/feature_extraction/comet/src/evaluate/utils.py +39 -0
- Model/COSMIC/feature_extraction/comet/src/interactive/functions.py +376 -0
- Model/COSMIC/feature_extraction/comet/src/main.py +19 -0
- Model/COSMIC/feature_extraction/comet/src/main_atomic.py +125 -0
- Model/COSMIC/feature_extraction/comet/src/main_conceptnet.py +138 -0
- Model/COSMIC/feature_extraction/comet/src/models/gpt.py +311 -0
- Model/COSMIC/feature_extraction/comet/src/models/models.py +32 -0
- Model/COSMIC/feature_extraction/comet/src/models/utils.py +12 -0
- Model/COSMIC/feature_extraction/comet/src/train/atomic_train.py +76 -0
- Model/COSMIC/feature_extraction/comet/src/train/batch.py +135 -0
- Model/COSMIC/feature_extraction/comet/src/train/conceptnet_train.py +67 -0
- Model/COSMIC/feature_extraction/comet/src/train/opt.py +122 -0
- Model/COSMIC/feature_extraction/comet/src/train/train.py +233 -0
- Model/COSMIC/feature_extraction/comet/src/train/utils.py +58 -0
- Model/COSMIC/feature_extraction/comet/utils/__init__.py +0 -0
- Model/COSMIC/feature_extraction/comet/utils/utils.py +210 -0
- Model/COSMIC/feature_extraction/multiprocessing_bpe_encoder.py +129 -0
- Model/COSMIC/feature_extraction/src/__init__.py +0 -0
- Model/COSMIC/feature_extraction/src/data/atomic.py +337 -0
- Model/COSMIC/feature_extraction/src/data/conceptnet.py +342 -0
- Model/COSMIC/feature_extraction/src/data/data.py +85 -0
- Model/COSMIC/feature_extraction/src/data/utils.py +134 -0
- Model/COSMIC/feature_extraction/src/evaluate/atomic_evaluate.py +40 -0
- Model/COSMIC/feature_extraction/src/evaluate/conceptnet_evaluate.py +82 -0
- Model/COSMIC/feature_extraction/src/evaluate/conceptnet_generate.py +112 -0
- Model/COSMIC/feature_extraction/src/evaluate/evaluate.py +85 -0
- Model/COSMIC/feature_extraction/src/evaluate/generate.py +72 -0
- Model/COSMIC/feature_extraction/src/evaluate/sampler.py +329 -0
- Model/COSMIC/feature_extraction/src/evaluate/utils.py +39 -0
- Model/COSMIC/feature_extraction/src/interactive/functions.py +328 -0
- Model/COSMIC/feature_extraction/src/main.py +19 -0
Model/.gitignore
ADDED
@@ -0,0 +1 @@
!config.py
Model/COSMIC/erc_training/commonsense_model.py
ADDED
@@ -0,0 +1,345 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_sequence
import numpy as np, itertools, random, copy, math
from model import SimpleAttention, MatchingAttention, Attention

class CommonsenseRNNCell(nn.Module):

    def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, listener_state=False,
                 context_attention='simple', D_a=100, dropout=0.5, emo_gru=True):
        super(CommonsenseRNNCell, self).__init__()

        self.D_m = D_m
        self.D_s = D_s
        self.D_g = D_g
        self.D_p = D_p
        self.D_r = D_r
        self.D_i = D_i
        self.D_e = D_e

        # print ('dmsg', D_m, D_s, D_g)
        self.g_cell = nn.GRUCell(D_m+D_p+D_r, D_g)
        self.p_cell = nn.GRUCell(D_s+D_g, D_p)
        self.r_cell = nn.GRUCell(D_m+D_s+D_g, D_r)
        self.i_cell = nn.GRUCell(D_s+D_p, D_i)
        self.e_cell = nn.GRUCell(D_m+D_p+D_r+D_i, D_e)

        self.emo_gru = emo_gru
        self.listener_state = listener_state
        if listener_state:
            self.pl_cell = nn.GRUCell(D_s+D_g, D_p)
            self.rl_cell = nn.GRUCell(D_m+D_s+D_g, D_r)

        self.dropout = nn.Dropout(dropout)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.dropout4 = nn.Dropout(dropout)
        self.dropout5 = nn.Dropout(dropout)

        if context_attention=='simple':
            self.attention = SimpleAttention(D_g)
        else:
            self.attention = MatchingAttention(D_g, D_m, D_a, context_attention)

    def _select_parties(self, X, indices):
        q0_sel = []
        for idx, j in zip(indices, X):
            q0_sel.append(j[idx].unsqueeze(0))
        q0_sel = torch.cat(q0_sel, 0)
        return q0_sel

    def forward(self, U, x1, x2, x3, o1, o2, qmask, g_hist, q0, r0, i0, e0):
        """
        U -> batch, D_m
        x1, x2, x3, o1, o2 -> batch, D_m
        x1 -> effect on self; x2 -> reaction of self; x3 -> intent of self
        o1 -> effect on others; o2 -> reaction of others
        qmask -> batch, party
        g_hist -> t-1, batch, D_g
        q0 -> batch, party, D_p
        e0 -> batch, self.D_e
        """
        qm_idx = torch.argmax(qmask, 1)
        q0_sel = self._select_parties(q0, qm_idx)
        r0_sel = self._select_parties(r0, qm_idx)

        ## global state ##
        g_ = self.g_cell(torch.cat([U, q0_sel, r0_sel], dim=1),
                         torch.zeros(U.size()[0], self.D_g).type(U.type()) if g_hist.size()[0]==0 else
                         g_hist[-1])
        # g_ = self.dropout(g_)

        ## context ##
        if g_hist.size()[0]==0:
            c_ = torch.zeros(U.size()[0], self.D_g).type(U.type())
            alpha = None
        else:
            c_, alpha = self.attention(g_hist, U)

        ## external state ##
        U_r_c_ = torch.cat([U, x2, c_], dim=1).unsqueeze(1).expand(-1, qmask.size()[1], -1)
        # print ('urc', U_r_c_.size())
        # print ('u x2, c', U.size(), x2.size(), c_.size())
        rs_ = self.r_cell(U_r_c_.contiguous().view(-1, self.D_m+self.D_s+self.D_g),
                          r0.view(-1, self.D_r)).view(U.size()[0], -1, self.D_r)
        # rs_ = self.dropout(rs_)

        ## internal state ##
        es_c_ = torch.cat([x1, c_], dim=1).unsqueeze(1).expand(-1, qmask.size()[1], -1)
        qs_ = self.p_cell(es_c_.contiguous().view(-1, self.D_s+self.D_g),
                          q0.view(-1, self.D_p)).view(U.size()[0], -1, self.D_p)
        # qs_ = self.dropout(qs_)

        if self.listener_state:
            ## listener external state ##
            U_ = U.unsqueeze(1).expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_m)
            er_ = o2.unsqueeze(1).expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_s)
            ss_ = self._select_parties(rs_, qm_idx).unsqueeze(1).\
                    expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_r)
            U_er_ss_ = torch.cat([U_, er_, ss_], 1)
            rl_ = self.rl_cell(U_er_ss_, r0.view(-1, self.D_r)).view(U.size()[0], -1, self.D_r)
            # rl_ = self.dropout(rl_)

            ## listener internal state ##
            es_ = o1.unsqueeze(1).expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_s)
            ss_ = self._select_parties(qs_, qm_idx).unsqueeze(1).\
                    expand(-1, qmask.size()[1], -1).contiguous().view(-1, self.D_p)
            es_ss_ = torch.cat([es_, ss_], 1)
            ql_ = self.pl_cell(es_ss_, q0.view(-1, self.D_p)).view(U.size()[0], -1, self.D_p)
            # ql_ = self.dropout(ql_)

        else:
            rl_ = r0
            ql_ = q0

        qmask_ = qmask.unsqueeze(2)
        q_ = ql_*(1-qmask_) + qs_*qmask_
        r_ = rl_*(1-qmask_) + rs_*qmask_

        ## intent ##
        i_q_ = torch.cat([x3, self._select_parties(q_, qm_idx)], dim=1).unsqueeze(1).expand(-1, qmask.size()[1], -1)
        is_ = self.i_cell(i_q_.contiguous().view(-1, self.D_s+self.D_p),
                          i0.view(-1, self.D_i)).view(U.size()[0], -1, self.D_i)
        # is_ = self.dropout(is_)
        il_ = i0
        i_ = il_*(1-qmask_) + is_*qmask_

        ## emotion ##
        es_ = torch.cat([U, self._select_parties(q_, qm_idx), self._select_parties(r_, qm_idx),
                         self._select_parties(i_, qm_idx)], dim=1)
        e0 = torch.zeros(qmask.size()[0], self.D_e).type(U.type()) if e0.size()[0]==0\
                else e0

        if self.emo_gru:
            e_ = self.e_cell(es_, e0)
        else:
            e_ = es_

        # e_ = self.dropout(e_)
        g_ = self.dropout1(g_)
        q_ = self.dropout2(q_)
        r_ = self.dropout3(r_)
        i_ = self.dropout4(i_)
        e_ = self.dropout5(e_)

        return g_, q_, r_, i_, e_, alpha


class CommonsenseRNN(nn.Module):

    def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, listener_state=False,
                 context_attention='simple', D_a=100, dropout=0.5, emo_gru=True):
        super(CommonsenseRNN, self).__init__()

        self.D_m = D_m
        self.D_g = D_g
        self.D_p = D_p
        self.D_r = D_r
        self.D_i = D_i
        self.D_e = D_e
        self.dropout = nn.Dropout(dropout)

        self.dialogue_cell = CommonsenseRNNCell(D_m, D_s, D_g, D_p, D_r, D_i, D_e,
                                                listener_state, context_attention, D_a, dropout, emo_gru)

    def forward(self, U, x1, x2, x3, o1, o2, qmask):
        """
        U -> seq_len, batch, D_m
        x1, x2, x3, o1, o2 -> seq_len, batch, D_s
        qmask -> seq_len, batch, party
        """

        g_hist = torch.zeros(0).type(U.type())  # 0-dimensional tensor
        q_ = torch.zeros(qmask.size()[1], qmask.size()[2], self.D_p).type(U.type())  # batch, party, D_p
        r_ = torch.zeros(qmask.size()[1], qmask.size()[2], self.D_r).type(U.type())  # batch, party, D_r
        i_ = torch.zeros(qmask.size()[1], qmask.size()[2], self.D_i).type(U.type())  # batch, party, D_i

        e_ = torch.zeros(0).type(U.type())  # batch, D_e
        e = e_

        alpha = []
        for u_, x1_, x2_, x3_, o1_, o2_, qmask_ in zip(U, x1, x2, x3, o1, o2, qmask):
            g_, q_, r_, i_, e_, alpha_ = self.dialogue_cell(u_, x1_, x2_, x3_, o1_, o2_,
                                                            qmask_, g_hist, q_, r_, i_, e_)

            g_hist = torch.cat([g_hist, g_.unsqueeze(0)], 0)
            e = torch.cat([e, e_.unsqueeze(0)], 0)

            if type(alpha_)!=type(None):
                alpha.append(alpha_[:,0,:])

        return e, alpha  # seq_len, batch, D_e


class CommonsenseGRUModel(nn.Module):

    def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, D_h, D_a=100, n_classes=7, listener_state=False,
                 context_attention='simple', dropout_rec=0.5, dropout=0.1, emo_gru=True, mode1=0, norm=0, residual=False):

        super(CommonsenseGRUModel, self).__init__()

        if mode1 == 0:
            D_x = 4 * D_m
        elif mode1 == 1:
            D_x = 2 * D_m
        else:
            D_x = D_m

        self.mode1 = mode1
        self.norm_strategy = norm
        self.linear_in = nn.Linear(D_x, D_h)
        self.residual = residual

        self.r_weights = nn.Parameter(torch.tensor([0.25, 0.25, 0.25, 0.25]))

        norm_train = True
        self.norm1a = nn.LayerNorm(D_m, elementwise_affine=norm_train)
        self.norm1b = nn.LayerNorm(D_m, elementwise_affine=norm_train)
        self.norm1c = nn.LayerNorm(D_m, elementwise_affine=norm_train)
        self.norm1d = nn.LayerNorm(D_m, elementwise_affine=norm_train)

        self.norm3a = nn.BatchNorm1d(D_m, affine=norm_train)
        self.norm3b = nn.BatchNorm1d(D_m, affine=norm_train)
        self.norm3c = nn.BatchNorm1d(D_m, affine=norm_train)
        self.norm3d = nn.BatchNorm1d(D_m, affine=norm_train)

        self.dropout = nn.Dropout(dropout)
        self.dropout_rec = nn.Dropout(dropout_rec)
        self.cs_rnn_f = CommonsenseRNN(D_h, D_s, D_g, D_p, D_r, D_i, D_e, listener_state,
                                       context_attention, D_a, dropout_rec, emo_gru)
        self.cs_rnn_r = CommonsenseRNN(D_h, D_s, D_g, D_p, D_r, D_i, D_e, listener_state,
                                       context_attention, D_a, dropout_rec, emo_gru)
        self.sense_gru = nn.GRU(input_size=D_s, hidden_size=D_s//2, num_layers=1, bidirectional=True)
        self.matchatt = MatchingAttention(2*D_e, 2*D_e, att_type='general2')
        self.linear = nn.Linear(2*D_e, D_h)
        self.smax_fc = nn.Linear(D_h, n_classes)

    def _reverse_seq(self, X, mask):
        """
        X -> seq_len, batch, dim
        mask -> batch, seq_len
        """
        X_ = X.transpose(0,1)
        mask_sum = torch.sum(mask, 1).int()

        xfs = []
        for x, c in zip(X_, mask_sum):
            xf = torch.flip(x[:c], [0])
            xfs.append(xf)
        return pad_sequence(xfs)

    def forward(self, r1, r2, r3, r4, x1, x2, x3, o1, o2, qmask, umask, att2=False, return_hidden=False):
        """
        U -> seq_len, batch, D_m
        qmask -> seq_len, batch, party
        """

        seq_len, batch, feature_dim = r1.size()

        if self.norm_strategy == 1:
            r1 = self.norm1a(r1.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r2 = self.norm1b(r2.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r3 = self.norm1c(r3.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r4 = self.norm1d(r4.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)

        elif self.norm_strategy == 2:
            norm2 = nn.LayerNorm((seq_len, feature_dim), elementwise_affine=False)
            r1 = norm2(r1.transpose(0, 1)).transpose(0, 1)
            r2 = norm2(r2.transpose(0, 1)).transpose(0, 1)
            r3 = norm2(r3.transpose(0, 1)).transpose(0, 1)
            r4 = norm2(r4.transpose(0, 1)).transpose(0, 1)

        elif self.norm_strategy == 3:
            r1 = self.norm3a(r1.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r2 = self.norm3b(r2.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r3 = self.norm3c(r3.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)
            r4 = self.norm3d(r4.transpose(0, 1).reshape(-1, feature_dim)).reshape(-1, seq_len, feature_dim).transpose(1, 0)

        if self.mode1 == 0:
            r = torch.cat([r1, r2, r3, r4], axis=-1)
        elif self.mode1 == 1:
            r = torch.cat([r1, r2], axis=-1)
        elif self.mode1 == 2:
            r = (r1 + r2 + r3 + r4)/4
        elif self.mode1 == 3:
            r = r1
        elif self.mode1 == 4:
            r = r2
        elif self.mode1 == 5:
            r = r3
        elif self.mode1 == 6:
            r = r4
        elif self.mode1 == 7:
            r = self.r_weights[0]*r1 + self.r_weights[1]*r2 + self.r_weights[2]*r3 + self.r_weights[3]*r4

        r = self.linear_in(r)

        emotions_f, alpha_f = self.cs_rnn_f(r, x1, x2, x3, o1, o2, qmask)

        out_sense, _ = self.sense_gru(x1)

        rev_r = self._reverse_seq(r, umask)
        rev_x1 = self._reverse_seq(x1, umask)
        rev_x2 = self._reverse_seq(x2, umask)
        rev_x3 = self._reverse_seq(x3, umask)
        rev_o1 = self._reverse_seq(o1, umask)
        rev_o2 = self._reverse_seq(o2, umask)
        rev_qmask = self._reverse_seq(qmask, umask)
        emotions_b, alpha_b = self.cs_rnn_r(rev_r, rev_x1, rev_x2, rev_x3, rev_o1, rev_o2, rev_qmask)
        emotions_b = self._reverse_seq(emotions_b, umask)

        emotions = torch.cat([emotions_f, emotions_b], dim=-1)
        emotions = self.dropout_rec(emotions)

        alpha, alpha_f, alpha_b = [], [], []
        if att2:
            att_emotions = []
            alpha = []
            for t in emotions:
                att_em, alpha_ = self.matchatt(emotions, t, mask=umask)
                att_emotions.append(att_em.unsqueeze(0))
                alpha.append(alpha_[:,0,:])
            att_emotions = torch.cat(att_emotions, dim=0)
            hidden = F.relu(self.linear(att_emotions))
        else:
            hidden = F.relu(self.linear(emotions))

        hidden = self.dropout(hidden)

        if self.residual:
            hidden = hidden + r

        log_prob = F.log_softmax(self.smax_fc(hidden), 2)

        if return_hidden:
            return hidden, alpha, alpha_f, alpha_b, emotions
        return log_prob, out_sense, alpha, alpha_f, alpha_b, emotions
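
A minimal smoke-test sketch for CommonsenseGRUModel (not part of the committed file): random tensors stand in for the RoBERTa and COMET features, and the dimension values mirror the ones hard-coded in predict_epik.py later in this diff. With mode1=2 the four RoBERTa views are averaged, so the model's input width stays D_m.

import torch
from commonsense_model import CommonsenseGRUModel

seq_len, batch, party = 4, 2, 2
D_m, D_s = 1024, 768                      # RoBERTa / COMET feature widths
D_g = D_p = D_r = D_i = 150
D_e = D_p + D_r + D_i                     # 450, as in predict_epik.py

model = CommonsenseGRUModel(D_m, D_s, D_g, D_p, D_r, D_i, D_e, D_h=100, D_a=100,
                            n_classes=15, mode1=2, residual=True)

r = [torch.randn(seq_len, batch, D_m) for _ in range(4)]        # roberta1..roberta4
x = [torch.randn(seq_len, batch, D_s) for _ in range(5)]        # x1, x2, x3, o1, o2
qmask = torch.zeros(seq_len, batch, party)
qmask[:, :, 0] = 1                                              # speaker 0 holds every turn
umask = torch.ones(batch, seq_len)                              # no padded utterances

log_prob, out_sense, *_ = model(*r, *x, qmask, umask)
print(log_prob.shape)                                           # (seq_len, batch, n_classes)
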
Model/COSMIC/erc_training/dataloader.py
ADDED
@@ -0,0 +1,276 @@
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pickle, pandas as pd

class IEMOCAPRobertaCometDataset(Dataset):

    def __init__(self, split):
        '''
        label index mapping = {'hap':0, 'sad':1, 'neu':2, 'ang':3, 'exc':4, 'fru':5}
        '''
        self.speakers, self.labels, \
        self.roberta1, self.roberta2, self.roberta3, self.roberta4, \
        self.sentences, self.trainIds, self.testIds, self.validIds \
            = pickle.load(open('iemocap/iemocap_features_roberta.pkl', 'rb'), encoding='latin1')

        self.xIntent, self.xAttr, self.xNeed, self.xWant, self.xEffect, self.xReact, self.oWant, self.oEffect, self.oReact \
            = pickle.load(open('iemocap/iemocap_features_comet.pkl', 'rb'), encoding='latin1')

        if split == 'train':
            self.keys = [x for x in self.trainIds]
        elif split == 'test':
            self.keys = [x for x in self.testIds]
        elif split == 'valid':
            self.keys = [x for x in self.validIds]

        self.len = len(self.keys)

    def __getitem__(self, index):
        vid = self.keys[index]
        return torch.FloatTensor(self.roberta1[vid]),\
               torch.FloatTensor(self.roberta2[vid]),\
               torch.FloatTensor(self.roberta3[vid]),\
               torch.FloatTensor(self.roberta4[vid]),\
               torch.FloatTensor(self.xIntent[vid]),\
               torch.FloatTensor(self.xAttr[vid]),\
               torch.FloatTensor(self.xNeed[vid]),\
               torch.FloatTensor(self.xWant[vid]),\
               torch.FloatTensor(self.xEffect[vid]),\
               torch.FloatTensor(self.xReact[vid]),\
               torch.FloatTensor(self.oWant[vid]),\
               torch.FloatTensor(self.oEffect[vid]),\
               torch.FloatTensor(self.oReact[vid]),\
               torch.FloatTensor([[1,0] if x=='M' else [0,1] for x in self.speakers[vid]]),\
               torch.FloatTensor([1]*len(self.labels[vid])),\
               torch.LongTensor(self.labels[vid]),\
               vid

    def __len__(self):
        return self.len

    def collate_fn(self, data):
        dat = pd.DataFrame(data)
        return [pad_sequence(dat[i]) if i<14 else pad_sequence(dat[i], True) if i<16 else dat[i].tolist() for i in dat]


class MELDRobertaCometDataset(Dataset):

    def __init__(self, split, classify='emotion'):
        '''
        label index mapping =
        '''
        self.speakers, self.emotion_labels, self.sentiment_labels, \
        self.roberta1, self.roberta2, self.roberta3, self.roberta4, \
        self.sentences, self.trainIds, self.testIds, self.validIds \
            = pickle.load(open('meld/meld_features_roberta.pkl', 'rb'), encoding='latin1')

        self.xIntent, self.xAttr, self.xNeed, self.xWant, self.xEffect, self.xReact, self.oWant, self.oEffect, self.oReact \
            = pickle.load(open('meld/meld_features_comet.pkl', 'rb'), encoding='latin1')

        if split == 'train':
            self.keys = [x for x in self.trainIds]
        elif split == 'test':
            self.keys = [x for x in self.testIds]
        elif split == 'valid':
            self.keys = [x for x in self.validIds]

        if classify == 'emotion':
            self.labels = self.emotion_labels
        else:
            self.labels = self.sentiment_labels

        self.len = len(self.keys)

    def __getitem__(self, index):
        vid = self.keys[index]
        return torch.FloatTensor(self.roberta1[vid]),\
               torch.FloatTensor(self.roberta2[vid]),\
               torch.FloatTensor(self.roberta3[vid]),\
               torch.FloatTensor(self.roberta4[vid]),\
               torch.FloatTensor(self.xIntent[vid]),\
               torch.FloatTensor(self.xAttr[vid]),\
               torch.FloatTensor(self.xNeed[vid]),\
               torch.FloatTensor(self.xWant[vid]),\
               torch.FloatTensor(self.xEffect[vid]),\
               torch.FloatTensor(self.xReact[vid]),\
               torch.FloatTensor(self.oWant[vid]),\
               torch.FloatTensor(self.oEffect[vid]),\
               torch.FloatTensor(self.oReact[vid]),\
               torch.FloatTensor(self.speakers[vid]),\
               torch.FloatTensor([1]*len(self.labels[vid])),\
               torch.LongTensor(self.labels[vid]),\
               vid

    def __len__(self):
        return self.len

    def collate_fn(self, data):
        dat = pd.DataFrame(data)
        return [pad_sequence(dat[i]) if i<14 else pad_sequence(dat[i], True) if i<16 else dat[i].tolist() for i in dat]

class RobertaCometDataset(Dataset):

    def __init__(self, split, path_roberta="epik/epik_features_roberta.pkl", path_comet="epik/epik_features_comet.pkl"):
        self.speakers, self.labels, \
        self.roberta1, self.roberta2, self.roberta3, self.roberta4, \
        self.sentences, self.trainIds, self.testIds, self.validIds \
            = pickle.load(open(path_roberta, 'rb'), encoding='latin1')

        self.xIntent, self.xAttr, self.xNeed, self.xWant, self.xEffect, self.xReact, self.oWant, self.oEffect, self.oReact \
            = pickle.load(open(path_comet, 'rb'), encoding='latin1')

        if split == 'train':
            self.keys = [x for x in self.trainIds]
        elif split == 'test':
            self.keys = [x for x in self.testIds]
        elif split == 'valid':
            self.keys = [x for x in self.validIds]

        self.len = len(self.keys)

    def __getitem__(self, index):
        vid = self.keys[index]
        return torch.FloatTensor(self.roberta1[vid]),\
               torch.FloatTensor(self.roberta2[vid]),\
               torch.FloatTensor(self.roberta3[vid]),\
               torch.FloatTensor(self.roberta4[vid]),\
               torch.FloatTensor(self.xIntent[vid]),\
               torch.FloatTensor(self.xAttr[vid]),\
               torch.FloatTensor(self.xNeed[vid]),\
               torch.FloatTensor(self.xWant[vid]),\
               torch.FloatTensor(self.xEffect[vid]),\
               torch.FloatTensor(self.xReact[vid]),\
               torch.FloatTensor(self.oWant[vid]),\
               torch.FloatTensor(self.oEffect[vid]),\
               torch.FloatTensor(self.oReact[vid]),\
               torch.FloatTensor([[1,0] if x=='0' else [0,1] for x in self.speakers[vid]]),\
               torch.FloatTensor([1]*len(self.labels[vid])),\
               torch.LongTensor(self.labels[vid]),\
               vid

    def __len__(self):
        return self.len

    def collate_fn(self, data):
        dat = pd.DataFrame(data)
        return [pad_sequence(dat[i]) if i<14 else pad_sequence(dat[i], True) if i<16 else dat[i].tolist() for i in dat]


class DailyDialogueRobertaCometDataset(Dataset):

    def __init__(self, split):

        self.speakers, self.labels, \
        self.roberta1, self.roberta2, self.roberta3, self.roberta4, \
        self.sentences, self.trainIds, self.testIds, self.validIds \
            = pickle.load(open('dailydialog/dailydialog_features_roberta.pkl', 'rb'), encoding='latin1')

        self.xIntent, self.xAttr, self.xNeed, self.xWant, self.xEffect, self.xReact, self.oWant, self.oEffect, self.oReact \
            = pickle.load(open('dailydialog/dailydialog_features_comet.pkl', 'rb'), encoding='latin1')

        if split == 'train':
            self.keys = [x for x in self.trainIds]
        elif split == 'test':
            self.keys = [x for x in self.testIds]
        elif split == 'valid':
            self.keys = [x for x in self.validIds]

        self.len = len(self.keys)

    def __getitem__(self, index):
        vid = self.keys[index]
        return torch.FloatTensor(self.roberta1[vid]),\
               torch.FloatTensor(self.roberta2[vid]),\
               torch.FloatTensor(self.roberta3[vid]),\
               torch.FloatTensor(self.roberta4[vid]),\
               torch.FloatTensor(self.xIntent[vid]),\
               torch.FloatTensor(self.xAttr[vid]),\
               torch.FloatTensor(self.xNeed[vid]),\
               torch.FloatTensor(self.xWant[vid]),\
               torch.FloatTensor(self.xEffect[vid]),\
               torch.FloatTensor(self.xReact[vid]),\
               torch.FloatTensor(self.oWant[vid]),\
               torch.FloatTensor(self.oEffect[vid]),\
               torch.FloatTensor(self.oReact[vid]),\
               torch.FloatTensor([[1,0] if x=='0' else [0,1] for x in self.speakers[vid]]),\
               torch.FloatTensor([1]*len(self.labels[vid])),\
               torch.LongTensor(self.labels[vid]),\
               vid

    def __len__(self):
        return self.len

    def collate_fn(self, data):
        dat = pd.DataFrame(data)
        return [pad_sequence(dat[i]) if i<14 else pad_sequence(dat[i], True) if i<16 else dat[i].tolist() for i in dat]

class EmoryNLPRobertaCometDataset(Dataset):

    def __init__(self, split, classify='emotion'):

        '''
        label index mapping = {'Joyful': 0, 'Mad': 1, 'Peaceful': 2, 'Neutral': 3, 'Sad': 4, 'Powerful': 5, 'Scared': 6}
        '''

        self.speakers, self.emotion_labels, \
        self.roberta1, self.roberta2, self.roberta3, self.roberta4, \
        self.sentences, self.trainId, self.testId, self.validId \
            = pickle.load(open('emorynlp/emorynlp_features_roberta.pkl', 'rb'), encoding='latin1')

        sentiment_labels = {}
        for item in self.emotion_labels:
            array = []
            # 0 negative, 1 neutral, 2 positive
            for e in self.emotion_labels[item]:
                if e in [1, 4, 6]:
                    array.append(0)
                elif e == 3:
                    array.append(1)
                elif e in [0, 2, 5]:
                    array.append(2)
            sentiment_labels[item] = array

        self.xIntent, self.xAttr, self.xNeed, self.xWant, self.xEffect, self.xReact, self.oWant, self.oEffect, self.oReact \
            = pickle.load(open('emorynlp/emorynlp_features_comet.pkl', 'rb'), encoding='latin1')

        if split == 'train':
            self.keys = [x for x in self.trainId]
        elif split == 'test':
            self.keys = [x for x in self.testId]
        elif split == 'valid':
            self.keys = [x for x in self.validId]

        if classify == 'emotion':
            self.labels = self.emotion_labels
        elif classify == 'sentiment':
            self.labels = sentiment_labels

        self.len = len(self.keys)

    def __getitem__(self, index):
        vid = self.keys[index]
        return torch.FloatTensor(self.roberta1[vid]),\
               torch.FloatTensor(self.roberta2[vid]),\
               torch.FloatTensor(self.roberta3[vid]),\
               torch.FloatTensor(self.roberta4[vid]),\
               torch.FloatTensor(self.xIntent[vid]),\
               torch.FloatTensor(self.xAttr[vid]),\
               torch.FloatTensor(self.xNeed[vid]),\
               torch.FloatTensor(self.xWant[vid]),\
               torch.FloatTensor(self.xEffect[vid]),\
               torch.FloatTensor(self.xReact[vid]),\
               torch.FloatTensor(self.oWant[vid]),\
               torch.FloatTensor(self.oEffect[vid]),\
               torch.FloatTensor(self.oReact[vid]),\
               torch.FloatTensor([[1,0] if x=='0' else [0,1] for x in self.speakers[vid]]),\
               torch.FloatTensor([1]*len(self.labels[vid])),\
               torch.LongTensor(self.labels[vid]),\
               vid

    def __len__(self):
        return self.len

    def collate_fn(self, data):
        dat = pd.DataFrame(data)
        return [pad_sequence(dat[i]) if i<14 else pad_sequence(dat[i], True) if i<16 else dat[i].tolist() for i in dat]
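
A short usage sketch for these datasets (not part of the committed file): it wires RobertaCometDataset into a PyTorch DataLoader with its own collate_fn. The default epik pickle paths are assumed to exist; the 17-way unpacking follows the order returned by __getitem__ and is the same one predict_epik.py uses later in this diff.

from torch.utils.data import DataLoader
from dataloader import RobertaCometDataset

valid_set = RobertaCometDataset('valid')   # reads the default epik/*.pkl paths
loader = DataLoader(valid_set, batch_size=1, collate_fn=valid_set.collate_fn)

for batch in loader:
    # columns 0-13 are padded seq-first, 14-15 batch-first, 16 is the raw vid list
    r1, r2, r3, r4, x1, x2, x3, x4, x5, x6, o1, o2, o3, qmask, umask, labels, vid = batch
    break
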
Model/COSMIC/erc_training/model.py
ADDED
@@ -0,0 +1,229 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_sequence

class MaskedNLLLoss(nn.Module):

    def __init__(self, weight=None):
        super(MaskedNLLLoss, self).__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight=weight,
                               reduction='sum')

    def forward(self, pred, target, mask):
        """
        pred -> batch*seq_len, n_classes
        target -> batch*seq_len
        mask -> batch, seq_len
        """
        mask_ = mask.view(-1,1)  # batch*seq_len, 1
        if type(self.weight)==type(None):
            loss = self.loss(pred*mask_, target)/torch.sum(mask)
        else:
            loss = self.loss(pred*mask_, target)\
                            /torch.sum(self.weight[target]*mask_.squeeze())
        return loss


class MaskedMSELoss(nn.Module):

    def __init__(self):
        super(MaskedMSELoss, self).__init__()
        self.loss = nn.MSELoss(reduction='sum')

    def forward(self, pred, target, mask):
        """
        pred -> batch*seq_len
        target -> batch*seq_len
        mask -> batch*seq_len
        """
        loss = self.loss(pred*mask, target)/torch.sum(mask)
        return loss


class UnMaskedWeightedNLLLoss(nn.Module):

    def __init__(self, weight=None):
        super(UnMaskedWeightedNLLLoss, self).__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight=weight,
                               reduction='sum')

    def forward(self, pred, target):
        """
        pred -> batch*seq_len, n_classes
        target -> batch*seq_len
        """
        if type(self.weight)==type(None):
            loss = self.loss(pred, target)
        else:
            loss = self.loss(pred, target)\
                            /torch.sum(self.weight[target])
        return loss


class SimpleAttention(nn.Module):

    def __init__(self, input_dim):
        super(SimpleAttention, self).__init__()
        self.input_dim = input_dim
        self.scalar = nn.Linear(self.input_dim, 1, bias=False)

    def forward(self, M, x=None):
        """
        M -> (seq_len, batch, vector)
        x -> dummy argument for the compatibility with MatchingAttention
        """
        scale = self.scalar(M)  # seq_len, batch, 1
        alpha = F.softmax(scale, dim=0).permute(1,2,0)  # batch, 1, seq_len
        attn_pool = torch.bmm(alpha, M.transpose(0,1))[:,0,:]  # batch, vector
        return attn_pool, alpha


class MatchingAttention(nn.Module):

    def __init__(self, mem_dim, cand_dim, alpha_dim=None, att_type='general'):
        super(MatchingAttention, self).__init__()
        assert att_type!='concat' or alpha_dim!=None
        assert att_type!='dot' or mem_dim==cand_dim
        self.mem_dim = mem_dim
        self.cand_dim = cand_dim
        self.att_type = att_type
        if att_type=='general':
            self.transform = nn.Linear(cand_dim, mem_dim, bias=False)
        if att_type=='general2':
            self.transform = nn.Linear(cand_dim, mem_dim, bias=True)
            #torch.nn.init.normal_(self.transform.weight,std=0.01)
        elif att_type=='concat':
            self.transform = nn.Linear(cand_dim+mem_dim, alpha_dim, bias=False)
            self.vector_prod = nn.Linear(alpha_dim, 1, bias=False)

    def forward(self, M, x, mask=None):
        """
        M -> (seq_len, batch, mem_dim)
        x -> (batch, cand_dim)
        mask -> (batch, seq_len)
        """
        if type(mask)==type(None):
            mask = torch.ones(M.size(1), M.size(0)).type(M.type())

        if self.att_type=='dot':
            # vector = cand_dim = mem_dim
            M_ = M.permute(1,2,0)  # batch, vector, seqlen
            x_ = x.unsqueeze(1)  # batch, 1, vector
            alpha = F.softmax(torch.bmm(x_, M_), dim=2)  # batch, 1, seqlen
        elif self.att_type=='general':
            M_ = M.permute(1,2,0)  # batch, mem_dim, seqlen
            x_ = self.transform(x).unsqueeze(1)  # batch, 1, mem_dim
            alpha = F.softmax(torch.bmm(x_, M_), dim=2)  # batch, 1, seqlen
        elif self.att_type=='general2':
            M_ = M.permute(1,2,0)  # batch, mem_dim, seqlen
            x_ = self.transform(x).unsqueeze(1)  # batch, 1, mem_dim
            mask_ = mask.unsqueeze(2).repeat(1, 1, self.mem_dim).transpose(1, 2)  # batch, seq_len, mem_dim
            M_ = M_ * mask_
            alpha_ = torch.bmm(x_, M_)*mask.unsqueeze(1)
            alpha_ = torch.tanh(alpha_)
            alpha_ = F.softmax(alpha_, dim=2)
            # alpha_ = F.softmax((torch.bmm(x_, M_))*mask.unsqueeze(1), dim=2)  # batch, 1, seqlen
            alpha_masked = alpha_*mask.unsqueeze(1)  # batch, 1, seqlen
            alpha_sum = torch.sum(alpha_masked, dim=2, keepdim=True)  # batch, 1, 1
            alpha = alpha_masked/alpha_sum  # batch, 1, 1 ; normalized
            #import ipdb;ipdb.set_trace()
        else:
            M_ = M.transpose(0,1)  # batch, seqlen, mem_dim
            x_ = x.unsqueeze(1).expand(-1, M.size()[0], -1)  # batch, seqlen, cand_dim
            M_x_ = torch.cat([M_, x_], 2)  # batch, seqlen, mem_dim+cand_dim
            mx_a = F.tanh(self.transform(M_x_))  # batch, seqlen, alpha_dim
            alpha = F.softmax(self.vector_prod(mx_a), 1).transpose(1,2)  # batch, 1, seqlen

        attn_pool = torch.bmm(alpha, M.transpose(0,1))[:,0,:]  # batch, mem_dim
        return attn_pool, alpha


class Attention(nn.Module):
    def __init__(self, embed_dim, hidden_dim=None, out_dim=None, n_head=1, score_function='dot_product', dropout=0):
        ''' Attention Mechanism
        :param embed_dim:
        :param hidden_dim:
        :param out_dim:
        :param n_head: num of head (Multi-Head Attention)
        :param score_function: scaled_dot_product / mlp (concat) / bi_linear (general dot)
        :return (?, q_len, out_dim,)
        '''
        super(Attention, self).__init__()
        if hidden_dim is None:
            hidden_dim = embed_dim // n_head
        if out_dim is None:
            out_dim = embed_dim
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.score_function = score_function
        self.w_k = nn.Linear(embed_dim, n_head * hidden_dim)
        self.w_q = nn.Linear(embed_dim, n_head * hidden_dim)
        self.proj = nn.Linear(n_head * hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        if score_function == 'mlp':
            self.weight = nn.Parameter(torch.Tensor(hidden_dim*2))
        elif self.score_function == 'bi_linear':
            self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        else:  # dot_product / scaled_dot_product
            self.register_parameter('weight', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.hidden_dim)
        if self.weight is not None:
            self.weight.data.uniform_(-stdv, stdv)

    def forward(self, k, q):
        if len(q.shape) == 2:  # q_len missing
            q = torch.unsqueeze(q, dim=1)
        if len(k.shape) == 2:  # k_len missing
            k = torch.unsqueeze(k, dim=1)
        mb_size = k.shape[0]  # ?
        k_len = k.shape[1]
        q_len = q.shape[1]
        # k: (?, k_len, embed_dim,)
        # q: (?, q_len, embed_dim,)
        # kx: (n_head*?, k_len, hidden_dim)
        # qx: (n_head*?, q_len, hidden_dim)
        # score: (n_head*?, q_len, k_len,)
        # output: (?, q_len, out_dim,)
        kx = self.w_k(k).view(mb_size, k_len, self.n_head, self.hidden_dim)
        kx = kx.permute(2, 0, 1, 3).contiguous().view(-1, k_len, self.hidden_dim)
        qx = self.w_q(q).view(mb_size, q_len, self.n_head, self.hidden_dim)
        qx = qx.permute(2, 0, 1, 3).contiguous().view(-1, q_len, self.hidden_dim)
        if self.score_function == 'dot_product':
            kt = kx.permute(0, 2, 1)
            score = torch.bmm(qx, kt)
        elif self.score_function == 'scaled_dot_product':
            kt = kx.permute(0, 2, 1)
            qkt = torch.bmm(qx, kt)
            score = torch.div(qkt, math.sqrt(self.hidden_dim))
        elif self.score_function == 'mlp':
            kxx = torch.unsqueeze(kx, dim=1).expand(-1, q_len, -1, -1)
            qxx = torch.unsqueeze(qx, dim=2).expand(-1, -1, k_len, -1)
            kq = torch.cat((kxx, qxx), dim=-1)  # (n_head*?, q_len, k_len, hidden_dim*2)
            # kq = torch.unsqueeze(kx, dim=1) + torch.unsqueeze(qx, dim=2)
            score = torch.tanh(torch.matmul(kq, self.weight))
        elif self.score_function == 'bi_linear':
            qw = torch.matmul(qx, self.weight)
            kt = kx.permute(0, 2, 1)
            score = torch.bmm(qw, kt)
        else:
            raise RuntimeError('invalid score_function')
        #score = F.softmax(score, dim=-1)
        score = F.softmax(score, dim=0)
        # print (score)
        # print (sum(score))
        output = torch.bmm(score, kx)  # (n_head*?, q_len, hidden_dim)
        output = torch.cat(torch.split(output, mb_size, dim=0), dim=-1)  # (?, q_len, n_head*hidden_dim)
        output = self.proj(output)  # (?, q_len, out_dim)
        output = self.dropout(output)
        return output, score
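
A small worked example of MaskedNLLLoss from the file above (not part of the committed file): the utterance mask zeroes out padded turns before the summed NLL is normalised by the number of real turns. Shapes and values are illustrative only.

import torch
from model import MaskedNLLLoss

n_classes = 15
log_prob = torch.log_softmax(torch.randn(2*3, n_classes), dim=1)  # batch*seq_len, n_classes
target = torch.randint(0, n_classes, (2*3,))
mask = torch.tensor([[1., 1., 1.],
                     [1., 1., 0.]])   # second dialogue has only two real turns

loss = MaskedNLLLoss()(log_prob, target, mask)   # averaged over the 5 unmasked turns
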
Model/COSMIC/erc_training/predict_epik.py
ADDED
@@ -0,0 +1,198 @@
import torch, argparse
from commonsense_model import CommonsenseGRUModel
from dataloader import RobertaCometDataset
from torch.utils.data import DataLoader


def load_model(model_path, args):
    emo_gru = True
    n_classes = 15
    cuda = args.cuda

    D_m = 1024
    D_s = 768
    D_g = 150
    D_p = 150
    D_r = 150
    D_i = 150
    D_h = 100
    D_a = 100
    D_e = D_p + D_r + D_i

    model = CommonsenseGRUModel(
        D_m,
        D_s,
        D_g,
        D_p,
        D_r,
        D_i,
        D_e,
        D_h,
        D_a,
        n_classes=n_classes,
        listener_state=args.active_listener,
        context_attention=args.attention,
        dropout_rec=args.rec_dropout,
        dropout=args.dropout,
        emo_gru=emo_gru,
        mode1=args.mode1,
        norm=args.norm,
        residual=args.residual,
    )

    if cuda:
        model.cuda()

    model.load_state_dict(torch.load(model_path))
    model.eval()

    return model


def get_valid_dataloader(
    roberta_features_path: str,
    comet_features_path: str,
    batch_size=1,
    num_workers=0,
    pin_memory=False,
):
    valid_set = RobertaCometDataset("valid", roberta_features_path, comet_features_path)

    test_loader = DataLoader(
        valid_set,
        batch_size=batch_size,
        collate_fn=valid_set.collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return test_loader, valid_set.keys


def predict(model, data_loader, args):
    predictions = []
    for data in data_loader:
        r1, r2, r3, r4, x1, x2, x3, x4, x5, x6, o1, o2, o3, qmask, umask, label = (
            [d.cuda() for d in data[:-1]] if args.cuda else data[:-1]
        )
        log_prob, _, alpha, alpha_f, alpha_b, _ = model(
            r1, r2, r3, r4, x5, x6, x1, o2, o3, qmask, umask
        )

        lp_ = log_prob.transpose(0, 1).contiguous().view(-1, log_prob.size()[2])
        preds = torch.argmax(lp_, dim=-1)
        predictions.append(preds.data.cpu().numpy())

    return predictions


def parse_cosmic_args():
    parser = argparse.ArgumentParser()

    # Parse arguments input into the cosmic model
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="does not use GPU"
    )
    parser.add_argument(
        "--lr", type=float, default=0.0001, metavar="LR", help="learning rate"
    )
    parser.add_argument(
        "--l2",
        type=float,
        default=0.00003,
        metavar="L2",
        help="L2 regularization weight",
    )
    parser.add_argument(
        "--rec-dropout",
        type=float,
        default=0.3,
        metavar="rec_dropout",
        help="rec_dropout rate",
    )
    parser.add_argument(
        "--dropout", type=float, default=0.5, metavar="dropout", help="dropout rate"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1, metavar="BS", help="batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, metavar="E", help="number of epochs"
    )
    parser.add_argument(
        "--class-weight", action="store_true", default=True, help="use class weights"
    )
    parser.add_argument(
        "--active-listener", action="store_true", default=True, help="active listener"
    )
    parser.add_argument(
        "--attention", default="simple", help="Attention type in context GRU"
    )
    parser.add_argument(
        "--tensorboard",
        action="store_true",
        default=False,
        help="Enables tensorboard log",
    )
    parser.add_argument("--mode1", type=int, default=2, help="Roberta features to use")
    parser.add_argument("--seed", type=int, default=500, metavar="seed", help="seed")
    parser.add_argument("--norm", type=int, default=0, help="normalization strategy")
    parser.add_argument("--mu", type=float, default=0, help="class_weight_mu")
    parser.add_argument(
        "--residual", action="store_true", default=True, help="use residual connection"
    )

    args = parser.parse_args()

    args.cuda = torch.cuda.is_available() and not args.no_cuda
    if args.cuda:
        print("Running on GPU")
    else:
        print("Running on CPU")

    return args


if __name__ == "__main__":

    def pred_to_labels(preds):
        mapped_predictions = []
        for pred in preds:
            # map the prediction for each conversation
            mapped_labels = []
            for label in pred:
                mapped_labels.append(label_mapping[label])

            mapped_predictions.append(mapped_labels)

        # return the mapped labels for each conversation
        return mapped_predictions

    label_mapping = {
        0: "Curiosity",
        1: "Obscene",
        2: "Informative",
        3: "Openness",
        4: "Acceptance",
        5: "Interest",
        6: "Greeting",
        7: "Disapproval",
        8: "Denial",
        9: "Anxious",
        10: "Uninterested",
        11: "Remorse",
        12: "Confused",
        13: "Accusatory",
        14: "Annoyed",
    }

    args = parse_cosmic_args()

    model = load_model("epik/best_model.pt", args)
    # get_valid_dataloader has no defaults for the feature paths, so pass the
    # same paths that RobertaCometDataset uses by default
    test_dataloader, ids = get_valid_dataloader(
        "epik/epik_features_roberta.pkl", "epik/epik_features_comet.pkl"
    )
    predicted_labels = pred_to_labels(predict(model, test_dataloader, args))

    for id, labels in zip(ids, predicted_labels):
        print(f"Conversation ID: {id}")
        print(f"Predicted Sentiment Labels: {labels}")
        print(len(labels))
Model/COSMIC/feature_extraction/comet/__init__.py
ADDED
File without changes
Model/COSMIC/feature_extraction/comet/csk_feature_extract.py
ADDED
@@ -0,0 +1,110 @@
import os
from tqdm import tqdm
from nltk import tokenize
import numpy as np
import pickle, torch
import comet.src.data.data as data
import comet.src.data.config as cfg
import comet.src.models.utils as model_utils
import comet.src.interactive.functions as interactive


class CSKFeatureExtractor:
    def __init__(self, dir="."):
        super(CSKFeatureExtractor, self).__init__()

        device = 0
        model_file = os.path.join(
            dir, "comet/pretrained_models/atomic_pretrained_model.pickle"
        )
        sampling_algorithm = "beam-5"
        category = "all"

        opt, state_dict = interactive.load_model_file(model_file)
        data_loader, text_encoder = interactive.load_data("atomic", opt, dir)

        self.opt = opt
        self.data_loader = data_loader
        self.text_encoder = text_encoder

        n_ctx = data_loader.max_event + data_loader.max_effect
        n_vocab = len(text_encoder.encoder) + n_ctx
        self.model = interactive.make_model(opt, n_vocab, n_ctx, state_dict)
        self.model.eval()

        if device != "cpu":
            cfg.device = int(device)
            cfg.do_gpu = True
            torch.cuda.set_device(cfg.device)
            self.model.cuda(cfg.device)
        else:
            cfg.device = "cpu"

    def set_atomic_inputs(self, input_event, category, data_loader, text_encoder):
        XMB = torch.zeros(1, data_loader.max_event + 1).long().to(cfg.device)
        prefix, suffix = data.atomic_data.do_example(
            text_encoder, input_event, None, True, None
        )

        if len(prefix) > data_loader.max_event + 1:
            prefix = prefix[: data_loader.max_event + 1]

        XMB[:, : len(prefix)] = torch.LongTensor(prefix)
        XMB[:, -1] = torch.LongTensor([text_encoder.encoder["<{}>".format(category)]])

        batch = {}
        batch["sequences"] = XMB
        batch["attention_mask"] = data.atomic_data.make_attention_mask(XMB)
        return batch

    def extract(self, sentence):
        atomic_keys = [
            "xIntent",
            "xAttr",
            "xNeed",
            "xWant",
            "xEffect",
            "xReact",
            "oWant",
            "oEffect",
            "oReact",
        ]
        map1 = [{}, {}, {}, {}, {}, {}, {}, {}, {}]
        all_keys = list(sentence.keys())

        for i in tqdm(range(len(all_keys))):
            item = all_keys[i]
            list1 = [[], [], [], [], [], [], [], [], []]

            for x in sentence[item]:
                input_event = x.encode("ascii", errors="ignore").decode("utf-8")
                m1 = []
                for sent in tokenize.sent_tokenize(input_event):
                    seqs = []
                    masks = []
                    for category in atomic_keys:
                        batch = self.set_atomic_inputs(
                            sent, category, self.data_loader, self.text_encoder
                        )
                        seqs.append(batch["sequences"])
                        masks.append(batch["attention_mask"])

                    XMB = torch.cat(seqs)
                    MMB = torch.cat(masks)
                    XMB = model_utils.prepare_position_embeddings(
                        self.opt, self.data_loader.vocab_encoder, XMB.unsqueeze(-1)
                    )
                    h, _ = self.model(XMB.unsqueeze(1), sequence_mask=MMB)

                    last_index = MMB[0][:-1].nonzero()[-1].cpu().numpy()[0] + 1
                    m1.append(h[:, -1, :].detach().cpu().numpy())

                m1 = np.mean(np.array(m1), axis=0)

                for k, l1 in enumerate(list1):
                    l1.append(m1[k])

            for k, v1 in enumerate(map1):
                v1[item] = list1[k]

        return map1
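
A hedged usage sketch for CSKFeatureExtractor (not part of the committed file): it assumes the pretrained COMET checkpoint sits at comet/pretrained_models/atomic_pretrained_model.pickle, that a CUDA device is available, and that NLTK's sentence-tokenizer data is installed. The input is a dict mapping a conversation id to its list of utterances.

from csk_feature_extract import CSKFeatureExtractor

extractor = CSKFeatureExtractor()      # loads COMET and moves it to the GPU
features = extractor.extract({"conv_0": ["Hi there!", "I just got a new job."]})

# features is a list of nine dicts keyed by conversation id, in atomic_keys
# order: xIntent, xAttr, xNeed, xWant, xEffect, xReact, oWant, oEffect, oReact
x_intent = features[0]["conv_0"]       # one vector per utterance
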
Model/COSMIC/feature_extraction/comet/src/__init__.py
ADDED
File without changes
Model/COSMIC/feature_extraction/comet/src/data/atomic.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import comet.utils.utils as utils
import comet.src.data.utils as data_utils
import comet.src.data.config as cfg

import pandas
import json
import random
import math
import torch

from tqdm import tqdm


def map_name(name):
    if name == "train":
        return "trn"
    elif name == "test":
        return "tst"
    else:
        return "dev"


class DataLoader(object):
    def __init__(self, opt):
        self.data = {}
        self.data["train"] = {}
        self.data["dev"] = {}
        self.data["test"] = {}

        self.sequences = {}
        self.sequences["train"] = {}
        self.sequences["dev"] = {}
        self.sequences["test"] = {}

        self.masks = {}
        self.masks["train"] = {}
        self.masks["dev"] = {}
        self.masks["test"] = {}

        self.offsets = {}
        self.offsets["train"] = {}
        self.offsets["dev"] = {}
        self.offsets["test"] = {}

    def offset_summary(self, split):
        return self.offsets[split]["total"]


def do_take_partial_dataset(data_opts):
    if data_opts.get("kr", None) is None:
        return False
    if data_opts.kr == 1:
        return False
    return True


def select_partial_dataset(data_opts, data):
    num_selections = math.ceil(data_opts.kr * len(data))
    return random.sample(data, num_selections)


class GenerationDataLoader(DataLoader):
    def __init__(self, opt, categories):
        super(GenerationDataLoader, self).__init__(opt)

        self.categories = categories
        self.opt = opt

        for split in self.data:
            self.data[split] = {"total": []}
            self.offsets[split] = {"total": 0}

        self.vocab_encoder = None
        self.vocab_decoder = None
        self.special_chars = None
        self.max_event = None
        self.max_effect = None

    def load_data(self, path):
        if ".pickle" in path:
            print("Loading data from: {}".format(path))
            data_utils.load_existing_data_loader(self, path)

            return True

        for split in self.data:
            file_name = "v4_atomic_{}.csv".format(map_name(split))

            df = pandas.read_csv("{}/{}".format(path, file_name), index_col=0)
            df.iloc[:, :9] = df.iloc[:, :9].apply(
                lambda col: col.apply(json.loads))

            for cat in self.categories:
                attr = df[cat]
                self.data[split]["total"] += utils.zipped_flatten(zip(
                    attr.index, ["<{}>".format(cat)] * len(attr), attr.values))

        if do_take_partial_dataset(self.opt.data):
            self.data["train"]["total"] = select_partial_dataset(
                self.opt.data, self.data["train"]["total"])

        return False

    def make_tensors(self, text_encoder, special,
                     splits=["train", "dev", "test"], test=False):
        self.vocab_encoder = text_encoder.encoder
        self.vocab_decoder = text_encoder.decoder
        self.special_chars = special

        sequences = {}
        for split in splits:
            sequences[split] = get_generation_sequences(
                self.opt, self.data, split, text_encoder, test)

            self.masks[split]["total"] = [(len(i[0]), len(i[1])) for
                                          i in sequences[split]]

        self.max_event = max([max([l[0] for l in self.masks[split]["total"]])
                              for split in self.masks])
        self.max_effect = max([max([l[1] for l in self.masks[split]["total"]])
                               for split in self.masks])

        print(self.max_event)
        print(self.max_effect)

        for split in splits:
            num_elements = len(sequences[split])
            self.sequences[split]["total"] = torch.LongTensor(
                num_elements, self.max_event + self.max_effect).fill_(0)

            for i, seq in enumerate(sequences[split]):
                # print(self.sequences[split]["total"][i, :len(seq[0])].size())
                # print(torch.FloatTensor(seq[0]).size())
                self.sequences[split]["total"][i, :len(seq[0])] = \
                    torch.LongTensor(seq[0])
                self.sequences[split]["total"][i, self.max_event:self.max_event + len(seq[1])] = \
                    torch.LongTensor(seq[1])

    def sample_batch(self, split, bs, idxs=None):
        offset = self.offsets[split]["total"]

        batch = {}

        # Decided not to reduce computation on here because it's all parallel
        # anyway and we don't want to run out of memory in cases where we
        # don't see the longest version quickly enough

        if idxs:
            seqs = self.sequences[split]["total"].index_select(
                0, torch.LongTensor(idxs).to(
                    self.sequences[split]["total"].device))
        else:
            seqs = self.sequences[split]["total"][offset:offset + bs]
        batch["sequences"] = seqs.to(cfg.device)
        batch["attention_mask"] = make_attention_mask(seqs)
        batch["loss_mask"] = make_loss_mask(
            seqs, self.max_event, 1)
        batch["key"] = ("total", offset, offset + bs)

        offset += seqs.size(0)

        self.offsets[split]["total"] = offset

        if split == "train" and offset + bs > len(self.sequences[split]["total"]):
            return batch, True
        elif offset >= len(self.sequences[split]["total"]):
            return batch, True
        else:
            return batch, False

    def reset_offsets(self, splits=["train", "test", "dev"],
                      shuffle=True, keys=None):
        if isinstance(splits, str):
            splits = [splits]

        for split in splits:
            if keys is None:
                keys = ["total"]

            for key in keys:
                self.offsets[split][key] = 0

            if shuffle:
                self.shuffle_sequences(split, keys)

    def shuffle_sequences(self, split="train", keys=None):
        if keys is None:
            # print(type(self.data))
            # print(type(self.data.keys()))
            keys = self.data[split].keys()

        for key in keys:
            idxs = list(range(len(self.data[split][key])))

            random.shuffle(idxs)

            self.sequences[split][key] = \
                self.sequences[split][key].index_select(
                    0, torch.LongTensor(idxs))

            temp = [self.data[split][key][i] for i in idxs]
            self.data[split][key] = temp
            temp = [self.masks[split][key][i] for i in idxs]
            self.masks[split][key] = temp


def prune_data_for_evaluation(data_loader, categories, split):
    indices = []
    for i, example in enumerate(data_loader.data[split]["total"]):
        if example[1] in categories:
            indices.append(i)

    data_loader.masks[split]["total"] = [data_loader.masks[split]["total"][i]
                                         for i in indices]
    data_loader.sequences[split]["total"] = \
        data_loader.sequences[split]["total"].index_select(
            0, torch.LongTensor(indices))
    data_loader.data[split]["total"] = [data_loader.data[split]["total"][i]
                                        for i in indices]


def make_attention_mask(sequences):
    return (sequences != 0).float().to(cfg.device)


def make_loss_mask(sequences, max_event, num_delim_tokens):
    # print(num_delim_tokens)
    # print(sequences.size())
    mask = (sequences != 0).float()
    mask[:, :max_event + num_delim_tokens] = 0
    return mask[:, 1:].to(cfg.device)


def find_underscore_length(seq):
    start = "_"

    while start in seq:
        start += "_"
    return start[:-1]


def handle_underscores(suffix, text_encoder, prefix=False):
    encoder = text_encoder.encoder
    if prefix:
        tok = "___"
    else:
        tok = find_underscore_length(suffix)

    suffix_parts = [i.strip() for i in suffix.split("{}".format(tok))]
    to_flatten = []
    for i, part in enumerate(suffix_parts):
        if part:
            to_flatten.append(text_encoder.encode([part], verbose=False)[0])

            if i != len(suffix_parts) - 1 and suffix_parts[i+1]:
                to_flatten.append([encoder["<blank>"]])
        else:
            to_flatten.append([encoder["<blank>"]])

    final_suffix = utils.flatten(to_flatten)

    return final_suffix


def get_generation_sequences(opt, data, split, text_encoder, test):
    sequences = []
    count = 0

    final_prefix = None
    final_suffix = None

    for prefix, category, suffix in tqdm(data[split]["total"]):
        final_prefix, final_suffix = do_example(
            text_encoder, prefix, suffix, True, True)
        # if do_prefix:
        #     if "___" in prefix:
        #         final_prefix = handle_underscores(prefix, text_encoder, True)
        #     else:
        #         final_prefix = text_encoder.encode([prefix], verbose=False)[0]
        # if do_suffix:
        #     if "_" in suffix:
        #         final_suffix = handle_underscores(suffix, text_encoder)
        #     else:
        #         final_suffix = text_encoder.encode([suffix], verbose=False)[0]

        final = compile_final_sequence(
            opt, final_prefix, final_suffix, category, text_encoder)

        sequences.append(final)

        count += 1

        if count > 10 and test:
            break

    return sequences


def do_example(text_encoder, prefix, suffix, do_prefix, do_suffix):
    final_prefix = None
    final_suffix = None

    if do_prefix:
        if "___" in prefix:
            final_prefix = handle_underscores(prefix, text_encoder, True)
        else:
            final_prefix = text_encoder.encode([prefix], verbose=False)[0]
    if do_suffix:
        if "_" in suffix:
            final_suffix = handle_underscores(suffix, text_encoder)
        else:
            final_suffix = text_encoder.encode([suffix], verbose=False)[0]

    return final_prefix, final_suffix


def compile_final_sequence(opt, final_prefix, final_suffix, category, text_encoder):
    final = []

    final.append(final_prefix)
    final.append(
        [text_encoder.encoder[category]]
        + final_suffix)

    final[-1].append(text_encoder.encoder["<END>"])

    return final


num_delimiter_tokens = {
    "category": 1,
    "hierarchy": 3,
    "hierarchy+label": 4,
    "category+hierarchy": 4,
    "category+hierarchy+label": 5
}
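Note: a hedged sketch of how this ATOMIC loader is typically driven; the `opt`, `text_encoder`, and `special` objects come from the surrounding training setup and are assumed here, as are the category list and data path:

import comet.src.data.atomic as atomic_data

# `opt`, `text_encoder`, and `special` are assumed from the usual COMET
# setup; the category subset and path below are illustrative only.
categories = ["xIntent", "xNeed", "xWant"]
data_loader = atomic_data.GenerationDataLoader(opt, categories)

# load_data returns True when it restored a cached .pickle loader and
# False when it parsed the raw v4_atomic_{trn,dev,tst}.csv files.
if not data_loader.load_data("data/atomic"):
    data_loader.make_tensors(text_encoder, special)

batch, reset = data_loader.sample_batch("train", bs=32)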
Model/COSMIC/feature_extraction/comet/src/data/conceptnet.py
ADDED
@@ -0,0 +1,342 @@
import comet.src.data.utils as data_utils
import comet.src.data.atomic as adata
import comet.src.data.config as cfg

import torch
import random
from tqdm import tqdm


def map_name(name, opt):
    if name == "train":
        return "train{}k.txt".format(opt.trainsize)
    elif name == "test":
        return "test.txt"
    else:
        return "dev{}.txt".format(opt.devversion)


conceptnet_relations = [
    'AtLocation', 'CapableOf', 'Causes', 'CausesDesire',
    'CreatedBy', 'DefinedAs', 'DesireOf', 'Desires', 'HasA',
    'HasFirstSubevent', 'HasLastSubevent', 'HasPainCharacter',
    'HasPainIntensity', 'HasPrerequisite', 'HasProperty',
    'HasSubevent', 'InheritsFrom', 'InstanceOf', 'IsA',
    'LocatedNear', 'LocationOfAction', 'MadeOf', 'MotivatedByGoal',
    'NotCapableOf', 'NotDesires', 'NotHasA', 'NotHasProperty',
    'NotIsA', 'NotMadeOf', 'PartOf', 'ReceivesAction', 'RelatedTo',
    'SymbolOf', 'UsedFor'
]


split_into_words = {
    'AtLocation': "at location",
    'CapableOf': "capable of",
    'Causes': "causes",
    'CausesDesire': "causes desire",
    'CreatedBy': "created by",
    'DefinedAs': "defined as",
    'DesireOf': "desire of",
    'Desires': "desires",
    'HasA': "has a",
    'HasFirstSubevent': "has first subevent",
    'HasLastSubevent': "has last subevent",
    'HasPainCharacter': "has pain character",
    'HasPainIntensity': "has pain intensity",
    'HasPrerequisite': "has prequisite",
    'HasProperty': "has property",
    'HasSubevent': "has subevent",
    'InheritsFrom': "inherits from",
    'InstanceOf': 'instance of',
    'IsA': "is a",
    'LocatedNear': "located near",
    'LocationOfAction': "location of action",
    'MadeOf': "made of",
    'MotivatedByGoal': "motivated by goal",
    'NotCapableOf': "not capable of",
    'NotDesires': "not desires",
    'NotHasA': "not has a",
    'NotHasProperty': "not has property",
    'NotIsA': "not is a",
    'NotMadeOf': "not made of",
    'PartOf': "part of",
    'ReceivesAction': "receives action",
    'RelatedTo': "related to",
    'SymbolOf': "symbol of",
    'UsedFor': "used for"
}


class GenerationDataLoader(adata.DataLoader):
    def __init__(self, opt, categories=None):
        super(GenerationDataLoader, self).__init__(opt)
        self.opt = opt

        for split in self.data:
            self.data[split] = {"total": []}
            self.offsets[split] = {"total": 0}

        self.vocab_encoder = None
        self.vocab_decoder = None
        self.special_chars = None
        self.max_e1 = None
        self.max_e2 = None
        self.max_r = None

    def offset_summary(self, split):
        return sum(self.offsets[split].values())

    def load_data(self, path):
        if ".pickle" in path:
            print("Loading data from: {}".format(path))
            data_utils.load_existing_data_loader(self, path)
            return True

        for split in self.data:
            file_name = map_name(split, self.opt.data)

            if split != "dev" or self.opt.data.devversion != "12":
                string_tuples = open("{}/{}".format(
                    path, file_name), "r").read().split("\n")
                tuples = [x.split("\t") for x in string_tuples if x]
            else:
                string_tuples = open("{}/{}".format(
                    path, "dev1.txt"), "r").read().split("\n")
                tuples = [x.split("\t") for x in string_tuples if x]
                string_tuples = open("{}/{}".format(
                    path, "dev2.txt"), "r").read().split("\n")
                tuples += [x.split("\t") for x in string_tuples if x]

            if split in ["dev", "test"]:
                if self.opt.data.rel == "language":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3])) for i in tuples]
                    self.data[split]["positive"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3])) for i in tuples if int(i[3])]
                    self.data[split]["negative"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3])) for i in tuples if not int(i[3])]
                elif self.opt.data.rel == "relation":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), int(i[3])) for i in tuples]
                    self.data[split]["positive"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), int(i[3])) for i in tuples if int(i[3])]
                    self.data[split]["negative"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), int(i[3])) for i in tuples if not int(i[3])]
            else:
                if self.opt.data.rel == "language":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), i[3]) for i in tuples]
                elif self.opt.data.rel == "relation":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), i[3]) for i in tuples]

        return False

    def make_tensors(self, text_encoder, special,
                     splits=["train", "dev", "test"], test=False):
        self.vocab_encoder = text_encoder.encoder
        self.vocab_decoder = text_encoder.decoder
        self.special_chars = special

        sequences = {}
        for split in splits:
            sequences[split], discarded = get_generation_sequences(
                self.data, split, text_encoder, test, self.opt.data.maxe1,
                self.opt.data.maxe2)

            if split == "train":
                self.data[split]["total"] = [j for i, j in enumerate(
                    self.data[split]["total"]) if i not in set(discarded)]
            self.masks[split]["total"] = [(len(i[0]), len(i[1]), len(i[2])) for
                                          i in sequences[split]]

        self.max_e1 = max([max([l[0] for l in self.masks[split]["total"]])
                           for split in self.masks])
        self.max_r = max([max([l[1] for l in self.masks[split]["total"]])
                          for split in self.masks])
        self.max_e2 = max([max([l[2] for l in self.masks[split]["total"]])
                           for split in self.masks])

        print(self.max_e1)
        print(self.max_r)
        print(self.max_e2)

        for split in splits:
            num_elements = len(sequences[split])
            self.sequences[split]["total"] = torch.LongTensor(
                num_elements, self.max_e1 + self.max_e2 + self.max_r).fill_(0)

            for i, seq in enumerate(sequences[split]):
                # print(self.sequences[split]["total"][i, :len(seq[0])].size())
                # print(torch.FloatTensor(seq[0]).size())
                self.sequences[split]["total"][i, :len(seq[0])] = \
                    torch.LongTensor(seq[0])
                start_r = self.max_e1
                end_r = self.max_e1 + len(seq[1])
                self.sequences[split]["total"][i, start_r:end_r] = \
                    torch.LongTensor(seq[1])
                start_e2 = self.max_e1 + self.max_r
                end_e2 = self.max_e1 + self.max_r + len(seq[2])
                self.sequences[split]["total"][i, start_e2:end_e2] = \
                    torch.LongTensor(seq[2])

            if split in ["test", "dev"]:
                print(split)
                self.sequences[split]["negative"] = \
                    self.sequences[split]["total"].index_select(
                        0, torch.LongTensor([i for i, j in enumerate(
                            self.data[split]['total']) if not j[3]]))
                # self.data[split]['total'][:self.sequences[split]["total"].size(0)]) if not j[3]]))
                self.sequences[split]["positive"] = \
                    self.sequences[split]["total"].index_select(
                        0, torch.LongTensor([i for i, j in enumerate(
                            self.data[split]['total']) if j[3]]))
                # self.data[split]['total'][:self.sequences[split]["total"].size(0)]) if j[3]]))

    def sample_batch(self, split, bs, cat="total", idxs=None):
        offset = self.offsets[split][cat]

        batch = {}

        # Decided not to reduce computation on here because it's all parallel
        # anyway and we don't want to run out of memory in cases where we
        # don't see the longest version quickly enough

        if idxs:
            seqs = self.sequences[split][cat].index_select(
                0, torch.LongTensor(idxs).to(
                    self.sequences[split][cat].device))
        else:
            seqs = self.sequences[split][cat][offset:offset + bs]
        batch["sequences"] = seqs.to(cfg.device)
        batch["attention_mask"] = make_attention_mask(seqs)
        batch["loss_mask"] = make_loss_mask(seqs, self.max_e1 + self.max_r)
        batch["key"] = (cat, offset, offset + bs)

        offset += seqs.size(0)

        self.offsets[split][cat] = offset

        if split == "train" and offset + bs > len(self.sequences[split][cat]):
            return batch, True
        elif offset >= len(self.sequences[split][cat]):
            return batch, True
        else:
            return batch, False

    def reset_offsets(self, splits=["train", "test", "dev"],
                      shuffle=True, keys=None):
        if isinstance(splits, str):
            splits = [splits]

        for split in splits:
            if keys is None:
                keys = ["total", "positive", "negative"]

            for key in keys:
                self.offsets[split][key] = 0

            if shuffle:
                self.shuffle_sequences(split, keys)

    def shuffle_sequences(self, split="train", keys=None):
        if keys is None:
            # print(type(self.data))
            # print(type(self.data.keys()))
            keys = self.data[split].keys()

        for key in keys:
            if key in ["positive", "negative"]:
                continue
            idxs = list(range(len(self.data[split][key])))

            random.shuffle(idxs)

            self.sequences[split][key] = \
                self.sequences[split][key].index_select(
                    0, torch.LongTensor(idxs))

            temp = [self.data[split][key][i] for i in idxs]
            self.data[split][key] = temp

            temp = [self.masks[split][key][i] for i in idxs]
            self.masks[split][key] = temp


def make_attention_mask(sequences):
    return (sequences != 0).float().to(cfg.device)


def make_loss_mask(sequences, max_event):
    # print(sequences.size())
    mask = (sequences != 0).float()
    mask[:, :max_event] = 0
    return mask[:, 1:].to(cfg.device)


def get_generation_sequences(data, split, text_encoder, test,
                             max_e1=10, max_e2=15):
    sequences = []
    count = 0

    final_event1 = None
    final_event2 = None
    final_relation = None

    discarded = []

    for event1, relation, event2, _ in tqdm(data[split]["total"]):
        e1, r, e2 = do_example(text_encoder, event1, relation, event2)

        if (split == "train" and len(e1) > max_e1 or
                len(e2) > max_e2):
            discarded.append(count)
            count += 1
            continue

        final = compile_final_sequence(
            e1, e2, r, text_encoder)

        sequences.append(final)

        count += 1

        if count > 10 and test:
            break

    return sequences, discarded


def do_example(text_encoder, event1, relation, event2):
    final_event1 = text_encoder.encode([event1], verbose=False)[0]
    if relation.lower() != relation:
        final_relation = [text_encoder.encoder[relation]]
    else:
        final_relation = text_encoder.encode(
            [relation], verbose=False)[0]
    if event2 is not None:
        final_event2 = text_encoder.encode([event2], verbose=False)[0]
    else:
        final_event2 = None

    return final_event1, final_relation, final_event2


def compile_final_sequence(final_event1, final_event2, final_relation, text_encoder):
    final = []

    final.append(final_event1)
    final.append(final_relation)
    final.append(final_event2)

    final[-1].append(text_encoder.encoder["<END>"])

    return final
Model/COSMIC/feature_extraction/comet/src/data/config.py
ADDED
@@ -0,0 +1,186 @@
import json
from comet.utils.utils import DD

device = "cpu"

save = False
test_save = False
toy = False
do_gen = False

save_strategy = "all"


def get_parameters(opt, exp_type="model"):
    params = DD()
    params.net = DD()

    params.mle = 0
    params.dataset = opt.dataset

    params.net = get_net_parameters(opt)
    params.train = get_training_parameters(opt)

    params.model = params.net.model
    params.exp = opt.exp

    params.data = get_data_parameters(opt, params.exp, params.dataset)
    params.eval = get_eval_parameters(opt, params.data.get("categories", None))

    meta = DD()

    params.trainer = opt.trainer

    meta.iterations = int(opt.iterations)
    meta.cycle = opt.cycle
    params.cycle = opt.cycle
    params.iters = int(opt.iterations)

    global toy
    toy = opt.toy

    global do_gen
    do_gen = opt.do_gen

    global save
    save = opt.save

    global test_save
    test_save = opt.test_save

    global save_strategy
    save_strategy = opt.save_strategy

    print(params)
    return params, meta


def get_eval_parameters(opt, force_categories=None):
    evaluate = DD()

    if opt.eval_sampler == "beam":
        evaluate.bs = opt.beam_size
    elif opt.eval_sampler == "greedy":
        evaluate.bs = 1
    elif opt.eval_sampler == "topk":
        evaluate.k = opt.topk_size

    evaluate.smax = opt.gen_seqlength
    evaluate.sample = opt.eval_sampler

    evaluate.numseq = opt.num_sequences

    evaluate.gs = opt.generate_sequences
    evaluate.es = opt.evaluate_sequences

    if opt.dataset == "atomic":
        if "eval_categories" in opt and force_categories is None:
            evaluate.categories = opt.eval_categories
        else:
            evaluate.categories = force_categories

    return evaluate


def get_data_parameters(opt, experiment, dataset):
    data = DD()
    if dataset == "atomic":
        data.categories = sorted(opt.categories)
        # hard-coded
        data.maxe1 = 17
        data.maxe2 = 35
        data.maxr = 1

    elif dataset == "conceptnet":
        data.rel = opt.relation_format
        data.trainsize = opt.training_set_size
        data.devversion = opt.development_set_versions_to_use
        data.maxe1 = opt.max_event_1_size
        data.maxe2 = opt.max_event_2_size
        if data.rel == "language":
            # hard-coded
            data.maxr = 5
        else:
            # hard-coded
            data.maxr = 1

    return data


def get_training_parameters(opt):
    train = DD()
    static = DD()
    static.exp = opt.exp

    static.seed = opt.random_seed

    # weight decay
    static.l2 = opt.l2
    static.vl2 = True
    static.lrsched = opt.learning_rate_schedule  # 'warmup_linear'
    static.lrwarm = opt.learning_rate_warmup  # 0.002

    # gradient clipping
    static.clip = opt.clip

    # what loss function to use
    static.loss = opt.loss

    dynamic = DD()
    dynamic.lr = opt.learning_rate  # learning rate
    dynamic.bs = opt.batch_size  # batch size
    # optimizer to use {adam, rmsprop, etc.}
    dynamic.optim = opt.optimizer

    # rmsprop
    # alpha is interpolation average

    static.update(opt[dynamic.optim])

    train.static = static
    train.dynamic = dynamic

    return train


def get_net_parameters(opt):
    net = DD()
    net.model = opt.model
    net.nL = opt.num_layers
    net.nH = opt.num_heads
    net.hSize = opt.hidden_dim
    net.edpt = opt.embedding_dropout
    net.adpt = opt.attention_dropout
    net.rdpt = opt.residual_dropout
    net.odpt = opt.output_dropout
    net.pt = opt.pretrain
    net.afn = opt.activation

    # how to initialize parameters
    # format is gauss+{}+{}.format(mean, std)
    # n = the default initialization pytorch
    net.init = opt.init

    return net


def read_config(file_):
    config = DD()
    print(file_)
    for k, v in file_.items():
        if v == "True" or v == "T" or v == "true":
            config[k] = True
        elif v == "False" or v == "F" or v == "false":
            config[k] = False
        elif type(v) == dict:
            config[k] = read_config(v)
        else:
            config[k] = v

    return config


def load_config(name):
    with open(name, "r") as f:
        config = json.load(f)
    return config
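Note: elsewhere in COMET-style code these helpers are chained to turn a JSON experiment config into the nested `DD` options object; the file name below is a hypothetical example:

import comet.src.data.config as cfg

# Hypothetical config path; any JSON file with the expected keys works.
config = cfg.read_config(cfg.load_config("config/atomic/config_0.json"))
opt, meta = cfg.get_parameters(config)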
Model/COSMIC/feature_extraction/comet/src/data/data.py
ADDED
@@ -0,0 +1,85 @@
import os
import comet.src.data.atomic as atomic_data
import comet.src.data.conceptnet as conceptnet_data
import comet.src.data.config as cfg

import comet.utils.utils as utils

import pickle
import torch
import json


start_token = "<START>"
end_token = "<END>"
blank_token = "<blank>"


def save_checkpoint(state, filename):
    print("Saving model to {}".format(filename))
    torch.save(state, filename)


def save_step(model, vocab, optimizer, opt, length, lrs):
    if cfg.test_save:
        name = "{}.pickle".format(utils.make_name(
            opt, prefix="garbage/models/", is_dir=False, eval_=True))
    else:
        name = "{}.pickle".format(utils.make_name(
            opt, prefix="models/", is_dir=False, eval_=True))
    save_checkpoint({
        "epoch": length, "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(), "opt": opt,
        "vocab": vocab, "epoch_learning_rates": lrs},
        name)


def save_eval_file(opt, stats, eval_type="losses", split="dev", ext="pickle"):
    if cfg.test_save:
        name = "{}/{}.{}".format(utils.make_name(
            opt, prefix="garbage/{}/".format(eval_type),
            is_dir=True, eval_=True), split, ext)
    else:
        name = "{}/{}.{}".format(utils.make_name(
            opt, prefix="results/{}/".format(eval_type),
            is_dir=True, eval_=True), split, ext)
    print("Saving {} {} to {}".format(split, eval_type, name))

    if ext == "pickle":
        with open(name, "wb") as f:
            pickle.dump(stats, f)
    elif ext == "txt":
        with open(name, "w") as f:
            f.write(stats)
    elif ext == "json":
        with open(name, "w") as f:
            json.dump(stats, f)
    else:
        raise


def load_checkpoint(filename, gpu=True):
    if os.path.exists(filename):
        checkpoint = torch.load(
            filename, map_location=lambda storage, loc: storage)
    else:
        print("No model found at {}".format(filename))
    return checkpoint


def make_data_loader(opt, *args):
    if opt.dataset == "atomic":
        return atomic_data.GenerationDataLoader(opt, *args)
    elif opt.dataset == "conceptnet":
        return conceptnet_data.GenerationDataLoader(opt, *args)


def set_max_sizes(data_loader, force_split=None):
    data_loader.total_size = {}
    if force_split is not None:
        data_loader.total_size[force_split] = \
            data_loader.sequences[force_split]["total"].size(0)
        return
    for split in data_loader.sequences:
        data_loader.total_size[split] = \
            data_loader.sequences[split]["total"].size(0)
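Note: a brief sketch of how `load_checkpoint` pairs with the checkpoint layout written by `save_step`; the path is a hypothetical example and `model` is an already-constructed network:

import comet.src.data.data as data

# Hypothetical checkpoint path; the dict keys match save_step above.
checkpoint = data.load_checkpoint("models/atomic_pretrained_model.pickle")
opt = checkpoint["opt"]
vocab = checkpoint["vocab"]
model.load_state_dict(checkpoint["state_dict"])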
Model/COSMIC/feature_extraction/comet/src/data/utils.py
ADDED
@@ -0,0 +1,134 @@
import re
import ftfy
import json
import spacy
import torch

from tqdm import tqdm


def load_existing_data_loader(data_loader, path):
    old_data_loader = torch.load(path)
    for attr in data_loader.__dict__.keys():
        if attr not in old_data_loader.__dict__.keys():
            continue
        setattr(data_loader, attr, getattr(old_data_loader, attr))


################################################################################
#
# Code Below taken from HuggingFace pytorch-openai-lm repository
#
################################################################################


def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    word is represented as tuple of symbols (symbols being variable-length strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def text_standardize(text):
    """
    fixes some issues the spacy tokenizer had on books corpus
    also does some whitespace standardization
    """
    text = text.replace('—', '-')
    text = text.replace('–', '-')
    text = text.replace('―', '-')
    text = text.replace('…', '...')
    text = text.replace('´', "'")
    text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
    text = re.sub(r'\s*\n\s*', ' \n ', text)
    text = re.sub(r'[^\S\n]+', ' ', text)
    return text.strip()


class TextEncoder(object):
    """
    mostly a wrapper for a public python bpe tokenizer
    """

    def __init__(self, encoder_path, bpe_path):
        self.nlp = spacy.load(
            'en_core_web_sm', disable=['parser', 'tagger', 'ner', 'textcat'])
        self.encoder = json.load(open(encoder_path))
        self.decoder = {v: k for k, v in self.encoder.items()}
        merges = open(bpe_path, encoding='utf-8').read().split('\n')[1:-1]
        merges = [tuple(merge.split()) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(
                pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if (word[i] == first and i < len(word) - 1 and
                        word[i+1] == second):
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        if word == '\n </w>':
            word = '\n</w>'
        self.cache[token] = word
        return word

    def encode(self, texts, verbose=True):
        texts_tokens = []
        if verbose:
            for text in tqdm(texts, ncols=80, leave=False):
                text = self.nlp(text_standardize(ftfy.fix_text(text)))
                text_tokens = []
                for token in text:
                    text_tokens.extend(
                        [self.encoder.get(t, 0) for t in
                         self.bpe(token.text.lower()).split(' ')])
                texts_tokens.append(text_tokens)
        else:
            for text in texts:
                text = self.nlp(text_standardize(ftfy.fix_text(text)))
                text_tokens = []
                for token in text:
                    text_tokens.extend(
                        [self.encoder.get(t, 0) for t in
                         self.bpe(token.text.lower()).split(' ')])
                texts_tokens.append(text_tokens)
        return texts_tokens
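Note: a minimal usage sketch for `TextEncoder`, assuming the standard OpenAI GPT 40k-merge BPE vocabulary files at hypothetical paths:

from comet.src.data.utils import TextEncoder

# Hypothetical paths to the pretrained BPE vocabulary files.
text_encoder = TextEncoder("model/encoder_bpe_40000.json",
                           "model/vocab_40000.bpe")

# encode() returns one list of BPE token ids per input string.
ids = text_encoder.encode(["PersonX goes to the store"], verbose=False)[0]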
Model/COSMIC/feature_extraction/comet/src/evaluate/atomic_evaluate.py
ADDED
@@ -0,0 +1,40 @@
import comet.src.train.batch as batch
import comet.src.evaluate.evaluate as base_evaluate
import numpy as np

def make_evaluator(opt, *args):
    if opt.exp == "generation":
        return AtomicGenerationEvaluator(opt, *args)
    else:
        return AtomicClassificationEvaluator(opt, *args)


class AtomicGenerationEvaluator(base_evaluate.Evaluator):
    def __init__(self, opt, model, data_loader):
        super(AtomicGenerationEvaluator, self).__init__(
            opt, model, data_loader)

        self.batch = batch.batch_atomic_generate

    def initialize_losses(self):
        average_loss = {"total_micro": 0, "total_macro": 0}
        nums = {"total_micro": 0, "total_macro": 0}
        return average_loss, nums

    def compute_final_scores(self, average_loss, nums):
        average_loss["total_macro"] /= nums["total_macro"]
        average_loss["total_micro"] /= nums["total_micro"]

        average_loss["ppl_macro"] = np.exp(average_loss["total_macro"])
        average_loss["ppl_micro"] = np.exp(average_loss["total_micro"])

        return average_loss

    def counter(self, nums):
        return nums["total_macro"]

    def print_result(self, split, epoch_losses):
        print("{} Loss: \t {}".format(
            split, epoch_losses["total_micro"]))
        print("{} Perplexity: \t {}".format(
            split, epoch_losses["ppl_micro"]))
Model/COSMIC/feature_extraction/comet/src/evaluate/conceptnet_evaluate.py
ADDED
@@ -0,0 +1,82 @@
import time
import numpy as np

import comet.src.train.batch as batch_utils
import comet.utils.utils as utils
import comet.src.evaluate.evaluate as base_evaluate


def make_evaluator(opt, *args, **kwargs):
    return ConceptNetGenerationEvaluator(opt, *args, **kwargs)


class ConceptNetGenerationEvaluator(base_evaluate.Evaluator):
    def __init__(self, opt, model, data_loader, track=False):
        super(ConceptNetGenerationEvaluator, self).__init__(
            opt, model, data_loader)

        if track:
            self.tracker = {"positive": [], "negative": []}
        else:
            self.tracker = None

    def batch(self, opt, nums, average_loss, batch_variables, eval_mode):
        batch_variables["category"] = self.current_category

        outputs = batch_utils.batch_conceptnet_generate(
            opt, nums, average_loss, batch_variables, eval_mode,
            tracking_mode=self.tracker is not None)

        if outputs.get("tracking", None) is not None:
            self.tracker[self.current_category] += outputs["tracking"]

        if outputs["reset"] and batch_variables["category"] == "positive":
            outputs["reset"] = False
            self.current_category = "negative"

        return outputs

    def initialize_losses(self):
        average_loss = {"total_micro": 0, "total_macro": 0,
                        "negative_micro": 0, "negative_macro": 0}
        nums = {"total_micro": 0, "total_macro": 0,
                "negative_micro": 0, "negative_macro": 0}

        self.current_category = "positive"

        if self.tracker is not None:
            self.tracker = {"positive": [], "negative": []}

        return average_loss, nums

    def compute_final_scores(self, average_loss, nums):
        average_loss["total_macro"] /= nums["total_macro"]
        average_loss["total_micro"] /= nums["total_micro"]

        if nums["negative_micro"]:
            average_loss["negative_macro"] /= nums["negative_macro"]
            average_loss["negative_micro"] /= nums["negative_micro"]
        else:
            average_loss["negative_macro"] = 0
            average_loss["negative_micro"] = 0

        average_loss["macro_diff"] = (average_loss["negative_macro"] -
                                      average_loss["total_macro"])
        average_loss["micro_diff"] = (average_loss["negative_micro"] -
                                      average_loss["total_micro"])

        average_loss["ppl_macro"] = np.exp(average_loss["total_macro"])
        average_loss["ppl_micro"] = np.exp(average_loss["total_micro"])

        return average_loss

    def counter(self, nums):
        return nums["total_macro"]

    def print_result(self, split, epoch_losses):
        print("{} Loss: \t {}".format(
            split, epoch_losses["total_micro"]))
        print("{} Diff: \t {}".format(
            split, epoch_losses["micro_diff"]))
        print("{} Perplexity: \t {}".format(
            split, epoch_losses["ppl_micro"]))
Model/COSMIC/feature_extraction/comet/src/evaluate/conceptnet_generate.py
ADDED
@@ -0,0 +1,112 @@
import time
import torch

import comet.src.evaluate.generate as base_generate
import comet.src.evaluate.sampler as sampling
import comet.utils.utils as utils
import comet.src.data.config as cfg


def make_generator(opt, *args):
    return ConceptNetGenerator(opt, *args)


class ConceptNetGenerator(base_generate.Generator):
    def __init__(self, opt, model, data_loader):
        self.opt = opt

        self.model = model
        self.data_loader = data_loader

        self.sampler = sampling.make_sampler(
            opt.eval.sample, opt, data_loader)

    def reset_sequences(self):
        return []

    def generate(self, split="dev"):
        print("Generating Sequences")

        # Set evaluation mode
        self.model.eval()

        # Reset evaluation set for dataset split
        self.data_loader.reset_offsets(splits=split, shuffle=False)

        start = time.time()
        count = 0
        sequences = None

        # Reset generated sequence buffer
        sequences = self.reset_sequences()

        # Initialize progress bar
        bar = utils.set_progress_bar(
            self.data_loader.total_size[split] / 2)

        reset = False

        with torch.no_grad():
            # Cycle through development set
            while not reset:

                start = len(sequences)
                # Generate a single batch
                reset = self.generate_batch(sequences, split, bs=1)

                end = len(sequences)

                if not reset:
                    bar.update(end - start)
                else:
                    print(end)

                count += 1

                if cfg.toy and count > 10:
                    break
                if (self.opt.eval.gs != "full" and (count > self.opt.eval.gs)):
                    break

        torch.cuda.synchronize()
        print("{} generations completed in: {} s".format(
            split, time.time() - start))

        # Compute scores for sequences (e.g., BLEU, ROUGE)
        # Computes scores that the generator is initialized with
        # Change define_scorers to add more scorers as possibilities
        # avg_scores, indiv_scores = self.compute_sequence_scores(
        #     sequences, split)
        avg_scores, indiv_scores = None, None

        return sequences, avg_scores, indiv_scores

    def generate_batch(self, sequences, split, verbose=False, bs=1):
        # Sample batch from data loader
        batch, reset = self.data_loader.sample_batch(
            split, bs=bs, cat="positive")

        start_idx = self.data_loader.max_e1 + self.data_loader.max_r
        max_end_len = self.data_loader.max_e2

        context = batch["sequences"][:, :start_idx]
        reference = batch["sequences"][:, start_idx:]
        init = "".join([self.data_loader.vocab_decoder[i].replace(
            '</w>', ' ') for i in context[:, :self.data_loader.max_e1].squeeze().tolist() if i]).strip()

        start = self.data_loader.max_e1
        end = self.data_loader.max_e1 + self.data_loader.max_r

        attr = "".join([self.data_loader.vocab_decoder[i].replace(
            '</w>', ' ') for i in context[:, start:end].squeeze(0).tolist() if i]).strip()

        # Decode sequence
        sampling_result = self.sampler.generate_sequence(
            batch, self.model, self.data_loader, start_idx, max_end_len)

        sampling_result["key"] = batch["key"]
        sampling_result["e1"] = init
        sampling_result["r"] = attr
        sequences.append(sampling_result)

        return reset
Model/COSMIC/feature_extraction/comet/src/evaluate/evaluate.py
ADDED
@@ -0,0 +1,85 @@
import time
import torch

import comet.utils.utils as utils
import comet.src.data.config as cfg


class Evaluator(object):
    def __init__(self, opt, model, data_loader):
        super(Evaluator, self).__init__()

        self.data_loader = data_loader
        self.model = model

        self.batch_variables = {
            "model": model,
            "data": data_loader
        }

        self.opt = opt

    def validate(self, l, split="dev", losses={}, keyset=None):
        self.batch_variables["split"] = split
        print("Evaluating {}".format(split))

        epoch_losses = self.epoch(
            self.opt, self.model, self.data_loader, split, keyset)

        self.print_result(split, epoch_losses)

        for loss_name, loss_val in epoch_losses.items():
            losses.setdefault(loss_name, {})
            losses[loss_name][l] = loss_val

    def epoch(self, opt, model, data_loader, split, keyset=None):
        average_loss, nums = self.initialize_losses()

        data_loader.reset_offsets(splits=split, shuffle=False)

        # Set evaluation mode
        model.eval()

        start = time.time()

        # Initialize progress bar
        bar = utils.set_progress_bar(
            data_loader.total_size[split])

        reset = False

        with torch.no_grad():
            while not reset:

                start = data_loader.offset_summary(split)

                outputs = self.batch(
                    opt, nums, average_loss,
                    self.batch_variables, eval_mode=True)

                end = data_loader.offset_summary(split)

                reset = outputs["reset"]

                if not reset:
                    bar.update(end - start)
                else:
                    print(end)

                if cfg.toy and self.counter(nums) > 100:
                    break
                if (opt.eval.es != "full" and
                        (self.counter(nums) > opt.eval.es)):
                    break

                nums = outputs["nums"]

        torch.cuda.synchronize()

        print("{} evaluation completed in: {} s".format(
            split.capitalize(), time.time() - start))

        average_loss = self.compute_final_scores(
            average_loss, nums)

        return average_loss
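Note: a hedged sketch of how `Evaluator.validate` is driven from a training loop; the evaluator subclass choice and the `opt`, `model`, and `data_loader` objects are assumed from the surrounding setup:

import comet.src.evaluate.atomic_evaluate as atomic_evaluate

evaluator = atomic_evaluate.make_evaluator(opt, model, data_loader)

# Losses accumulate per loss name, keyed by the epoch index l;
# ppl_micro is exp(mean token NLL), as computed in compute_final_scores.
losses = {}
evaluator.validate(l=0, split="dev", losses=losses)
print(losses["ppl_micro"][0])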
Model/COSMIC/feature_extraction/comet/src/evaluate/generate.py
ADDED
@@ -0,0 +1,72 @@
import comet.src.data.data as data
import comet.src.data.config as cfg
import comet.src.evaluate.sampler as sampling


def do_gen_run(opt, generator, l, split="dev", scores={}):
    # Generate sequences for examples in evaluation set using
    # current trained model

    if opt.eval.gs == "full":
        sequences, avg_scores, indiv_scores = generator.generate(split)
    else:
        sequences, avg_scores, indiv_scores = generator.generate_some(split)

    if avg_scores is not None:
        # Record scores from generated sequences
        for score_name, score_val in avg_scores.items():
            scores.setdefault(score_name, {})
            scores[score_name].setdefault(l, [])
            scores[score_name][l] += [score_val]

    # Save generated sequences
    save_sequences(opt, sequences, avg_scores, indiv_scores,
                   l, split, opt.eval.gs == "full",
                   generator.data_loader)


def save_sequences(opt, sequences, avg_scores, indiv_scores,
                   l, split, full, data_loader):
    # This seems a bit roundabout since l = opt.train.dynamic in train.py
    # But it's in case we start checkpointing outside of epoch boundaries
    opt.train.dynamic.epoch = l

    if cfg.save:
        if full:
            names = {"gens": "gens", "scores": "scores",
                     "indiv": "indiv.scores"}
        else:
            names = {"gens": "gens.small", "scores": "scores.small",
                     "indiv": "indiv.scores.small"}
        # Save generated sequences
        data.save_eval_file(opt, sequences, names["gens"], split)

        if avg_scores is not None:
            # Save average scores over evaluation set for generated sequences
            # Scores computed are the ones the generator was initialized with
            data.save_eval_file(opt, avg_scores, names["scores"], split)

            if split == "dev":
                # Save individual scores
                data.save_eval_file(
                    opt, indiv_scores, names["indiv"], split)


class Generator(object):
    def __init__(self, opt, model, data_loader, scorers, reward_function=None):
        super(Generator, self).__init__()
        self.opt = opt

        self.model = model
        self.data_loader = data_loader

        self.sampler = sampling.make_sampler(
            opt.eval.sample, opt, data_loader)

    def generate(self, split="dev"):
        pass

    def generate_batch(self, sequences, split, verbose=False, bs=32):
        pass
Model/COSMIC/feature_extraction/comet/src/evaluate/sampler.py
ADDED
@@ -0,0 +1,329 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import comet.src.data.data as data
import comet.src.data.config as cfg
import comet.src.models.utils as model_utils
import comet.src.evaluate.utils as eval_utils
import comet.src.train.batch as batch_utils


def make_sampler(sampler_type, opt, *args, **kwargs):
    print("Initializing Greedy Sampler")
    return GreedySampler(opt, *args, **kwargs)


class Sampler():
    def __init__(self, opt, data_loader, batch_mode=False):
        # Token on which to end sampling
        self.end_token = data_loader.vocab_encoder[data.end_token]

        self.opt = opt

    def generate_sequence(self, batch, model):
        raise NotImplementedError


class GreedySampler(Sampler):
    def __init__(self, opt, data_loader, batch_mode=True):
        super(GreedySampler, self).__init__(opt, data_loader)

    def append_batch(self, X, next_idx, mask):
        next_pos = X[:, -1:, 1] + 1
        next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1)
        next_mask = torch.cat([mask, torch.ones(X.size(0), 1, device=mask.device)], 1)
        return torch.cat((X, next_x), 1), next_mask

    def generate_sequence(self, batch, model, data_loader, start_idx, end_len):
        XMB = batch["sequences"][:, :start_idx]
        MMB = batch["attention_mask"][:, :start_idx]

        XMB = model_utils.prepare_position_embeddings(
            self.opt, data_loader.vocab_encoder, XMB.unsqueeze(-1))

        _, lp = model(
            XMB.unsqueeze(1), sequence_mask=MMB)
        lm_probs = F.log_softmax(lp, dim=-1)

        values, indices = lm_probs[:, -1, :].max(dim=-1)
        seqs = indices.clone().unsqueeze(1)

        loss = values
        counts = 1
        next_pos = XMB[:, -1:, 1] + 1
        next_x = torch.cat((indices.view(-1, 1), next_pos), -1).unsqueeze(1)
        XMB = torch.cat((XMB, next_x), 1)
        MMB = torch.cat([MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

        # Greedy decoding: repeatedly take the most likely next token
        for _ in range(self.opt.eval.smax):
            _, lp = model(
                XMB.unsqueeze(1), sequence_mask=MMB)
            lm_probs = F.log_softmax(lp, dim=-1)

            # Take the argmax of the next-token distribution
            values, next_idx = lm_probs[:, -1, :].max(dim=-1)

            loss += values
            counts += 1

            next_idx = next_idx.unsqueeze(1)

            seqs = torch.cat([seqs, next_idx], 1)

            if (next_idx.item() == self.end_token) or (_ == end_len - 1):
                break

            XMB, MMB = self.append_batch(XMB, next_idx, MMB)

        beams = []

        for beam in seqs:
            beams.append(" ".join("".join(
                [data_loader.vocab_decoder[tok.item()].replace(
                    '</w>', ' ').replace('\n', '')
                 for tok in beam if tok != self.end_token]).split()))

        sampling_result = {
            "sequence": beams[0],
            "beams": beams,
            "beam_losses": [loss.item()],
            "loss": loss.item(),
            "beam_lengths": [counts],
            "length": counts
        }

        return sampling_result


class TopKSampler(Sampler):
    def __init__(self, opt, data_loader, batch_mode=True):
        super(TopKSampler, self).__init__(opt, data_loader)

    def append_batch(self, X, next_idx, mask):
        next_pos = X[:, -1:, 1] + 1
        next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1)
        next_mask = torch.cat([mask, torch.ones(X.size(0), 1, device=mask.device)], 1)
        return torch.cat((X, next_x), 1), next_mask

    def generate_sequence(self, batch, model, data_loader, start_idx, end_len):
        # start_idx = context_size_event + 1
        # start_idx = max_e1 + max_r
        # end_idx = context_size_effect - 1
        # end_idx = max_e2
        XMB = batch["sequences"][:, :start_idx]
        MMB = batch["attention_mask"][:, :start_idx]

        XMB = model_utils.prepare_position_embeddings(
            self.opt, data_loader.vocab_encoder, XMB.unsqueeze(-1))

        _, lp = model(
            XMB.unsqueeze(1), sequence_mask=MMB)
        lm_probs = F.log_softmax(lp, dim=-1)

        values, indices = lm_probs[:, -1, :].topk(self.opt.eval.k)
        seqs = indices.t().clone()

        losses = - values.view(-1, 1)

        ended = (seqs == self.end_token).float()
        counts = (1 - ended)
        XMB = XMB.repeat(self.opt.eval.k, 1, 1)
        MMB = MMB.repeat(self.opt.eval.k, 1)
        next_pos = XMB[:, -1:, 1] + 1
        next_x = torch.cat((indices.view(self.opt.eval.k, -1), next_pos), -1).unsqueeze(1)
        XMB = torch.cat((XMB, next_x), 1)
        MMB = torch.cat([MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

        # Sample from top k
        for _ in range(end_len):
            _, lp = model(XMB.unsqueeze(1), sequence_mask=MMB)
            lm_probs = F.log_softmax(lp, dim=-1)

            # Sample from top k
            values, indices = lm_probs[:, -1, :].topk(self.opt.eval.k)
            choice = torch.multinomial(values.exp(), 1)
            next_idx = indices.gather(-1, choice)

            ended = ended + (next_idx == self.end_token).float() * (1 - ended)

            next_idx = next_idx * (1 - ended).long() + ended.long() * self.end_token

            counts += (1 - ended)

            seqs = torch.cat([seqs, next_idx], 1)

            if ended.sum().item() == self.opt.eval.k:
                break

            losses -= values.gather(-1, choice) * (1 - ended)

            XMB, MMB = self.append_batch(XMB, next_idx, MMB)

        beams = []

        for beam in seqs:
            beams.append(" ".join("".join(
                [data_loader.vocab_decoder[tok.item()].replace(
                    '</w>', ' ').replace('\n', '')
                 for tok in beam if tok != self.end_token]).split()))

        sampling_result = {
            "sequence": beams[0],
            "beams": beams,
            "beam_losses": losses.squeeze().tolist(),
            "loss": losses[0].item(),
            "beam_lengths": counts.long().squeeze().tolist(),
            "length": counts[0].long().item()
        }

        return sampling_result


class BeamSampler(TopKSampler):
    def __init__(self, opt, data_loader, batch_mode=True, scorer=None):
        super(BeamSampler, self).__init__(opt, data_loader, batch_mode)

        self.kill_mask = torch.ones(opt.eval.bs, opt.eval.bs).to(cfg.device) * 9000
        self.kill_mask[:, 0] = 0

    def make_batch(self, X):
        # NOTE: `n_vocab`, `n_special`, and `device` are not defined in this
        # module; this helper appears to be unused leftover code from the
        # original OpenAI implementation.
        X = np.array(X)
        assert X.ndim in [1, 2]
        if X.ndim == 1:
            X = np.expand_dims(X, axis=0)
        pos_enc = np.arange(n_vocab + n_special, n_vocab + n_special + X.shape[-1])
        pos_enc = np.expand_dims(pos_enc, axis=0)
        batch = np.stack([X, pos_enc], axis=-1)
        batch = torch.tensor(batch, dtype=torch.long).to(device)
        return batch

    def append_batch(self, X, beam_toks, mask):
        next_pos = X[:, -1:, 1] + 1
        next_x = torch.cat((beam_toks.unsqueeze(1), next_pos), -1).unsqueeze(1)
        next_mask = torch.cat([mask, torch.ones(X.size(0), 1, device=mask.device)], 1)
        return torch.cat((X, next_x), 1), next_mask

    def generate_sequence(self, batch, model, data_loader, start_idx, end_len):
        # start_idx = context_size_event + 1
        # start_idx = max_e1 + max_r
        # end_idx = context_size_effect - 1
        # end_idx = max_e2
        XMB = batch["sequences"][:, :start_idx]
        MMB = batch["attention_mask"][:, :start_idx]

        XMB = model_utils.prepare_position_embeddings(
            self.opt, data_loader.vocab_encoder, XMB.unsqueeze(-1))

        tokens = []
        beam_losses = []
        # Beam Search
        beam_lls, beam_toks, beam_seqs = None, None, None
        _, lp = model(XMB.unsqueeze(1), sequence_mask=MMB)
        lm_probs = F.log_softmax(lp, dim=-1)
        dist = lm_probs[:, -1, :].squeeze()
        beam_lls, beam_toks = dist.topk(self.opt.eval.bs)
        beam_losses.append(beam_lls)

        ended = (beam_toks == self.end_token).float()
        counts = (2 - ended)
        beam_toks = beam_toks.unsqueeze(1)
        beam_seqs = beam_toks.clone()
        XMB = XMB.repeat(self.opt.eval.bs, 1, 1)
        MMB = MMB.repeat(self.opt.eval.bs, 1)
        next_pos = XMB[:, -1:, 1] + 1
        next_x = torch.cat((beam_toks, next_pos), -1).unsqueeze(1)
        XMB = torch.cat((XMB, next_x), 1)
        MMB = torch.cat([MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

        for _ in range(end_len):

            # Compute distribution for current beam
            _, lp = model(
                XMB.unsqueeze(1), sequence_mask=MMB)
            lm_probs = F.log_softmax(lp, dim=-1)
            dist = lm_probs[:, -1, :].squeeze()

            # get hypothesis tokens for distribution
            hyp_beam_lls, hyp_beam_toks = dist.topk(self.opt.eval.bs)

            # Compute masks and expand beam
            expanded_ended = ended.unsqueeze(1).repeat(1, self.opt.eval.bs)
            hypothesis_mask = expanded_ended * self.kill_mask + (1 - expanded_ended)

            paper_results = False

            if paper_results:
                # Results from paper with slightly buggy beam search
                current_beam_lls = beam_lls.unsqueeze(1).repeat(
                    1, self.opt.eval.bs).view(self.opt.eval.bs**2)
            else:
                # Current beam search implementation
                current_beam_lls = beam_losses[-1].unsqueeze(1).repeat(
                    1, self.opt.eval.bs).view(self.opt.eval.bs**2)

            # Compute losses of hypotheses, masking those that have ended
            hyp_beam_lls = (hyp_beam_lls.view(self.opt.eval.bs**2) *
                            hypothesis_mask.view(-1)) + current_beam_lls

            # Get normalizer for sequences
            temp_counts = counts.unsqueeze(1).repeat(1, self.opt.eval.bs).view(
                self.opt.eval.bs ** 2)

            # Select best beams with lowest aggregate loss
            beam_lls, top_beam_idxs = (hyp_beam_lls / temp_counts).topk(self.opt.eval.bs)

            # Update placements in beam based on selection
            beam_losses = [i.index_select(0, top_beam_idxs // self.opt.eval.bs)
                           for i in beam_losses]
            ended = ended.index_select(0, top_beam_idxs // self.opt.eval.bs)
            counts = temp_counts.index_select(0, top_beam_idxs)

            # Save beam losses
            beam_losses.append(beam_lls * counts)

            # Update beam tokens
            ended_mask = (1 - ended).long()
            end_replacement = (self.end_token * ended).long()
            next_toks = hyp_beam_toks.view(-1)[top_beam_idxs]
            beam_toks = next_toks * ended_mask + end_replacement

            # Update ended and counts
            ended = ended + (beam_toks == self.end_token).float() * (1 - ended)
            counts = counts + (1 - ended)

            # Update beam sequences
            beam_seqs = beam_seqs.t().repeat(self.opt.eval.bs, 1).t().contiguous().view(
                self.opt.eval.bs**2, -1)[top_beam_idxs]
            beam_seqs = torch.cat((beam_seqs, beam_toks.unsqueeze(1)), dim=1)

            # I have no idea what's going on but Ari's on point with it
            XMB = XMB.transpose(0, 1).transpose(1, 2).repeat(
                self.opt.eval.bs, 1, 1).transpose(2, 1).transpose(
                1, 0).contiguous().view(
                self.opt.eval.bs**2, XMB.size(1), XMB.size(2))[top_beam_idxs]

            XMB, MMB = self.append_batch(XMB, beam_toks, MMB)

            if (beam_toks == self.end_token).sum().item() == self.opt.eval.bs:
                break

        beams = []

        for beam in beam_seqs:
            beams.append(" ".join("".join(
                [data_loader.vocab_decoder[tok.item()].replace(
                    '</w>', ' ').replace('\n', '')
                 for tok in beam if tok != self.end_token]).split()))

        sampling_result = {
            "sequence": beams[0],
            "beams": beams,
            "beam_losses": beam_lls.tolist(),
            "loss": beam_lls[0].item(),
            "beam_lengths": counts.tolist(),
            "length": counts[0].item()
        }

        return sampling_result
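
For orientation (not part of the commit), a minimal sketch of how a sampler is driven and what its result dictionary contains; `opt`, `model`, `data_loader`, and `batch` are assumed to be built as in the training scripts further below:

    # Hypothetical usage sketch -- all surrounding objects are assumed.
    sampler = GreedySampler(opt, data_loader)
    result = sampler.generate_sequence(
        batch, model, data_loader,
        start_idx=data_loader.max_e1 + data_loader.max_r,  # length of the context
        end_len=data_loader.max_e2)                        # max tokens to generate
    print(result["sequence"])   # best decoded string
    print(result["beams"])      # all decoded candidates
    print(result["loss"])       # accumulated log-probability score
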
Model/COSMIC/feature_extraction/comet/src/evaluate/utils.py
ADDED
@@ -0,0 +1,39 @@

def update_classification_losses(losses, nums, name, bs, loss):
    if not isinstance(loss, float):
        print(type(loss))
        raise TypeError

    nums[name] += bs

    losses[name] += loss * bs


def update_generation_losses(losses, nums, micro, macro, bs, length, loss):
    # Update Losses
    nums[macro] += bs

    if isinstance(length, int):
        update_indiv_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
    else:
        update_tensor_generation_losses(
            losses, nums, micro, macro, bs, length, loss)


def update_indiv_generation_losses(losses, nums, micro,
                                   macro, bs, length, loss):
    nums[micro] += bs * length

    batch_loss = loss * bs

    losses[micro] += batch_loss
    losses[macro] += batch_loss / length


def update_tensor_generation_losses(losses, nums, micro,
                                    macro, bs, length, loss):
    nums[micro] += length.sum().item()

    losses[micro] += loss.sum().item()
    losses[macro] += (loss / length.float()).sum().item()
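
A quick illustration (not from the commit) of the bookkeeping above, using the scalar path through update_indiv_generation_losses; all numbers are hypothetical:

    # A batch of 4 sequences of length 10 with a per-sequence loss of 2.3.
    losses = {"total_micro": 0, "total_macro": 0}
    nums = {"total_micro": 0, "total_macro": 0}
    update_generation_losses(losses, nums, "total_micro", "total_macro",
                             bs=4, length=10, loss=2.3)
    # nums == {"total_micro": 40, "total_macro": 4}   (tokens / sequences seen)
    # losses["total_micro"] == 9.2    (2.3 * 4, summed over the batch)
    # losses["total_macro"] == 0.92   (the same, normalized by length)
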
Model/COSMIC/feature_extraction/comet/src/interactive/functions.py
ADDED
@@ -0,0 +1,376 @@
import os
import torch

from comet.src.data.utils import TextEncoder
import comet.src.data.config as cfg
import comet.src.data.data as data
import comet.src.models.models as models
from comet.src.evaluate.sampler import BeamSampler, GreedySampler, TopKSampler
import comet.utils.utils as utils


def load_model_file(model_file):
    model_stuff = data.load_checkpoint(model_file)
    opt = model_stuff["opt"]
    state_dict = model_stuff["state_dict"]

    return opt, state_dict


def load_data(dataset, opt, dir="."):
    if dataset == "atomic":
        data_loader = load_atomic_data(opt, dir)
    elif dataset == "conceptnet":
        data_loader = load_conceptnet_data(opt, dir)

    # Initialize TextEncoder
    encoder_path = os.path.join(dir, "comet/model/encoder_bpe_40000.json")
    bpe_path = os.path.join(dir, "comet/model/vocab_40000.bpe")
    text_encoder = TextEncoder(encoder_path, bpe_path)
    text_encoder.encoder = data_loader.vocab_encoder
    text_encoder.decoder = data_loader.vocab_decoder

    return data_loader, text_encoder


def load_atomic_data(opt, dir="."):
    # Hacky workaround, you may have to change this
    # if your models use different pad lengths for e1, e2, r
    if opt.data.get("maxe1", None) is None:
        opt.data.maxe1 = 17
        opt.data.maxe2 = 35
        opt.data.maxr = 1

    # temporarily change to the target directory
    current_dir = os.getcwd()
    os.chdir(dir)

    path = "comet/data/atomic/processed/generation/categories_oEffect#oReact#oWant#xAttr#xEffect#xIntent#xNeed#xReact#xWant-maxe1_17-maxe2_35-maxr_1.pickle"
    data_loader = data.make_data_loader(opt, opt.data.categories)
    loaded = data_loader.load_data(path)

    # go back to the original working directory
    os.chdir(current_dir)

    return data_loader


def load_conceptnet_data(opt, dir="."):
    # Hacky workaround, you may have to change this
    # if your models use different pad lengths for r
    if opt.data.get("maxr", None) is None:
        if opt.data.rel == "language":
            opt.data.maxr = 5
        else:
            opt.data.maxr = 1

    # temporarily change to the target directory
    current_dir = os.getcwd()
    os.chdir(dir)

    path = "comet/data/conceptnet/processed/generation/{}.pickle".format(
        utils.make_name_string(opt.data)
    )
    data_loader = data.make_data_loader(opt)
    loaded = data_loader.load_data(path)

    # go back to the original working directory
    os.chdir(current_dir)

    return data_loader


def make_model(opt, n_vocab, n_ctx, state_dict):
    model = models.make_model(
        opt, n_vocab, n_ctx, None, load=False, return_acts=True, return_probs=False
    )

    models.load_state_dict(model, state_dict)

    model.eval()
    return model


def set_sampler(opt, sampling_algorithm, data_loader):
    if "beam" in sampling_algorithm:
        opt.eval.bs = int(sampling_algorithm.split("-")[1])
        sampler = BeamSampler(opt, data_loader)
    elif "topk" in sampling_algorithm:
        # print("Still bugs in the topk sampler. Use beam or greedy instead")
        # raise NotImplementedError
        opt.eval.k = int(sampling_algorithm.split("-")[1])
        sampler = TopKSampler(opt, data_loader)
    else:
        sampler = GreedySampler(opt, data_loader)

    return sampler


def get_atomic_sequence(
    input_event, model, sampler, data_loader, text_encoder, category
):
    if isinstance(category, list):
        outputs = {}
        for cat in category:
            new_outputs = get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder, cat
            )
            outputs.update(new_outputs)
        return outputs
    elif category == "all":
        outputs = {}

        for category in data_loader.categories:
            new_outputs = get_atomic_sequence(
                input_event, model, sampler, data_loader, text_encoder, category
            )
            outputs.update(new_outputs)
        return outputs
    else:
        sequence_all = {}

        sequence_all["event"] = input_event
        sequence_all["effect_type"] = category

        with torch.no_grad():
            batch = set_atomic_inputs(input_event, category, data_loader, text_encoder)

            sampling_result = sampler.generate_sequence(
                batch,
                model,
                data_loader,
                data_loader.max_event
                + data.atomic_data.num_delimiter_tokens["category"],
                data_loader.max_effect
                - data.atomic_data.num_delimiter_tokens["category"],
            )

        sequence_all["beams"] = sampling_result["beams"]

        # print_atomic_sequence(sequence_all)

        return {category: sequence_all}


def print_atomic_sequence(sequence_object):
    input_event = sequence_object["event"]
    category = sequence_object["effect_type"]

    print("Input Event: {}".format(input_event))
    print("Target Effect: {}".format(category))
    print("")
    print("Candidate Sequences:")
    for beam in sequence_object["beams"]:
        print(beam)
    print("")
    print("====================================================")
    print("")


def set_atomic_inputs(input_event, category, data_loader, text_encoder):
    XMB = torch.zeros(1, data_loader.max_event + 1).long().to(cfg.device)
    prefix, suffix = data.atomic_data.do_example(
        text_encoder, input_event, None, True, None
    )

    if len(prefix) > data_loader.max_event + 1:
        prefix = prefix[: data_loader.max_event + 1]

    XMB[:, : len(prefix)] = torch.LongTensor(prefix)
    XMB[:, -1] = torch.LongTensor([text_encoder.encoder["<{}>".format(category)]])

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.atomic_data.make_attention_mask(XMB)

    return batch


def get_conceptnet_sequence(
    e1, model, sampler, data_loader, text_encoder, relation, force=False
):
    if isinstance(relation, list):
        outputs = {}

        for rel in relation:
            new_outputs = get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, rel
            )
            outputs.update(new_outputs)
        return outputs
    elif relation == "all":
        outputs = {}

        for relation in data.conceptnet_data.conceptnet_relations:
            new_outputs = get_conceptnet_sequence(
                e1, model, sampler, data_loader, text_encoder, relation
            )
            outputs.update(new_outputs)
        return outputs
    else:
        sequence_all = {}

        sequence_all["e1"] = e1
        sequence_all["relation"] = relation

        with torch.no_grad():
            if data_loader.max_r != 1:
                relation_sequence = data.conceptnet_data.split_into_words[relation]
            else:
                relation_sequence = "<{}>".format(relation)

            batch, abort = set_conceptnet_inputs(
                e1,
                relation_sequence,
                text_encoder,
                data_loader.max_e1,
                data_loader.max_r,
                force,
            )

            if abort:
                return {relation: sequence_all}

            sampling_result = sampler.generate_sequence(
                batch,
                model,
                data_loader,
                data_loader.max_e1 + data_loader.max_r,
                data_loader.max_e2,
            )

        sequence_all["beams"] = sampling_result["beams"]

        print_conceptnet_sequence(sequence_all)

        return {relation: sequence_all}


def set_conceptnet_inputs(input_event, relation, text_encoder, max_e1, max_r, force):
    abort = False

    e1_tokens, rel_tokens, _ = data.conceptnet_data.do_example(
        text_encoder, input_event, relation, None
    )

    if len(e1_tokens) > max_e1:
        if force:
            XMB = torch.zeros(1, len(e1_tokens) + max_r).long().to(cfg.device)
        else:
            XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)
            return {}, True
    else:
        XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)

    XMB[:, : len(e1_tokens)] = torch.LongTensor(e1_tokens)
    XMB[:, max_e1 : max_e1 + len(rel_tokens)] = torch.LongTensor(rel_tokens)

    batch = {}
    batch["sequences"] = XMB
    batch["attention_mask"] = data.conceptnet_data.make_attention_mask(XMB)

    return batch, abort


def print_conceptnet_sequence(sequence_object):
    e1 = sequence_object["e1"]
    relation = sequence_object["relation"]

    print("Input Entity: {}".format(e1))
    print("Target Relation: {}".format(relation))
    print("")
    print("Candidate Sequences:")
    for beam in sequence_object["beams"]:
        print(beam)
    print("")
    print("====================================================")
    print("")


def print_help(data):
    print("")
    if data == "atomic":
        print('Provide a seed event such as "PersonX goes to the mall"')
        print("Don't include names, instead replacing them with PersonX, PersonY, etc.")
        print("The event should always have PersonX included")
    if data == "conceptnet":
        print('Provide a seed entity such as "go to the mall"')
        print("Because the model was trained on lemmatized entities,")
        print("it works best if the input entities are also lemmatized")
    print("")


def print_relation_help(data):
    print_category_help(data)


def print_category_help(data):
    print("")
    if data == "atomic":
        print("Enter a possible effect type from the following effect types:")
        print(
            "all - compute the output for all effect types {oEffect, oReact, oWant, xAttr, xEffect, xIntent, xNeed, xReact, xWant}"
        )
        print(
            "oEffect - generate the effect of the event on participants other than PersonX"
        )
        print(
            "oReact - generate the reactions of participants other than PersonX to the event"
        )
        print(
            "oWant - generate what participants other than PersonX may want after the event"
        )
    elif data == "conceptnet":
        print("Enter a possible relation from the following list:")
        print("")
        print("AtLocation")
        print("CapableOf")
        print("Causes")
        print("CausesDesire")
        print("CreatedBy")
        print("DefinedAs")
        print("DesireOf")
        print("Desires")
        print("HasA")
        print("HasFirstSubevent")
        print("HasLastSubevent")
        print("HasPainCharacter")
        print("HasPainIntensity")
        print("HasPrerequisite")
        print("HasProperty")
        print("HasSubevent")
        print("InheritsFrom")
        print("InstanceOf")
        print("IsA")
        print("LocatedNear")
        print("LocationOfAction")
        print("MadeOf")
        print("MotivatedByGoal")
        print("NotCapableOf")
        print("NotDesires")
        print("NotHasA")
        print("NotHasProperty")
        print("NotIsA")
        print("NotMadeOf")
        print("PartOf")
        print("ReceivesAction")
        print("RelatedTo")
        print("SymbolOf")
        print("UsedFor")
        print("")
        print("NOTE: Capitalization is important")
    else:
        raise ValueError(data)
    print("")


def print_sampling_help():
    print("")
    print(
        "Provide a sampling algorithm to produce the sequence with from the following:"
    )
    print("")
    print("greedy")
    print("beam-# where # is the beam size")
    print("topk-# where # is k")
    print("")
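
Taken together, a minimal end-to-end sketch of this interactive API (not part of the commit; the checkpoint path below is hypothetical):

    # Assumes a pretrained COMET checkpoint and the comet/ data files on disk.
    opt, state_dict = load_model_file("pretrained_models/atomic_pretrained_model.pickle")
    data_loader, text_encoder = load_data("atomic", opt)
    n_ctx = data_loader.max_event + data_loader.max_effect
    model = make_model(opt, len(text_encoder.encoder) + n_ctx, n_ctx, state_dict)
    sampler = set_sampler(opt, "beam-5", data_loader)
    outputs = get_atomic_sequence(
        "PersonX goes to the mall", model, sampler, data_loader,
        text_encoder, "xReact")
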
Model/COSMIC/feature_extraction/comet/src/main.py
ADDED
@@ -0,0 +1,19 @@
import sys
import os
import argparse

sys.path.append(os.getcwd())

parser = argparse.ArgumentParser()
parser.add_argument("--experiment_type", type=str, default='atomic',
                    choices=["atomic", "conceptnet"])
parser.add_argument("--experiment_num", type=str, default="0")

args = parser.parse_args()

if args.experiment_type == "atomic":
    from main_atomic import main
    main(args.experiment_num)
if args.experiment_type == "conceptnet":
    from main_conceptnet import main
    main(args.experiment_num)
Model/COSMIC/feature_extraction/comet/src/main_atomic.py
ADDED
@@ -0,0 +1,125 @@
import random

import torch

import comet.src.train.atomic_train as train
import comet.src.models.models as models
import comet.src.data.data as data
import comet.utils.utils as utils
import comet.src.train.utils as train_utils
import comet.src.data.config as cfg

from comet.src.data.utils import TextEncoder
from comet.src.train.opt import OpenAIAdam


def main(num):
    # Generate configuration files depending on experiment being run
    utils.generate_config_files("atomic", num)

    # Loads the correct configuration file
    config_file = "config/atomic/config_{}.json".format(num)

    print(config_file)

    # Read config file to option
    config = cfg.read_config(cfg.load_config(config_file))
    opt, meta = cfg.get_parameters(config)

    # Set the random seeds
    torch.manual_seed(opt.train.static.seed)
    random.seed(opt.train.static.seed)
    if config.gpu_mode:
        torch.cuda.manual_seed_all(opt.train.static.seed)

    # Where to find the data
    splits = ["train", "dev", "test"]

    opt.train.dynamic.epoch = 0

    print("Loading Data")

    categories = opt.data.categories

    path = "data/atomic/processed/{}/{}.pickle".format(
        opt.exp, utils.make_name_string(opt.data))

    data_loader = data.make_data_loader(opt, categories)
    loaded = data_loader.load_data(path)
    print(data_loader.sequences["train"]["total"].size(0))
    data_loader.opt = opt
    data_loader.batch_size = opt.train.dynamic.bs

    print("Done.")

    # Initialize text_encoder
    text_encoder = TextEncoder(config.encoder_path, config.bpe_path)

    special = [data.start_token, data.end_token]
    special += ["<{}>".format(cat) for cat in categories]
    special += [data.blank_token]

    text_encoder.encoder = data_loader.vocab_encoder
    text_encoder.decoder = data_loader.vocab_decoder

    opt.data.maxe1 = data_loader.max_event
    opt.data.maxe2 = data_loader.max_effect
    opt.data.maxr = data.atomic_data.num_delimiter_tokens["category"]

    n_special = len(special)
    n_ctx = opt.data.maxe1 + opt.data.maxe2
    n_vocab = len(text_encoder.encoder) + n_ctx

    print(data_loader.__dict__.keys())
    opt.net.vSize = n_vocab

    print("Building Model")

    model = models.make_model(
        opt, n_vocab, n_ctx, n_special,
        load=(opt.net.init == "pt"))

    print("Done.")

    print("Files will be logged at: {}".format(
        utils.make_name(opt, prefix="results/losses/",
                        is_dir=True, eval_=True)))

    data_loader.reset_offsets("train")

    # Get number of examples
    data.set_max_sizes(data_loader)

    if config.gpu_mode:
        print("Pushing to GPU: {}".format(config.gpu_index))
        cfg.device = config.gpu_index
        cfg.do_gpu = True
        torch.cuda.set_device(cfg.device)
        if config.multigpu:
            model = models.multi_gpu(
                model, config.gpu_indices).cuda()
        else:
            model.cuda(cfg.device)
        print("Done.")

    print("Training")

    optimizer = OpenAIAdam(model.parameters(),
                           lr=opt.train.dynamic.lr,
                           schedule=opt.train.static.lrsched,
                           warmup=opt.train.static.lrwarm,
                           t_total=meta.iterations,
                           b1=opt.train.static.b1,
                           b2=opt.train.static.b2,
                           e=opt.train.static.e,
                           l2=opt.train.static.l2,
                           vector_l2=opt.train.static.vl2,
                           max_grad_norm=opt.train.static.clip)

    scorers = ["bleu", "rouge", "cider"]
    trainer = train.make_trainer(
        opt, meta, data_loader, model, optimizer)
    trainer.set_evaluator(opt, model, data_loader)

    trainer.run()
Model/COSMIC/feature_extraction/comet/src/main_conceptnet.py
ADDED
@@ -0,0 +1,138 @@
import random

import torch

import comet.src.train.conceptnet_train as train
import comet.src.models.models as models
import comet.src.data.data as data
import comet.utils.utils as utils
import comet.src.train.utils as train_utils
import comet.src.data.config as cfg

from comet.src.data.utils import TextEncoder
from comet.src.train.opt import OpenAIAdam


def main(num):
    # Generate configuration files depending on experiment being run
    utils.generate_config_files("conceptnet", num)

    # Loads the correct configuration file
    config_file = "config/conceptnet/config_{}.json".format(num)

    print(config_file)

    # Read config file to option
    config = cfg.read_config(cfg.load_config(config_file))
    opt, meta = cfg.get_parameters(config)

    # config.gpu_mode = torch.cuda.is_available()

    # Set the random seeds
    torch.manual_seed(opt.train.static.seed)
    random.seed(opt.train.static.seed)
    if config.gpu_mode:
        torch.cuda.manual_seed_all(opt.train.static.seed)

    # Load the data
    splits = ["train", "dev", "test"]

    opt.train.dynamic.epoch = 0

    print("Loading Data")

    # Initialize path to pre-set data loader
    path = "data/conceptnet/processed/{}/{}.pickle".format(
        opt.exp, utils.make_name_string(opt.data))

    # Make data loader
    data_loader = data.make_data_loader(opt)
    loaded = data_loader.load_data(path)
    print(data_loader.sequences["train"]["total"].size(0))
    data_loader.opt = opt
    data_loader.batch_size = opt.train.dynamic.bs

    print("Done.")

    text_encoder = TextEncoder(config.encoder_path, config.bpe_path)

    categories = data.conceptnet_data.conceptnet_relations

    special = [data.start_token, data.end_token]
    special += ["<{}>".format(cat) for cat in categories]

    if loaded:
        text_encoder.encoder = data_loader.vocab_encoder
        text_encoder.decoder = data_loader.vocab_decoder
    else:
        for special_token in special:
            text_encoder.decoder[len(text_encoder.encoder)] = special_token
            text_encoder.encoder[special_token] = len(text_encoder.encoder)
        data_loader.make_tensors(text_encoder, special)

    # Set max size of different parts of relation
    context_size_e1 = data_loader.max_e1
    context_size_e2 = data_loader.max_e2
    context_size_r = data_loader.max_r

    opt.data.maxr = context_size_r

    n_special = len(special)
    n_ctx = context_size_e1 + context_size_r + context_size_e2
    n_vocab = len(text_encoder.encoder) + n_ctx

    print(data_loader.__dict__.keys())
    opt.net.vSize = n_vocab

    # Build Model
    print("Building Model")

    model = models.make_model(
        opt, n_vocab, n_ctx, n_special,
        load=(opt.net.init == "pt"))

    print("Done.")

    print("Files will be logged at: {}".format(
        utils.make_name(opt, prefix="results/losses/",
                        is_dir=True, eval_=True)))

    data_loader.reset_offsets("train", keys=["total"])

    data.set_max_sizes(data_loader)

    # Push to GPU
    if config.gpu_mode:
        print("Pushing to GPU: {}".format(config.gpu_index))
        cfg.device = config.gpu_index
        cfg.do_gpu = True
        torch.cuda.set_device(cfg.device)
        if config.multigpu:
            model = models.multi_gpu(
                model, config.gpu_indices).cuda()
        else:
            model.cuda(cfg.device)
        print("Done.")

    print("Training")

    optimizer = OpenAIAdam(model.parameters(),
                           lr=opt.train.dynamic.lr,
                           schedule=opt.train.static.lrsched,
                           warmup=opt.train.static.lrwarm,
                           t_total=meta.iterations,
                           b1=opt.train.static.b1,
                           b2=opt.train.static.b2,
                           e=opt.train.static.e,
                           l2=opt.train.static.l2,
                           vector_l2=opt.train.static.vl2,
                           max_grad_norm=opt.train.static.clip)

    trainer = train.make_trainer(
        opt, meta, data_loader, model, optimizer)
    print(data_loader.sequences["dev"]["total"].max())
    trainer.set_generator(opt, model, data_loader)
    trainer.set_evaluator(opt, model, data_loader)

    trainer.run()
Model/COSMIC/feature_extraction/comet/src/models/gpt.py
ADDED
@@ -0,0 +1,311 @@
import copy
import json
import math
import re

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


'''
Much of this code is taken from HuggingFace's OpenAI LM Implementation here:

https://github.com/huggingface/pytorch-openai-transformer-lm
'''


def gelu(x):
    return (0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) *
                                      (x + 0.044715 * torch.pow(x, 3)))))


def swish(x):
    return x * torch.sigmoid(x)


ACT_FNS = {
    'relu': F.relu,  # functional form, so it can be called like gelu/swish
    'swish': swish,
    'gelu': gelu
}


class LayerNorm(nn.Module):
    "Construct a layernorm module in the OpenAI style \
    (epsilon inside the square root)."

    def __init__(self, n_state, e=1e-5):
        super(LayerNorm, self).__init__()
        self.g = nn.Parameter(torch.ones(n_state))
        self.b = nn.Parameter(torch.zeros(n_state))
        self.e = e

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.e)
        return self.g * x + self.b


class Conv1D(nn.Module):
    def __init__(self, nf, rf, nx):
        super(Conv1D, self).__init__()
        self.rf = rf
        self.nf = nf
        if rf == 1:  # faster 1x1 conv
            w = torch.empty(nx, nf)
            nn.init.normal_(w, std=0.02)
            self.w = Parameter(w)
            self.b = Parameter(torch.zeros(nf))
        else:  # was used to train LM
            raise NotImplementedError

    def forward(self, x):
        if self.rf == 1:
            size_out = x.size()[:-1] + (self.nf,)
            x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
            x = x.view(*size_out)
        else:
            raise NotImplementedError
        return x


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, cfg, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)

        assert n_state % cfg.nH == 0
        self.register_buffer('b', torch.tril(torch.ones(
            n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = cfg.nH
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, 1, nx)
        self.c_proj = Conv1D(n_state, 1, nx)
        self.attn_dropout = nn.Dropout(cfg.adpt)
        self.resid_dropout = nn.Dropout(cfg.rdpt)

    # dimensions of w: (batch_size x num_heads x seq_length x seq_length)
    def _attn(self, q, k, v, sequence_mask):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))

        b_subset = self.b[:, :, :w.size(-2), :w.size(-1)]

        if sequence_mask is not None:
            b_subset = b_subset * sequence_mask.view(
                sequence_mask.size(0), 1, -1)
            b_subset = b_subset.permute(1, 0, 2, 3)

        w = w * b_subset + -1e9 * (1 - b_subset)
        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, sequence_mask):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        a = self._attn(query, key, value, sequence_mask)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
        return a


class MLP(nn.Module):
    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = cfg.hSize
        self.c_fc = Conv1D(n_state, 1, nx)
        self.c_proj = Conv1D(nx, 1, n_state)
        self.act = ACT_FNS[cfg.afn]
        self.dropout = nn.Dropout(cfg.rdpt)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, cfg, scale=False):
        super(Block, self).__init__()
        nx = cfg.hSize
        self.attn = Attention(nx, n_ctx, cfg, scale)
        self.ln_1 = LayerNorm(nx)
        self.mlp = MLP(4 * nx, cfg)
        self.ln_2 = LayerNorm(nx)

    def forward(self, x, sequence_mask):
        a = self.attn(x, sequence_mask)
        n = self.ln_1(x + a)
        m = self.mlp(n)
        h = self.ln_2(n + m)
        return h


class TransformerModel(nn.Module):
    """ Transformer model """

    def __init__(self, cfg, vocab=40990, n_ctx=512):
        super(TransformerModel, self).__init__()
        self.vocab = vocab
        self.embed = nn.Embedding(vocab, cfg.hSize)
        self.drop = nn.Dropout(cfg.edpt)
        block = Block(n_ctx, cfg, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block)
                                for _ in range(cfg.nL)])

        nn.init.normal_(self.embed.weight, std=0.02)

    def forward(self, x, sequence_mask):
        x = x.view(-1, x.size(-2), x.size(-1))
        e = self.embed(x)
        # Add the position information to the input embeddings
        h = e.sum(dim=2)
        for block in self.h:
            h = block(h, sequence_mask)
        return h


class LMModel(nn.Module):
    """ Transformer with language model head only """
    def __init__(self, cfg, vocab=40990, n_ctx=512,
                 return_probs=False, return_acts=False):
        super(LMModel, self).__init__()
        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
        self.lm_head = LMHead(self.transformer, cfg, trunc_and_reshape=False)
        self.return_probs = return_probs
        self.return_acts = return_acts
        if self.return_probs or self.return_acts:
            pos_emb_mask = torch.zeros(1, 1, vocab)
            pos_emb_mask[:, :, -n_ctx:] = -1e12
            self.register_buffer('pos_emb_mask', pos_emb_mask)

    def forward(self, x, sequence_mask=None):
        h = self.transformer(x, sequence_mask)
        lm_logits = self.lm_head(h)
        if self.return_probs:
            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
        elif self.return_acts:
            lm_logits = lm_logits + self.pos_emb_mask
        return h, lm_logits


class LMHead(nn.Module):
    """ Language Model Head for the transformer """

    def __init__(self, model, cfg, trunc_and_reshape=True):
        super(LMHead, self).__init__()
        self.n_embd = cfg.hSize
        embed_shape = model.embed.weight.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model.embed.weight  # Tied weights
        self.trunc_and_reshape = trunc_and_reshape  # XD

    def forward(self, h):
        # Truncated Language modeling logits (we remove the last token)
        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) \
            if self.trunc_and_reshape else h  # XD
        lm_logits = self.decoder(h_trunc)
        return lm_logits


def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12,
                                 n_embd=768, path='./model/', path_names='./'):
    # Load weights from TF model
    print("Loading weights...")
    names = json.load(open(path_names + 'parameters_names.json'))
    shapes = json.load(open(path + 'params_shapes.json'))
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)]
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
    if n_ctx > 0:
        init_params[0] = init_params[0][:n_ctx]
    if n_special > 0:
        init_params[0] = np.concatenate(
            [init_params[1],
             (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
             init_params[0]
             ], 0)
    else:
        init_params[0] = np.concatenate(
            [init_params[1],
             init_params[0]
             ], 0)
    del init_params[1]
    if n_transfer == -1:
        n_transfer = 0
    else:
        n_transfer = 1 + n_transfer * 12
    init_params = [arr.squeeze() for arr in init_params]

    try:
        assert model.embed.weight.shape == init_params[0].shape
    except AssertionError as e:
        e.args += (model.embed.weight.shape, init_params[0].shape)
        raise

    model.embed.weight.data = torch.from_numpy(init_params[0])

    for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
        name = name[6:]  # skip "model/"
        assert name[-2:] == ":0"
        name = name[:-2]
        name = name.split('/')
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                l = re.split(r'(\d+)', m_name)
            else:
                l = [m_name]
            pointer = getattr(pointer, l[0])
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == ip.shape
        except AssertionError as e:
            e.args += (pointer.shape, ip.shape)
            raise
        pointer.data = torch.from_numpy(ip)


class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


# Default hyperparameters from the original HuggingFace implementation;
# note the field names differ from the opt.net names (hSize, nH, ...) used above.
DEFAULT_CONFIG = dotdict({
    'n_embd': 768,
    'n_head': 12,
    'n_layer': 12,
    'embd_pdrop': 0.1,
    'attn_pdrop': 0.1,
    'resid_pdrop': 0.1,
    'afn': 'gelu',
    'clf_pdrop': 0.1})
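
As a shape check (not part of the commit), a tiny LMModel can be built directly from a dotdict using the field names consumed above (hSize, nH, nL, edpt, adpt, rdpt, afn); all values here are hypothetical:

    # Input is (batch, seq, [token id, position id]); zeros are valid indices.
    tiny_cfg = dotdict({'hSize': 768, 'nH': 12, 'nL': 2,
                        'edpt': 0.1, 'adpt': 0.1, 'rdpt': 0.1, 'afn': 'gelu'})
    lm = LMModel(tiny_cfg, vocab=40990, n_ctx=512, return_acts=True)
    x = torch.zeros(1, 8, 2, dtype=torch.long)
    h, logits = lm(x)   # h: (1, 8, 768), logits: (1, 8, 40990)
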
Model/COSMIC/feature_extraction/comet/src/models/models.py
ADDED
@@ -0,0 +1,32 @@
from comet.src.models.gpt import (LMModel, DEFAULT_CONFIG, load_openai_pretrained_model)
import torch.nn as nn


def make_model(opt, n_vocab, n_ctx, n_special, load=True,
               return_acts=True, return_probs=False,
               clf_token="<CLASS>", answer_size=None):
    print(n_ctx)
    if opt.exp == "generation":
        model = LMModel(
            opt.net, n_vocab, n_ctx, return_acts=return_acts,
            return_probs=return_probs)
    elif opt.exp == "classification":
        # NOTE: ClfModel is not defined or imported in this module, so the
        # classification path would raise a NameError if taken.
        model = ClfModel(
            opt.net, n_vocab, n_ctx, clf_token, answer_size)
    if load:
        print("LOADING PRETRAINED TRANSFORMER")
        load_openai_pretrained_model(
            model.transformer, n_ctx=n_ctx, n_special=n_special)
    return model


def multi_gpu(model, devices):
    return nn.DataParallel(model, device_ids=devices)


def load_state_dict(model, state_dict):
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Checkpoints saved from nn.DataParallel prefix keys with "module."
        new_state_dict = {i[len("module."):]: j for i, j in state_dict.items()}
        model.load_state_dict(new_state_dict)
Model/COSMIC/feature_extraction/comet/src/models/utils.py
ADDED
@@ -0,0 +1,12 @@
import torch


def prepare_position_embeddings(opt, encoder_vocab, sequences):
    vocab_size = len(encoder_vocab)
    num_positions = sequences.size(-2)
    position_embeddings = torch.LongTensor(
        range(vocab_size, vocab_size + num_positions)).to(sequences.device)
    sequences = sequences.repeat(1, 1, 2)
    sequences[:, :, 1] = position_embeddings
    return sequences
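
A worked example (not part of the commit) of what this helper produces:

    # A toy vocabulary of size 2, so positions are numbered 2, 3, 4, ...
    seqs = torch.tensor([[1, 0, 1]]).unsqueeze(-1)   # (bs=1, seq_len=3, 1)
    out = prepare_position_embeddings(None, {"a": 0, "b": 1}, seqs)
    # out[:, :, 0] -> tensor([[1, 0, 1]])  (the original token ids)
    # out[:, :, 1] -> tensor([[2, 3, 4]])  (position ids offset past the vocab)
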
Model/COSMIC/feature_extraction/comet/src/train/atomic_train.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
import random
|
2 |
+
|
3 |
+
import comet.src.train.train as base_train
|
4 |
+
import comet.src.train.batch as batch
|
5 |
+
import comet.src.evaluate.atomic_evaluate as evaluate
|
6 |
+
# import comet.src.evaluate.atomic_generate as gen
|
7 |
+
|
8 |
+
|
9 |
+
def make_trainer(opt, *args):
|
10 |
+
return AtomicGenerationIteratorTrainer(opt, *args)
|
11 |
+
|
12 |
+
|
13 |
+
class AtomicGenerationIteratorTrainer(base_train.IteratorTrainer):
|
14 |
+
def __init__(self, opt, *args):
|
15 |
+
super(AtomicGenerationIteratorTrainer, self).__init__(opt, *args)
|
16 |
+
|
17 |
+
self.initialize_losses(opt.data.get("categories", []))
|
18 |
+
|
19 |
+
def set_evaluator(self, opt, model, data_loader):
|
20 |
+
self.evaluator = evaluate.make_evaluator(
|
21 |
+
opt, model, data_loader)
|
22 |
+
|
23 |
+
# def set_generator(self, opt, model, data_loader, scores, reward=None):
|
24 |
+
# self.generator = gen.make_generator(
|
25 |
+
# opt, model, data_loader, scores, reward)
|
26 |
+
|
27 |
+
def set_sampler(self, opt):
|
28 |
+
if opt.train.static.samp not in self.samplers:
|
29 |
+
self.samplers[opt.train.static.samp] = sampling.make_sampler(
|
30 |
+
opt.train.static.samp, opt, self.data_loader, batch_mode=True)
|
31 |
+
self.batch_variables["sampler"] = self.samplers
|
32 |
+
|
33 |
+
def batch(self, opt, *args):
|
34 |
+
outputs = batch.batch_atomic_generate(opt, *args)
|
35 |
+
|
36 |
+
token_loss = outputs["loss"]
|
37 |
+
nums = outputs["nums"]
|
38 |
+
reset = outputs["reset"]
|
39 |
+
|
40 |
+
return token_loss, nums, reset
|
41 |
+
|
42 |
+
def initialize_losses(self, categories):
|
43 |
+
self.losses["train"] = {
|
44 |
+
"total_micro": [0],
|
45 |
+
"total_macro": [0]
|
46 |
+
}
|
47 |
+
|
48 |
+
nums = {"total_micro": 0, "total_macro": 0}
|
49 |
+
|
50 |
+
for category in categories:
|
51 |
+
micro_name = "{}_micro".format(category)
|
52 |
+
macro_name = "{}_macro".format(category)
|
53 |
+
|
54 |
+
self.losses["train"][micro_name] = [0]
|
55 |
+
self.losses["train"][macro_name] = [0]
|
56 |
+
|
57 |
+
nums[micro_name] = 0
|
58 |
+
nums[macro_name] = 0
|
59 |
+
|
60 |
+
return nums
|
61 |
+
|
62 |
+
def update_top_score(self, opt):
|
63 |
+
print(self.top_score)
|
64 |
+
if self.top_score is None:
|
65 |
+
self.top_score = (self.opt.train.dynamic.epoch,
|
66 |
+
self.get_tracked_score())
|
67 |
+
elif self.get_tracked_score() < self.top_score[-1]:
|
68 |
+
self.top_score = (self.opt.train.dynamic.epoch,
|
69 |
+
self.get_tracked_score())
|
70 |
+
print(self.top_score)
|
71 |
+
|
72 |
+
def get_tracked_score(self):
|
73 |
+
return self.losses["dev"]["total_micro"][self.opt.train.dynamic.epoch]
|
74 |
+
|
75 |
+
def counter(self, nums):
|
76 |
+
return nums["total_macro"]
|
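# Hypothetical illustration (category names are just examples) of the layout
# initialize_losses builds: one running-average list per loss name; assumes
# `trainer` is a constructed AtomicGenerationIteratorTrainer.
nums = trainer.initialize_losses(["xNeed", "xWant"])
# trainer.losses["train"] == {"total_micro": [0], "total_macro": [0],
#                             "xNeed_micro": [0], "xNeed_macro": [0],
#                             "xWant_micro": [0], "xWant_macro": [0]}
# nums holds the same keys, all mapped to 0 (example/token counts).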
Model/COSMIC/feature_extraction/comet/src/train/batch.py
ADDED
@@ -0,0 +1,135 @@
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

import comet.src.data.config as cfg
import comet.src.train.utils as train_utils
import comet.src.models.utils as model_utils
import comet.src.evaluate.utils as eval_utils
import comet.utils.utils as utils
from IPython import embed


##############################################################################
# BATCH
##############################################################################


def batch_atomic_generate(opt, nums, losses, batch_variables, eval_mode=False):
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]

    batch, reset = data_loader.sample_batch(split, bs=opt.train.dynamic.bs)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # Set loss name
    micro_name = "total_micro"
    macro_name = "total_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, (loss * loss_mask).sum(1), split)

    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    return outputs


def batch_conceptnet_generate(opt, nums, losses, batch_variables,
                              eval_mode=False, tracking_mode=False):
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]
    category = batch_variables["category"]

    batch, reset = data_loader.sample_batch(
        split, bs=opt.train.dynamic.bs, cat=category)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # Set loss name
    if not eval_mode or batch_variables["category"] == "positive":
        micro_name = "total_micro"
        macro_name = "total_macro"
    else:
        micro_name = "negative_micro"
        macro_name = "negative_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, (loss * loss_mask).sum(1), split)

    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    if tracking_mode:
        outputs["tracking"] = final_loss.squeeze().tolist()

    return outputs


def mle_steps(key, model, input_, targets, attention_mask,
              loss_reduction="mean", i=None):
    word_acts = decode(model, input_.unsqueeze(1),
                       attention_mask, i)

    word_dist = train_utils.modify_output_for_loss_fn(
        "nll", word_acts, dim=-1)

    # Compute losses
    loss = F.nll_loss(
        word_dist.view(-1, word_dist.size(-1)),
        targets, reduction=loss_reduction)

    if loss_reduction != "mean":
        return loss.view(word_dist.size(0), -1), word_dist
    else:
        return loss, word_dist


def decode(model, input_, attention_mask, i=None):
    return model(input_, sequence_mask=attention_mask)


def update_generation_losses(losses, nums, micro, macro, bs,
                             length, loss, split):
    if split == "train":
        train_utils.update_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
    else:
        eval_utils.update_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
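# Self-contained sketch of the masked token-level loss computed above, with
# random logits standing in for the transformer output (all shapes invented).
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 10)                     # (bs, seq_len, vocab)
targets = torch.randint(0, 10, (2, 4))
loss_mask = torch.tensor([[1., 1., 0., 0.],        # 1 only on effect tokens
                          [1., 1., 1., 0.]])

per_token = F.nll_loss(
    F.log_softmax(logits, dim=-1).view(-1, 10),
    targets.view(-1), reduction="none").view(2, 4)
per_example = (per_token * loss_mask).sum(1) / loss_mask.sum(1)
loss = per_example.sum()                           # what the trainer backpropagates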
Model/COSMIC/feature_extraction/comet/src/train/conceptnet_train.py
ADDED
@@ -0,0 +1,67 @@
import random
import torch

import comet.src.data.config as cfg

import comet.src.train.atomic_train as base_train
import comet.src.train.batch as batch_utils
import comet.src.evaluate.conceptnet_evaluate as evaluate
import comet.src.evaluate.conceptnet_generate as gen


def make_trainer(opt, *args):
    return ConceptNetGenerationIteratorTrainer(opt, *args)


class ConceptNetGenerationIteratorTrainer(
        base_train.AtomicGenerationIteratorTrainer):
    def set_evaluator(self, opt, model, data_loader):
        self.evaluator = evaluate.make_evaluator(
            opt, model, data_loader)

    def set_generator(self, opt, model, data_loader):
        self.generator = gen.make_generator(
            opt, model, data_loader)

    def batch(self, opt, *args):
        outputs = batch_utils.batch_atomic_generate(opt, *args)

        token_loss = outputs["loss"]
        nums = outputs["nums"]
        reset = outputs["reset"]

        return token_loss, nums, reset

    def update_top_score(self, opt):
        print(self.top_score)

        tracked_scores = self.get_tracked_score()

        if self.top_score is None:
            self.top_score = {"epoch": {}, "score": {}}
            self.top_score["epoch"]["total_micro"] = self.opt.train.dynamic.epoch
            self.top_score["score"]["total_micro"] = tracked_scores["total_micro"]
        else:
            if tracked_scores["total_micro"] < self.top_score["score"]["total_micro"]:
                self.top_score["epoch"]["total_micro"] = self.opt.train.dynamic.epoch
                self.top_score["score"]["total_micro"] = tracked_scores["total_micro"]

        print(self.top_score)

    def get_tracked_score(self):
        return {
            "total_micro": self.losses["dev"]["total_micro"][self.opt.train.dynamic.epoch]
        }

    def decide_to_save(self):
        to_save = cfg.save and not cfg.toy

        curr_epoch = self.opt.train.dynamic.epoch

        to_save = to_save or cfg.test_save
        print(cfg.save_strategy)
        if cfg.save_strategy == "best":
            if self.top_score["epoch"]["total_micro"] != curr_epoch:
                to_save = False
        return to_save
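# Toy walk-through (invented values) of the "best" save strategy implemented
# above: a checkpoint is written only in the epoch that set a new best dev
# total_micro score.
top_score = {"epoch": {"total_micro": 3}, "score": {"total_micro": 1.7}}
curr_epoch, to_save = 4, True
if top_score["epoch"]["total_micro"] != curr_epoch:
    to_save = False
print(to_save)  # False: epoch 4 did not improve on the epoch-3 score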
Model/COSMIC/feature_extraction/comet/src/train/opt.py
ADDED
@@ -0,0 +1,122 @@
'''TAKEN from OpenAI LM Code by HuggingFace'''

import math
import torch
from torch.optim import Optimizer
from torch.nn.utils import clip_grad_norm_


def warmup_cosine(x, warmup=0.002):
    s = 1 if x <= warmup else 0
    # x is a Python float here, so use math.cos (torch.cos requires a tensor)
    return s*(x/warmup) + (1-s)*(0.5 * (1 + math.cos(math.pi * x)))


def warmup_constant(x, warmup=0.002):
    s = 1 if x <= warmup else 0
    return s*(x/warmup) + (1-s)*1


def warmup_linear(x, warmup=0.002):
    s = 1 if x <= warmup else 0

    # print(s)

    return (s*(x/warmup) + (1-s))*(1-x)


SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
}


class OpenAIAdam(Optimizer):
    """Implements Open AI version of Adam algorithm with weight decay fix.
    """
    def __init__(self, params, lr, schedule, warmup, t_total,
                 b1=0.9, b2=0.999, e=1e-8, l2=0,
                 vector_l2=False, max_grad_norm=-1, **kwargs):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0 <= warmup:
            raise ValueError("Invalid warmup: {}".format(warmup))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {}".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {}".format(b2))
        if not 0.0 <= e:
            raise ValueError("Invalid epsilon value: {}".format(e))
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
                        b1=b1, b2=b2, e=e, l2=l2, vector_l2=vector_l2,
                        max_grad_norm=max_grad_norm)
        super(OpenAIAdam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            # print(group['t_total'])
            # print(group['warmup'])
            # if self.state[group['params'][0]]:
            #     print(self.state[group['params'][0]]['step'] / group['t_total'])
            # print()
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, \
                        please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['b1'], group['b2']

                state['step'] += 1

                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                denom = exp_avg_sq.sqrt().add_(group['e'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                schedule_fct = SCHEDULES[group['schedule']]
                lr_scheduled = (group['lr'] * schedule_fct(state['step'] /
                                group['t_total'], group['warmup']))
                step_size = (lr_scheduled * math.sqrt(bias_correction2) /
                             bias_correction1)

                p.data.addcdiv_(-step_size, exp_avg, denom)

                # Add weight decay at the end (fixed version)
                if (len(p.size()) > 1 or group['vector_l2']) and group['l2'] > 0:
                    p.data.add_(-lr_scheduled * group['l2'], p.data)

        return loss
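# A minimal, hypothetical construction of this optimizer; every
# hyperparameter value below is invented for illustration.
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = OpenAIAdam(model.parameters(),
                       lr=6.25e-5, schedule="warmup_linear", warmup=0.002,
                       t_total=10000, b1=0.9, b2=0.999, e=1e-8,
                       l2=0.01, vector_l2=True, max_grad_norm=1)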
Model/COSMIC/feature_extraction/comet/src/train/train.py
ADDED
@@ -0,0 +1,233 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import comet.src.data.config as cfg
import comet.src.data.data as data
import comet.src.train.utils as train_utils
import comet.src.train.batch as batch

import comet.src.evaluate.evaluate as evaluate
import comet.src.evaluate.generate as gen
import comet.src.evaluate.sampler as sampling

import comet.utils.utils as utils

from tensorboardX import SummaryWriter


class Trainer(object):
    def __init__(self, opt, meta, data_loader, model, optimizer):
        self.optimizer = optimizer

        self.model = model

        if opt.trainer == "epoch":
            self.epochs = meta.epochs
        self.data_loader = data_loader
        self.opt = opt

        self.losses = {"dev": {}, "test": {}, "train": {}}

        self.top_score = None

        self.lrs = {}

        self.batch_variables = {
            "data": self.data_loader,
            "model": self.model,
            "split": "train"
        }

        self.do_gen = cfg.do_gen
        self.samplers = {}

    def decide_to_save(self):
        to_save = cfg.save and not cfg.toy

        to_save = to_save or cfg.test_save
        print(cfg.save_strategy)
        if cfg.save_strategy == "best":
            if self.top_score[0] != self.opt.train.dynamic.epoch:
                print("DOING IT RIGHT")
                to_save = False
        return to_save

    def save_model(self, tracked_score):
        lrs = {}
        for i, param_group in enumerate(self.optimizer.param_groups):
            lrs[i] = param_group['lr']
        self.lrs[self.opt.train.dynamic.epoch] = lrs

        to_save = self.decide_to_save()

        if to_save:
            data.save_step(
                self.model, self.data_loader.vocab_encoder,
                self.optimizer, self.opt,
                self.opt.train.dynamic.epoch, self.lrs)

    def log_losses(self, opt, losses):
        if (not cfg.toy and cfg.save) or cfg.test_save:
            data.save_eval_file(opt, losses["train"], "losses", split="train")
            data.save_eval_file(opt, losses['dev'], "losses", split="dev")
            data.save_eval_file(opt, losses['test'], "losses", split="test")

    def set_logger(self):
        if cfg.toy:
            self.logger = SummaryWriter(utils.make_name(
                self.opt, prefix="garbage/logs/", eval_=True, do_epoch=False))
        else:
            self.logger = SummaryWriter(utils.make_name(
                self.opt, prefix="logs/", eval_=True, do_epoch=False))
        print("Logging Tensorboard Files at: {}".format(self.logger.logdir))

    def stop_logger(self):
        self.logger.close()

    def run(self):
        self.set_logger()
        self.count = 0
        for epoch in range(self.epochs):
            self.model.train()
            self.opt.train.dynamic.epoch += 1
            self.epoch()

        self.stop_logger()

    def epoch(self):
        nums = self.reset_losses()

        # Initialize progress bar
        bar = utils.initialize_progress_bar(
            self.data_loader.sequences["train"])

        reset = False

        while not reset:
            loss, nums, reset = self.do_forward_pass(nums)
            self.do_backward_pass(loss)
            self.update_parameters()

            bar.update(self.opt.train.dynamic.bs)
            self.count += 1

            for loss_name in self.losses["train"]:
                self.logger.add_scalar(
                    "train/{}".format(loss_name),
                    loss.item() / self.opt.train.dynamic.bs,
                    self.count)

            if cfg.toy and self.counter(nums) > 300:
                break

        with torch.no_grad():
            self.run_evaluation_cycle()

        self.log_losses(self.opt, self.losses)
        self.update_top_score(self.opt)
        self.save_model(self.get_tracked_score())

        self.data_loader.reset_offsets("train")

    def run_evaluation_cycle(self):
        for split in ["dev", "test"]:
            self.evaluator.validate(
                self.opt.train.dynamic.epoch, split,
                self.losses[split])

            if self.do_gen:
                gen.do_gen_run(
                    self.opt, self.generator, self.opt.train.dynamic.epoch,
                    split, self.losses[split])
            iter_num = self.opt.train.dynamic.epoch

            for loss_name in self.losses[split]:
                self.logger.add_scalar(
                    "{}/{}".format(split, loss_name),
                    self.losses[split][loss_name][iter_num],
                    iter_num)

    def clip_gradients(self):
        if self.opt.train.static.clip:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.opt.train.static.clip)

    def do_forward_pass(self, nums):
        token_loss, nums, reset = self.batch(
            self.opt, nums, self.losses["train"],
            self.batch_variables)
        return token_loss, nums, reset

    def do_backward_pass(self, loss):
        loss.backward()

    def update_parameters(self):
        if self.opt.model == "lstm":
            self.clip_gradients()
        self.optimizer.step()
        self.optimizer.zero_grad()

    def reset_losses(self):
        loss_names = set([i.rstrip("maicro").rstrip("_") for
                          i in self.losses["train"].keys()])
        return self.initialize_losses(list(loss_names))


class IteratorTrainer(Trainer):
    def __init__(self, opt, meta, data_loader, model, optimizer):
        super(IteratorTrainer, self).__init__(
            opt, meta, data_loader, model, optimizer)

        self.iters = meta.cycle
        self.total_iters = meta.iterations

    def run(self):
        self.set_logger()

        # Initialize progress bar
        bar = utils.set_progress_bar(self.total_iters)

        for cycle_num in range(int(self.total_iters / self.iters)):
            self.model.train()

            self.cycle(bar, cycle_num)

            with torch.no_grad():
                self.run_evaluation_cycle()

            self.log_losses(self.opt, self.losses)
            self.update_top_score(self.opt)
            self.save_model(self.get_tracked_score())

        self.stop_logger()

    def cycle(self, bar, cycle_num):
        nums = self.reset_losses()
        print(self.losses["train"])

        for i in range(1, self.iters + 1):
            # self.model.zero_grad()

            loss, nums, reset = self.do_forward_pass(nums)
            self.do_backward_pass(loss)

            self.update_parameters()
            # print(loss)
            # print(loss.item())
            self.opt.train.dynamic.epoch += 1

            for loss_name in self.losses["train"]:
                self.logger.add_scalar(
                    "train/{}".format(loss_name),
                    loss.item() / self.opt.train.dynamic.bs,
                    self.opt.train.dynamic.epoch)

            bar.update(1)

            if cfg.toy and i > 10:
                break

            if reset:
                self.data_loader.reset_offsets("train")
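# Toy sketch (invented numbers) of the IteratorTrainer schedule above:
# 500 total iterations with meta.cycle == 100 gives five rounds of
# "train 100 batches, then evaluate, log, and maybe save".
total_iters, iters = 500, 100
for cycle_num in range(int(total_iters / iters)):
    print("cycle", cycle_num)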
Model/COSMIC/feature_extraction/comet/src/train/utils.py
ADDED
@@ -0,0 +1,58 @@
import torch
import torch.optim
import torch.nn.functional as F

import copy


def update_generation_losses(losses, nums, micro, macro, bs, length, loss):
    # Update Losses
    losses[micro] += \
        [copy.deepcopy(losses[micro][-1])]
    losses[macro] += \
        [copy.deepcopy(losses[macro][-1])]

    losses[micro][-1] *= nums[micro]
    losses[macro][-1] *= nums[macro]

    nums[macro] += bs

    if isinstance(length, int):
        update_indiv_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
    else:
        update_tensor_generation_losses(
            losses, nums, micro, macro, bs, length, loss)


def update_indiv_generation_losses(losses, nums, micro,
                                   macro, bs, length, loss):
    nums[micro] += (bs * length)

    batch_loss = loss * bs

    losses[micro][-1] += batch_loss
    losses[micro][-1] /= nums[micro]
    losses[macro][-1] += batch_loss / length
    losses[macro][-1] /= nums[macro]


def update_tensor_generation_losses(losses, nums, micro,
                                    macro, bs, length, loss):
    nums[micro] += length.sum().item()

    losses[micro][-1] += loss.sum().item()
    losses[micro][-1] /= nums[micro]
    losses[macro][-1] += (loss / length.float()).sum().item()
    losses[macro][-1] /= nums[macro]


def modify_output_for_loss_fn(loss_fn, output, dim):
    if loss_fn == "ce":
        return output
    if loss_fn == "mse":
        return F.softmax(output, dim=dim)
    if loss_fn == "nll":
        return F.log_softmax(output, dim=dim)
    if loss_fn in ["bce", "wbce", "wbce1"]:
        return torch.sigmoid(output)
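# Worked toy example (invented values) of update_generation_losses on the
# tensor path: micro averages the loss per token, macro per example.
import torch

losses = {"total_micro": [0.0], "total_macro": [0.0]}
nums = {"total_micro": 0, "total_macro": 0}

length = torch.tensor([2, 4])      # effect-token count per example
loss = torch.tensor([1.0, 2.0])    # summed token loss per example
update_generation_losses(losses, nums, "total_micro", "total_macro",
                         bs=2, length=length, loss=loss)
print(losses["total_micro"][-1])   # 0.5  (3.0 total loss / 6 tokens)
print(losses["total_macro"][-1])   # 0.5  ((1/2 + 2/4) / 2 examples)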
Model/COSMIC/feature_extraction/comet/utils/__init__.py
ADDED
File without changes
Model/COSMIC/feature_extraction/comet/utils/utils.py
ADDED
@@ -0,0 +1,210 @@
import json
import copy

import torch

import numpy as np
import contextlib

from distutils.dir_util import mkpath

from tqdm import tqdm


def make_new_tensor_from_list(items, device_num, dtype=torch.float32):
    if device_num is not None:
        device = torch.device("cuda:{}".format(device_num))
    else:
        device = torch.device("cpu")
    return torch.tensor(items, dtype=dtype, device=device)


# is_dir looks at whether the name we make
# should be a directory or a filename
def make_name(opt, prefix="", eval_=False, is_dir=True, set_epoch=None,
              do_epoch=True):
    string = prefix
    string += "{}-{}".format(opt.dataset, opt.exp)
    string += "/"
    string += "{}-{}-{}".format(opt.trainer, opt.cycle, opt.iters)
    string += "/"
    string += opt.model
    if opt.mle:
        string += "-{}".format(opt.mle)
    string += "/"
    string += make_name_string(opt.data) + "/"

    string += make_name_string(opt.net) + "/"
    string += make_name_string(opt.train.static) + "/"

    if eval_:
        string += make_name_string(opt.eval) + "/"
    # mkpath caches whether a directory has been created
    # In IPython, this can be a problem if the kernel is
    # not reset after a dir is deleted. Trying to recreate
    # that dir will be a problem because mkpath will think
    # the directory already exists
    if not is_dir:
        mkpath(string)
    string += make_name_string(
        opt.train.dynamic, True, do_epoch, set_epoch)
    if is_dir:
        mkpath(string)

    return string


def make_name_string(dict_, final=False, do_epoch=False, set_epoch=None):
    if final:
        if not do_epoch:
            string = "{}_{}_{}".format(
                dict_.lr, dict_.optim, dict_.bs)
        elif set_epoch is not None:
            string = "{}_{}_{}_{}".format(
                dict_.lr, dict_.optim, dict_.bs, set_epoch)
        else:
            string = "{}_{}_{}_{}".format(
                dict_.lr, dict_.optim, dict_.bs, dict_.epoch)

        return string

    string = ""

    for k, v in dict_.items():
        if type(v) == DD:
            continue
        if isinstance(v, list):
            val = "#".join(is_bool(str(vv)) for vv in v)
        else:
            val = is_bool(v)
        if string:
            string += "-"
        string += "{}_{}".format(k, val)

    return string


def is_bool(v):
    if str(v) == "False":
        return "F"
    elif str(v) == "True":
        return "T"
    return v


def generate_config_files(type_, key, name="base", eval_mode=False):
    with open("config/default.json".format(type_), "r") as f:
        base_config = json.load(f)
    with open("config/{}/default.json".format(type_), "r") as f:
        base_config_2 = json.load(f)
    if eval_mode:
        with open("config/{}/eval_changes.json".format(type_), "r") as f:
            changes_by_machine = json.load(f)
    else:
        with open("config/{}/changes.json".format(type_), "r") as f:
            changes_by_machine = json.load(f)

    base_config.update(base_config_2)

    if name in changes_by_machine:
        changes = changes_by_machine[name]
    else:
        changes = changes_by_machine["base"]

    # for param in changes[key]:
    #     base_config[param] = changes[key][param]

    replace_params(base_config, changes[key])

    mkpath("config/{}".format(type_))

    with open("config/{}/config_{}.json".format(type_, key), "w") as f:
        json.dump(base_config, f, indent=4)


def replace_params(base_config, changes):
    for param, value in changes.items():
        if isinstance(value, dict) and param in base_config:
            replace_params(base_config[param], changes[param])
        else:
            base_config[param] = value


def initialize_progress_bar(data_loader_list):
    num_examples = sum([len(tensor) for tensor in
                        data_loader_list.values()])
    return set_progress_bar(num_examples)


def set_progress_bar(num_examples):
    bar = tqdm(total=num_examples)
    bar.update(0)
    return bar


def merge_list_of_dicts(L):
    result = {}
    for d in L:
        result.update(d)
    return result


def return_iterator_by_type(data_type):
    if isinstance(data_type, dict):
        iterator = data_type.items()
    else:
        iterator = enumerate(data_type)
    return iterator


@contextlib.contextmanager
def temp_seed(seed):
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)


def flatten(outer):
    return [el for inner in outer for el in inner]


def zipped_flatten(outer):
    return [(key, fill, el) for key, fill, inner in outer for el in inner]


def remove_none(l):
    return [e for e in l if e is not None]


# Taken from Jobman 0.1
class DD(dict):
    def __getattr__(self, attr):
        if attr == '__getstate__':
            return super(DD, self).__getstate__
        elif attr == '__setstate__':
            return super(DD, self).__setstate__
        elif attr == '__slots__':
            return super(DD, self).__slots__
        return self[attr]

    def __setattr__(self, attr, value):
        # Safety check to ensure consistent behavior with __getattr__.
        assert attr not in ('__getstate__', '__setstate__', '__slots__')
        # if attr.startswith('__'):
        #     return super(DD, self).__setattr__(attr, value)
        self[attr] = value

    def __str__(self):
        return 'DD%s' % dict(self)

    def __repr__(self):
        return str(self)

    def __deepcopy__(self, memo):
        z = DD()
        for k, kv in self.items():
            z[k] = copy.deepcopy(kv, memo)
        return z
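# Minimal illustration of the DD attribute-dict used for configs across this
# codebase (the keys below are invented):
opt = DD()
opt.train = DD()
opt.train.dynamic = DD()
opt.train.dynamic.bs = 32
print(opt.train.dynamic.bs)           # 32
print(opt["train"]["dynamic"]["bs"])  # 32 -- attribute and key access agree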
Model/COSMIC/feature_extraction/multiprocessing_bpe_encoder.py
ADDED
@@ -0,0 +1,129 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import sys

from collections import Counter
from multiprocessing import Pool

from fairseq.data.encoders.gpt2_bpe import get_encoder


def main():
    """
    Helper script to encode raw text with the GPT-2 BPE using multiple processes.

    The encoder.json and vocab.bpe files can be obtained here:
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-json",
        help='path to encoder.json',
    )
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help='path to vocab.bpe',
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=['-'],
        help="input files to filter/encode",
    )
    parser.add_argument(
        "--outputs",
        nargs="+",
        default=['-'],
        help="path to save encoded outputs",
    )
    parser.add_argument(
        "--keep-empty",
        action="store_true",
        help="keep empty lines",
    )
    parser.add_argument("--workers", type=int, default=20)
    args = parser.parse_args()

    assert len(args.inputs) == len(args.outputs), \
        "number of input and output paths should match"

    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(input, "r", encoding="utf-8"))
            if input != "-" else sys.stdin
            for input in args.inputs
        ]
        outputs = [
            stack.enter_context(open(output, "w", encoding="utf-8"))
            if output != "-" else sys.stdout
            for output in args.outputs
        ]

        encoder = MultiprocessingEncoder(args)
        pool = Pool(args.workers, initializer=encoder.initializer)
        encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)

        stats = Counter()
        for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
            if filt == "PASS":
                for enc_line, output_h in zip(enc_lines, outputs):
                    print(enc_line, file=output_h)
            else:
                stats["num_filtered_" + filt] += 1
            if i % 10000 == 0:
                print("processed {} lines".format(i), file=sys.stderr)

        for k, v in stats.most_common():
            print("[{}] filtered {} lines".format(k, v), file=sys.stderr)


class MultiprocessingEncoder(object):

    def __init__(self, args):
        self.args = args

    def initializer(self):
        global bpe
        bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)

    def encode(self, line):
        global bpe
        ids = bpe.encode(line)
        return list(map(str, ids))

    def decode(self, tokens):
        global bpe
        return bpe.decode(tokens)

    def encode_lines(self, lines):
        """
        Encode a set of lines. All lines will be encoded together.
        """
        enc_lines = []
        for line in lines:
            line = line.strip()
            if len(line) == 0 and not self.args.keep_empty:
                return ["EMPTY", None]
            tokens = self.encode(line)
            enc_lines.append(" ".join(tokens))
        return ["PASS", enc_lines]

    def decode_lines(self, lines):
        dec_lines = []
        for line in lines:
            tokens = map(int, line.strip().split())
            dec_lines.append(self.decode(tokens))
        return ["PASS", dec_lines]


if __name__ == "__main__":
    main()
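# Hedged single-process sketch of the encoder class above; the two BPE file
# paths are placeholders and must point at the files named in the docstring.
import argparse

args = argparse.Namespace(encoder_json="encoder.json",
                          vocab_bpe="vocab.bpe", keep_empty=False)
enc = MultiprocessingEncoder(args)
enc.initializer()                         # loads the module-level bpe object
print(enc.encode_lines(["Hello world"]))  # ["PASS", ["<id> <id>"]] with GPT-2 ids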
Model/COSMIC/feature_extraction/src/__init__.py
ADDED
File without changes
Model/COSMIC/feature_extraction/src/data/atomic.py
ADDED
@@ -0,0 +1,337 @@
import comet.utils.utils as utils
import comet.src.data.utils as data_utils
import comet.src.data.config as cfg

import pandas
import json
import random
import math
import torch

from tqdm import tqdm


def map_name(name):
    if name == "train":
        return "trn"
    elif name == "test":
        return "tst"
    else:
        return "dev"


class DataLoader(object):
    def __init__(self, opt):
        self.data = {}
        self.data["train"] = {}
        self.data["dev"] = {}
        self.data["test"] = {}

        self.sequences = {}
        self.sequences["train"] = {}
        self.sequences["dev"] = {}
        self.sequences["test"] = {}

        self.masks = {}
        self.masks["train"] = {}
        self.masks["dev"] = {}
        self.masks["test"] = {}

        self.offsets = {}
        self.offsets["train"] = {}
        self.offsets["dev"] = {}
        self.offsets["test"] = {}

    def offset_summary(self, split):
        return self.offsets[split]["total"]


def do_take_partial_dataset(data_opts):
    if data_opts.get("kr", None) is None:
        return False
    if data_opts.kr == 1:
        return False
    return True


def select_partial_dataset(data_opts, data):
    num_selections = math.ceil(data_opts.kr * len(data))
    return random.sample(data, num_selections)


class GenerationDataLoader(DataLoader):
    def __init__(self, opt, categories):
        super(GenerationDataLoader, self).__init__(opt)

        self.categories = categories
        self.opt = opt

        for split in self.data:
            self.data[split] = {"total": []}
            self.offsets[split] = {"total": 0}

        self.vocab_encoder = None
        self.vocab_decoder = None
        self.special_chars = None
        self.max_event = None
        self.max_effect = None

    def load_data(self, path):
        if ".pickle" in path:
            print("Loading data from: {}".format(path))
            data_utils.load_existing_data_loader(self, path)

            return True

        for split in self.data:
            file_name = "v4_atomic_{}.csv".format(map_name(split))

            df = pandas.read_csv("{}/{}".format(path, file_name), index_col=0)
            df.iloc[:, :9] = df.iloc[:, :9].apply(
                lambda col: col.apply(json.loads))

            for cat in self.categories:
                attr = df[cat]
                self.data[split]["total"] += utils.zipped_flatten(zip(
                    attr.index, ["<{}>".format(cat)] * len(attr), attr.values))

        if do_take_partial_dataset(self.opt.data):
            self.data["train"]["total"] = select_partial_dataset(
                self.opt.data, self.data["train"]["total"])

        return False

    def make_tensors(self, text_encoder, special,
                     splits=["train", "dev", "test"], test=False):
        self.vocab_encoder = text_encoder.encoder
        self.vocab_decoder = text_encoder.decoder
        self.special_chars = special

        sequences = {}
        for split in splits:
            sequences[split] = get_generation_sequences(
                self.opt, self.data, split, text_encoder, test)

            self.masks[split]["total"] = [(len(i[0]), len(i[1])) for
                                          i in sequences[split]]

        self.max_event = max([max([l[0] for l in self.masks[split]["total"]])
                              for split in self.masks])
        self.max_effect = max([max([l[1] for l in self.masks[split]["total"]])
                               for split in self.masks])

        print(self.max_event)
        print(self.max_effect)

        for split in splits:
            num_elements = len(sequences[split])
            self.sequences[split]["total"] = torch.LongTensor(
                num_elements, self.max_event + self.max_effect).fill_(0)

            for i, seq in enumerate(sequences[split]):
                # print(self.sequences[split]["total"][i, :len(seq[0])].size())
                # print(torch.FloatTensor(seq[0]).size())
                self.sequences[split]["total"][i, :len(seq[0])] = \
                    torch.LongTensor(seq[0])
                self.sequences[split]["total"][i, self.max_event:self.max_event + len(seq[1])] = \
                    torch.LongTensor(seq[1])

    def sample_batch(self, split, bs, idxs=None):
        offset = self.offsets[split]["total"]

        batch = {}

        # Decided not to reduce computation on here because it's all parallel
        # anyway and we don't want to run out of memory in cases where we
        # don't see the longest version quickly enough

        if idxs:
            seqs = self.sequences[split]["total"].index_select(
                0, torch.LongTensor(idxs).to(
                    self.sequences[split]["total"].device))
        else:
            seqs = self.sequences[split]["total"][offset:offset + bs]
        batch["sequences"] = seqs.to(cfg.device)
        batch["attention_mask"] = make_attention_mask(seqs)
        batch["loss_mask"] = make_loss_mask(
            seqs, self.max_event, 1)
        batch["key"] = ("total", offset, offset + bs)

        offset += seqs.size(0)

        self.offsets[split]["total"] = offset

        if split == "train" and offset + bs > len(self.sequences[split]["total"]):
            return batch, True
        elif offset >= len(self.sequences[split]["total"]):
            return batch, True
        else:
            return batch, False

    def reset_offsets(self, splits=["train", "test", "dev"],
                      shuffle=True, keys=None):
        if isinstance(splits, str):
            splits = [splits]

        for split in splits:
            if keys is None:
                keys = ["total"]

            for key in keys:
                self.offsets[split][key] = 0

            if shuffle:
                self.shuffle_sequences(split, keys)

    def shuffle_sequences(self, split="train", keys=None):
        if keys is None:
            # print(type(self.data))
            # print(type(self.data.keys()))
            keys = self.data[split].keys()

        for key in keys:
            idxs = list(range(len(self.data[split][key])))

            random.shuffle(idxs)

            self.sequences[split][key] = \
                self.sequences[split][key].index_select(
                    0, torch.LongTensor(idxs))

            temp = [self.data[split][key][i] for i in idxs]
            self.data[split][key] = temp
            temp = [self.masks[split][key][i] for i in idxs]
            self.masks[split][key] = temp


def prune_data_for_evaluation(data_loader, categories, split):
    indices = []
    for i, example in enumerate(data_loader.data[split]["total"]):
        if example[1] in categories:
            indices.append(i)

    data_loader.masks[split]["total"] = [data_loader.masks[split]["total"][i]
                                         for i in indices]
    data_loader.sequences[split]["total"] = \
        data_loader.sequences[split]["total"].index_select(
            0, torch.LongTensor(indices))
    data_loader.data[split]["total"] = [data_loader.data[split]["total"][i]
                                        for i in indices]


def make_attention_mask(sequences):
    return (sequences != 0).float().to(cfg.device)


def make_loss_mask(sequences, max_event, num_delim_tokens):
    # print(num_delim_tokens)
    # print(sequences.size())
    mask = (sequences != 0).float()
    mask[:, :max_event + num_delim_tokens] = 0
    return mask[:, 1:].to(cfg.device)


def find_underscore_length(seq):
    start = "_"

    while start in seq:
        start += "_"
    return start[:-1]


def handle_underscores(suffix, text_encoder, prefix=False):
    encoder = text_encoder.encoder
    if prefix:
        tok = "___"
    else:
        tok = find_underscore_length(suffix)

    suffix_parts = [i.strip() for i in suffix.split("{}".format(tok))]
    to_flatten = []
    for i, part in enumerate(suffix_parts):
        if part:
            to_flatten.append(text_encoder.encode([part], verbose=False)[0])

            if i != len(suffix_parts) - 1 and suffix_parts[i+1]:
                to_flatten.append([encoder["<blank>"]])
        else:
            to_flatten.append([encoder["<blank>"]])

    final_suffix = utils.flatten(to_flatten)

    return final_suffix


def get_generation_sequences(opt, data, split, text_encoder, test):
    sequences = []
    count = 0

    final_prefix = None
    final_suffix = None

    for prefix, category, suffix in tqdm(data[split]["total"]):
        final_prefix, final_suffix = do_example(
            text_encoder, prefix, suffix, True, True)
        # if do_prefix:
        #     if "___" in prefix:
        #         final_prefix = handle_underscores(prefix, text_encoder, True)
        #     else:
        #         final_prefix = text_encoder.encode([prefix], verbose=False)[0]
        # if do_suffix:
        #     if "_" in suffix:
        #         final_suffix = handle_underscores(suffix, text_encoder)
        #     else:
        #         final_suffix = text_encoder.encode([suffix], verbose=False)[0]

        final = compile_final_sequence(
            opt, final_prefix, final_suffix, category, text_encoder)

        sequences.append(final)

        count += 1

        if count > 10 and test:
            break

    return sequences


def do_example(text_encoder, prefix, suffix, do_prefix, do_suffix):
    final_prefix = None
    final_suffix = None

    if do_prefix:
        if "___" in prefix:
            final_prefix = handle_underscores(prefix, text_encoder, True)
        else:
            final_prefix = text_encoder.encode([prefix], verbose=False)[0]
    if do_suffix:
        if "_" in suffix:
            final_suffix = handle_underscores(suffix, text_encoder)
        else:
            final_suffix = text_encoder.encode([suffix], verbose=False)[0]

    return final_prefix, final_suffix


def compile_final_sequence(opt, final_prefix, final_suffix, category, text_encoder):
    final = []

    final.append(final_prefix)
    final.append(
        [text_encoder.encoder[category]]
        + final_suffix)

    final[-1].append(text_encoder.encoder["<END>"])

    return final


num_delimiter_tokens = {
    "category": 1,
    "hierarchy": 3,
    "hierarchy+label": 4,
    "category+hierarchy": 4,
    "category+hierarchy+label": 5
}
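# Toy demonstration (invented ids, pad id 0) of make_loss_mask above: the
# padded event span plus the category delimiter are excluded, so only effect
# tokens contribute to the loss; the final shift aligns with the targets.
import torch

seqs = torch.tensor([[7, 8, 0, 9, 5, 6, 0]])   # event (max_event=3) | <cat> | effect
mask = (seqs != 0).float()
mask[:, :3 + 1] = 0                            # max_event + num_delim_tokens
print(mask[:, 1:])                             # tensor([[0., 0., 0., 1., 1., 0.]])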
Model/COSMIC/feature_extraction/src/data/conceptnet.py
ADDED
@@ -0,0 +1,342 @@
import comet.src.data.utils as data_utils
import comet.src.data.atomic as adata
import comet.src.data.config as cfg

import torch
import random
from tqdm import tqdm


def map_name(name, opt):
    if name == "train":
        return "train{}k.txt".format(opt.trainsize)
    elif name == "test":
        return "test.txt"
    else:
        return "dev{}.txt".format(opt.devversion)


conceptnet_relations = [
    'AtLocation', 'CapableOf', 'Causes', 'CausesDesire',
    'CreatedBy', 'DefinedAs', 'DesireOf', 'Desires', 'HasA',
    'HasFirstSubevent', 'HasLastSubevent', 'HasPainCharacter',
    'HasPainIntensity', 'HasPrerequisite', 'HasProperty',
    'HasSubevent', 'InheritsFrom', 'InstanceOf', 'IsA',
    'LocatedNear', 'LocationOfAction', 'MadeOf', 'MotivatedByGoal',
    'NotCapableOf', 'NotDesires', 'NotHasA', 'NotHasProperty',
    'NotIsA', 'NotMadeOf', 'PartOf', 'ReceivesAction', 'RelatedTo',
    'SymbolOf', 'UsedFor'
]


split_into_words = {
    'AtLocation': "at location",
    'CapableOf': "capable of",
    'Causes': "causes",
    'CausesDesire': "causes desire",
    'CreatedBy': "created by",
    'DefinedAs': "defined as",
    'DesireOf': "desire of",
    'Desires': "desires",
    'HasA': "has a",
    'HasFirstSubevent': "has first subevent",
    'HasLastSubevent': "has last subevent",
    'HasPainCharacter': "has pain character",
    'HasPainIntensity': "has pain intensity",
    'HasPrerequisite': "has prequisite",
    'HasProperty': "has property",
    'HasSubevent': "has subevent",
    'InheritsFrom': "inherits from",
    'InstanceOf': 'instance of',
    'IsA': "is a",
    'LocatedNear': "located near",
    'LocationOfAction': "location of action",
    'MadeOf': "made of",
    'MotivatedByGoal': "motivated by goal",
    'NotCapableOf': "not capable of",
    'NotDesires': "not desires",
    'NotHasA': "not has a",
    'NotHasProperty': "not has property",
    'NotIsA': "not is a",
    'NotMadeOf': "not made of",
    'PartOf': "part of",
    'ReceivesAction': "receives action",
    'RelatedTo': "related to",
    'SymbolOf': "symbol of",
    'UsedFor': "used for"
}


class GenerationDataLoader(adata.DataLoader):
    def __init__(self, opt, categories=None):
        super(GenerationDataLoader, self).__init__(opt)
        self.opt = opt

        for split in self.data:
            self.data[split] = {"total": []}
            self.offsets[split] = {"total": 0}

        self.vocab_encoder = None
        self.vocab_decoder = None
        self.special_chars = None
        self.max_e1 = None
        self.max_e2 = None
        self.max_r = None

    def offset_summary(self, split):
        return sum(self.offsets[split].values())

    def load_data(self, path):
        if ".pickle" in path:
            print("Loading data from: {}".format(path))
            data_utils.load_existing_data_loader(self, path)
            return True

        for split in self.data:
            file_name = map_name(split, self.opt.data)

            if split != "dev" or self.opt.data.devversion != "12":
                string_tuples = open("{}/{}".format(
                    path, file_name), "r").read().split("\n")
                tuples = [x.split("\t") for x in string_tuples if x]
            else:
                string_tuples = open("{}/{}".format(
                    path, "dev1.txt"), "r").read().split("\n")
                tuples = [x.split("\t") for x in string_tuples if x]
                string_tuples = open("{}/{}".format(
                    path, "dev2.txt"), "r").read().split("\n")
                tuples += [x.split("\t") for x in string_tuples if x]

            if split in ["dev", "test"]:
                if self.opt.data.rel == "language":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3])) for i in tuples]
                    self.data[split]["positive"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3])) for i in tuples if int(i[3])]
                    self.data[split]["negative"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), int(i[3])) for i in tuples if not int(i[3])]
                elif self.opt.data.rel == "relation":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), int(i[3])) for i in tuples]
                    self.data[split]["positive"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), int(i[3])) for i in tuples if int(i[3])]
                    self.data[split]["negative"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), int(i[3])) for i in tuples if not int(i[3])]
            else:
                if self.opt.data.rel == "language":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), split_into_words[i[0]],
                          i[2].lower().strip(), i[3]) for i in tuples]
                elif self.opt.data.rel == "relation":
                    self.data[split]["total"] = \
                        [(i[1].lower().strip(), "<{}>".format(i[0]),
                          i[2].lower().strip(), i[3]) for i in tuples]

        return False

    def make_tensors(self, text_encoder, special,
                     splits=["train", "dev", "test"], test=False):
        self.vocab_encoder = text_encoder.encoder
        self.vocab_decoder = text_encoder.decoder
        self.special_chars = special

        sequences = {}
        for split in splits:
            sequences[split], discarded = get_generation_sequences(
                self.data, split, text_encoder, test, self.opt.data.maxe1,
                self.opt.data.maxe2)

            if split == "train":
                self.data[split]["total"] = [j for i, j in enumerate(
                    self.data[split]["total"]) if i not in set(discarded)]
            self.masks[split]["total"] = [(len(i[0]), len(i[1]), len(i[2])) for
                                          i in sequences[split]]

        self.max_e1 = max([max([l[0] for l in self.masks[split]["total"]])
                           for split in self.masks])
        self.max_r = max([max([l[1] for l in self.masks[split]["total"]])
                          for split in self.masks])
        self.max_e2 = max([max([l[2] for l in self.masks[split]["total"]])
                           for split in self.masks])

        print(self.max_e1)
        print(self.max_r)
        print(self.max_e2)

        for split in splits:
            num_elements = len(sequences[split])
            self.sequences[split]["total"] = torch.LongTensor(
                num_elements, self.max_e1 + self.max_e2 + self.max_r).fill_(0)

            for i, seq in enumerate(sequences[split]):
                # print(self.sequences[split]["total"][i, :len(seq[0])].size())
                # print(torch.FloatTensor(seq[0]).size())
                self.sequences[split]["total"][i, :len(seq[0])] = \
                    torch.LongTensor(seq[0])
                start_r = self.max_e1
                end_r = self.max_e1 + len(seq[1])
                self.sequences[split]["total"][i, start_r:end_r] = \
                    torch.LongTensor(seq[1])
                start_e2 = self.max_e1 + self.max_r
                end_e2 = self.max_e1 + self.max_r + len(seq[2])
                self.sequences[split]["total"][i, start_e2:end_e2] = \
                    torch.LongTensor(seq[2])

            if split in ["test", "dev"]:
                print(split)
                self.sequences[split]["negative"] = \
                    self.sequences[split]["total"].index_select(
                        0, torch.LongTensor([i for i, j in enumerate(
                            self.data[split]['total']) if not j[3]]))
                # self.data[split]['total'][:self.sequences[split]["total"].size(0)]) if not j[3]]))
                self.sequences[split]["positive"] = \
                    self.sequences[split]["total"].index_select(
                        0, torch.LongTensor([i for i, j in enumerate(
                            self.data[split]['total']) if j[3]]))
                # self.data[split]['total'][:self.sequences[split]["total"].size(0)]) if j[3]]))

    def sample_batch(self, split, bs, cat="total", idxs=None):
        offset = self.offsets[split][cat]

        batch = {}

        # Decided not to reduce computation on here because it's all parallel
        # anyway and we don't want to run out of memory in cases where we
        # don't see the longest version quickly enough

        if idxs:
            seqs = self.sequences[split][cat].index_select(
                0, torch.LongTensor(idxs).to(
                    self.sequences[split][cat].device))
        else:
            seqs = self.sequences[split][cat][offset:offset + bs]
        batch["sequences"] = seqs.to(cfg.device)
        batch["attention_mask"] = make_attention_mask(seqs)
        batch["loss_mask"] = make_loss_mask(seqs, self.max_e1 + self.max_r)
        batch["key"] = (cat, offset, offset + bs)

        offset += seqs.size(0)

        self.offsets[split][cat] = offset

        if split == "train" and offset + bs > len(self.sequences[split][cat]):
|
229 |
+
return batch, True
|
230 |
+
elif offset >= len(self.sequences[split][cat]):
|
231 |
+
return batch, True
|
232 |
+
else:
|
233 |
+
return batch, False
|
234 |
+
|
235 |
+
def reset_offsets(self, splits=["train", "test", "dev"],
|
236 |
+
shuffle=True, keys=None):
|
237 |
+
if isinstance(splits, str):
|
238 |
+
splits = [splits]
|
239 |
+
|
240 |
+
for split in splits:
|
241 |
+
if keys is None:
|
242 |
+
keys = ["total", "positive", "negative"]
|
243 |
+
|
244 |
+
for key in keys:
|
245 |
+
self.offsets[split][key] = 0
|
246 |
+
|
247 |
+
if shuffle:
|
248 |
+
self.shuffle_sequences(split, keys)
|
249 |
+
|
250 |
+
def shuffle_sequences(self, split="train", keys=None):
|
251 |
+
if keys is None:
|
252 |
+
# print(type(self.data))
|
253 |
+
# print(type(self.data.keys()))
|
254 |
+
keys = self.data[split].keys()
|
255 |
+
|
256 |
+
for key in keys:
|
257 |
+
if key in ["positive", "negative"]:
|
258 |
+
continue
|
259 |
+
idxs = list(range(len(self.data[split][key])))
|
260 |
+
|
261 |
+
random.shuffle(idxs)
|
262 |
+
|
263 |
+
self.sequences[split][key] = \
|
264 |
+
self.sequences[split][key].index_select(
|
265 |
+
0, torch.LongTensor(idxs))
|
266 |
+
|
267 |
+
temp = [self.data[split][key][i] for i in idxs]
|
268 |
+
self.data[split][key] = temp
|
269 |
+
|
270 |
+
temp = [self.masks[split][key][i] for i in idxs]
|
271 |
+
self.masks[split][key] = temp
|
272 |
+
|
273 |
+
|
274 |
+
def make_attention_mask(sequences):
|
275 |
+
return (sequences != 0).float().to(cfg.device)
|
276 |
+
|
277 |
+
|
278 |
+
def make_loss_mask(sequences, max_event):
|
279 |
+
# print(sequences.size())
|
280 |
+
mask = (sequences != 0).float()
|
281 |
+
mask[:, :max_event] = 0
|
282 |
+
return mask[:, 1:].to(cfg.device)
|
283 |
+
|
284 |
+
|
285 |
+
def get_generation_sequences(data, split, text_encoder, test,
|
286 |
+
max_e1=10, max_e2=15):
|
287 |
+
sequences = []
|
288 |
+
count = 0
|
289 |
+
|
290 |
+
final_event1 = None
|
291 |
+
final_event2 = None
|
292 |
+
final_relation = None
|
293 |
+
|
294 |
+
discarded = []
|
295 |
+
|
296 |
+
for event1, relation, event2, _ in tqdm(data[split]["total"]):
|
297 |
+
e1, r, e2 = do_example(text_encoder, event1, relation, event2)
|
298 |
+
|
299 |
+
if (split == "train" and len(e1) > max_e1 or
|
300 |
+
len(e2) > max_e2):
|
301 |
+
discarded.append(count)
|
302 |
+
count += 1
|
303 |
+
continue
|
304 |
+
|
305 |
+
final = compile_final_sequence(
|
306 |
+
e1, e2, r, text_encoder)
|
307 |
+
|
308 |
+
sequences.append(final)
|
309 |
+
|
310 |
+
count += 1
|
311 |
+
|
312 |
+
if count > 10 and test:
|
313 |
+
break
|
314 |
+
|
315 |
+
return sequences, discarded
|
316 |
+
|
317 |
+
|
318 |
+
def do_example(text_encoder, event1, relation, event2):
|
319 |
+
final_event1 = text_encoder.encode([event1], verbose=False)[0]
|
320 |
+
if relation.lower() != relation:
|
321 |
+
final_relation = [text_encoder.encoder[relation]]
|
322 |
+
else:
|
323 |
+
final_relation = text_encoder.encode(
|
324 |
+
[relation], verbose=False)[0]
|
325 |
+
if event2 is not None:
|
326 |
+
final_event2 = text_encoder.encode([event2], verbose=False)[0]
|
327 |
+
else:
|
328 |
+
final_event2 = None
|
329 |
+
|
330 |
+
return final_event1, final_relation, final_event2
|
331 |
+
|
332 |
+
|
333 |
+
def compile_final_sequence(final_event1, final_event2, final_relation, text_encoder):
|
334 |
+
final = []
|
335 |
+
|
336 |
+
final.append(final_event1)
|
337 |
+
final.append(final_relation)
|
338 |
+
final.append(final_event2)
|
339 |
+
|
340 |
+
final[-1].append(text_encoder.encoder["<END>"])
|
341 |
+
|
342 |
+
return final
|
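A minimal sketch (not part of the diff) of how the fixed-width `[e1 | r | e2]` packing above interacts with the two mask helpers; the token ids and lengths here are invented:

```python
import torch

# Toy packed batch: max_e1=3, max_r=2, max_e2=3; 0 is the pad id.
max_e1, max_r = 3, 2
seqs = torch.LongTensor([[5, 6, 0, 9, 0, 7, 8, 2]])  # e1=[5,6], r=[9], e2=[7,8,<END>]

attention = (seqs != 0).float()      # 1 wherever there is a real token
loss_mask = (seqs != 0).float()
loss_mask[:, :max_e1 + max_r] = 0    # no loss on the e1/relation context
loss_mask = loss_mask[:, 1:]         # shift: loss applies to *predicted* tokens

print(attention)  # tensor([[1., 1., 0., 1., 0., 1., 1., 1.]])
print(loss_mask)  # tensor([[0., 0., 0., 0., 1., 1., 1.]])
```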
Model/COSMIC/feature_extraction/src/data/data.py
ADDED
@@ -0,0 +1,85 @@
import os
import comet.src.data.atomic as atomic_data
import comet.src.data.conceptnet as conceptnet_data
import comet.src.data.config as cfg

import comet.utils.utils as utils

import pickle
import torch
import json


start_token = "<START>"
end_token = "<END>"
blank_token = "<blank>"


def save_checkpoint(state, filename):
    print("Saving model to {}".format(filename))
    torch.save(state, filename)


def save_step(model, vocab, optimizer, opt, length, lrs):
    if cfg.test_save:
        name = "{}.pickle".format(utils.make_name(
            opt, prefix="garbage/models/", is_dir=False, eval_=True))
    else:
        name = "{}.pickle".format(utils.make_name(
            opt, prefix="models/", is_dir=False, eval_=True))
    save_checkpoint({
        "epoch": length, "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(), "opt": opt,
        "vocab": vocab, "epoch_learning_rates": lrs},
        name)


def save_eval_file(opt, stats, eval_type="losses", split="dev", ext="pickle"):
    if cfg.test_save:
        name = "{}/{}.{}".format(utils.make_name(
            opt, prefix="garbage/{}/".format(eval_type),
            is_dir=True, eval_=True), split, ext)
    else:
        name = "{}/{}.{}".format(utils.make_name(
            opt, prefix="results/{}/".format(eval_type),
            is_dir=True, eval_=True), split, ext)
    print("Saving {} {} to {}".format(split, eval_type, name))

    if ext == "pickle":
        with open(name, "wb") as f:
            pickle.dump(stats, f)
    elif ext == "txt":
        with open(name, "w") as f:
            f.write(stats)
    elif ext == "json":
        with open(name, "w") as f:
            json.dump(stats, f)
    else:
        # Was a bare `raise` (RuntimeError: no active exception); fail clearly.
        raise ValueError("Unsupported extension: {}".format(ext))


def load_checkpoint(filename, gpu=True):
    if os.path.exists(filename):
        checkpoint = torch.load(
            filename, map_location=lambda storage, loc: storage)
    else:
        # Previously printed a message and then hit an UnboundLocalError on
        # `checkpoint`; raise explicitly instead.
        raise IOError("No model found at {}".format(filename))
    return checkpoint


def make_data_loader(opt, *args):
    if opt.dataset == "atomic":
        return atomic_data.GenerationDataLoader(opt, *args)
    elif opt.dataset == "conceptnet":
        return conceptnet_data.GenerationDataLoader(opt, *args)


def set_max_sizes(data_loader, force_split=None):
    data_loader.total_size = {}
    if force_split is not None:
        data_loader.total_size[force_split] = \
            data_loader.sequences[force_split]["total"].size(0)
        return
    for split in data_loader.sequences:
        data_loader.total_size[split] = \
            data_loader.sequences[split]["total"].size(0)
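A small usage sketch for the checkpoint helpers above; the file name and state are illustrative only:

```python
import torch

state = {"epoch": 3, "state_dict": {"w": torch.zeros(2, 2)}}
torch.save(state, "demo_checkpoint.pickle")        # what save_checkpoint does

loaded = torch.load("demo_checkpoint.pickle",      # what load_checkpoint does,
                    map_location=lambda storage, loc: storage)  # CPU-safe load
print(loaded["epoch"])  # 3
```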
Model/COSMIC/feature_extraction/src/data/utils.py
ADDED
@@ -0,0 +1,134 @@
import re
import ftfy
import json
import spacy
import torch

from tqdm import tqdm


def load_existing_data_loader(data_loader, path):
    old_data_loader = torch.load(path)
    for attr in data_loader.__dict__.keys():
        if attr not in old_data_loader.__dict__.keys():
            continue
        setattr(data_loader, attr, getattr(old_data_loader, attr))


################################################################################
#
# Code below taken from HuggingFace pytorch-openai-lm repository
#
################################################################################


def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    word is represented as tuple of symbols (symbols being variable-length strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def text_standardize(text):
    """
    fixes some issues the spacy tokenizer had on books corpus
    also does some whitespace standardization
    """
    text = text.replace('—', '-')
    text = text.replace('–', '-')
    text = text.replace('―', '-')
    text = text.replace('…', '...')
    text = text.replace('´', "'")
    text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
    text = re.sub(r'\s*\n\s*', ' \n ', text)
    text = re.sub(r'[^\S\n]+', ' ', text)
    return text.strip()


class TextEncoder(object):
    """
    mostly a wrapper for a public python bpe tokenizer
    """

    def __init__(self, encoder_path, bpe_path):
        self.nlp = spacy.load(
            'en', disable=['parser', 'tagger', 'ner', 'textcat'])
        self.encoder = json.load(open(encoder_path))
        self.decoder = {v: k for k, v in self.encoder.items()}
        merges = open(bpe_path, encoding='utf-8').read().split('\n')[1:-1]
        merges = [tuple(merge.split()) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(
                pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:  # was a bare except; .index raises ValueError
                    new_word.extend(word[i:])
                    break

                if (word[i] == first and i < len(word) - 1 and
                        word[i + 1] == second):
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        if word == '\n  </w>':
            word = '\n</w>'
        self.cache[token] = word
        return word

    def encode(self, texts, verbose=True):
        texts_tokens = []
        if verbose:
            for text in tqdm(texts, ncols=80, leave=False):
                text = self.nlp(text_standardize(ftfy.fix_text(text)))
                text_tokens = []
                for token in text:
                    text_tokens.extend(
                        [self.encoder.get(t, 0) for t in
                         self.bpe(token.text.lower()).split(' ')])
                texts_tokens.append(text_tokens)
        else:
            for text in texts:
                text = self.nlp(text_standardize(ftfy.fix_text(text)))
                text_tokens = []
                for token in text:
                    text_tokens.extend(
                        [self.encoder.get(t, 0) for t in
                         self.bpe(token.text.lower()).split(' ')])
                texts_tokens.append(text_tokens)
        return texts_tokens
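A toy illustration of the rank-ordered merge loop in `TextEncoder.bpe`, using a made-up two-rule merge table rather than the real `vocab_40000.bpe`:

```python
def get_pairs(word):
    return {(a, b) for a, b in zip(word, word[1:])}

bpe_ranks = {("l", "o"): 0, ("lo", "w</w>"): 1}  # lower rank merges first

word = ("l", "o", "w</w>")                       # "low" plus end-of-word marker
while True:
    pairs = get_pairs(word)
    bigram = min(pairs, key=lambda p: bpe_ranks.get(p, float("inf")))
    if bigram not in bpe_ranks:
        break
    first, second = bigram
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
            merged.append(first + second)   # apply the highest-priority merge
            i += 2
        else:
            merged.append(word[i])
            i += 1
    word = tuple(merged)
    if len(word) == 1:
        break

print(" ".join(word))  # "low</w>" after both merges fire in rank order
```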
Model/COSMIC/feature_extraction/src/evaluate/atomic_evaluate.py
ADDED
@@ -0,0 +1,40 @@
import comet.src.train.batch as batch
import comet.src.evaluate.evaluate as base_evaluate
import numpy as np


def make_evaluator(opt, *args):
    if opt.exp == "generation":
        return AtomicGenerationEvaluator(opt, *args)
    else:
        # NOTE: AtomicClassificationEvaluator is not defined in this module;
        # only the generation path is exercised by this code.
        return AtomicClassificationEvaluator(opt, *args)


class AtomicGenerationEvaluator(base_evaluate.Evaluator):
    def __init__(self, opt, model, data_loader):
        super(AtomicGenerationEvaluator, self).__init__(
            opt, model, data_loader)

        self.batch = batch.batch_atomic_generate

    def initialize_losses(self):
        average_loss = {"total_micro": 0, "total_macro": 0}
        nums = {"total_micro": 0, "total_macro": 0}
        return average_loss, nums

    def compute_final_scores(self, average_loss, nums):
        average_loss["total_macro"] /= nums["total_macro"]
        average_loss["total_micro"] /= nums["total_micro"]

        average_loss["ppl_macro"] = np.exp(average_loss["total_macro"])
        average_loss["ppl_micro"] = np.exp(average_loss["total_micro"])

        return average_loss

    def counter(self, nums):
        return nums["total_macro"]

    def print_result(self, split, epoch_losses):
        print("{} Loss: \t {}".format(
            split, epoch_losses["total_micro"]))
        print("{} Perplexity: \t {}".format(
            split, epoch_losses["ppl_micro"]))
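For reference, `compute_final_scores` divides the accumulated negative log-likelihood by the corresponding count and exponentiates, since perplexity is `exp` of the average per-token NLL. A numeric sketch with invented totals:

```python
import numpy as np

total_micro, n_micro = 110.0, 50   # summed token NLL, token count (made up)
avg = total_micro / n_micro        # 2.2 nats per token
print(np.exp(avg))                 # ppl_micro ≈ 9.03
```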
Model/COSMIC/feature_extraction/src/evaluate/conceptnet_evaluate.py
ADDED
@@ -0,0 +1,82 @@
import time
import numpy as np

import comet.src.train.batch as batch_utils
import comet.utils.utils as utils
import comet.src.evaluate.evaluate as base_evaluate


def make_evaluator(opt, *args, **kwargs):
    return ConceptNetGenerationEvaluator(opt, *args, **kwargs)


class ConceptNetGenerationEvaluator(base_evaluate.Evaluator):
    def __init__(self, opt, model, data_loader, track=False):
        super(ConceptNetGenerationEvaluator, self).__init__(
            opt, model, data_loader)

        if track:
            self.tracker = {"positive": [], "negative": []}
        else:
            self.tracker = None

    def batch(self, opt, nums, average_loss, batch_variables, eval_mode):
        batch_variables["category"] = self.current_category

        outputs = batch_utils.batch_conceptnet_generate(
            opt, nums, average_loss, batch_variables, eval_mode,
            tracking_mode=self.tracker is not None)

        if outputs.get("tracking", None) is not None:
            self.tracker[self.current_category] += outputs["tracking"]

        if outputs["reset"] and batch_variables["category"] == "positive":
            outputs["reset"] = False
            self.current_category = "negative"

        return outputs

    def initialize_losses(self):
        average_loss = {"total_micro": 0, "total_macro": 0,
                        "negative_micro": 0, "negative_macro": 0}
        nums = {"total_micro": 0, "total_macro": 0,
                "negative_micro": 0, "negative_macro": 0}

        self.current_category = "positive"

        if self.tracker is not None:
            self.tracker = {"positive": [], "negative": []}

        return average_loss, nums

    def compute_final_scores(self, average_loss, nums):
        average_loss["total_macro"] /= nums["total_macro"]
        average_loss["total_micro"] /= nums["total_micro"]

        if nums["negative_micro"]:
            average_loss["negative_macro"] /= nums["negative_macro"]
            average_loss["negative_micro"] /= nums["negative_micro"]
        else:
            average_loss["negative_macro"] = 0
            average_loss["negative_micro"] = 0

        average_loss["macro_diff"] = (average_loss["negative_macro"] -
                                      average_loss["total_macro"])
        average_loss["micro_diff"] = (average_loss["negative_micro"] -
                                      average_loss["total_micro"])

        average_loss["ppl_macro"] = np.exp(average_loss["total_macro"])
        average_loss["ppl_micro"] = np.exp(average_loss["total_micro"])

        return average_loss

    def counter(self, nums):
        return nums["total_macro"]

    def print_result(self, split, epoch_losses):
        print("{} Loss: \t {}".format(
            split, epoch_losses["total_micro"]))
        print("{} Diff: \t {}".format(
            split, epoch_losses["micro_diff"]))
        print("{} Perplexity: \t {}".format(
            split, epoch_losses["ppl_micro"]))
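The diff scores above are just the gap between the average loss on negative tuples and on the positive ("total") tuples, so a larger diff suggests the model assigns false tuples noticeably higher loss. A toy check, with invented averages:

```python
avg = {"total_micro": 2.6, "negative_micro": 4.1}  # made-up average NLLs
avg["micro_diff"] = avg["negative_micro"] - avg["total_micro"]
print(avg["micro_diff"])  # 1.5
```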
Model/COSMIC/feature_extraction/src/evaluate/conceptnet_generate.py
ADDED
@@ -0,0 +1,112 @@
import time
import torch

import comet.src.evaluate.generate as base_generate
import comet.src.evaluate.sampler as sampling
import comet.utils.utils as utils
import comet.src.data.config as cfg


def make_generator(opt, *args):
    return ConceptNetGenerator(opt, *args)


class ConceptNetGenerator(base_generate.Generator):
    def __init__(self, opt, model, data_loader):
        self.opt = opt

        self.model = model
        self.data_loader = data_loader

        self.sampler = sampling.make_sampler(
            opt.eval.sample, opt, data_loader)

    def reset_sequences(self):
        return []

    def generate(self, split="dev"):
        print("Generating Sequences")

        # Set evaluation mode
        self.model.eval()

        # Reset evaluation set for dataset split
        self.data_loader.reset_offsets(splits=split, shuffle=False)

        # Keep the timer separate from the loop's `start`, which previously
        # clobbered it and made the final timing print wrong.
        start_time = time.time()
        count = 0

        # Reset generated sequence buffer
        sequences = self.reset_sequences()

        # Initialize progress bar
        bar = utils.set_progress_bar(
            self.data_loader.total_size[split] / 2)

        reset = False

        with torch.no_grad():
            # Cycle through development set
            while not reset:

                start = len(sequences)
                # Generate a single batch
                reset = self.generate_batch(sequences, split, bs=1)

                end = len(sequences)

                if not reset:
                    bar.update(end - start)
                else:
                    print(end)

                count += 1

                if cfg.toy and count > 10:
                    break
                if (self.opt.eval.gs != "full" and
                        (count > self.opt.eval.gs)):  # was `opt.eval.gs`: NameError
                    break

        torch.cuda.synchronize()
        print("{} generations completed in: {} s".format(
            split, time.time() - start_time))

        # Compute scores for sequences (e.g., BLEU, ROUGE)
        # Computes scores that the generator is initialized with
        # Change define_scorers to add more scorers as possibilities
        # avg_scores, indiv_scores = self.compute_sequence_scores(
        #     sequences, split)
        avg_scores, indiv_scores = None, None

        return sequences, avg_scores, indiv_scores

    def generate_batch(self, sequences, split, verbose=False, bs=1):
        # Sample batch from data loader
        batch, reset = self.data_loader.sample_batch(
            split, bs=bs, cat="positive")

        start_idx = self.data_loader.max_e1 + self.data_loader.max_r
        max_end_len = self.data_loader.max_e2

        context = batch["sequences"][:, :start_idx]
        reference = batch["sequences"][:, start_idx:]
        init = "".join([self.data_loader.vocab_decoder[i].replace(
            '</w>', ' ') for i in context[:, :self.data_loader.max_e1].squeeze().tolist() if i]).strip()

        start = self.data_loader.max_e1
        end = self.data_loader.max_e1 + self.data_loader.max_r

        attr = "".join([self.data_loader.vocab_decoder[i].replace(
            '</w>', ' ') for i in context[:, start:end].squeeze(0).tolist() if i]).strip()

        # Decode sequence
        sampling_result = self.sampler.generate_sequence(
            batch, self.model, self.data_loader, start_idx, max_end_len)

        sampling_result["key"] = batch["key"]
        sampling_result["e1"] = init
        sampling_result["r"] = attr
        sequences.append(sampling_result)

        return reset
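A sketch of the `</w>`-stripping decode used above to turn packed ids back into strings; the tiny decoder table is invented:

```python
vocab_decoder = {7: "go</w>", 8: "to</w>", 9: "mall</w>"}
ids = [7, 8, 9, 0, 0]  # a padded e1 span; 0 is the pad id

# Drop pads (`if i`), replace the end-of-word marker with a space, trim.
text = "".join(vocab_decoder[i].replace("</w>", " ") for i in ids if i).strip()
print(text)  # "go to mall"
```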
Model/COSMIC/feature_extraction/src/evaluate/evaluate.py
ADDED
@@ -0,0 +1,85 @@
import time
import torch

import comet.utils.utils as utils
import comet.src.data.config as cfg


class Evaluator(object):
    def __init__(self, opt, model, data_loader):
        super(Evaluator, self).__init__()

        self.data_loader = data_loader
        self.model = model

        self.batch_variables = {
            "model": model,
            "data": data_loader
        }

        self.opt = opt

    def validate(self, l, split="dev", losses={}, keyset=None):
        self.batch_variables["split"] = split
        print("Evaluating {}".format(split))

        epoch_losses = self.epoch(
            self.opt, self.model, self.data_loader, split, keyset)

        self.print_result(split, epoch_losses)

        for loss_name, loss_val in epoch_losses.items():
            losses.setdefault(loss_name, {})
            losses[loss_name][l] = loss_val

    def epoch(self, opt, model, data_loader, split, keyset=None):
        average_loss, nums = self.initialize_losses()

        data_loader.reset_offsets(splits=split, shuffle=False)

        # Set evaluation mode
        model.eval()

        # Keep the timer separate from the loop's `start`, which previously
        # clobbered it and made the final timing print wrong.
        start_time = time.time()

        # Initialize progress bar
        bar = utils.set_progress_bar(
            data_loader.total_size[split])

        reset = False

        with torch.no_grad():
            while not reset:

                start = data_loader.offset_summary(split)

                outputs = self.batch(
                    opt, nums, average_loss,
                    self.batch_variables, eval_mode=True)

                end = data_loader.offset_summary(split)

                reset = outputs["reset"]

                if not reset:
                    bar.update(end - start)
                else:
                    print(end)

                if cfg.toy and self.counter(nums) > 100:
                    break
                if (opt.eval.es != "full" and
                        (self.counter(nums) > opt.eval.es)):
                    break

                nums = outputs["nums"]

        torch.cuda.synchronize()

        print("{} evaluation completed in: {} s".format(
            split.capitalize(), time.time() - start_time))

        average_loss = self.compute_final_scores(
            average_loss, nums)

        return average_loss
Model/COSMIC/feature_extraction/src/evaluate/generate.py
ADDED
@@ -0,0 +1,72 @@
import comet.src.data.data as data
import comet.src.data.config as cfg
import comet.src.evaluate.sampler as sampling


def do_gen_run(opt, generator, l, split="dev", scores={}):
    # Generate sequences for examples in evaluation set using
    # current trained model

    if opt.eval.gs == "full":
        sequences, avg_scores, indiv_scores = generator.generate(split)
    else:
        sequences, avg_scores, indiv_scores = generator.generate_some(split)

    if avg_scores is not None:
        # Record scores from generated sequences
        for score_name, score_val in avg_scores.items():
            scores.setdefault(score_name, {})
            scores[score_name].setdefault(l, [])
            scores[score_name][l] += [score_val]

    # Save generated sequences
    save_sequences(opt, sequences, avg_scores, indiv_scores,
                   l, split, opt.eval.gs == "full",
                   generator.data_loader)


def save_sequences(opt, sequences, avg_scores, indiv_scores,
                   l, split, full, data_loader):
    # This seems a bit roundabout since l = opt.train.dynamic in train.py
    # But it's in case we start checkpointing outside of epoch boundaries
    opt.train.dynamic.epoch = l

    if cfg.save:
        if full:
            names = {"gens": "gens", "scores": "scores",
                     "indiv": "indiv.scores"}
        else:
            names = {"gens": "gens.small", "scores": "scores.small",
                     "indiv": "indiv.scores.small"}
        # Save generated sequences
        data.save_eval_file(opt, sequences, names["gens"], split)

        if avg_scores is not None:
            # Save average scores over evaluation set for generated sequences
            # Scores computed are the ones the generator was initialized with
            data.save_eval_file(opt, avg_scores, names["scores"], split)

            if split == "dev":
                # Save individual scores
                data.save_eval_file(
                    opt, indiv_scores, names["indiv"], split)


class Generator(object):
    def __init__(self, opt, model, data_loader, scorers, reward_function=None):
        super(Generator, self).__init__()
        self.opt = opt

        self.model = model
        self.data_loader = data_loader

        self.sampler = sampling.make_sampler(
            opt.eval.sample, opt, data_loader)

    def generate(self, split="dev"):
        pass

    def generate_batch(self, sequences, split, verbose=False, bs=32):
        pass
Model/COSMIC/feature_extraction/src/evaluate/sampler.py
ADDED
@@ -0,0 +1,329 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np  # needed by BeamSampler.make_batch below

import comet.src.data.data as data
import comet.src.data.config as cfg
import comet.src.models.utils as model_utils
import comet.src.evaluate.utils as eval_utils
import comet.src.train.batch as batch_utils


def make_sampler(sampler_type, opt, *args, **kwargs):
    print("Initializing Greedy Sampler")
    return GreedySampler(opt, *args, **kwargs)


class Sampler():
    def __init__(self, opt, data_loader, batch_mode=False):
        # Token on which to end sampling
        self.end_token = data_loader.vocab_encoder[data.end_token]

        self.opt = opt

    def generate_sequence(self, batch, model):
        raise NotImplementedError  # was a bare `raise`


class GreedySampler(Sampler):
    def __init__(self, opt, data_loader, batch_mode=True):
        super(GreedySampler, self).__init__(opt, data_loader)

    def append_batch(self, X, next_idx, mask):
        next_pos = X[:, -1:, 1] + 1
        next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1)
        next_mask = torch.cat([mask, torch.ones(X.size(0), 1, device=mask.device)], 1)
        return torch.cat((X, next_x), 1), next_mask

    def generate_sequence(self, batch, model, data_loader, start_idx, end_len):
        XMB = batch["sequences"][:, :start_idx]
        MMB = batch["attention_mask"][:, :start_idx]

        XMB = model_utils.prepare_position_embeddings(
            self.opt, data_loader.vocab_encoder, XMB.unsqueeze(-1))

        _, lp = model(
            XMB.unsqueeze(1), sequence_mask=MMB)
        lm_probs = F.log_softmax(lp, dim=-1)

        values, indices = lm_probs[:, -1, :].max(dim=-1)
        seqs = indices.clone().unsqueeze(1)

        loss = values
        counts = 1
        next_pos = XMB[:, -1:, 1] + 1
        next_x = torch.cat((indices.view(-1, 1), next_pos), -1).unsqueeze(1)
        XMB = torch.cat((XMB, next_x), 1)
        MMB = torch.cat([MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

        # Greedy decoding loop. `step` replaces the original `_` loop
        # variable, which was clobbered by the model output below and
        # broke the length check.
        for step in range(self.opt.eval.smax):
            _, lp = model(
                XMB.unsqueeze(1), sequence_mask=MMB)
            lm_probs = F.log_softmax(lp, dim=-1)

            # Greedy: take the argmax token
            values, next_idx = lm_probs[:, -1, :].max(dim=-1)

            loss += values
            counts += 1

            next_idx = next_idx.unsqueeze(1)

            seqs = torch.cat([seqs, next_idx], 1)

            if (next_idx.item() == self.end_token) or (step == end_len - 1):
                break

            XMB, MMB = self.append_batch(XMB, next_idx, MMB)

        beams = []

        for beam in seqs:
            beams.append(" ".join("".join(
                [data_loader.vocab_decoder[tok.item()].replace(
                    '</w>', ' ').replace('\n', '')
                 for tok in beam if tok != self.end_token]).split()))

        sampling_result = {
            "sequence": beams[0],
            "beams": beams,
            "beam_losses": [loss.item()],
            "loss": loss.item(),
            "beam_lengths": [counts],
            "length": counts
        }

        return sampling_result


class TopKSampler(Sampler):
    def __init__(self, opt, data_loader, batch_mode=True):
        super(TopKSampler, self).__init__(opt, data_loader)

    def append_batch(self, X, next_idx, mask):
        next_pos = X[:, -1:, 1] + 1
        next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1)
        next_mask = torch.cat([mask, torch.ones(X.size(0), 1, device=mask.device)], 1)
        return torch.cat((X, next_x), 1), next_mask

    def generate_sequence(self, batch, model, data_loader, start_idx, end_len):
        # start_idx = context_size_event + 1
        # start_idx = max_e1 + max_r
        # end_idx = context_size_effect - 1
        # end_idx = max_e2
        XMB = batch["sequences"][:, :start_idx]
        MMB = batch["attention_mask"][:, :start_idx]

        XMB = model_utils.prepare_position_embeddings(
            self.opt, data_loader.vocab_encoder, XMB.unsqueeze(-1))

        _, lp = model(
            XMB.unsqueeze(1), sequence_mask=MMB)
        lm_probs = F.log_softmax(lp, dim=-1)

        values, indices = lm_probs[:, -1, :].topk(self.opt.eval.k)
        seqs = indices.t().clone()

        losses = - values.view(-1, 1)

        ended = (seqs == self.end_token).float()
        counts = (1 - ended)
        XMB = XMB.repeat(self.opt.eval.k, 1, 1)
        MMB = MMB.repeat(self.opt.eval.k, 1)
        next_pos = XMB[:, -1:, 1] + 1
        next_x = torch.cat((indices.view(self.opt.eval.k, -1), next_pos), -1).unsqueeze(1)
        XMB = torch.cat((XMB, next_x), 1)
        MMB = torch.cat([MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

        # Sample from top k
        for _ in range(end_len):
            _, lp = model(XMB.unsqueeze(1), sequence_mask=MMB)
            lm_probs = F.log_softmax(lp, dim=-1)

            # Sample from top k
            values, indices = lm_probs[:, -1, :].topk(self.opt.eval.k)
            choice = torch.multinomial(values.exp(), 1)
            next_idx = indices.gather(-1, choice)

            ended = ended + (next_idx == self.end_token).float() * (1 - ended)

            next_idx = next_idx * (1 - ended).long() + ended.long() * self.end_token

            counts += (1 - ended)

            seqs = torch.cat([seqs, next_idx], 1)

            if ended.sum().item() == self.opt.eval.k:
                break

            losses -= values.gather(-1, choice) * (1 - ended)

            XMB, MMB = self.append_batch(XMB, next_idx, MMB)

        beams = []

        for beam in seqs:
            beams.append(" ".join("".join(
                [data_loader.vocab_decoder[tok.item()].replace(
                    '</w>', ' ').replace('\n', '')
                 for tok in beam if tok != self.end_token]).split()))

        sampling_result = {
            "sequence": beams[0],
            "beams": beams,
            "beam_losses": losses.squeeze().tolist(),
            "loss": losses[0].item(),
            "beam_lengths": counts.long().squeeze().tolist(),
            "length": counts[0].long().item()
        }

        return sampling_result


class BeamSampler(TopKSampler):
    def __init__(self, opt, data_loader, batch_mode=True, scorer=None):
        super(BeamSampler, self).__init__(opt, data_loader, batch_mode)

        self.kill_mask = torch.ones(opt.eval.bs, opt.eval.bs).to(cfg.device) * 9000
        self.kill_mask[:, 0] = 0

    def make_batch(self, X):
        # NOTE: dead code inherited from the original GPT repo; np is imported
        # above, but n_vocab, n_special, and device are not defined here.
        X = np.array(X)
        assert X.ndim in [1, 2]
        if X.ndim == 1:
            X = np.expand_dims(X, axis=0)
        pos_enc = np.arange(n_vocab + n_special, n_vocab + n_special + X.shape[-1])
        pos_enc = np.expand_dims(pos_enc, axis=0)
        batch = np.stack([X, pos_enc], axis=-1)
        batch = torch.tensor(batch, dtype=torch.long).to(device)
        return batch

    def append_batch(self, X, beam_toks, mask):
        next_pos = X[:, -1:, 1] + 1
        next_x = torch.cat((beam_toks.unsqueeze(1), next_pos), -1).unsqueeze(1)
        next_mask = torch.cat([mask, torch.ones(X.size(0), 1, device=mask.device)], 1)
        return torch.cat((X, next_x), 1), next_mask

    def generate_sequence(self, batch, model, data_loader, start_idx, end_len):
        # start_idx = context_size_event + 1
        # start_idx = max_e1 + max_r
        # end_idx = context_size_effect - 1
        # end_idx = max_e2
        XMB = batch["sequences"][:, :start_idx]
        MMB = batch["attention_mask"][:, :start_idx]

        XMB = model_utils.prepare_position_embeddings(
            self.opt, data_loader.vocab_encoder, XMB.unsqueeze(-1))

        tokens = []
        beam_losses = []
        # Beam Search
        beam_lls, beam_toks, beam_seqs = None, None, None
        _, lp = model(XMB.unsqueeze(1), sequence_mask=MMB)
        lm_probs = F.log_softmax(lp, dim=-1)
        dist = lm_probs[:, -1, :].squeeze()
        beam_lls, beam_toks = dist.topk(self.opt.eval.bs)
        beam_losses.append(beam_lls)

        ended = (beam_toks == self.end_token).float()
        counts = (2 - ended)
        beam_toks = beam_toks.unsqueeze(1)
        beam_seqs = beam_toks.clone()
        XMB = XMB.repeat(self.opt.eval.bs, 1, 1)
        MMB = MMB.repeat(self.opt.eval.bs, 1)
        next_pos = XMB[:, -1:, 1] + 1
        next_x = torch.cat((beam_toks, next_pos), -1).unsqueeze(1)
        XMB = torch.cat((XMB, next_x), 1)
        MMB = torch.cat([MMB, torch.ones(XMB.size(0), 1, device=MMB.device)], 1)

        for _ in range(end_len):

            # Compute distribution for current beam
            _, lp = model(
                XMB.unsqueeze(1), sequence_mask=MMB)
            lm_probs = F.log_softmax(lp, dim=-1)
            dist = lm_probs[:, -1, :].squeeze()

            # get hypothesis tokens for distribution
            hyp_beam_lls, hyp_beam_toks = dist.topk(self.opt.eval.bs)

            # Compute masks and expand beam
            expanded_ended = ended.unsqueeze(1).repeat(1, self.opt.eval.bs)
            hypothesis_mask = expanded_ended * self.kill_mask + (1 - expanded_ended)

            paper_results = False

            if paper_results:
                # Results from paper with slightly buggy beam search
                current_beam_lls = beam_lls.unsqueeze(1).repeat(
                    1, self.opt.eval.bs).view(self.opt.eval.bs**2)
            else:
                # Current beam search implementation
                current_beam_lls = beam_losses[-1].unsqueeze(1).repeat(
                    1, self.opt.eval.bs).view(self.opt.eval.bs**2)

            # Compute losses of hypotheses, masking those that have ended
            hyp_beam_lls = (hyp_beam_lls.view(self.opt.eval.bs**2) *
                            hypothesis_mask.view(-1)) + current_beam_lls

            # Get normalizer for sequences
            temp_counts = counts.unsqueeze(1).repeat(1, self.opt.eval.bs).view(
                self.opt.eval.bs ** 2)

            # Select best beams with lowest aggregate loss
            beam_lls, top_beam_idxs = (hyp_beam_lls / temp_counts).topk(self.opt.eval.bs)

            # Update placements in beam based on selection
            beam_losses = [i.index_select(0, top_beam_idxs // self.opt.eval.bs)
                           for i in beam_losses]
            ended = ended.index_select(0, top_beam_idxs // self.opt.eval.bs)
            counts = temp_counts.index_select(0, top_beam_idxs)

            # Save beam losses
            beam_losses.append(beam_lls * counts)

            # Update beam tokens
            ended_mask = (1 - ended).long()
            end_replacement = (self.end_token * ended).long()
            next_toks = hyp_beam_toks.view(-1)[top_beam_idxs]
            beam_toks = next_toks * ended_mask + end_replacement

            # Update ended and counts
            ended = ended + (beam_toks == self.end_token).float() * (1 - ended)
            counts = counts + (1 - ended)

            # Update beam sequences
            beam_seqs = beam_seqs.t().repeat(self.opt.eval.bs, 1).t().contiguous().view(
                self.opt.eval.bs**2, -1)[top_beam_idxs]
            beam_seqs = torch.cat((beam_seqs, beam_toks.unsqueeze(1)), dim=1)

            # I have no idea what's going on but Ari's on point with it
            XMB = XMB.transpose(0, 1).transpose(1, 2).repeat(
                self.opt.eval.bs, 1, 1).transpose(2, 1).transpose(
                1, 0).contiguous().view(
                self.opt.eval.bs**2, XMB.size(1), XMB.size(2))[top_beam_idxs]

            XMB, MMB = self.append_batch(XMB, beam_toks, MMB)

            if (beam_toks == self.end_token).sum().item() == self.opt.eval.bs:
                break

        beams = []

        for beam in beam_seqs:
            beams.append(" ".join("".join(
                [data_loader.vocab_decoder[tok.item()].replace(
                    '</w>', ' ').replace('\n', '')
                 for tok in beam if tok != self.end_token]).split()))

        sampling_result = {
            "sequence": beams[0],
            "beams": beams,
            "beam_losses": beam_lls.tolist(),
            "loss": beam_lls[0].item(),
            "beam_lengths": counts.tolist(),
            "length": counts[0].item()
        }

        return sampling_result
Model/COSMIC/feature_extraction/src/evaluate/utils.py
ADDED
@@ -0,0 +1,39 @@
def update_classification_losses(losses, nums, name, bs, loss):
    if not isinstance(loss, float):
        print(type(loss))
        raise TypeError("loss must be a float")  # was a bare `raise`

    nums[name] += bs

    losses[name] += loss * bs


def update_generation_losses(losses, nums, micro, macro, bs, length, loss):
    # Update losses
    nums[macro] += bs

    if isinstance(length, int):
        update_indiv_generation_losses(
            losses, nums, micro, macro, bs, length, loss)
    else:
        update_tensor_generation_losses(
            losses, nums, micro, macro, bs, length, loss)


def update_indiv_generation_losses(losses, nums, micro,
                                   macro, bs, length, loss):
    nums[micro] += bs * length

    batch_loss = loss * bs

    losses[micro] += batch_loss
    losses[macro] += batch_loss / length


def update_tensor_generation_losses(losses, nums, micro,
                                    macro, bs, length, loss):
    nums[micro] += length.sum().item()

    losses[micro] += loss.sum().item()
    losses[macro] += (loss / length.float()).sum().item()
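A numeric sketch of the micro/macro bookkeeping above, assuming `loss` is the per-sequence token-summed NLL; the batch sizes and losses are invented:

```python
losses = {"total_micro": 0.0, "total_macro": 0.0}
nums = {"total_micro": 0, "total_macro": 0}

# Two batches of (bs, summed token loss per sequence, sequence length):
for bs, loss, length in [(4, 20.0, 10), (2, 15.0, 5)]:
    nums["total_macro"] += bs               # one count per sequence
    nums["total_micro"] += bs * length      # one count per token
    losses["total_micro"] += loss * bs
    losses["total_macro"] += (loss * bs) / length

print(losses["total_micro"] / nums["total_micro"])  # 2.2   (token-weighted)
print(losses["total_macro"] / nums["total_macro"])  # ~2.33 (sequence-weighted)
```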
Model/COSMIC/feature_extraction/src/interactive/functions.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from comet.src.data.utils import TextEncoder
|
4 |
+
import comet.src.data.config as cfg
|
5 |
+
import comet.src.data.data as data
|
6 |
+
import comet.src.models.models as models
|
7 |
+
from comet.src.evaluate.sampler import BeamSampler, GreedySampler, TopKSampler
|
8 |
+
|
9 |
+
import comet.utils.utils as utils
|
10 |
+
|
11 |
+
|
12 |
+
def load_model_file(model_file):
|
13 |
+
model_stuff = data.load_checkpoint(model_file)
|
14 |
+
opt = model_stuff["opt"]
|
15 |
+
state_dict = model_stuff["state_dict"]
|
16 |
+
|
17 |
+
return opt, state_dict
|
18 |
+
|
19 |
+
def load_data(dataset, opt):
|
20 |
+
if dataset == "atomic":
|
21 |
+
data_loader = load_atomic_data(opt)
|
22 |
+
elif dataset == "conceptnet":
|
23 |
+
data_loader = load_conceptnet_data(opt)
|
24 |
+
|
25 |
+
# Initialize TextEncoder
|
26 |
+
encoder_path = "comet/model/encoder_bpe_40000.json"
|
27 |
+
bpe_path = "comet/model/vocab_40000.bpe"
|
28 |
+
text_encoder = TextEncoder(encoder_path, bpe_path)
|
29 |
+
text_encoder.encoder = data_loader.vocab_encoder
|
30 |
+
text_encoder.decoder = data_loader.vocab_decoder
|
31 |
+
|
32 |
+
return data_loader, text_encoder
|
33 |
+
|
34 |
+
|
35 |
+
def load_atomic_data(opt):
|
36 |
+
# Hacky workaround, you may have to change this
|
37 |
+
# if your models use different pad lengths for e1, e2, r
|
38 |
+
if opt.data.get("maxe1", None) is None:
|
39 |
+
opt.data.maxe1 = 17
|
40 |
+
opt.data.maxe2 = 35
|
41 |
+
opt.data.maxr = 1
|
42 |
+
# path = "data/atomic/processed/generation/{}.pickle".format(
|
43 |
+
# utils.make_name_string(opt.data))
|
44 |
+
path = "comet/data/atomic/processed/generation/categories_oEffect#oReact#oWant#xAttr#xEffect#xIntent#xNeed#xReact#xWant-maxe1_17-maxe2_35-maxr_1.pickle"
|
45 |
+
data_loader = data.make_data_loader(opt, opt.data.categories)
|
46 |
+
loaded = data_loader.load_data(path)
|
47 |
+
|
48 |
+
return data_loader
|
49 |
+
|
50 |
+
|
51 |
+
def load_conceptnet_data(opt):
|
52 |
+
# Hacky workaround, you may have to change this
|
53 |
+
# if your models use different pad lengths for r
|
54 |
+
if opt.data.get("maxr", None) is None:
|
55 |
+
if opt.data.rel == "language":
|
56 |
+
opt.data.maxr = 5
|
57 |
+
else:
|
58 |
+
opt.data.maxr = 1
|
59 |
+
path = "comet/data/conceptnet/processed/generation/{}.pickle".format(
|
60 |
+
utils.make_name_string(opt.data))
|
61 |
+
data_loader = data.make_data_loader(opt)
|
62 |
+
loaded = data_loader.load_data(path)
|
63 |
+
return data_loader
|
64 |
+
|
65 |
+
|
66 |
+
def make_model(opt, n_vocab, n_ctx, state_dict):
|
67 |
+
model = models.make_model(
|
68 |
+
opt, n_vocab, n_ctx, None, load=False,
|
69 |
+
return_acts=True, return_probs=False)
|
70 |
+
|
71 |
+
models.load_state_dict(model, state_dict)
|
72 |
+
|
73 |
+
model.eval()
|
74 |
+
return model
|
75 |
+
|
76 |
+
|
77 |
+
def set_sampler(opt, sampling_algorithm, data_loader):
|
78 |
+
if "beam" in sampling_algorithm:
|
79 |
+
opt.eval.bs = int(sampling_algorithm.split("-")[1])
|
80 |
+
sampler = BeamSampler(opt, data_loader)
|
81 |
+
elif "topk" in sampling_algorithm:
|
82 |
+
# print("Still bugs in the topk sampler. Use beam or greedy instead")
|
83 |
+
# raise NotImplementedError
|
84 |
+
opt.eval.k = int(sampling_algorithm.split("-")[1])
|
85 |
+
sampler = TopKSampler(opt, data_loader)
|
86 |
+
else:
|
87 |
+
sampler = GreedySampler(opt, data_loader)
|
88 |
+
|
89 |
+
return sampler
|
90 |
+
|
91 |
+
|
92 |
+
def get_atomic_sequence(input_event, model, sampler, data_loader, text_encoder, category):
|
93 |
+
if isinstance(category, list):
|
94 |
+
outputs = {}
|
95 |
+
for cat in category:
|
96 |
+
new_outputs = get_atomic_sequence(
|
97 |
+
input_event, model, sampler, data_loader, text_encoder, cat)
|
98 |
+
outputs.update(new_outputs)
|
99 |
+
return outputs
|
100 |
+
elif category == "all":
|
101 |
+
outputs = {}
|
102 |
+
|
103 |
+
for category in data_loader.categories:
|
104 |
+
new_outputs = get_atomic_sequence(
|
105 |
+
input_event, model, sampler, data_loader, text_encoder, category)
|
106 |
+
outputs.update(new_outputs)
|
107 |
+
return outputs
|
108 |
+
else:
|
109 |
+
|
110 |
+
sequence_all = {}
|
111 |
+
|
112 |
+
sequence_all["event"] = input_event
|
113 |
+
sequence_all["effect_type"] = category
|
114 |
+
|
115 |
+
with torch.no_grad():
|
116 |
+
|
117 |
+
batch = set_atomic_inputs(
|
118 |
+
input_event, category, data_loader, text_encoder)
|
119 |
+
|
120 |
+
sampling_result = sampler.generate_sequence(
|
121 |
+
batch, model, data_loader, data_loader.max_event +
|
122 |
+
data.atomic_data.num_delimiter_tokens["category"],
|
123 |
+
data_loader.max_effect -
|
124 |
+
data.atomic_data.num_delimiter_tokens["category"])
|
125 |
+
|
126 |
+
sequence_all['beams'] = sampling_result["beams"]
|
127 |
+
|
128 |
+
# print_atomic_sequence(sequence_all)
|
129 |
+
|
130 |
+
return {category: sequence_all}
|
131 |
+
|
132 |
+
|
133 |
+
def print_atomic_sequence(sequence_object):
|
134 |
+
input_event = sequence_object["event"]
|
135 |
+
category = sequence_object["effect_type"]
|
136 |
+
|
137 |
+
print("Input Event: {}".format(input_event))
|
138 |
+
print("Target Effect: {}".format(category))
|
139 |
+
print("")
|
140 |
+
print("Candidate Sequences:")
|
141 |
+
for beam in sequence_object["beams"]:
|
142 |
+
print(beam)
|
143 |
+
print("")
|
144 |
+
print("====================================================")
|
145 |
+
print("")
|
146 |
+
|
147 |
+
|
148 |
+
def set_atomic_inputs(input_event, category, data_loader, text_encoder):
|
149 |
+
XMB = torch.zeros(1, data_loader.max_event + 1).long().to(cfg.device)
|
150 |
+
prefix, suffix = data.atomic_data.do_example(text_encoder, input_event, None, True, None)
|
151 |
+
|
152 |
+
if len(prefix) > data_loader.max_event + 1:
|
153 |
+
prefix = prefix[:data_loader.max_event + 1]
|
154 |
+
|
155 |
+
XMB[:, :len(prefix)] = torch.LongTensor(prefix)
|
156 |
+
XMB[:, -1] = torch.LongTensor([text_encoder.encoder["<{}>".format(category)]])
|
157 |
+
|
158 |
+
batch = {}
|
159 |
+
batch["sequences"] = XMB
|
160 |
+
batch["attention_mask"] = data.atomic_data.make_attention_mask(XMB)
|
161 |
+
|
162 |
+
return batch
|
163 |
+
|
164 |
+
|
165 |
+
def get_conceptnet_sequence(e1, model, sampler, data_loader, text_encoder, relation, force=False):
|
166 |
+
if isinstance(relation, list):
|
167 |
+
outputs = {}
|
168 |
+
|
169 |
+
for rel in relation:
|
170 |
+
new_outputs = get_conceptnet_sequence(
|
171 |
+
e1, model, sampler, data_loader, text_encoder, rel)
|
172 |
+
outputs.update(new_outputs)
|
173 |
+
return outputs
|
174 |
+
elif relation == "all":
|
175 |
+
outputs = {}
|
176 |
+
|
177 |
+
for relation in data.conceptnet_data.conceptnet_relations:
|
178 |
+
new_outputs = get_conceptnet_sequence(
|
179 |
+
e1, model, sampler, data_loader, text_encoder, relation)
|
180 |
+
outputs.update(new_outputs)
|
181 |
+
return outputs
|
182 |
+
else:
|
183 |
+
|
184 |
+
sequence_all = {}
|
185 |
+
|
186 |
+
sequence_all["e1"] = e1
|
187 |
+
sequence_all["relation"] = relation
|
188 |
+
|
189 |
+
with torch.no_grad():
|
190 |
+
if data_loader.max_r != 1:
|
191 |
+
relation_sequence = data.conceptnet_data.split_into_words[relation]
|
192 |
+
else:
|
193 |
+
relation_sequence = "<{}>".format(relation)
|
194 |
+
|
195 |
+
batch, abort = set_conceptnet_inputs(
|
196 |
+
e1, relation_sequence, text_encoder,
|
197 |
+
data_loader.max_e1, data_loader.max_r, force)
|
198 |
+
|
199 |
+
if abort:
|
200 |
+
return {relation: sequence_all}
|
201 |
+
|
202 |
+
sampling_result = sampler.generate_sequence(
|
203 |
+
batch, model, data_loader,
|
204 |
+
data_loader.max_e1 + data_loader.max_r,
|
205 |
+
data_loader.max_e2)
|
206 |
+
|
207 |
+
sequence_all['beams'] = sampling_result["beams"]
|
208 |
+
|
209 |
+
print_conceptnet_sequence(sequence_all)
|
210 |
+
|
211 |
+
return {relation: sequence_all}
|
+
+
+def set_conceptnet_inputs(input_event, relation, text_encoder, max_e1, max_r, force):
+    abort = False
+
+    e1_tokens, rel_tokens, _ = data.conceptnet_data.do_example(text_encoder, input_event, relation, None)
+
+    if len(e1_tokens) > max_e1:
+        if force:
+            XMB = torch.zeros(1, len(e1_tokens) + max_r).long().to(cfg.device)
+        else:
+            XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)
+            return {}, True
+    else:
+        XMB = torch.zeros(1, max_e1 + max_r).long().to(cfg.device)
+
+    XMB[:, :len(e1_tokens)] = torch.LongTensor(e1_tokens)
+    XMB[:, max_e1:max_e1 + len(rel_tokens)] = torch.LongTensor(rel_tokens)
+
+    batch = {}
+    batch["sequences"] = XMB
+    batch["attention_mask"] = data.conceptnet_data.make_attention_mask(XMB)
+
+    return batch, abort
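Concretely, the row built here places the entity tokens in positions [0, max_e1) and the relation tokens at offset max_e1; when the entity is too long and force is False, the function instead returns ({}, True) so the caller aborts. A self-contained sketch with invented ids:

    import torch

    max_e1, max_r = 10, 5   # stand-ins for data_loader.max_e1 / data_loader.max_r
    e1_tokens = [7, 8, 9]   # stand-in BPE ids for the entity
    rel_tokens = [42]       # stand-in token id(s) for the relation

    XMB = torch.zeros(1, max_e1 + max_r).long()
    XMB[:, :len(e1_tokens)] = torch.LongTensor(e1_tokens)   # entity at the front
    XMB[:, max_e1:max_e1 + len(rel_tokens)] = torch.LongTensor(rel_tokens)  # relation at a fixed offset
    print(XMB)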
+
+
+def print_conceptnet_sequence(sequence_object):
+    e1 = sequence_object["e1"]
+    relation = sequence_object["relation"]
+
+    print("Input Entity: {}".format(e1))
+    print("Target Relation: {}".format(relation))
+    print("")
+    print("Candidate Sequences:")
+    for beam in sequence_object["beams"]:
+        print(beam)
+    print("")
+    print("====================================================")
+    print("")
+
+
+def print_help(data):
+    print("")
+    if data == "atomic":
+        print("Provide a seed event such as \"PersonX goes to the mall\"")
+        print("Don't include names; replace them with PersonX, PersonY, etc.")
+        print("The event should always include PersonX")
+    elif data == "conceptnet":
+        print("Provide a seed entity such as \"go to the mall\"")
+        print("Because the model was trained on lemmatized entities,")
+        print("it works best if the input entities are also lemmatized")
+    print("")
+
+
+def print_relation_help(data):
+    print_category_help(data)
+
+
+def print_category_help(data):
+    print("")
+    if data == "atomic":
+        print("Enter a possible effect type from the following effect types:")
+        print("all - compute the output for all effect types {oEffect, oReact, oWant, xAttr, xEffect, xIntent, xNeed, xReact, xWant}")
+        print("oEffect - generate the effect of the event on participants other than PersonX")
+        print("oReact - generate the reactions of participants other than PersonX to the event")
+        print("oWant - generate what participants other than PersonX may want after the event")
+    elif data == "conceptnet":
+        print("Enter a possible relation from the following list:")
+        print("")
+        print('AtLocation')
+        print('CapableOf')
+        print('Causes')
+        print('CausesDesire')
+        print('CreatedBy')
+        print('DefinedAs')
+        print('DesireOf')
+        print('Desires')
+        print('HasA')
+        print('HasFirstSubevent')
+        print('HasLastSubevent')
+        print('HasPainCharacter')
+        print('HasPainIntensity')
+        print('HasPrerequisite')
+        print('HasProperty')
+        print('HasSubevent')
+        print('InheritsFrom')
+        print('InstanceOf')
+        print('IsA')
+        print('LocatedNear')
+        print('LocationOfAction')
+        print('MadeOf')
+        print('MotivatedByGoal')
+        print('NotCapableOf')
+        print('NotDesires')
+        print('NotHasA')
+        print('NotHasProperty')
+        print('NotIsA')
+        print('NotMadeOf')
+        print('PartOf')
+        print('ReceivesAction')
+        print('RelatedTo')
+        print('SymbolOf')
+        print('UsedFor')
+        print("")
+        print("NOTE: Capitalization is important")
+    else:
+        raise ValueError("unknown dataset: {}".format(data))
+    print("")
+
+def print_sampling_help():
+    print("")
+    print("Provide a sampling algorithm to produce the sequence, from the following:")
+    print("")
+    print("greedy")
+    print("beam-# where # is the beam size")
+    print("topk-# where # is k")
+    print("")
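The strings this help text advertises are parsed elsewhere in the repository (by the sampler setup code, which is not part of this diff); a minimal sketch of how a "beam-5" / "topk-10" style string might decompose, with invented variable names:

    algorithm = "beam-5"                 # example user input
    kind, _, number = algorithm.partition("-")
    size = int(number) if number else 1  # "greedy" carries no numeric suffix
    print(kind, size)                    # -> beam 5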
Model/COSMIC/feature_extraction/src/main.py
ADDED
@@ -0,0 +1,19 @@
+import sys
+import os
+import argparse
+
+sys.path.append(os.getcwd())
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--experiment_type", type=str, default='atomic',
+                    choices=["atomic", "conceptnet"])
+parser.add_argument("--experiment_num", type=str, default="0")
+
+args = parser.parse_args()
+
+if args.experiment_type == "atomic":
+    from main_atomic import main
+    main(args.experiment_num)
+if args.experiment_type == "conceptnet":
+    from main_conceptnet import main
+    main(args.experiment_num)
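Assuming this script is run so that main_atomic.py and main_conceptnet.py are importable (the script's own directory is on sys.path automatically, and os.getcwd() is appended for repo-rooted imports), a typical invocation from the feature_extraction root might look like:

    python src/main.py --experiment_type conceptnet --experiment_num 0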