jljiu commited on
Commit
a903621
·
verified ·
1 Parent(s): 26da30c

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +18 -109
train.py CHANGED
@@ -12,118 +12,27 @@ class ModelTrainer:
12
  self.model_id = model_id
13
 
14
  # 加载系统提示词
15
- try:
16
- with open(system_prompts_path, 'r', encoding='utf-8') as f:
17
- self.system_prompts = json.load(f)
18
- except Exception as e:
19
- print(f"加载系统提示词失败: {e}")
20
- self.system_prompts = {"base_prompt": "默认系统提示词"}
21
-
22
- # 首先尝试检测可用资源
23
- self.device = self._detect_device()
24
- self.dtype = self._detect_optimal_dtype()
25
 
26
- # 初始化模型和分词器
27
- self._initialize_model_and_tokenizer()
28
-
29
- def _detect_device(self):
30
- """检测并返回最优设备配置"""
31
- try:
32
- if torch.cuda.is_available():
33
- print("检测到 CUDA 设备")
34
- return "cuda"
35
- elif torch.backends.mps.is_available():
36
- print("检测到 MPS 设备")
37
- return "mps"
38
- else:
39
- print("使用 CPU 设备")
40
- return "cpu"
41
- except:
42
- print("设备检测失败,默认使用 CPU")
43
- return "cpu"
44
-
45
- def _detect_optimal_dtype(self):
46
- """检测并返回最优数据类型"""
47
- try:
48
- if self.device == "cuda":
49
- if torch.cuda.get_device_capability()[0] >= 7:
50
- print("使用 float16 精度")
51
- return torch.float16
52
- print("使用 float32 精度")
53
- return torch.float32
54
- except:
55
- print("数据类型检测失败,默认使用 float32")
56
- return torch.float32
57
-
58
- def _initialize_model_and_tokenizer(self):
59
- """初始化模型和分词器,包含多个备选方案"""
60
- print(f"开始初始化模型,使用设备: {self.device},数据类型: {self.dtype}")
61
 
62
- # 首先尝试加载分词器
63
- try:
64
- self.tokenizer = AutoTokenizer.from_pretrained(
65
- self.model_id,
66
- trust_remote_code=True
67
- )
68
- except Exception as e:
69
- print(f"分词器加载失败: {e}")
70
- raise RuntimeError("分词器初始化失败")
71
-
72
- # 尝试不同的模型加载配置
73
- loading_configs = [
74
- # 配置1:标准加载
75
- {
76
- "trust_remote_code": True,
77
- "torch_dtype": self.dtype,
78
- "device_map": "auto",
79
- "low_cpu_mem_usage": True
80
- },
81
- # 配置2:8bit量化加载
82
- {
83
- "trust_remote_code": True,
84
- "load_in_8bit": True,
85
- "device_map": "auto",
86
- },
87
- # 配置3:CPU加载
88
- {
89
- "trust_remote_code": True,
90
- "torch_dtype": torch.float32,
91
- "device_map": "cpu",
92
- "low_cpu_mem_usage": True
93
- }
94
- ]
95
-
96
- last_exception = None
97
- for config in loading_configs:
98
- try:
99
- print(f"尝试加载模型,配置: {config}")
100
- self.model = AutoModelForCausalLM.from_pretrained(
101
- self.model_id,
102
- **config
103
- )
104
- print("模型加载成功")
105
-
106
- # 配置 LoRA
107
- try:
108
- self._setup_lora()
109
- print("LoRA 配置成功")
110
- except Exception as e:
111
- print(f"LoRA 配置失败: {e}")
112
- raise
113
-
114
- return # 成功加载后退出
115
- except Exception as e:
116
- last_exception = e
117
- print(f"当前配置加载失败: {e}")
118
- continue
119
-
120
- # 如果所有配置都失败
121
- raise RuntimeError(f"所有模型加载配置均失败,最后的错误: {last_exception}")
122
-
123
- def _setup_lora(self):
124
- """配置 LoRA"""
125
  self.lora_config = LoraConfig(
126
- r=4,
127
  lora_alpha=16,
128
  target_modules=["q_proj", "v_proj"],
129
  lora_dropout=0.05,
 
12
  self.model_id = model_id
13
 
14
  # 加载系统提示词
15
+ with open(system_prompts_path, 'r', encoding='utf-8') as f:
16
+ self.system_prompts = json.load(f)
 
 
 
 
 
 
 
 
17
 
18
+ # 修改模型初始化参数
19
+ self.tokenizer = AutoTokenizer.from_pretrained(
20
+ model_id,
21
+ trust_remote_code=True
22
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ # 修改这部分的初始化参数
25
+ self.model = AutoModelForCausalLM.from_pretrained(
26
+ model_id,
27
+ trust_remote_code=True,
28
+ torch_dtype=torch.float32, # 使用 torch.float32 而不是字符串
29
+ device_map='auto', # 自动选择设备
30
+ low_cpu_mem_usage=True
31
+ )
32
+
33
+ # 使用更轻量的LoRA配置
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  self.lora_config = LoraConfig(
35
+ r=4, # 降低rank
36
  lora_alpha=16,
37
  target_modules=["q_proj", "v_proj"],
38
  lora_dropout=0.05,