Files changed (1)
  1. README.md +3 -1
README.md CHANGED
@@ -78,7 +78,7 @@ def setup_pipeline(model_path, use_4bit=True):
     tokenizer = AutoTokenizer.from_pretrained(model_path)
 
     model_kwargs = {"device_map": "auto"}
-
+
     if use_4bit:
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -86,6 +86,8 @@ def setup_pipeline(model_path, use_4bit=True):
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
+    else:
+        model_kwargs["torch_dtype"] = torch.bfloat16
 
     model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
 
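For context, here is a sketch of how the full `setup_pipeline` function reads after this change. The imports, the `bnb_4bit_compute_dtype` argument (which sits between the two hunks and is not visible in the diff), and the final pipeline construction are assumptions, not part of the diff.

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)


def setup_pipeline(model_path, use_4bit=True):
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    model_kwargs = {"device_map": "auto"}

    if use_4bit:
        # 4-bit NF4 quantization with double quantization, as shown in the diff.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: this line is outside the diff hunks
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        # New fallback added by this change: load full weights in bfloat16.
        model_kwargs["torch_dtype"] = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)

    # Assumption: the return statement is outside the diff hunks;
    # the name suggests a text-generation pipeline is returned.
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
```

The practical effect of the added `else` branch: when 4-bit loading is disabled, the model is now loaded in bfloat16 rather than the float32 default of `from_pretrained`, roughly halving memory use for full-precision loads.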