kleinay commited on
Commit
1af26a8
1 Parent(s): 7188707

update _sanitize_parameters to accept `generate_kwargs`

Browse files
Files changed (1) hide show
  1. pipeline.py +7 -6
pipeline.py CHANGED
@@ -66,12 +66,13 @@ class QASRL_Pipeline(Text2TextGenerationPipeline):
66
 
67
  def _sanitize_parameters(self, **kwargs):
68
  preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {} # super()._sanitize_parameters(**kwargs)
69
- if "predicate_marker" in kwargs:
70
- preprocess_kwargs["predicate_marker"] = kwargs["predicate_marker"]
71
- if "predicate_type" in kwargs:
72
- preprocess_kwargs["predicate_type"] = kwargs["predicate_type"]
73
- if "verb_form" in kwargs:
74
- preprocess_kwargs["verb_form"] = kwargs["verb_form"]
 
75
  return preprocess_kwargs, forward_kwargs, postprocess_kwargs
76
 
77
  def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):
 
66
 
67
  def _sanitize_parameters(self, **kwargs):
68
  preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {} # super()._sanitize_parameters(**kwargs)
69
+ forward_kwargs.update(kwargs.get("generate_kwargs", dict()))
70
+ forward_kwargs.update(kwargs.get("model_kwargs", dict()))
71
+ preprocess_keywords = ("predicate_marker", "predicate_type", "verb_form")
72
+ for key in preprocess_keywords:
73
+ if key in kwargs:
74
+ preprocess_kwargs[key] = kwargs[key]
75
+
76
  return preprocess_kwargs, forward_kwargs, postprocess_kwargs
77
 
78
  def preprocess(self, inputs, predicate_marker="<predicate>", predicate_type=None, verb_form=None):