Iggnis committed on
Commit
68540c5
1 Parent(s): 4d652f5
Files changed (1) hide show
  1. handler.py +10 -3
handler.py CHANGED
@@ -1,14 +1,21 @@
 
1
  from typing import Dict, List, Any
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
 
 
 
 
 
 
7
  # load the model
8
  tokenizer = AutoTokenizer.from_pretrained(path)
9
- model = AutoModelForCausalLM.from_pretrained(path, device_map="auto")
10
  # create inference pipeline
11
- self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
12
 
13
  def __call__(self, data: Any):
14
  inputs = data.pop("inputs", data)
 
1
+ import torch
2
  from typing import Dict, List, Any
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
4
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
+ bnb_config = BitsAndBytesConfig(
9
+ load_in_4bit=True,
10
+ bnb_4bit_use_double_quant=True,
11
+ bnb_4bit_quant_type="nf4",
12
+ bnb_4bit_compute_dtype=torch.bfloat16
13
+ )
14
  # load the model
15
  tokenizer = AutoTokenizer.from_pretrained(path)
16
+ model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", quantization_config=bnb_config)
17
  # create inference pipeline
18
+ self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
19
 
20
  def __call__(self, data: Any):
21
  inputs = data.pop("inputs", data)