nathantablang commited on
Commit
ed0457a
1 Parent(s): a323040

Upload DemoT5QAPipeline

Browse files
Files changed (1) hide show
  1. custom_pipeline.py +18 -18
custom_pipeline.py CHANGED
@@ -23,27 +23,27 @@ class DemoT5QAPipeline(QuestionAnsweringPipeline):
23
  # raise ValueError("Inputs must be a dictionary with 'question' and 'context' keys.")
24
 
25
 
26
- def _forward(self, model_inputs, **generate_kwargs):
27
- if self.framework == "pt":
28
- in_b, input_length = model_inputs["input_ids"].shape
29
- elif self.framework == "tf":
30
- in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()
31
 
32
- self.check_inputs(
33
- input_length,
34
- generate_kwargs.get("min_length", self.model.config.min_length),
35
- generate_kwargs.get("max_length", self.model.config.max_length),
36
- )
37
 
38
- outputs = self.model.generate(**model_inputs, **generate_kwargs, return_dict_in_generate=True, output_scores=True)
39
- output_ids = outputs.sequences
40
- out_b = output_ids.shape[0]
41
- if self.framework == "pt":
42
- output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
43
- elif self.framework == "tf":
44
- output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
45
 
46
- return {"output_ids": output_ids, "output_sequences": outputs.sequences, "output_scores": outputs.scores}
47
 
48
  def postprocess(self, model_outputs):
49
  guess_text = super().postprocess(model_outputs)[0]['generated_text']
 
23
  # raise ValueError("Inputs must be a dictionary with 'question' and 'context' keys.")
24
 
25
 
26
+ # def _forward(self, model_inputs, **generate_kwargs):
27
+ # if self.framework == "pt":
28
+ # in_b, input_length = model_inputs["input_ids"].shape
29
+ # elif self.framework == "tf":
30
+ # in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()
31
 
32
+ # self.check_inputs(
33
+ # input_length,
34
+ # generate_kwargs.get("min_length", self.model.config.min_length),
35
+ # generate_kwargs.get("max_length", self.model.config.max_length),
36
+ # )
37
 
38
+ # outputs = self.model.generate(**model_inputs, **generate_kwargs, return_dict_in_generate=True, output_scores=True)
39
+ # output_ids = outputs.sequences
40
+ # out_b = output_ids.shape[0]
41
+ # if self.framework == "pt":
42
+ # output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
43
+ # elif self.framework == "tf":
44
+ # output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
45
 
46
+ # return {"output_ids": output_ids, "output_sequences": outputs.sequences, "output_scores": outputs.scores}
47
 
48
  def postprocess(self, model_outputs):
49
  guess_text = super().postprocess(model_outputs)[0]['generated_text']