jljiu commited on
Commit
88ca2f2
·
verified ·
1 Parent(s): a903621

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +6 -1
train.py CHANGED
@@ -9,6 +9,9 @@ import torch
9
 
10
  class ModelTrainer:
11
  def __init__(self, model_id, system_prompts_path):
 
 
 
12
  self.model_id = model_id
13
 
14
  # 加载系统提示词
@@ -27,7 +30,9 @@ class ModelTrainer:
27
  trust_remote_code=True,
28
  torch_dtype=torch.float32, # 使用 torch.float32 而不是字符串
29
  device_map='auto', # 自动选择设备
30
- low_cpu_mem_usage=True
 
 
31
  )
32
 
33
  # 使用更轻量的LoRA配置
 
9
 
10
  class ModelTrainer:
11
  def __init__(self, model_id, system_prompts_path):
12
+ # 确保临时文件夹存在
13
+ os.makedirs("temp_model_dir", exist_ok=True)
14
+
15
  self.model_id = model_id
16
 
17
  # 加载系统提示词
 
30
  trust_remote_code=True,
31
  torch_dtype=torch.float32, # 使用 torch.float32 而不是字符串
32
  device_map='auto', # 自动选择设备
33
+ low_cpu_mem_usage=True,
34
+ offload_folder="temp_model_dir", # 添加临时文件夹
35
+ use_safetensors=True # 使用 safetensors
36
  )
37
 
38
  # 使用更轻量的LoRA配置