MrOvkill committed on
Commit
deb69ba
1 Parent(s): d82a4e5

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +13 -9
handler.py CHANGED
@@ -4,8 +4,11 @@ from typing import Dict, List, Any
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
- MAX_TOKENS=8192
8
- GPU_LAYERS=99 if torch.cuda.is_available() else 0
 
 
 
9
 
10
  class EndpointHandler():
11
  def __init__(self, data):
@@ -17,19 +20,20 @@ class EndpointHandler():
17
 
18
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
19
  inputs = data.pop("inputs", "Q: What is the chemical composition of common concrete in 2024?\nA: ")
20
- max_length = data.pop("max_length", 1024)
21
  try:
22
- max_length = int(max_length)
23
  except Exception as e:
24
  return json.dumps({
25
  "status": "error",
26
  "reason": "max_length was passed as something that was absolutely not a plain old int"
27
  })
28
 
29
- res = self.model(f"""
30
- <|user|>
31
- {inputs} <|end|>
32
- <|assistant|>
33
- """, max_new_tokens=max_new_tokens, do_sample=False)
 
34
 
35
  return res
 
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
+ PROMPT_FORMAT= """
8
+ <|user|>
9
+ {inputs} <|end|>
10
+ <|assistant|>
11
+ """
12
 
13
  class EndpointHandler():
14
  def __init__(self, data):
 
20
 
21
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
22
  inputs = data.pop("inputs", "Q: What is the chemical composition of common concrete in 2024?\nA: ")
23
+ max_new_tokens = data.pop("max_length", 1024)
24
  try:
25
+ max_new_tokens = int(max_new_tokens)
26
  except Exception as e:
27
  return json.dumps({
28
  "status": "error",
29
  "reason": "max_length was passed as something that was absolutely not a plain old int"
30
  })
31
 
32
+ res = PROMPT_FORMAT.format(inputs=inputs)
33
+ return self.model(
34
+ res,
35
+ max_new_tokens=max_new_tokens,
36
+ do_sample=False
37
+ )
38
 
39
  return res