SiyaYan commited on
Commit
1810dad
·
verified ·
1 Parent(s): 3c92847

comment wandb

Browse files
Files changed (1) hide show
  1. main.py +19 -19
main.py CHANGED
@@ -9,7 +9,7 @@ from PIL import Image
9
  import time
10
  import numpy as np
11
  import math
12
- import wandb
13
  from datasets import load_dataset
14
  import os
15
 
@@ -98,7 +98,7 @@ if __name__ == "__main__":
98
  "BS": 32,
99
  "HIDDEN": 128,
100
  "DROPOUT": 0.01,
101
- "WANDB": True,
102
  "REPO_ID": "TristanKE/RemainingLifespanPredictionFaces",
103
  # "REPO_ID": "TristanKE/RemainingLifespanPredictionWholeImgs",
104
  "DINO_MODEL": "dinov2_vitl14_reg",
@@ -107,8 +107,8 @@ if __name__ == "__main__":
107
  # "DINO_DIM": 1536, #for the larger model
108
  }
109
 
110
- if cfg["WANDB"]:
111
- wandb.init(project="mortpred", config=cfg)
112
 
113
  torch.manual_seed(1)
114
  ds = LifespanDataset(repo_id=cfg["REPO_ID"],transform=imgtransform)
@@ -147,12 +147,12 @@ if __name__ == "__main__":
147
 
148
  tr_nll += loss.item() * imgs.size(0)
149
  tr_mae += torch.abs(mu.detach() - tgt).sum().item()
150
- if cfg["WANDB"]:
151
- wandb.log({
152
- "train_nll": loss.item(),
153
- "train_mae": torch.abs(mu.detach() - tgt).mean().item() * ds.lifespan_std,
154
- "train_std": torch.exp(0.5 * logvar).mean().item() * ds.lifespan_std,
155
- })
156
 
157
  tr_nll /= train_sz
158
  tr_mae = tr_mae / train_sz * ds.lifespan_std
@@ -172,13 +172,13 @@ if __name__ == "__main__":
172
 
173
  print(f"Epoch {epoch+1}/{cfg['N_EPOCHS']} | {time.time()-t0:.1f}s | NLL tr {tr_nll:.3f} / te {te_nll:.3f} | MAE(te) {te_mae:.2f} yrs")
174
 
175
- if cfg["WANDB"]:
176
- wandb.log({
177
- "train_nll": tr_nll,
178
- "test_nll": te_nll,
179
- "test_mae_yrs": te_mae,
180
- "lr": scheduler.get_last_lr()[0],
181
- })
182
 
183
  scheduler.step()
184
 
@@ -190,5 +190,5 @@ if __name__ == "__main__":
190
  torch.save(model.state_dict(), f"savedmodels/dino_finetuned_faces_l1_{cfg['DINO_DIM']}_best.pth")
191
  print(f"\tNew best model saved (test MAE {te_mae:.3f})")
192
 
193
- if cfg["WANDB"]:
194
- wandb.finish()
 
9
  import time
10
  import numpy as np
11
  import math
12
+ # import wandb
13
  from datasets import load_dataset
14
  import os
15
 
 
98
  "BS": 32,
99
  "HIDDEN": 128,
100
  "DROPOUT": 0.01,
101
+ # "WANDB": True,
102
  "REPO_ID": "TristanKE/RemainingLifespanPredictionFaces",
103
  # "REPO_ID": "TristanKE/RemainingLifespanPredictionWholeImgs",
104
  "DINO_MODEL": "dinov2_vitl14_reg",
 
107
  # "DINO_DIM": 1536, #for the larger model
108
  }
109
 
110
+ # if cfg["WANDB"]:
111
+ # wandb.init(project="mortpred", config=cfg)
112
 
113
  torch.manual_seed(1)
114
  ds = LifespanDataset(repo_id=cfg["REPO_ID"],transform=imgtransform)
 
147
 
148
  tr_nll += loss.item() * imgs.size(0)
149
  tr_mae += torch.abs(mu.detach() - tgt).sum().item()
150
+ # if cfg["WANDB"]:
151
+ # wandb.log({
152
+ # "train_nll": loss.item(),
153
+ # "train_mae": torch.abs(mu.detach() - tgt).mean().item() * ds.lifespan_std,
154
+ # "train_std": torch.exp(0.5 * logvar).mean().item() * ds.lifespan_std,
155
+ # })
156
 
157
  tr_nll /= train_sz
158
  tr_mae = tr_mae / train_sz * ds.lifespan_std
 
172
 
173
  print(f"Epoch {epoch+1}/{cfg['N_EPOCHS']} | {time.time()-t0:.1f}s | NLL tr {tr_nll:.3f} / te {te_nll:.3f} | MAE(te) {te_mae:.2f} yrs")
174
 
175
+ # if cfg["WANDB"]:
176
+ # wandb.log({
177
+ # "train_nll": tr_nll,
178
+ # "test_nll": te_nll,
179
+ # "test_mae_yrs": te_mae,
180
+ # "lr": scheduler.get_last_lr()[0],
181
+ # })
182
 
183
  scheduler.step()
184
 
 
190
  torch.save(model.state_dict(), f"savedmodels/dino_finetuned_faces_l1_{cfg['DINO_DIM']}_best.pth")
191
  print(f"\tNew best model saved (test MAE {te_mae:.3f})")
192
 
193
+ # if cfg["WANDB"]:
194
+ # wandb.finish()