YasaminAbb committed
Commit 43d5442 · 1 Parent(s): 1227d51
Update handler.py
handler.py: +4 -4
handler.py CHANGED
@@ -5,22 +5,22 @@ from peft import PeftConfig, PeftModel
 
 class EndpointHandler:
     def __init__(self,path=""):
+        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
+
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=
+            bnb_4bit_compute_dtype=dtype
         )
         config = PeftConfig.from_pretrained(path)
-        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
         model = AutoModelForCausalLM.from_pretrained(
             config.base_model_name_or_path,
             return_dict=True,
             quantization_config=bnb_config,
-            device_map="auto" ,
+            device_map="auto" ,
             torch_dtype=dtype,
             trust_remote_code=True,
-            load_in_8bit=True
         )
         tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
         tokenizer.pad_token = tokenizer.eos_token
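
The net effect of the commit is easier to see in one piece: the dtype selection now runs before BitsAndBytesConfig is built so it can be passed as bnb_4bit_compute_dtype (which previously had no value at all), and the load_in_8bit=True flag, which conflicts with the 4-bit quantization config, is gone. The following is a consolidated sketch of the updated __init__, not the full committed file: the imports other than peft and the adapter attachment below the hunk are assumptions and are marked as such in the comments.

# Sketch of handler.py after this commit. Imports other than peft are assumed;
# only "from peft import PeftConfig, PeftModel" is visible in the hunk header.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftConfig, PeftModel


class EndpointHandler:
    def __init__(self, path=""):
        # bfloat16 on compute capability 8.x (Ampere) GPUs, float16 otherwise.
        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16

        # dtype is now defined before the quantization config so it can be
        # passed as bnb_4bit_compute_dtype.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
        )
        config = PeftConfig.from_pretrained(path)
        # load_in_8bit=True was dropped: it conflicts with the 4-bit
        # quantization_config passed here.
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        # Not shown in the hunk: the LoRA adapter is presumably attached to the
        # base model further down, e.g. PeftModel.from_pretrained(model, path).

torch.cuda.get_device_capability() returns a (major, minor) tuple, so the check selects bfloat16 only on major version 8 devices; older cards fall back to float16, which the NF4 kernels also accept as a compute dtype.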