apoorvkh committed on
Commit
e68919f
1 Parent(s): 76aff17

Updated bitsandbytes config

Browse files
Files changed (1) hide show
  1. handler.py +3 -3
handler.py CHANGED
@@ -1,7 +1,7 @@
1
  from typing import Dict, Any
2
 
3
  import torch
4
- from transformers import Blip2Processor, Blip2Config, Blip2ForConditionalGeneration
5
  from accelerate import init_empty_weights, infer_auto_device_map
6
 
7
  from PIL import Image
@@ -19,11 +19,11 @@ class EndpointHandler():
19
  model = Blip2ForConditionalGeneration(config)
20
  device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])
21
  device_map['language_model.lm_head'] = device_map["language_model.encoder.embed_tokens"]
22
-
23
  self.model = Blip2ForConditionalGeneration.from_pretrained(
24
  "Salesforce/blip2-flan-t5-xxl", device_map=device_map,
25
  torch_dtype=torch.float16,
26
- load_in_8bit=True, load_in_8bit_fp32_cpu_offload=True
27
  )
28
 
29
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
1
  from typing import Dict, Any
2
 
3
  import torch
4
+ from transformers import Blip2Processor, Blip2Config, Blip2ForConditionalGeneration, BitsAndBytesConfig
5
  from accelerate import init_empty_weights, infer_auto_device_map
6
 
7
  from PIL import Image
 
19
  model = Blip2ForConditionalGeneration(config)
20
  device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])
21
  device_map['language_model.lm_head'] = device_map["language_model.encoder.embed_tokens"]
22
+
23
  self.model = Blip2ForConditionalGeneration.from_pretrained(
24
  "Salesforce/blip2-flan-t5-xxl", device_map=device_map,
25
  torch_dtype=torch.float16,
26
+ quantization_config=BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
27
  )
28
 
29
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: