KoichiYasuoka commited on
Commit
891a875
1 Parent(s): a1c6f33

support transformers>=4.28

Browse files
Files changed (1) hide show
  1. ud.py +2 -0
ud.py CHANGED
@@ -16,6 +16,8 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
16
  return {"logits":e.logits[:,1:-2,:],**model_inputs}
17
  def postprocess(self,model_outputs,**kwargs):
18
  import numpy
 
 
19
  e=model_outputs["logits"].numpy()
20
  r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
21
  e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
16
  return {"logits":e.logits[:,1:-2,:],**model_inputs}
17
  def postprocess(self,model_outputs,**kwargs):
18
  import numpy
19
+ if "logits" not in model_outputs:
20
+ return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
21
  e=model_outputs["logits"].numpy()
22
  r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
23
  e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)