Spaces:
Running
Running
comment wandb
Browse files
main.py
CHANGED
|
@@ -9,7 +9,7 @@ from PIL import Image
|
|
| 9 |
import time
|
| 10 |
import numpy as np
|
| 11 |
import math
|
| 12 |
-
import wandb
|
| 13 |
from datasets import load_dataset
|
| 14 |
import os
|
| 15 |
|
|
@@ -98,7 +98,7 @@ if __name__ == "__main__":
|
|
| 98 |
"BS": 32,
|
| 99 |
"HIDDEN": 128,
|
| 100 |
"DROPOUT": 0.01,
|
| 101 |
-
"WANDB": True,
|
| 102 |
"REPO_ID": "TristanKE/RemainingLifespanPredictionFaces",
|
| 103 |
# "REPO_ID": "TristanKE/RemainingLifespanPredictionWholeImgs",
|
| 104 |
"DINO_MODEL": "dinov2_vitl14_reg",
|
|
@@ -107,8 +107,8 @@ if __name__ == "__main__":
|
|
| 107 |
# "DINO_DIM": 1536, #for the larger model
|
| 108 |
}
|
| 109 |
|
| 110 |
-
if cfg["WANDB"]:
|
| 111 |
-
|
| 112 |
|
| 113 |
torch.manual_seed(1)
|
| 114 |
ds = LifespanDataset(repo_id=cfg["REPO_ID"],transform=imgtransform)
|
|
@@ -147,12 +147,12 @@ if __name__ == "__main__":
|
|
| 147 |
|
| 148 |
tr_nll += loss.item() * imgs.size(0)
|
| 149 |
tr_mae += torch.abs(mu.detach() - tgt).sum().item()
|
| 150 |
-
if cfg["WANDB"]:
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
|
| 157 |
tr_nll /= train_sz
|
| 158 |
tr_mae = tr_mae / train_sz * ds.lifespan_std
|
|
@@ -172,13 +172,13 @@ if __name__ == "__main__":
|
|
| 172 |
|
| 173 |
print(f"Epoch {epoch+1}/{cfg['N_EPOCHS']} | {time.time()-t0:.1f}s | NLL tr {tr_nll:.3f} / te {te_nll:.3f} | MAE(te) {te_mae:.2f} yrs")
|
| 174 |
|
| 175 |
-
if cfg["WANDB"]:
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
|
| 183 |
scheduler.step()
|
| 184 |
|
|
@@ -190,5 +190,5 @@ if __name__ == "__main__":
|
|
| 190 |
torch.save(model.state_dict(), f"savedmodels/dino_finetuned_faces_l1_{cfg['DINO_DIM']}_best.pth")
|
| 191 |
print(f"\tNew best model saved (test MAE {te_mae:.3f})")
|
| 192 |
|
| 193 |
-
if cfg["WANDB"]:
|
| 194 |
-
|
|
|
|
| 9 |
import time
|
| 10 |
import numpy as np
|
| 11 |
import math
|
| 12 |
+
# import wandb
|
| 13 |
from datasets import load_dataset
|
| 14 |
import os
|
| 15 |
|
|
|
|
| 98 |
"BS": 32,
|
| 99 |
"HIDDEN": 128,
|
| 100 |
"DROPOUT": 0.01,
|
| 101 |
+
# "WANDB": True,
|
| 102 |
"REPO_ID": "TristanKE/RemainingLifespanPredictionFaces",
|
| 103 |
# "REPO_ID": "TristanKE/RemainingLifespanPredictionWholeImgs",
|
| 104 |
"DINO_MODEL": "dinov2_vitl14_reg",
|
|
|
|
| 107 |
# "DINO_DIM": 1536, #for the larger model
|
| 108 |
}
|
| 109 |
|
| 110 |
+
# if cfg["WANDB"]:
|
| 111 |
+
# wandb.init(project="mortpred", config=cfg)
|
| 112 |
|
| 113 |
torch.manual_seed(1)
|
| 114 |
ds = LifespanDataset(repo_id=cfg["REPO_ID"],transform=imgtransform)
|
|
|
|
| 147 |
|
| 148 |
tr_nll += loss.item() * imgs.size(0)
|
| 149 |
tr_mae += torch.abs(mu.detach() - tgt).sum().item()
|
| 150 |
+
# if cfg["WANDB"]:
|
| 151 |
+
# wandb.log({
|
| 152 |
+
# "train_nll": loss.item(),
|
| 153 |
+
# "train_mae": torch.abs(mu.detach() - tgt).mean().item() * ds.lifespan_std,
|
| 154 |
+
# "train_std": torch.exp(0.5 * logvar).mean().item() * ds.lifespan_std,
|
| 155 |
+
# })
|
| 156 |
|
| 157 |
tr_nll /= train_sz
|
| 158 |
tr_mae = tr_mae / train_sz * ds.lifespan_std
|
|
|
|
| 172 |
|
| 173 |
print(f"Epoch {epoch+1}/{cfg['N_EPOCHS']} | {time.time()-t0:.1f}s | NLL tr {tr_nll:.3f} / te {te_nll:.3f} | MAE(te) {te_mae:.2f} yrs")
|
| 174 |
|
| 175 |
+
# if cfg["WANDB"]:
|
| 176 |
+
# wandb.log({
|
| 177 |
+
# "train_nll": tr_nll,
|
| 178 |
+
# "test_nll": te_nll,
|
| 179 |
+
# "test_mae_yrs": te_mae,
|
| 180 |
+
# "lr": scheduler.get_last_lr()[0],
|
| 181 |
+
# })
|
| 182 |
|
| 183 |
scheduler.step()
|
| 184 |
|
|
|
|
| 190 |
torch.save(model.state_dict(), f"savedmodels/dino_finetuned_faces_l1_{cfg['DINO_DIM']}_best.pth")
|
| 191 |
print(f"\tNew best model saved (test MAE {te_mae:.3f})")
|
| 192 |
|
| 193 |
+
# if cfg["WANDB"]:
|
| 194 |
+
# wandb.finish()
|