Spaces:
Paused
Paused
Update utils.py
Browse files
utils.py
CHANGED
@@ -22,6 +22,7 @@ import transformers
|
|
22 |
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Tokenizer, GPT2LMHeadModel
|
23 |
#import auto_gptq
|
24 |
#from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
|
|
25 |
|
26 |
|
27 |
def reset_state():
|
@@ -99,6 +100,19 @@ def load_tokenizer_and_model(base_model,load_8bit=False):
|
|
99 |
return tokenizer,model,device
|
100 |
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
def load_tokenizer_and_model_gpt2(base_model,load_8bit=False):
|
103 |
if torch.cuda.is_available():
|
104 |
device = "cuda"
|
|
|
22 |
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Tokenizer, GPT2LMHeadModel
|
23 |
#import auto_gptq
|
24 |
#from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
25 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
26 |
|
27 |
|
28 |
def reset_state():
|
|
|
100 |
return tokenizer,model,device
|
101 |
|
102 |
|
103 |
+
def load_tokenizer_and_model_Baize(base_model, load_8bit=True):
    """Load a Baize (LLaMA-based) tokenizer and causal-LM model.

    Args:
        base_model: Hugging Face model id or local path of the checkpoint.
        load_8bit: Load the model weights in 8-bit precision. Previously this
            flag was accepted but ignored (8-bit was hard-coded); it is now
            honored. The default ``True`` preserves the old behavior for
            callers that pass nothing.

    Returns:
        Tuple ``(tokenizer, model, device)`` where ``device`` is ``"cuda"``
        when a GPU is visible to torch, else ``"cpu"``.
    """
    # Prefer GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # add_eos_token so generated sequences carry an end-of-sequence token;
    # use_auth_token is needed for gated LLaMA checkpoints on the Hub
    # (presumably this repo uses a gated base model — confirm with callers).
    tokenizer = LlamaTokenizer.from_pretrained(
        base_model, add_eos_token=True, use_auth_token=True
    )

    # Honor load_8bit instead of hard-coding load_in_8bit=True.
    # NOTE: device_map="auto" lets accelerate place layers across devices,
    # so the returned `device` is informational for callers, not used here.
    model = LlamaForCausalLM.from_pretrained(
        base_model, load_in_8bit=load_8bit, device_map="auto"
    )

    return tokenizer, model, device
|
116 |
def load_tokenizer_and_model_gpt2(base_model,load_8bit=False):
|
117 |
if torch.cuda.is_available():
|
118 |
device = "cuda"
|