KoichiYasuoka
commited on
Commit
•
b96cf19
1
Parent(s):
d8ebfde
root analysis improved
Browse files
README.md
CHANGED
@@ -32,16 +32,16 @@ class UDgoeswith(object):
|
|
32 |
import numpy,torch,ufal.chu_liu_edmonds
|
33 |
w=self.tokenizer(text,return_offsets_mapping=True)
|
34 |
v=w["input_ids"]
|
35 |
-
|
36 |
with torch.no_grad():
|
37 |
-
|
38 |
-
|
39 |
-
e[
|
40 |
-
m=numpy.full((
|
41 |
m[1:,1:]=numpy.nanmax(e,axis=2).transpose()
|
42 |
-
p=numpy.zeros(
|
43 |
p[1:,1:]=numpy.nanargmax(e,axis=2).transpose()
|
44 |
-
for i in range(1,
|
45 |
m[i,0],m[i,i],p[i,0]=m[i,i],numpy.nan,p[i,i]
|
46 |
h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0]
|
47 |
u="# text = "+text+"\n"
|
|
|
32 |
import numpy,torch,ufal.chu_liu_edmonds
|
33 |
w=self.tokenizer(text,return_offsets_mapping=True)
|
34 |
v=w["input_ids"]
|
35 |
+
x=[v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)]
|
36 |
with torch.no_grad():
|
37 |
+
e=self.model(input_ids=torch.tensor(x)).logits.numpy()[:,1:-2,:]
|
38 |
+
r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
|
39 |
+
e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
|
40 |
+
m=numpy.full((e.shape[0]+1,e.shape[1]+1),numpy.nan)
|
41 |
m[1:,1:]=numpy.nanmax(e,axis=2).transpose()
|
42 |
+
p=numpy.zeros(m.shape)
|
43 |
p[1:,1:]=numpy.nanargmax(e,axis=2).transpose()
|
44 |
+
for i in range(1,m.shape[0]):
|
45 |
m[i,0],m[i,i],p[i,0]=m[i,i],numpy.nan,p[i,i]
|
46 |
h=ufal.chu_liu_edmonds.chu_liu_edmonds(m)[0]
|
47 |
u="# text = "+text+"\n"
|