philschmid (HF staff) committed
Commit: 28d71e8
1 Parent(s): 36a8565

Update handler.py

Files changed (1):
  1. handler.py +23 -21
handler.py CHANGED
@@ -1,26 +1,28 @@
-import torch
-from typing import  Dict, List, Any
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+import torch
+from typing import Dict, List, Any
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
 # check for GPU
-device = 0 if torch.cuda.is_available() else -1
+device = 0 if torch.cuda.is_available() else -1
 
-class EndpointHandler():
-    def __init__(self, path=""):
-        # load the model
-        tokenizer = AutoTokenizer.from_pretrained(path)
-        model = AutoModelForSeq2SeqLM.from_pretrained(path ,low_cpu_mem_usage=True)
-        # create inference pipeline
-        self.pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer,device=device)
+
+class EndpointHandler():
+    def __init__(self, path=""):
+        # load the model
+        tokenizer = AutoTokenizer.from_pretrained(path)
+        model = AutoModelForSeq2SeqLM.from_pretrained(path ,low_cpu_mem_usage=True)
+        # create inference pipeline
+        self.pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer,device=device)
 
-    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
-        inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
+
+    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
+        inputs = data.pop("inputs", data)
+        parameters = data.pop("parameters", None)
 
-        # pass inputs with all kwargs in data
-        if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
-        else:
-            prediction = self.pipeline(inputs)
-        # postprocess the prediction
-        return prediction
+        # pass inputs with all kwargs in data
+        if parameters is not None:
+            prediction = self.pipeline(inputs, **parameters)
+        else:
+            prediction = self.pipeline(inputs)
+        # postprocess the prediction
+        return prediction
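
For orientation, here is a minimal sketch of how this handler is exercised: Inference Endpoints instantiates EndpointHandler with the path of the downloaded model repository and calls it with the parsed JSON request body. The ./model path, the prompt, and the checkpoint type below are illustrative assumptions, not part of this commit.

# Sketch: local smoke test for handler.py. Assumes the repository has been
# downloaded to ./model and holds a seq2seq checkpoint (e.g. a T5 variant)
# compatible with the text2text-generation pipeline.
from handler import EndpointHandler

handler = EndpointHandler(path="./model")

# Inference Endpoints passes the parsed request body as a dict; "parameters"
# is optional and is forwarded to the pipeline call as keyword arguments.
payload = {
    "inputs": "translate English to German: Hello, how are you?",
    "parameters": {"max_length": 40},
}

prediction = handler(payload)
# the text2text-generation pipeline returns a list of dicts,
# e.g. [{"generated_text": "..."}]
print(prediction)

Because "parameters" is popped from the body and splatted into the pipeline call, any generation argument the pipeline accepts (max_length, num_beams, and so on) can be set per request.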