luohoa97 committed on
Commit
e8b6287
·
verified ·
1 Parent(s): 34e94cf

Deploy BitNet-Transformer Trainer

Browse files
Files changed (1) hide show
  1. scripts/train_ai_model.py +7 -7
scripts/train_ai_model.py CHANGED
@@ -45,7 +45,7 @@ def get_max_batch_size(model, input_dim, seq_len, device, start_batch=128):
45
  if device.type == 'cpu':
46
  return 64
47
 
48
- print("πŸ” Searching for optimal batch size for your GPU...", flush=True)
49
  batch_size = start_batch
50
  last_success = batch_size
51
 
@@ -70,7 +70,7 @@ def get_max_batch_size(model, input_dim, seq_len, device, start_batch=128):
70
  except RuntimeError as e:
71
  pbar.close()
72
  if "out of memory" in str(e).lower():
73
- print(f"πŸ’‘ GPU Hit limit at {batch_size}. Using {last_success} as optimal batch.", flush=True)
74
  torch.cuda.empty_cache()
75
  else:
76
  raise e
@@ -121,7 +121,7 @@ def train():
121
  use_bf16 = (dtype == torch.bfloat16)
122
  scaler = torch.amp.GradScaler(device_type, enabled=(not use_bf16 and device.type == 'cuda'))
123
 
124
- print(f"πŸš€ Starting training (Batch Size: {batch_size}, Precision: {dtype})...", flush=True)
125
  best_val_loss = float('inf')
126
 
127
  for epoch in range(EPOCHS):
@@ -184,13 +184,13 @@ def train():
184
  pnl = float(np.sum(all_rets[all_preds == 1]) - np.sum(all_rets[all_preds == 2]))
185
  win_rate = float(np.sum((all_preds == 1) & (all_true == 1)) / (buys + 1e-6))
186
 
187
- print(f"\n--- Epoch {epoch+1} Statistics ---", flush=True)
188
- print(f"Val Loss: {avg_val_loss:.4f} | Total PnL: {pnl:+.4f} | Win Rate: {win_rate:.1%}", flush=True)
189
- print(f"Signals: {buys} BUY | {sells} SELL | Activity: {(buys+sells)/len(all_preds):.1%}", flush=True)
190
 
191
  if buys + sells > 0:
192
  cm = confusion_matrix(all_true, all_preds, labels=[0, 1, 2])
193
- print(f"Confusion Matrix (HOLD/BUY/SELL):\n{cm}", flush=True)
194
 
195
  if avg_val_loss < best_val_loss:
196
  best_val_loss = avg_val_loss
 
45
  if device.type == 'cpu':
46
  return 64
47
 
48
+ tqdm.write("πŸ” Searching for optimal batch size for your GPU...")
49
  batch_size = start_batch
50
  last_success = batch_size
51
 
 
70
  except RuntimeError as e:
71
  pbar.close()
72
  if "out of memory" in str(e).lower():
73
+ tqdm.write(f"πŸ’‘ GPU Hit limit at {batch_size}. Using {last_success} as optimal batch.")
74
  torch.cuda.empty_cache()
75
  else:
76
  raise e
 
121
  use_bf16 = (dtype == torch.bfloat16)
122
  scaler = torch.amp.GradScaler(device_type, enabled=(not use_bf16 and device.type == 'cuda'))
123
 
124
+ tqdm.write(f"πŸš€ Starting training (Batch Size: {batch_size}, Precision: {dtype})...")
125
  best_val_loss = float('inf')
126
 
127
  for epoch in range(EPOCHS):
 
184
  pnl = float(np.sum(all_rets[all_preds == 1]) - np.sum(all_rets[all_preds == 2]))
185
  win_rate = float(np.sum((all_preds == 1) & (all_true == 1)) / (buys + 1e-6))
186
 
187
+ tqdm.write(f"\n--- Epoch {epoch+1} Statistics ---")
188
+ tqdm.write(f"Val Loss: {avg_val_loss:.4f} | Total PnL: {pnl:+.4f} | Win Rate: {win_rate:.1%}")
189
+ tqdm.write(f"Signals: {buys} BUY | {sells} SELL | Activity: {(buys+sells)/len(all_preds):.1%}")
190
 
191
  if buys + sells > 0:
192
  cm = confusion_matrix(all_true, all_preds, labels=[0, 1, 2])
193
+ tqdm.write(f"Confusion Matrix (HOLD/BUY/SELL):\n{cm}")
194
 
195
  if avg_val_loss < best_val_loss:
196
  best_val_loss = avg_val_loss