booksouls committed (verified)
Commit c3c4be8 · 1 Parent(s): 8b9d004

update handler to load predefined 4-bit model

Files changed (1): handler.py (+3 -16)
handler.py CHANGED
@@ -4,19 +4,7 @@ from typing import Any
 
 class EndpointHandler():
     def __init__(self, path=""):
-        # bitsandbytes quantization is only supported on CUDA devices.
-        bits_and_bytes_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.bfloat16,
-        )
-        quantization_config = bits_and_bytes_config if torch.cuda.is_available() else None
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(
-            path,
-            quantization_config=quantization_config,
-            device_map="auto",
-        )
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(f"{path}/4-bit", device_map="auto")
         self.tokenizer = AutoTokenizer.from_pretrained(path)
 
     def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
@@ -41,8 +29,8 @@ class EndpointHandler():
             return_attention_mask=False,
         )
 
-        # Ensure the input_ids and the model are on the same device to prevent errors.
-        input_ids = tokens.input_ids.to(self.device)
+        # Ensure the input_ids and the model are both on the GPU to prevent errors.
+        input_ids = tokens.input_ids.to("cuda")
 
         # Gradient calculation is not needed for inference.
         with torch.no_grad():
@@ -53,4 +41,3 @@ class EndpointHandler():
 
         generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
         return {"generated_text": generated_text}
-