viethoangtranduong committed
Commit: a679b46
1 Parent(s): 958469b

Create handler.py

Files changed (1):
  handler.py  +38 −0
handler.py ADDED
@@ -0,0 +1,38 @@
import torch
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForCausalLM

# Batching constants (declared for future use; not referenced below).
MAX_TOKENS_IN_BATCH = 4_000
DEFAULT_MAX_NEW_TOKENS = 10


class EndpointHandler:
    def __init__(self, path: str = ""):
        assert torch.cuda.device_count() >= 4, f"Only found access to {torch.cuda.device_count()} GPUs"

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Causal-LM tokenizers often ship without a pad token; fall back to the
        # EOS token so that batched padding below does not fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
        self.model = self.model.to("cuda:0")

        # Spread the model's layers across the available GPUs. Note that
        # `parallelize()` is deprecated and only implemented for a few
        # architectures; passing `device_map="auto"` at load time is the
        # usual alternative today.
        self.model.parallelize()

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """
        Args:
            data (:obj:`dict`):
                Holds the input prompts under the "inputs" key, plus any
                parameters for the inference.
        Return:
            A :obj:`list` of generated strings, one per input prompt.
        """
        prompts = [f"<human>: {prompt}\n<bot>:" for prompt in data["inputs"]]

        inputs = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.model.device)
        input_length = inputs.input_ids.shape[1]
        outputs = self.model.generate(
            **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.7, top_k=50
        )
        # Decode only the newly generated tokens, dropping the echoed prompts.
        output_strs = self.tokenizer.batch_decode(
            outputs[:, input_length:], skip_special_tokens=True
        )

        return output_strs
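
For reference, a minimal sketch of how this handler might be exercised locally. The model path and prompts below are placeholders, and the `from handler import EndpointHandler` import assumes this file sits on the Python path; the assert in `__init__` still requires at least 4 visible GPUs.

# Hypothetical smoke test for the handler above (placeholder path/prompts).
from handler import EndpointHandler

handler = EndpointHandler(path="path/to/model")  # placeholder model directory
payload = {"inputs": ["What is machine learning?", "Write a haiku about GPUs."]}
completions = handler(payload)
for text in completions:
    print(text)

On Inference Endpoints this call happens for you: the serving toolkit instantiates `EndpointHandler` with the repository path and invokes it with each deserialized request body.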