fwittel committed on
Commit
040e104
1 Parent(s): 4c61288

Added tokenizer to handler.py

Browse files
Files changed (1) hide show
  1. handler.py +14 -4
handler.py CHANGED
@@ -1,16 +1,26 @@
1
  from typing import Dict, List, Any
2
- from transformers import AutoModel, pipeline
 
3
 
4
class EndpointHandler:
    """Inference Endpoints handler: wraps the model stored at *path* in a
    text-generation pipeline and runs it on incoming request payloads."""

    def __init__(self, path=""):
        # Instantiate the model from the endpoint's model directory.
        loaded_model = AutoModel.from_pretrained(path, low_cpu_mem_usage=True)
        # Wrap it in a text-generation inference pipeline.
        # Do I have to check device?
        self.pipeline = pipeline("text-generation", model=loaded_model)

    # (Might have to adjust typing)
    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Extract the prompt from *data* and return the pipeline output."""
        # Should I get and pass parameters?
        prompt = data.pop("inputs", data)
        return self.pipeline(prompt)
 
1
  from typing import Dict, List, Any
2
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline
3
+
4
 
5
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler around a
    text-generation pipeline."""

    def __init__(self, path=""):
        # Load the tokenizer from the endpoint's model directory.
        tokenizer = AutoTokenizer.from_pretrained(path)
        # BUG FIX: AutoModel returns the bare backbone without a language-model
        # head, which a "text-generation" pipeline cannot drive. Use the
        # causal-LM auto class so generation works.
        model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)
        # NOTE(review): device placement is not handled here — the pipeline
        # defaults to CPU; confirm whether a GPU device should be passed.
        self.pipeline = pipeline(
            "text-generation", model=model, tokenizer=tokenizer)

    # (Might have to adjust typing: text-generation pipelines typically
    # return [{"generated_text": str}], not float values — TODO confirm.)
    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run inference on a request payload.

        Args:
            data: either a dict with an ``"inputs"`` key (the prompt) and an
                optional ``"parameters"`` dict of generation kwargs, or the
                raw prompt itself.

        Returns:
            The pipeline's prediction for the prompt.
        """
        # Accept either a payload dict or the raw prompt itself; a non-dict
        # payload previously crashed on .pop().
        if isinstance(data, dict):
            inputs = data.pop("inputs", data)
            parameters = data.pop("parameters", None) or {}
        else:
            inputs, parameters = data, {}
        # Forward any generation kwargs straight to the pipeline.
        return self.pipeline(inputs, **parameters)