viethoangtranduong committed
Commit 8f1e990
1 Parent(s): 724f127
Update handler.py

handler.py CHANGED (+6 -9)
@@ -8,14 +8,11 @@ DEFAULT_MAX_NEW_TOKENS = 10
 
 
 class EndpointHandler():
     def __init__(self, path: str = ""):
-        assert torch.cuda.device_count() >= 4, f"Only found access to {torch.cuda.device_count()} GPUs"
 
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
         self.model = self.model.to('cuda:0')
 
-        self.model.parallelize()
-
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
         Args:
@@ -27,14 +24,14 @@ class EndpointHandler():
 
         prompts = [f"<human>: {prompt}\n<bot>:" for prompt in data["inputs"]]
 
-
-        raise ValueError(inputs)
-
+        self.tokenizer.pad_token = self.tokenizer.eos_token
         inputs = self.tokenizer(prompts, padding=True, return_tensors='pt').to(self.model.device)
         input_length = inputs.input_ids.shape[1]
+
         outputs = self.model.generate(
-            **inputs,
+            **inputs, **data["parameters"]
         )
-        output_strs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
-
+        output_strs = self.tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
+
+        return [{"generated_text": output_strs}]
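For reference, a minimal sketch of exercising the updated handler locally. The checkpoint path and the specific generation parameters below are illustrative assumptions, not part of the commit; the "inputs" and "parameters" keys mirror the payload shape the handler reads.

# Minimal sketch: drive the updated EndpointHandler directly.
from handler import EndpointHandler

# Hypothetical path to a local checkout of this model repository.
handler = EndpointHandler(path=".")

payload = {
    # Each prompt is wrapped as "<human>: ...\n<bot>:" by the handler.
    "inputs": ["What is a tokenizer?", "Name three prime numbers."],
    # After this commit, everything under "parameters" is forwarded to
    # model.generate() as keyword arguments.
    "parameters": {"max_new_tokens": 64, "do_sample": False},
}

result = handler(payload)
# The handler decodes only outputs[:, input_length:], so the echoed
# prompts are stripped and result looks like:
# [{"generated_text": ["<answer 1>", "<answer 2>"]}]
print(result)

Two of the added lines do quiet but necessary work here: setting tokenizer.pad_token = tokenizer.eos_token makes padding=True valid for tokenizers that ship without a pad token, which batching prompts of different lengths requires, and slicing the outputs at input_length keeps the echoed prompt out of generated_text.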