KoichiYasuoka commited on
Commit
5b3c45d
1 Parent(s): b441a93

support transformers>=4.28

Browse files
Files changed (1) hide show
  1. ud.py +2 -0
ud.py CHANGED
@@ -32,6 +32,8 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
32
  return {"logits":e.logits[:,1:-2,:],**model_inputs}
33
  def postprocess(self,model_outputs,**kwargs):
34
  import numpy
 
 
35
  e=model_outputs["logits"].numpy()
36
  r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
37
  e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
 
32
  return {"logits":e.logits[:,1:-2,:],**model_inputs}
33
  def postprocess(self,model_outputs,**kwargs):
34
  import numpy
35
+ if "logits" not in model_outputs:
36
+ return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
37
  e=model_outputs["logits"].numpy()
38
  r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
39
  e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)