wassemgtk committed on
Commit da1ea71
1 Parent(s): 99129c0

Create handler.py

Files changed (1)
  1. handler.py +52 -0
handler.py ADDED
@@ -0,0 +1,52 @@
+ import torch
+ from typing import Any, Dict, List
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
+
+ # Alpaca-style prompt template; every incoming instruction is wrapped
+ # in this format before generation
+ format_input = (
+     "Below is an instruction that describes a task. "
+     "Write a response that appropriately completes the request.\n\n"
+     "### Instruction:\n{instruction}\n\n### Response:"
+ )
+
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         # load the tokenizer and model; device_map="auto" lets accelerate
+         # place the fp16 weights on the available GPU(s), falling back to CPU
+         tokenizer = AutoTokenizer.from_pretrained(path)
+         model = AutoModelForCausalLM.from_pretrained(
+             path,
+             device_map="auto",
+             torch_dtype=torch.float16,
+         )
+         # create the inference pipeline; no explicit device= argument here,
+         # since the model has already been placed by device_map="auto"
+         self.pipeline = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             max_length=256,
+         )
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
+         inputs = data.pop("inputs", data)
+         parameters = data.pop("parameters", None)
+
+         text_input = format_input.format(instruction=inputs)
+
+         # forward any generation kwargs supplied in the request
+         if parameters is not None:
+             prediction = self.pipeline(text_input, **parameters)
+         else:
+             prediction = self.pipeline(text_input)
+
+         # keep only the text generated after the "### Response:" marker
+         output = [
+             {"generated_text": pred["generated_text"].split("### Response:")[-1].strip()}
+             for pred in prediction
+         ]
+
+         return output
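
For reference, a handler like this can be smoke-tested locally before deploying. The sketch below is not part of the commit; the model path, instruction text, and generation parameters are placeholders — substitute the repository this handler ships with. It mirrors the JSON body that Inference Endpoints passes to `__call__`:

```python
# Local smoke test for handler.py (a sketch, assuming the model repo
# has been downloaded to the current directory).
from handler import EndpointHandler

handler = EndpointHandler(path=".")

# Mirrors the request payload Inference Endpoints sends to __call__;
# "parameters" is forwarded as generation kwargs to the pipeline.
payload = {
    "inputs": "Give three tips for staying healthy.",
    "parameters": {"do_sample": True, "top_p": 0.9, "temperature": 0.7},
}

result = handler(payload)
print(result)  # e.g. [{"generated_text": "..."}]
```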