Spaces:
Runtime error
Runtime error
Deploy BitNet-Transformer Trainer
Browse files
scripts/train_ai_model.py
CHANGED
|
@@ -45,7 +45,7 @@ def get_max_batch_size(model, input_dim, seq_len, device, start_batch=128):
|
|
| 45 |
if device.type == 'cpu':
|
| 46 |
return 64
|
| 47 |
|
| 48 |
-
|
| 49 |
batch_size = start_batch
|
| 50 |
last_success = batch_size
|
| 51 |
|
|
@@ -70,7 +70,7 @@ def get_max_batch_size(model, input_dim, seq_len, device, start_batch=128):
|
|
| 70 |
except RuntimeError as e:
|
| 71 |
pbar.close()
|
| 72 |
if "out of memory" in str(e).lower():
|
| 73 |
-
|
| 74 |
torch.cuda.empty_cache()
|
| 75 |
else:
|
| 76 |
raise e
|
|
@@ -121,7 +121,7 @@ def train():
|
|
| 121 |
use_bf16 = (dtype == torch.bfloat16)
|
| 122 |
scaler = torch.amp.GradScaler(device_type, enabled=(not use_bf16 and device.type == 'cuda'))
|
| 123 |
|
| 124 |
-
|
| 125 |
best_val_loss = float('inf')
|
| 126 |
|
| 127 |
for epoch in range(EPOCHS):
|
|
@@ -184,13 +184,13 @@ def train():
|
|
| 184 |
pnl = float(np.sum(all_rets[all_preds == 1]) - np.sum(all_rets[all_preds == 2]))
|
| 185 |
win_rate = float(np.sum((all_preds == 1) & (all_true == 1)) / (buys + 1e-6))
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
|
| 191 |
if buys + sells > 0:
|
| 192 |
cm = confusion_matrix(all_true, all_preds, labels=[0, 1, 2])
|
| 193 |
-
|
| 194 |
|
| 195 |
if avg_val_loss < best_val_loss:
|
| 196 |
best_val_loss = avg_val_loss
|
|
|
|
| 45 |
if device.type == 'cpu':
|
| 46 |
return 64
|
| 47 |
|
| 48 |
+
tqdm.write("🔍 Searching for optimal batch size for your GPU...")
|
| 49 |
batch_size = start_batch
|
| 50 |
last_success = batch_size
|
| 51 |
|
|
|
|
| 70 |
except RuntimeError as e:
|
| 71 |
pbar.close()
|
| 72 |
if "out of memory" in str(e).lower():
|
| 73 |
+
tqdm.write(f"💡 GPU Hit limit at {batch_size}. Using {last_success} as optimal batch.")
|
| 74 |
torch.cuda.empty_cache()
|
| 75 |
else:
|
| 76 |
raise e
|
|
|
|
| 121 |
use_bf16 = (dtype == torch.bfloat16)
|
| 122 |
scaler = torch.amp.GradScaler(device_type, enabled=(not use_bf16 and device.type == 'cuda'))
|
| 123 |
|
| 124 |
+
tqdm.write(f"🚀 Starting training (Batch Size: {batch_size}, Precision: {dtype})...")
|
| 125 |
best_val_loss = float('inf')
|
| 126 |
|
| 127 |
for epoch in range(EPOCHS):
|
|
|
|
| 184 |
pnl = float(np.sum(all_rets[all_preds == 1]) - np.sum(all_rets[all_preds == 2]))
|
| 185 |
win_rate = float(np.sum((all_preds == 1) & (all_true == 1)) / (buys + 1e-6))
|
| 186 |
|
| 187 |
+
tqdm.write(f"\n--- Epoch {epoch+1} Statistics ---")
|
| 188 |
+
tqdm.write(f"Val Loss: {avg_val_loss:.4f} | Total PnL: {pnl:+.4f} | Win Rate: {win_rate:.1%}")
|
| 189 |
+
tqdm.write(f"Signals: {buys} BUY | {sells} SELL | Activity: {(buys+sells)/len(all_preds):.1%}")
|
| 190 |
|
| 191 |
if buys + sells > 0:
|
| 192 |
cm = confusion_matrix(all_true, all_preds, labels=[0, 1, 2])
|
| 193 |
+
tqdm.write(f"Confusion Matrix (HOLD/BUY/SELL):\n{cm}")
|
| 194 |
|
| 195 |
if avg_val_loss < best_val_loss:
|
| 196 |
best_val_loss = avg_val_loss
|