luohoa97 commited on
Commit
3a263ff
·
verified ·
1 Parent(s): e8b6287

Deploy BitNet-Transformer Trainer

Browse files
Files changed (1) hide show
  1. scripts/train_ai_model.py +30 -9
scripts/train_ai_model.py CHANGED
@@ -109,11 +109,14 @@ def train():
109
  # 4. Dynamic Batch Sizing
110
  batch_size = get_max_batch_size(model, input_dim, SEQ_LEN, device)
111
 
112
- train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2)
113
- val_loader = DataLoader(val_ds, batch_size=batch_size, pin_memory=True, num_workers=2)
114
 
115
  optimizer = optim.AdamW(model.parameters(), lr=LR)
116
- criterion = nn.CrossEntropyLoss()
 
 
 
117
 
118
  # Mixed Precision Setup
119
  dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
@@ -167,8 +170,19 @@ def train():
167
  loss = criterion(outputs, batch_y)
168
 
169
  val_loss += loss.item()
170
- preds = torch.argmax(outputs, dim=-1)
171
- all_preds.extend(preds.cpu().numpy())
 
 
 
 
 
 
 
 
 
 
 
172
  all_true.extend(batch_y.cpu().numpy())
173
  all_rets.extend(batch_r.numpy())
174
 
@@ -181,12 +195,19 @@ def train():
181
 
182
  buys = int((all_preds == 1).sum())
183
  sells = int((all_preds == 2).sum())
184
- pnl = float(np.sum(all_rets[all_preds == 1]) - np.sum(all_rets[all_preds == 2]))
185
- win_rate = float(np.sum((all_preds == 1) & (all_true == 1)) / (buys + 1e-6))
 
 
 
 
 
186
 
187
  tqdm.write(f"\n--- Epoch {epoch+1} Statistics ---")
188
- tqdm.write(f"Val Loss: {avg_val_loss:.4f} | Total PnL: {pnl:+.4f} | Win Rate: {win_rate:.1%}")
189
- tqdm.write(f"Signals: {buys} BUY | {sells} SELL | Activity: {(buys+sells)/len(all_preds):.1%}")
 
 
190
 
191
  if buys + sells > 0:
192
  cm = confusion_matrix(all_true, all_preds, labels=[0, 1, 2])
 
109
  # 4. Dynamic Batch Sizing
110
  batch_size = get_max_batch_size(model, input_dim, SEQ_LEN, device)
111
 
112
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=0)
113
+ val_loader = DataLoader(val_ds, batch_size=batch_size, pin_memory=True, num_workers=0)
114
 
115
  optimizer = optim.AdamW(model.parameters(), lr=LR)
116
+
117
+ # 5. Class Weights (HOLD: 2.0, BUY: 1.0, SELL: 3.0)
118
+ class_weights = torch.tensor([2.0, 1.0, 3.0]).to(device)
119
+ criterion = nn.CrossEntropyLoss(weight=class_weights)
120
 
121
  # Mixed Precision Setup
122
  dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
 
170
  loss = criterion(outputs, batch_y)
171
 
172
  val_loss += loss.item()
173
+
174
+ # Apply Probability Threshold (0.6)
175
+ probs = torch.softmax(outputs, dim=-1)
176
+ conf, preds = torch.max(probs, dim=-1)
177
+
178
+ # If confidence < 0.6, force HOLD (0)
179
+ # This reduces noisy trades and targets high-conviction signals
180
+ threshold = 0.6
181
+ final_preds = preds.clone()
182
+ mask = (conf < threshold) & (preds != 0)
183
+ final_preds[mask] = 0
184
+
185
+ all_preds.extend(final_preds.cpu().numpy())
186
  all_true.extend(batch_y.cpu().numpy())
187
  all_rets.extend(batch_r.numpy())
188
 
 
195
 
196
  buys = int((all_preds == 1).sum())
197
  sells = int((all_preds == 2).sum())
198
+
199
+ buy_pnl = float(np.sum(all_rets[all_preds == 1]))
200
+ sell_pnl = float(-np.sum(all_rets[all_preds == 2])) # Future return is inverse for SELL
201
+ total_pnl = buy_pnl + sell_pnl
202
+
203
+ buy_win_rate = float(np.sum((all_preds == 1) & (all_true == 1)) / (buys + 1e-6))
204
+ sell_win_rate = float(np.sum((all_preds == 2) & (all_true == 2)) / (sells + 1e-6))
205
 
206
  tqdm.write(f"\n--- Epoch {epoch+1} Statistics ---")
207
+ tqdm.write(f"Val Loss: {avg_val_loss:.4f} | Total PnL: {total_pnl:+.4f}")
208
+ tqdm.write(f"BUYs: {buys} | PnL: {buy_pnl:+.4f} | Win Rate: {buy_win_rate:.1%}")
209
+ tqdm.write(f"SELLs: {sells} | PnL: {sell_pnl:+.4f} | Win Rate: {sell_win_rate:.1%}")
210
+ tqdm.write(f"Activity: {(buys+sells)/len(all_preds):.1%}")
211
 
212
  if buys + sells > 0:
213
  cm = confusion_matrix(all_true, all_preds, labels=[0, 1, 2])