Files changed (1) hide show
  1. handler.py +6 -9
handler.py CHANGED
@@ -8,14 +8,11 @@ DEFAULT_MAX_NEW_TOKENS = 10
8
 
9
  class EndpointHandler():
10
  def __init__(self, path: str = ""):
11
- assert torch.cuda.device_count() >= 4, f"Only found access to {torch.cuda.device_count()} GPUs"
12
 
13
  self.tokenizer = AutoTokenizer.from_pretrained(path)
14
  self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
15
  self.model = self.model.to('cuda:0')
16
 
17
- self.model.parallelize()
18
-
19
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
20
  """
21
  Args:
@@ -27,14 +24,14 @@ class EndpointHandler():
27
 
28
  prompts = [f"<human>: {prompt}\n<bot>:" for prompt in data["inputs"]]
29
 
30
- print("prompts")
31
- raise ValueError(inputs)
32
-
33
  inputs = self.tokenizer(prompts, padding=True, return_tensors='pt').to(self.model.device)
34
  input_length = inputs.input_ids.shape[1]
 
35
  outputs = self.model.generate(
36
- **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.7, top_k=50
37
  )
38
- output_strs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
39
 
40
- return {"generated_text": output_strs}
 
 
 
8
 
9
  class EndpointHandler():
10
  def __init__(self, path: str = ""):
 
11
 
12
  self.tokenizer = AutoTokenizer.from_pretrained(path)
13
  self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
14
  self.model = self.model.to('cuda:0')
15
 
 
 
16
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
17
  """
18
  Args:
 
24
 
25
  prompts = [f"<human>: {prompt}\n<bot>:" for prompt in data["inputs"]]
26
 
27
+ self.tokenizer.pad_token = self.tokenizer.eos_token
 
 
28
  inputs = self.tokenizer(prompts, padding=True, return_tensors='pt').to(self.model.device)
29
  input_length = inputs.input_ids.shape[1]
30
+
31
  outputs = self.model.generate(
32
+ **inputs, **data["parameters"]
33
  )
 
34
 
35
+ output_strs = self.tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
36
+
37
+ return [{"generated_text": output_strs}]