Update train.py
Browse files
train.py
CHANGED
@@ -9,6 +9,9 @@ import torch
|
|
9 |
|
10 |
class ModelTrainer:
|
11 |
def __init__(self, model_id, system_prompts_path):
|
|
|
|
|
|
|
12 |
self.model_id = model_id
|
13 |
|
14 |
# 加载系统提示词
|
@@ -27,7 +30,9 @@ class ModelTrainer:
|
|
27 |
trust_remote_code=True,
|
28 |
torch_dtype=torch.float32, # 使用 torch.float32 而不是字符串
|
29 |
device_map='auto', # 自动选择设备
|
30 |
-
low_cpu_mem_usage=True
|
|
|
|
|
31 |
)
|
32 |
|
33 |
# 使用更轻量的LoRA配置
|
|
|
9 |
|
10 |
class ModelTrainer:
|
11 |
def __init__(self, model_id, system_prompts_path):
|
12 |
+
# 确保临时文件夹存在
|
13 |
+
os.makedirs("temp_model_dir", exist_ok=True)
|
14 |
+
|
15 |
self.model_id = model_id
|
16 |
|
17 |
# 加载系统提示词
|
|
|
30 |
trust_remote_code=True,
|
31 |
torch_dtype=torch.float32, # 使用 torch.float32 而不是字符串
|
32 |
device_map='auto', # 自动选择设备
|
33 |
+
low_cpu_mem_usage=True,
|
34 |
+
offload_folder="temp_model_dir", # 添加临时文件夹
|
35 |
+
use_safetensors=True # 使用 safetensors
|
36 |
)
|
37 |
|
38 |
# 使用更轻量的LoRA配置
|