curiousily committed on
Commit
bbf6392
1 Parent(s): 43005e8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +6 -4
handler.py CHANGED
@@ -5,9 +5,9 @@ import torch
5
 
6
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
7
 
8
-
9
  class EndpointHandler:
10
  def __init__(self, path=""):
 
11
  tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  path,
@@ -27,10 +27,12 @@ class EndpointHandler:
27
  self.generation_config = generation_config
28
 
29
  self.pipeline = transformers.pipeline(
30
- "text-generation", model=model, tokenizer=tokenizer
 
 
31
  )
32
 
33
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
34
- prompt = data.pop("prompt", data)
35
  result = self.pipeline(prompt, generation_config=self.generation_config)
36
- return result
 
5
 
6
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
7
 
 
8
  class EndpointHandler:
9
  def __init__(self, path=""):
10
+
11
  tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  path,
 
27
  self.generation_config = generation_config
28
 
29
  self.pipeline = transformers.pipeline(
30
+ "text-generation",
31
+ model=model,
32
+ tokenizer=tokenizer
33
  )
34
 
35
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
36
+ prompt = data.pop("inputs", data)
37
  result = self.pipeline(prompt, generation_config=self.generation_config)
38
+ return result