File size: 3,226 Bytes
bd8aef1
 
 
d26baab
bd8aef1
d26baab
bd8aef1
570fb37
d26baab
 
bd8aef1
df0553b
 
d26baab
bd8aef1
 
83960e8
 
 
 
 
 
0ff8f8f
 
 
 
2abab5e
 
0ff8f8f
d26baab
70336f3
fc5131d
dfe0ec5
70336f3
fceb6cc
a92173a
70336f3
a92173a
d26baab
bd8aef1
70336f3
 
bd8aef1
0ff8f8f
 
 
 
 
 
 
 
 
 
 
 
f796df0
 
ee10b29
 
62ffe79
2abab5e
62ffe79
f796df0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from transformers import TokenClassificationPipeline

class UniversalDependenciesPipeline(TokenClassificationPipeline):
  def _forward(self,model_inputs):
    import torch
    v=model_inputs["input_ids"][0].tolist()
    with torch.no_grad():
      e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)],device=self.device))
    return {"logits":e.logits[:,1:-2,:],**model_inputs}
  def postprocess(self,model_outputs,**kwargs):
    import numpy
    if "logits" not in model_outputs:
      return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
    e=model_outputs["logits"].numpy()
    r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
    e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
    g=self.model.config.label2id["X|_|goeswith"]
    r=numpy.tri(e.shape[0])
    for i in range(e.shape[0]):
      for j in range(i+2,e.shape[1]):
        r[i,j]=r[i,j-1] if numpy.nanargmax(e[i,j-1])==g else 1
    e[:,:,g]+=numpy.where(r==0,0,numpy.nan)
    m,p=numpy.nanmax(e,axis=2),numpy.nanargmax(e,axis=2)
    h=self.chu_liu_edmonds(m)
    z=[i for i,j in enumerate(h) if i==j]
    if len(z)>1:
      k,h=z[numpy.nanargmax(m[z,z])],numpy.nanmin(m)-numpy.nanmax(m)
      m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
      h=self.chu_liu_edmonds(m)
    v=[(s,e) for s,e in model_outputs["offset_mapping"][0].tolist() if s<e]
    q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
    g="aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none"
    if g:
      for i,j in reversed(list(enumerate(q[1:],1))):
        if j[-1]=="goeswith" and set([t[-1] for t in q[h[i]+1:i+1]])=={"goeswith"}:
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
    t=model_outputs["sentence"].replace("\n"," ")
    u="# text = "+t+"\n"
    for i,(s,e) in enumerate(v):
      u+="\t".join([str(i+1),t[s:e],t[s:e] if g else "_",q[i][0],"_","|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),q[i][-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
    return u+"\n"
  def chu_liu_edmonds(self,matrix):
    import numpy
    h=numpy.nanargmax(matrix,axis=0)
    x=[-1 if i==j else j for i,j in enumerate(h)]
    for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
      y=[]
      while x!=y:
        y=list(x)
        for i,j in enumerate(x):
          x[i]=b(x,i,j)
      if max(x)<0:
        return h
    y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
    z=matrix-numpy.nanmax(matrix,axis=0)
    m=numpy.block([[z[x,:][:,x],numpy.nanmax(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.nanmax(z[y,:][:,x],axis=0),numpy.nanmax(z[y,y])]])
    k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.nanargmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
    h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
    i=y[numpy.nanargmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
    h[i]=x[k[-1]] if k[-1]<len(x) else i
    return h