Ozgur98 committed
Commit 3b97c7e
Parent: 470f098

Update handler.py

Files changed (1):
  handler.py (+4 -5)
handler.py CHANGED
@@ -6,10 +6,9 @@ import torch.cuda
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 LOGGER = logging.getLogger(__name__)
-
 class EndpointHandler():
     def __init__(self, path=""):
-        self.model = AutoModelForCausalLM.from_pretrained("Ozgur98/pushed_model_mosaic_small")
+        self.model = AutoModelForCausalLM.from_pretrained("Ozgur98/pushed_model_mosaic_small", trust_remote_code=True).to(device='cuda:0', dtype=torch.bfloat16)
         self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
         # Load the Lora model
 
@@ -18,13 +17,13 @@ class EndpointHandler():
         Args:
             data (Dict): The payload with the text prompt and generation parameters.
         """
-        print("CALLED")
         LOGGER.info(data)
         # Forward
         LOGGER.info(f"Start generation.")
-        tokenized_example = tokenizer(data, return_tensors='pt')
+        tokenized_example = self.tokenizer(data, return_tensors='pt')
         outputs = self.model.generate(tokenized_example['input_ids'].to('cuda:0'), max_new_tokens=100, do_sample=True, top_k=10, top_p = 0.95)
+
         # Postprocess
-        answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        answer = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
         prompt = answer[0].rstrip()
         return prompt
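
For reference, here is how the full handler.py would plausibly read after this commit. This is a sketch reconstructed from the hunks above, not the verbatim file: the import block and the `__call__` signature are assumptions inferred from the hunk context (`import torch.cuda` and the standard Inference Endpoints `EndpointHandler` convention), since the diff does not show them.

import logging

import torch
import torch.cuda
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
LOGGER = logging.getLogger(__name__)


class EndpointHandler():
    def __init__(self, path=""):
        # trust_remote_code=True lets transformers execute the custom model code
        # shipped in the repo (the base appears to be MPT-style); casting to
        # bfloat16 roughly halves GPU memory versus float32.
        self.model = AutoModelForCausalLM.from_pretrained(
            "Ozgur98/pushed_model_mosaic_small", trust_remote_code=True
        ).to(device='cuda:0', dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        # Load the Lora model

    def __call__(self, data):
        """
        Args:
            data (Dict): The payload with the text prompt and generation parameters.
        """
        LOGGER.info(data)
        # Forward: tokenize the payload and sample up to 100 new tokens.
        LOGGER.info(f"Start generation.")
        tokenized_example = self.tokenizer(data, return_tensors='pt')
        outputs = self.model.generate(tokenized_example['input_ids'].to('cuda:0'),
                                      max_new_tokens=100, do_sample=True,
                                      top_k=10, top_p=0.95)

        # Postprocess: decode and strip trailing whitespace. Note the decoded
        # text includes the prompt, since generate() returns prompt + completion.
        answer = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        prompt = answer[0].rstrip()
        return prompt

Two quirks survive the commit: `data` is passed straight to the tokenizer, so the handler effectively expects a raw prompt string rather than the {"inputs": ...} payload dict its docstring describes, and the hard-coded 'cuda:0' bypasses the `device` fallback computed at module level, so the endpoint still requires a GPU.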