KoichiYasuoka commited on
Commit
0bb8002
·
1 Parent(s): f1179fb
Files changed (1) hide show
  1. ud.py +5 -2
ud.py CHANGED
@@ -85,7 +85,7 @@ class UniversalDependenciesCausalPipeline(BellmanFordTokenClassificationPipeline
85
  m=torch.stack(m)
86
  k=list(range(len(d)+1))
87
  with torch.no_grad():
88
- e=self.model(inputs_embeds=torch.stack([m[k+list(range(i,len(d)))+[-1]*i,:] for i in range(len(d))])).logits[:,-len(d):,:].numpy()
89
  for i in range(len(d)):
90
  for j in range(i):
91
  e[-j-1,-i-1],e[-i-1,-j-1]=e[-i-1,i-j]+self.left_arc,e[-i-1,i-j]+self.right_arc
@@ -118,4 +118,7 @@ class UniversalDependenciesCausalPipeline(BellmanFordTokenClassificationPipeline
118
  z=matrix-numpy.nanmax(matrix,axis=0)
119
  m=numpy.block([[z[x,:][:,x],numpy.nanmax(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.nanmax(z[y,:][:,x],axis=0),numpy.nanmax(z[y,y])]])
120
  k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.nanargmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
121
-
 
 
 
 
85
  m=torch.stack(m)
86
  k=list(range(len(d)+1))
87
  with torch.no_grad():
88
+ e=self.model(inputs_embeds=torch.stack([m[k+list(range(i,len(d)))+[-1]*i,:] for i in range(len(d))]).to(self.device)).logits[:,-len(d):,:].numpy()
89
  for i in range(len(d)):
90
  for j in range(i):
91
  e[-j-1,-i-1],e[-i-1,-j-1]=e[-i-1,i-j]+self.left_arc,e[-i-1,i-j]+self.right_arc
 
118
  z=matrix-numpy.nanmax(matrix,axis=0)
119
  m=numpy.block([[z[x,:][:,x],numpy.nanmax(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.nanmax(z[y,:][:,x],axis=0),numpy.nanmax(z[y,y])]])
120
  k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.nanargmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
121
+ h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
122
+ i=y[numpy.nanargmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
123
+ h[i]=x[k[-1]] if k[-1]<len(x) else i
124
+ return h