File size: 1,606 Bytes
92f0e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import matplotlib
# Select the non-interactive 'Agg' backend so figures can be rendered on a
# headless training machine (no display required). This is done before the
# project imports below because the backend should be chosen before
# matplotlib.pyplot is first imported — presumably `data`/`model` import
# pyplot; TODO confirm.
matplotlib.use('Agg')

import os
import wandb
import pytorch_lightning as pl

from data import VerseDataModule
from model import VerseFxClassifier
from utils.config import get_config

def _build_callbacks(hparams, save_model, run_name):
    """Return the Lightning callbacks for one training run.

    Early stopping on ``val/F1`` (maximised, patience taken from
    ``hparams.early_stopping_patience``) is always enabled. When
    ``save_model`` is truthy, a checkpoint callback that tracks the best
    ``val/F1`` and writes into ``saved_models/`` is appended as well.

    Args:
        hparams: Run hyperparameters; must expose ``early_stopping_patience``.
        save_model: Whether to add the checkpointing callback.
        run_name: Prefix for checkpoint file names (the wandb run name).

    Returns:
        A list of ``pl.callbacks`` instances to pass to ``pl.Trainer``.
    """
    callbacks = [
        pl.callbacks.EarlyStopping(
            monitor="val/F1",
            mode="max",
            patience=hparams.early_stopping_patience,
        )
    ]
    if save_model:
        callbacks.append(
            pl.callbacks.ModelCheckpoint(
                monitor="val/F1",
                mode="max",
                dirpath="saved_models",
                # Doubled braces survive the f-string and become Lightning's
                # own format fields ({epoch}, {val/F1:.3f}); the metric names
                # are already spelled out in the template, hence
                # auto_insert_metric_name=False.
                filename=f"{run_name}-epoch{{epoch}}-val_F1={{val/F1:.3f}}",
                auto_insert_metric_name=False,
            )
        )
    return callbacks


def main():
    """Train ``VerseFxClassifier`` on ``VerseDataModule`` with wandb logging.

    Reads ``config.yaml`` via ``get_config``. The launcher-only keys
    ``USE_WANDB``, ``WANDB_API_KEY`` and ``SAVE_MODEL`` are optional and are
    removed from the config before it is passed to ``wandb.init``, so they
    are not recorded as run hyperparameters. ``config["task"]`` is required;
    it selects the wandb project name.
    """
    config = get_config("config.yaml")

    # Strip launcher-only settings so they (in particular the API key) are not
    # stored in the run config handed to wandb.init below.
    wandb_mode = 'online' if config.pop('USE_WANDB', False) else 'disabled'
    wandb_api_key = config.pop('WANDB_API_KEY', None)
    save_model = bool(config.pop('SAVE_MODEL', False))

    # Only authenticate when actually logging online; a disabled run needs no
    # credentials, so a missing key must not abort local training.
    if wandb_mode == 'online':
        wandb.login(key=wandb_api_key)

    run = wandb.init(
        project=f'fx-{config["task"]}-baseline-3d',
        entity='ifl-diva',
        config=config,
        mode=wandb_mode,
    )

    # Everything after init runs inside the run's context manager so the run
    # is finished even if model/data/trainer setup raises, not only fit().
    with run:
        # Read hyperparameters back from wandb.config rather than `config`;
        # presumably so values overridden by wandb (e.g. a sweep agent) apply.
        hparams = wandb.config

        # Created after wandb.init; presumably reuses the active run instead
        # of starting a new one — verify against the installed Lightning.
        wandb_logger = pl.loggers.WandbLogger()

        model = VerseFxClassifier(hparams)
        data = VerseDataModule(hparams)

        trainer = pl.Trainer(
            gpus=1,
            logger=wandb_logger,
            log_every_n_steps=2,
            callbacks=_build_callbacks(hparams, save_model, run.name),
        )
        trainer.fit(model, data)


if __name__ == '__main__':
    main()