KoichiYasuoka commited on
Commit
18d812a
1 Parent(s): a827d92

support transformers>=4.28

Browse files
Files changed (1) hide show
  1. ud.py +2 -0
ud.py CHANGED
@@ -9,6 +9,8 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
9
  return {"logits":e.logits[:,1:-2,:],**model_inputs}
10
  def postprocess(self,model_outputs,**kwargs):
11
  import numpy
 
 
12
  e=model_outputs["logits"].numpy()
13
  r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
14
  e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
 
9
  return {"logits":e.logits[:,1:-2,:],**model_inputs}
10
  def postprocess(self,model_outputs,**kwargs):
11
  import numpy
12
+ if "logits" not in model_outputs:
13
+ return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
14
  e=model_outputs["logits"].numpy()
15
  r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
16
  e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)