skorkmaz88 commited on
Commit
1ae4036
1 Parent(s): 754c9b5

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +23 -0
pipeline.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline
2
+
3
+
4
+ class MyPipeline(Pipeline):
5
+ def _sanitize_parameters(self, **kwargs):
6
+ preprocess_kwargs = {}
7
+ if "maybe_arg" in kwargs:
8
+ preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
9
+ return preprocess_kwargs, {}, {}
10
+
11
+ def preprocess(self, inputs, maybe_arg=2):
12
+ model_input = Tensor(inputs["input_ids"])
13
+ return {"model_input": model_input}
14
+
15
+ def _forward(self, model_inputs):
16
+ # model_inputs == {"model_input": model_input}
17
+ outputs = self.model(**model_inputs)
18
+ # Maybe {"logits": Tensor(...)}
19
+ return outputs
20
+
21
+ def postprocess(self, model_outputs):
22
+ best_class = model_outputs["logits"].softmax(-1)
23
+ return best_class