minhdang commited on
Commit
6abed24
1 Parent(s): 0c372c1

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +7 -1
inference.py CHANGED
@@ -33,12 +33,18 @@ from transformers import (
33
  from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
34
  from deepseek_vl.utils.conversation import Conversation
35
 
 
36
 
 
 
 
 
 
37
def load_model(model_path):
    """Load the DeepSeek-VL model, tokenizer and chat processor from a checkpoint.

    Args:
        model_path: Hugging Face model id or local checkpoint directory.

    Returns:
        A ``(tokenizer, vl_gpt, vl_chat_processor)`` tuple ready for inference.
    """
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    tokenizer = vl_chat_processor.tokenizer
    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True
    )
    # Cast to bf16 and move to GPU; assumes a CUDA device is available.
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
    return tokenizer, vl_gpt, vl_chat_processor
 
33
  from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
34
  from deepseek_vl.utils.conversation import Conversation
35
 
36
from transformers import BitsAndBytesConfig

# 8-bit quantization config for bitsandbytes.
# NOTE(review): the original passed bnb_8bit_use_double_quant=True and
# bnb_8bit_quant_type="nf8". Neither keyword exists on BitsAndBytesConfig —
# double quantization and the "nf4"/"fp4" quant types are 4-bit-only options
# (bnb_4bit_use_double_quant / bnb_4bit_quant_type), and "nf8" is not a valid
# quant type at all; the unknown kwargs were silently ignored. For 8-bit
# loading, load_in_8bit=True is the entire configuration.
nf8_config = BitsAndBytesConfig(
    load_in_8bit=True,
)


def load_model(model_path):
    """Load the DeepSeek-VL model 8-bit-quantized, plus tokenizer and processor.

    Args:
        model_path: Hugging Face model id or local checkpoint directory.

    Returns:
        A ``(tokenizer, vl_gpt, vl_chat_processor)`` tuple ready for inference.
    """
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    tokenizer = vl_chat_processor.tokenizer
    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True, quantization_config=nf8_config
    )
    # A bitsandbytes-quantized model is dispatched to GPU by from_pretrained
    # and must NOT be cast with .to(torch.bfloat16) — transformers raises
    # ValueError ("`.to` is not supported for 4-bit or 8-bit ... models").
    vl_gpt = vl_gpt.eval()
    return tokenizer, vl_gpt, vl_chat_processor