KoichiYasuoka
commited on
Commit
•
18d812a
1
Parent(s):
a827d92
support transformers>=4.28
Browse files
ud.py
CHANGED
@@ -9,6 +9,8 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
|
9 |
return {"logits":e.logits[:,1:-2,:],**model_inputs}
|
10 |
def postprocess(self,model_outputs,**kwargs):
|
11 |
import numpy
|
|
|
|
|
12 |
e=model_outputs["logits"].numpy()
|
13 |
r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
|
14 |
e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
|
|
|
9 |
return {"logits":e.logits[:,1:-2,:],**model_inputs}
|
10 |
def postprocess(self,model_outputs,**kwargs):
|
11 |
import numpy
|
12 |
+
if "logits" not in model_outputs:
|
13 |
+
return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
|
14 |
e=model_outputs["logits"].numpy()
|
15 |
r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
|
16 |
e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,numpy.nan)
|