File size: 3,533 Bytes
c75c928
 
56db83f
c75c928
 
56db83f
c75c928
 
 
 
 
 
 
 
56db83f
 
c75c928
 
 
 
 
 
 
 
 
 
 
56db83f
c75c928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56db83f
 
af16e92
 
 
c769ff4
af16e92
56db83f
af16e92
c769ff4
56db83f
c75c928
 
 
 
 
 
 
 
 
 
 
56db83f
c75c928
 
 
 
 
 
 
 
 
56db83f
c75c928
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Minimal command:
    python training_loop.py --hub_dir "segments/sidewalk-semantic" --push_to_hub

Maximal command:
    python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train --push_to_hub
"""

import json
import torch

from pytorch_lightning import Trainer, callbacks, seed_everything
from pytorch_lightning.loggers import WandbLogger

from transformers import AutoConfig, SegformerForSemanticSegmentation, SegformerFeatureExtractor

from dataloader import SidewalkSegmentationDataLoader
from model import SidewalkSegmentationModel


def main(
    hub_dir: str, 
    batch_size: int = 32, 
    learning_rate: float = 6e-5, 
    model_flavor: int = 0, 
    seed: int = 42,
    split: str = "train",
    push_to_hub: bool = False,
):
    """Train the sidewalk segmentation model and optionally push it to the Hub.

    Reads the label mapping from ``id2label.json`` in the current working
    directory, fits the model with checkpointing and early stopping on
    ``val_mean_iou``, and logs to the ``sidewalk-segmentation`` W&B project.

    Args:
        hub_dir: Dataset location forwarded to ``SidewalkSegmentationDataLoader``.
        batch_size: Batch size forwarded to the data module.
        learning_rate: Learning rate forwarded to ``SidewalkSegmentationModel``.
        model_flavor: Integer suffix of the ``nvidia/mit-b{N}`` backbone.
        seed: Seed passed to ``seed_everything``.
        split: Dataset split forwarded to the data module.
        push_to_hub: When true, push the config and the best checkpoint to
            the Hugging Face Hub after training.
    """
    seed_everything(seed)
    logger = WandbLogger(project="sidewalk-segmentation")
    gpu_value = 1 if torch.cuda.is_available() else 0

    # Use a context manager so the file handle is closed deterministically.
    with open("id2label.json", "r") as id2label_fp:
        id2label_file = json.load(id2label_fp)
    # JSON object keys are always strings; class ids must be integers.
    id2label = {int(k): v for k, v in id2label_file.items()}
    num_labels = len(id2label)

    model = SidewalkSegmentationModel(
        num_labels=num_labels,
        id2label=id2label,
        model_flavor=model_flavor,
        learning_rate=learning_rate,
    )
    data_module = SidewalkSegmentationDataLoader(
        hub_dir=hub_dir,
        batch_size=batch_size,
        split=split,
    )
    data_module.setup()

    # Both callbacks track the same metric: keep only the best checkpoint and
    # stop once it has not improved for 5 validation rounds.
    checkpoint_callback = callbacks.ModelCheckpoint(
        dirpath="checkpoints",
        save_top_k=1,
        verbose=True,
        monitor="val_mean_iou",
        mode="max",
    )
    early_stopping_callback = callbacks.EarlyStopping(
        monitor="val_mean_iou",
        patience=5,
        verbose=True,
        mode="max",
    )

    trainer = Trainer(
        max_epochs=200,
        progress_bar_refresh_rate=10,
        gpus=gpu_value,
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback],
        deterministic=False,
    )
    trainer.fit(model, data_module)

    if push_to_hub:
        repo_path = f"flavors/b{model_flavor}"
        repo_url = f"https://huggingface.co/ChainYo/segformer-{model_flavor}-sidewalk"

        config = AutoConfig.from_pretrained(f"nvidia/mit-b{model_flavor}")
        config.num_labels = num_labels
        config.id2label = id2label
        # Invert the int-keyed mapping so label2id values are integer ids,
        # consistent with id2label (the raw JSON dict has string ids).
        config.label2id = {label: idx for idx, label in id2label.items()}
        config.push_to_hub(repo_path, repo_url=repo_url)

        # ModelCheckpoint exposes the best checkpoint as `best_model_path`.
        checkpoint_path = checkpoint_callback.best_model_path
        # NOTE(review): this is a Lightning checkpoint file; confirm that
        # from_pretrained actually restores the trained weights from it
        # rather than re-initialising them — TODO verify.
        hf_model = SegformerForSemanticSegmentation.from_pretrained(checkpoint_path, config=config)
        hf_model.push_to_hub(repo_path, repo_url=repo_url)


if __name__ == "__main__":
    import argparse

    # Command-line entry point: each flag maps one-to-one onto a keyword
    # argument of main(), with the same defaults.
    cli = argparse.ArgumentParser()
    cli.add_argument("--hub_dir", type=str, required=True)

    # Optional, typed flags registered in the same order as main()'s signature.
    optional_flags = (
        ("--batch_size", int, 32),
        ("--learning_rate", float, 6e-5),
        ("--model_flavor", int, 0),
        ("--seed", int, 42),
        ("--split", str, "train"),
    )
    for flag_name, flag_type, flag_default in optional_flags:
        cli.add_argument(flag_name, type=flag_type, default=flag_default)

    cli.add_argument("--push_to_hub", action="store_true")

    # The parsed namespace holds exactly main()'s parameters, so it can be
    # expanded directly into keyword arguments.
    options = cli.parse_args()
    main(**vars(options))