hdnh2006 committed on
Commit
539ede7
1 Parent(s): 4f267c9

handler uses LlamaForCausalLM

Browse files
Files changed (1) hide show
  1. handler.py +5 -2
handler.py CHANGED
@@ -13,7 +13,7 @@ for text generation, leveraging the capabilities of the Llama 2 model.
13
  """
14
 
15
  import torch
16
- from transformers import pipeline, BitsAndBytesConfig
17
  from typing import Dict, List, Any
18
  import logging
19
  import sys
@@ -51,7 +51,10 @@ class EndpointHandler:
51
  bnb_4bit_compute_dtype=torch.bfloat16
52
  )
53
 
54
- self.pipeline = pipeline('text-generation', model=path, quantization_config=self.bnb_config)
 
 
 
55
 
56
 
57
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
13
  """
14
 
15
  import torch
16
+ from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline, BitsAndBytesConfig
17
  from typing import Dict, List, Any
18
  import logging
19
  import sys
 
51
  bnb_4bit_compute_dtype=torch.bfloat16
52
  )
53
 
54
+ tokenizer = LlamaTokenizer.from_pretrained(path)
55
+ model = LlamaForCausalLM.from_pretrained(path, device_map=0, quantization_config=self.bnb_config)
56
+
57
+ self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)
58
 
59
 
60
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: