Pierce Maloney committed on
Commit
b358b49
1 Parent(s): 7e24db7

quantization trial

Browse files
Files changed (1) hide show
  1. handler.py +10 -2
handler.py CHANGED
@@ -1,6 +1,8 @@
1
  import logging
2
  from typing import Dict, List, Any
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList
 
 
4
 
5
  # Configure logging
6
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -10,7 +12,13 @@ class EndpointHandler():
10
  logging.info("Initializing EndpointHandler with model path: %s", path)
11
  tokenizer = AutoTokenizer.from_pretrained(path)
12
  tokenizer.pad_token = tokenizer.eos_token
13
- self.model = AutoModelForCausalLM.from_pretrained(path)
 
 
 
 
 
 
14
  self.tokenizer = tokenizer
15
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
16
 
 
1
  import logging
2
  from typing import Dict, List, Any
3
+ import torch
4
+
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, StoppingCriteria, StoppingCriteriaList, BitsAndBytesConfig
6
 
7
  # Configure logging
8
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
12
  logging.info("Initializing EndpointHandler with model path: %s", path)
13
  tokenizer = AutoTokenizer.from_pretrained(path)
14
  tokenizer.pad_token = tokenizer.eos_token
15
+ bnb_config = BitsAndBytesConfig(
16
+ load_in_4bit=True,
17
+ bnb_4bit_use_double_quant=True,
18
+ bnb_4bit_quant_type="nf4",
19
+ bnb_4bit_compute_dtype=torch.bfloat16
20
+ )
21
+ self.model = AutoModelForCausalLM.from_pretrained(path, quantization_config=bnb_config)
22
  self.tokenizer = tokenizer
23
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
24