File size: 3,226 Bytes
6fb6bd4
 
 
753e33a
6fb6bd4
753e33a
6fb6bd4
6b6e7b4
753e33a
 
6fb6bd4
6cfc49c
 
753e33a
6fb6bd4
 
cedf196
 
 
 
 
 
b8a63dd
 
 
 
7c2f5c4
 
b8a63dd
753e33a
b8a63dd
3c9ddba
17bc6c9
b8a63dd
7cc383a
2d9b614
b8a63dd
2d9b614
753e33a
6fb6bd4
b8a63dd
 
6fb6bd4
b8a63dd
 
 
 
 
 
 
 
 
 
 
 
7c2f5c4
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from transformers import TokenClassificationPipeline

class UniversalDependenciesPipeline(TokenClassificationPipeline):
  """Token-classification pipeline that emits a dependency parse as CoNLL-U-style text.

  Every model label is a "|"-separated string.  When writing a token line, the
  first field goes to the UPOS column, the last field to the DEPREL column, and
  the fields in between (re-joined with "|") to the FEATS column.  Labels ending
  in "|root" are only allowed for a token headed by itself, and the label
  "X|_|goeswith" must be present in ``label2id``.
  """
  def _forward(self,model_inputs):
    """Score every (head, dependent) token pair with one batched forward pass.

    The code treats the first and last ids of ``input_ids`` as framing tokens and
    works on the n ids between them.  For each inner position i it builds one row
    in which token i is replaced by the mask token and the original id of token i
    is appended after the last id.  So the batch has n rows of length
    len(input_ids)+1.  Slicing positions 1:-2 drops the first position, the
    trailing framing token and the appended id, which leaves n positions per row.
    The returned "logits" is therefore an (n, n, num_labels) tensor.  All entries
    of ``model_inputs`` are passed through unchanged.

    If there are no inner tokens (e.g. an empty sentence), the model is not
    called and "logits" is an empty (0, 0, num_labels) tensor.
    """
    import torch
    v=model_inputs["input_ids"][0].tolist()
    if len(v)<3:
      # The batch below would be torch.tensor([]), a 1-D float tensor that is not a valid input_ids batch.
      return {"logits":torch.zeros((0,0,len(self.model.config.id2label))),**model_inputs}
    with torch.no_grad():
      e=self.model(input_ids=torch.tensor([v[0:i]+[self.tokenizer.mask_token_id]+v[i+1:]+[j] for i,j in enumerate(v[1:-1],1)],device=self.device))
    return {"logits":e.logits[:,1:-2,:],**model_inputs}
  def postprocess(self,model_outputs,**kwargs):
    """Decode the pair logits from ``_forward`` into a CoNLL-U-style sentence block.

    ``model_outputs`` must provide "logits", "offset_mapping" and "sentence".  If
    it has no "logits" key, it is treated as an iterable of such outputs, and the
    decoded blocks are concatenated.

    Returns a string made of three parts:
      - a "# text = <sentence>" line, with newlines in the sentence replaced by
        spaces;
      - one 10-column tab-separated line per token;
      - a terminating blank line.
    If ``aggregation_strategy`` is given and is not "none", goeswith tokens are
    folded into the preceding token (see below) and LEMMA repeats the surface
    form.  Otherwise LEMMA is "_".

    Raises KeyError if the model has no "X|_|goeswith" label.
    """
    import numpy
    if "logits" not in model_outputs:
      return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
    e=model_outputs["logits"].numpy()
    if e.shape[0]==0:
      # No tokens to parse (see _forward): emit only the sentence comment and the blank line.
      return "# text = "+model_outputs["sentence"].replace("\n"," ")+"\n\n"
    # r[l] per label id: 1 for id 0, -1 for "...|root" labels, 0 otherwise.  Adding r to the
    # identity gives 0 exactly where a label is allowed: root labels only on the diagonal,
    # other labels (except id 0) only off it.  Every other (i, j, label) entry becomes NaN.
    r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
    e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
    g=self.model.config.label2id["X|_|goeswith"]
    # r[i,j]==0 marks pairs where goeswith stays allowed: j==i+1, or j>i+1 when the best
    # label of (i,j-1) is goeswith and that pair was itself allowed.
    r=numpy.tri(e.shape[0])
    for i in range(e.shape[0]):
      for j in range(i+2,e.shape[1]):
        r[i,j]=r[i,j-1] if numpy.nanargmax(e[i,j-1])==g else 1
    e[:,:,g]+=numpy.where(r==0,0,numpy.nan)
    # m[i,j]: best allowed score for i heading j (the diagonal holds root scores).
    # p[i,j]: the label id that achieves it.
    m,p=numpy.nanmax(e,axis=2),numpy.nanargmax(e,axis=2)
    h=self.chu_liu_edmonds(m)
    z=[i for i,j in enumerate(h) if i==j]
    if len(z)>1:
      # Several roots.  Let k be the candidate with the best diagonal (root) score.
      # Arcs into each root candidate get a penalty of nanmin-nanmax (the full score range
      # of m), unless the arc comes from another root candidate or is k's own root arc.
      # Then decode again.
      k,h=z[numpy.nanargmax(m[z,z])],numpy.nanmin(m)-numpy.nanmax(m)
      m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
      h=self.chu_liu_edmonds(m)
    # Character spans of tokens with a non-empty offset span.
    # NOTE(review): this assumes exactly the scored tokens have non-empty spans; a scored
    # token with an empty span would misalign v with q/h below - confirm for the tokenizer in use.
    v=[(s,e) for s,e in model_outputs["offset_mapping"][0].tolist() if s<e]
    # Label fields of each token i, taken from the arc chosen for it (head h[i]).
    q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
    g="aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none"
    if g:
      # Walk the tokens right to left.  A goeswith token is folded into the previous token
      # when every token after its head, up to and including itself, is goeswith.  Folding
      # does three things:
      #   - renumbers heads (arcs pointing at i move to i-1; later indices shift down);
      #   - joins the two character spans;
      #   - drops the folded token's labels.
      for i,j in reversed(list(enumerate(q[1:],1))):
        if j[-1]=="goeswith" and set([t[-1] for t in q[h[i]+1:i+1]])=={"goeswith"}:
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
    t=model_outputs["sentence"].replace("\n"," ")
    u="# text = "+t+"\n"
    # Columns: ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC.  HEAD is 0 for a root.
    # MISC is "_" only when the next token starts after this one ends, so the last token
    # always gets "SpaceAfter=No".
    for i,(s,e) in enumerate(v):
      u+="\t".join([str(i+1),t[s:e],t[s:e] if g else "_",q[i][0],"_","|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),q[i][-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
    return u+"\n"
  def chu_liu_edmonds(self,matrix):
    """Return a head index for every node via recursive Chu-Liu/Edmonds decoding.

    ``matrix[i,j]`` is the score of node i being the head of node j.  A node whose
    returned head is itself is a root.  If the greedy (column-wise argmax) heads
    contain no cycle, they are returned unchanged as a numpy array.  Otherwise one
    cycle is contracted into a single node, the reduced matrix is decoded
    recursively, and the result is expanded back into a list of head indices.
    """
    import numpy
    # Greedy choice: the best-scoring head for each column j.
    h=numpy.nanargmax(matrix,axis=0)
    # x[i] is the head of i, or -1 when i heads itself.
    x=[-1 if i==j else j for i,j in enumerate(h)]
    # Pass 1 repeatedly sets x[i] to -1 for nodes that no remaining node points at, which
    # leaves only the nodes on cycles.  Pass 2 repeatedly replaces x[i] by x[x[i]] until all
    # nodes of a cycle share one representative.  If nothing is left, h has no cycle.
    for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
      y=[]
      while x!=y:
        y=list(x)
        for i,j in enumerate(x):
          x[i]=b(x,i,j)
      if max(x)<0:
        return h
    # y: the nodes of the cycle with the largest representative.  x: every other node.
    y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
    # Shift each column so that its best incoming score is 0.
    z=matrix-numpy.nanmax(matrix,axis=0)
    # Contracted matrix with the cycle as the last node:
    #   - last column: the best score from each other node into any cycle node;
    #   - last row: the best score from any cycle node into each other node;
    #   - corner: the best root (diagonal) score among the cycle nodes.
    m=numpy.block([[z[x,:][:,x],numpy.nanmax(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.nanmax(z[y,:][:,x],axis=0),numpy.nanmax(z[y,y])]])
    # Map the recursive result back to original indices.  An arc from the contracted node
    # comes from the cycle node with the best score into that dependent.  The last entry
    # (the contracted node's own head, in contracted indices) is kept as-is.
    k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.nanargmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
    # Cycle nodes keep their greedy heads; all other nodes take the decoded heads.
    h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
    # Break the cycle at one node.  If the contracted node got a head, pick the cycle node
    # whose arc from that head scores best; otherwise pick the cycle node with the best
    # root score and make it a root.
    i=y[numpy.nanargmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
    h[i]=x[k[-1]] if k[-1]<len(x) else i
    return h