kleinay commited on
Commit
417ef6d
1 Parent(s): ec979c9

add update_config

Browse files
Files changed (1) hide show
  1. pipeline.py +9 -1
pipeline.py CHANGED
@@ -62,7 +62,15 @@ class QASRL_Pipeline(Text2TextGenerationPipeline):
62
  self.data_args.use_bilateral_predicate_marker = True
63
  if "append_verb_form" not in vars(self.data_args):
64
  self.data_args.append_verb_form = True
65
-
 
 
 
 
 
 
 
 
66
 
67
  def _sanitize_parameters(self, **kwargs):
68
  preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {} # super()._sanitize_parameters(**kwargs)
62
  self.data_args.use_bilateral_predicate_marker = True
63
  if "append_verb_form" not in vars(self.data_args):
64
  self.data_args.append_verb_form = True
65
+ self._update_config(**kwargs)
66
+
67
+ def _update_config(self, **kwargs):
68
+ " Update self.model.config with initialization parameters and necessary defaults. "
69
+ # set default values that will always override model.config, but can overriden by __init__ kwargs
70
+ kwargs["max_length"] = kwargs.get("max_length", 80)
71
+ # override model.config with kwargs
72
+ for k,v in kwargs.items():
73
+ self.model.config.__dict__[k] = v
74
 
75
  def _sanitize_parameters(self, **kwargs):
76
  preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {} # super()._sanitize_parameters(**kwargs)