Spaces:
Runtime error
Runtime error
Deploy BitNet-Transformer Trainer
Browse files
- scripts/train_ai_model.py +30 -9
scripts/train_ai_model.py
CHANGED
|
@@ -109,11 +109,14 @@ def train():
|
|
| 109 |
# 4. Dynamic Batch Sizing
|
| 110 |
batch_size = get_max_batch_size(model, input_dim, SEQ_LEN, device)
|
| 111 |
|
| 112 |
-
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=
|
| 113 |
-
val_loader = DataLoader(val_ds, batch_size=batch_size, pin_memory=True, num_workers=
|
| 114 |
|
| 115 |
optimizer = optim.AdamW(model.parameters(), lr=LR)
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
# Mixed Precision Setup
|
| 119 |
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
|
@@ -167,8 +170,19 @@ def train():
|
|
| 167 |
loss = criterion(outputs, batch_y)
|
| 168 |
|
| 169 |
val_loss += loss.item()
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
all_true.extend(batch_y.cpu().numpy())
|
| 173 |
all_rets.extend(batch_r.numpy())
|
| 174 |
|
|
@@ -181,12 +195,19 @@ def train():
|
|
| 181 |
|
| 182 |
buys = int((all_preds == 1).sum())
|
| 183 |
sells = int((all_preds == 2).sum())
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
tqdm.write(f"\n--- Epoch {epoch+1} Statistics ---")
|
| 188 |
-
tqdm.write(f"Val Loss: {avg_val_loss:.4f} | Total PnL: {
|
| 189 |
-
tqdm.write(f"
|
|
|
|
|
|
|
| 190 |
|
| 191 |
if buys + sells > 0:
|
| 192 |
cm = confusion_matrix(all_true, all_preds, labels=[0, 1, 2])
|
|
|
|
| 109 |
# 4. Dynamic Batch Sizing
|
| 110 |
batch_size = get_max_batch_size(model, input_dim, SEQ_LEN, device)
|
| 111 |
|
| 112 |
+
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=0)
|
| 113 |
+
val_loader = DataLoader(val_ds, batch_size=batch_size, pin_memory=True, num_workers=0)
|
| 114 |
|
| 115 |
optimizer = optim.AdamW(model.parameters(), lr=LR)
|
| 116 |
+
|
| 117 |
+
# 5. Class Weights (HOLD: 2.0, BUY: 1.0, SELL: 3.0)
|
| 118 |
+
class_weights = torch.tensor([2.0, 1.0, 3.0]).to(device)
|
| 119 |
+
criterion = nn.CrossEntropyLoss(weight=class_weights)
|
| 120 |
|
| 121 |
# Mixed Precision Setup
|
| 122 |
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
|
|
|
| 170 |
loss = criterion(outputs, batch_y)
|
| 171 |
|
| 172 |
val_loss += loss.item()
|
| 173 |
+
|
| 174 |
+
# Apply Probability Threshold (0.6)
|
| 175 |
+
probs = torch.softmax(outputs, dim=-1)
|
| 176 |
+
conf, preds = torch.max(probs, dim=-1)
|
| 177 |
+
|
| 178 |
+
# If confidence < 0.6, force HOLD (0)
|
| 179 |
+
# This reduces noisy trades and targets high-conviction signals
|
| 180 |
+
threshold = 0.6
|
| 181 |
+
final_preds = preds.clone()
|
| 182 |
+
mask = (conf < threshold) & (preds != 0)
|
| 183 |
+
final_preds[mask] = 0
|
| 184 |
+
|
| 185 |
+
all_preds.extend(final_preds.cpu().numpy())
|
| 186 |
all_true.extend(batch_y.cpu().numpy())
|
| 187 |
all_rets.extend(batch_r.numpy())
|
| 188 |
|
|
|
|
| 195 |
|
| 196 |
buys = int((all_preds == 1).sum())
|
| 197 |
sells = int((all_preds == 2).sum())
|
| 198 |
+
|
| 199 |
+
buy_pnl = float(np.sum(all_rets[all_preds == 1]))
|
| 200 |
+
sell_pnl = float(-np.sum(all_rets[all_preds == 2])) # Future return is inverse for SELL
|
| 201 |
+
total_pnl = buy_pnl + sell_pnl
|
| 202 |
+
|
| 203 |
+
buy_win_rate = float(np.sum((all_preds == 1) & (all_true == 1)) / (buys + 1e-6))
|
| 204 |
+
sell_win_rate = float(np.sum((all_preds == 2) & (all_true == 2)) / (sells + 1e-6))
|
| 205 |
|
| 206 |
tqdm.write(f"\n--- Epoch {epoch+1} Statistics ---")
|
| 207 |
+
tqdm.write(f"Val Loss: {avg_val_loss:.4f} | Total PnL: {total_pnl:+.4f}")
|
| 208 |
+
tqdm.write(f"BUYs: {buys} | PnL: {buy_pnl:+.4f} | Win Rate: {buy_win_rate:.1%}")
|
| 209 |
+
tqdm.write(f"SELLs: {sells} | PnL: {sell_pnl:+.4f} | Win Rate: {sell_win_rate:.1%}")
|
| 210 |
+
tqdm.write(f"Activity: {(buys+sells)/len(all_preds):.1%}")
|
| 211 |
|
| 212 |
if buys + sells > 0:
|
| 213 |
cm = confusion_matrix(all_true, all_preds, labels=[0, 1, 2])
|