KoichiYasuoka
commited on
Commit
·
0bb8002
1
Parent(s):
f1179fb
bug fix
Browse files
ud.py
CHANGED
@@ -85,7 +85,7 @@ class UniversalDependenciesCausalPipeline(BellmanFordTokenClassificationPipeline
|
|
85 |
m=torch.stack(m)
|
86 |
k=list(range(len(d)+1))
|
87 |
with torch.no_grad():
|
88 |
-
e=self.model(inputs_embeds=torch.stack([m[k+list(range(i,len(d)))+[-1]*i,:] for i in range(len(d))])).logits[:,-len(d):,:].numpy()
|
89 |
for i in range(len(d)):
|
90 |
for j in range(i):
|
91 |
e[-j-1,-i-1],e[-i-1,-j-1]=e[-i-1,i-j]+self.left_arc,e[-i-1,i-j]+self.right_arc
|
@@ -118,4 +118,7 @@ class UniversalDependenciesCausalPipeline(BellmanFordTokenClassificationPipeline
|
|
118 |
z=matrix-numpy.nanmax(matrix,axis=0)
|
119 |
m=numpy.block([[z[x,:][:,x],numpy.nanmax(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.nanmax(z[y,:][:,x],axis=0),numpy.nanmax(z[y,y])]])
|
120 |
k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.nanargmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
|
121 |
-
|
|
|
|
|
|
|
|
85 |
m=torch.stack(m)
|
86 |
k=list(range(len(d)+1))
|
87 |
with torch.no_grad():
|
88 |
+
e=self.model(inputs_embeds=torch.stack([m[k+list(range(i,len(d)))+[-1]*i,:] for i in range(len(d))]).to(self.device)).logits[:,-len(d):,:].numpy()
|
89 |
for i in range(len(d)):
|
90 |
for j in range(i):
|
91 |
e[-j-1,-i-1],e[-i-1,-j-1]=e[-i-1,i-j]+self.left_arc,e[-i-1,i-j]+self.right_arc
|
|
|
118 |
z=matrix-numpy.nanmax(matrix,axis=0)
|
119 |
m=numpy.block([[z[x,:][:,x],numpy.nanmax(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.nanmax(z[y,:][:,x],axis=0),numpy.nanmax(z[y,y])]])
|
120 |
k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.nanargmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
|
121 |
+
h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
|
122 |
+
i=y[numpy.nanargmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
|
123 |
+
h[i]=x[k[-1]] if k[-1]<len(x) else i
|
124 |
+
return h
|