kleinay commited on
Commit
7b79038
1 Parent(s): 3dc25ec

Update qasrl_model_pipeline.py

Browse files

add update_config to fix max_length problem and allow customization in __init__

Files changed (1) hide show
  1. qasrl_model_pipeline.py +9 -1
qasrl_model_pipeline.py CHANGED
@@ -61,7 +61,15 @@ class QASRL_Pipeline(Text2TextGenerationPipeline):
61
  self.data_args.use_bilateral_predicate_marker = True
62
  if "append_verb_form" not in vars(self.data_args):
63
  self.data_args.append_verb_form = True
64
-
 
 
 
 
 
 
 
 
65
 
66
  def _sanitize_parameters(self, **kwargs):
67
  preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}
 
61
  self.data_args.use_bilateral_predicate_marker = True
62
  if "append_verb_form" not in vars(self.data_args):
63
  self.data_args.append_verb_form = True
64
+ self._update_config(**kwargs)
65
+
66
+ def _update_config(self, **kwargs):
67
+ " Update self.model.config with initialization parameters and necessary defaults. "
68
+ # set default values that will always override model.config, but can overriden by __init__ kwargs
69
+ kwargs["max_length"] = kwargs.get("max_length", 80)
70
+ # override model.config with kwargs
71
+ for k,v in kwargs.items():
72
+ self.model.config.__dict__[k] = v
73
 
74
  def _sanitize_parameters(self, **kwargs):
75
  preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}