SupermanxKiaski commited on
Commit
c9cbb49
1 Parent(s): cf5e410

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -104
app.py DELETED
@@ -1,104 +0,0 @@
1
- import datetime
2
- import random
3
- from argparse import ArgumentParser
4
- from pathlib import Path
5
-
6
- import numpy as np
7
- import torch
8
- import yaml
9
- from tqdm import tqdm
10
-
11
- from datasets.video_dataset import AtlasDataset
12
- from models.video_model import VideoModel
13
- from util.atlas_loss import AtlasLoss
14
- from util.util import get_optimizer
15
- from util.video_logger import DataLogger
16
-
17
-
18
- def train_model(config):
19
- # set seed
20
- seed = config["seed"]
21
- if seed == -1:
22
- seed = np.random.randint(2 ** 32)
23
- random.seed(seed)
24
- np.random.seed(seed)
25
- torch.manual_seed(seed)
26
- print(f"running with seed: {seed}.")
27
-
28
- dataset = AtlasDataset(config)
29
- model = VideoModel(config)
30
- criterion = AtlasLoss(config)
31
- optimizer = get_optimizer(config, model.parameters())
32
-
33
- logger = DataLogger(config, dataset)
34
- with tqdm(range(1, config["n_epochs"] + 1)) as tepoch:
35
- for epoch in tepoch:
36
- inputs = dataset[0]
37
- optimizer.zero_grad()
38
- outputs = model(inputs)
39
- losses = criterion(outputs, inputs)
40
-
41
- loss = 0.
42
- if config["finetune_foreground"]:
43
- loss += losses["foreground"]["loss"]
44
- elif config["finetune_background"]:
45
- loss += losses["background"]["loss"]
46
-
47
- lr = optimizer.param_groups[0]["lr"]
48
- log_data = logger.log_data(epoch, lr, losses, model, dataset)
49
-
50
- loss.backward()
51
- optimizer.step()
52
- optimizer.param_groups[0]["lr"] = max(config["min_lr"], config["gamma"] * optimizer.param_groups[0]["lr"])
53
-
54
- if config["use_wandb"]:
55
- wandb.log(log_data)
56
- else:
57
- if epoch % config["log_images_freq"] == 0:
58
- logger.save_locally(log_data)
59
-
60
- tepoch.set_description(f"Epoch {epoch}")
61
- tepoch.set_postfix(loss=loss.item())
62
-
63
-
64
- if __name__ == "__main__":
65
- parser = ArgumentParser()
66
- parser.add_argument(
67
- "--config",
68
- default="./configs/video_config.yaml",
69
- help="Config path",
70
- )
71
- parser.add_argument(
72
- "--example_config",
73
- default="car-turn_winter.yaml",
74
- help="Example config name",
75
- )
76
- args = parser.parse_args()
77
- config_path = args.config
78
-
79
- with open(config_path, "r") as f:
80
- config = yaml.safe_load(f)
81
- with open(f"./configs/video_example_configs/{args.example_config}", "r") as f:
82
- example_config = yaml.safe_load(f)
83
- config["example_config"] = args.example_config
84
- config.update(example_config)
85
-
86
- run_name = f"-{config['checkpoint_path'].split('/')[-2]}"
87
- if config["use_wandb"]:
88
- import wandb
89
-
90
- wandb.init(project=config["wandb_project"], entity=config["wandb_entity"], config=config, name=run_name)
91
- wandb.run.name = str(wandb.run.id) + wandb.run.name
92
- config = dict(wandb.config)
93
- else:
94
- now = datetime.datetime.now()
95
- run_name = f"{now.strftime('%Y-%m-%d_%H-%M-%S')}{run_name}"
96
- path = Path(f"{config['results_folder']}/{run_name}")
97
- path.mkdir(parents=True, exist_ok=True)
98
- with open(path / "config.yaml", "w") as f:
99
- yaml.dump(config, f)
100
- config["results_folder"] = str(path)
101
-
102
- train_model(config)
103
- if config["use_wandb"]:
104
- wandb.finish()