kleinay commited on
Commit
8af025c
1 Parent(s): 64232f7

fix source prefix

Browse files
Files changed (1) hide show
  1. pipeline.py +2 -2
pipeline.py CHANGED
@@ -125,9 +125,9 @@ class QASRL_Pipeline(Text2TextGenerationPipeline):
125
  def _get_source_prefix(self, predicate_type: Optional[str]):
126
  if not self.is_t5_model or self.data_args.source_prefix is None:
127
  return ''
128
- if "Generate QAs for <predicate_type> QASRL: " in self.data_args.source_prefix:
129
  if predicate_type is None:
130
- raise ValueError("source_prefix includes 'Generate QAs for <predicate_type> QASRL: ' but input has no `predicate_type`.")
131
  if self.data_args.source_prefix == "Generate QAs for <predicate_type> QASRL: ": # backwrad compatibility - "Generate QAs for <predicate_type> QASRL: " alone was a sign for a longer prefix
132
  return f"Generate QAs for {predicate_type} QASRL: "
133
  else:
 
125
  def _get_source_prefix(self, predicate_type: Optional[str]):
126
  if not self.is_t5_model or self.data_args.source_prefix is None:
127
  return ''
128
+ if "<predicate_type>" in self.data_args.source_prefix:
129
  if predicate_type is None:
130
+ raise ValueError("source_prefix includes '<predicate_type>' but input has no `predicate_type`.")
131
  if self.data_args.source_prefix == "Generate QAs for <predicate_type> QASRL: ": # backwrad compatibility - "Generate QAs for <predicate_type> QASRL: " alone was a sign for a longer prefix
132
  return f"Generate QAs for {predicate_type} QASRL: "
133
  else: