# segformer-sidewalk / training_loop.py
# Author: chainyo
# Commit c769ff4: "fix training_loop for flavors"
# (Copied from the Hugging Face file viewer: raw / history / blame; "No virus"; 3.53 kB)
"""
Minimal command:
python training_loop.py --hub_dir "segments/sidewalk-semantic" --push_to_hub
Maximal command:
python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train --push_to_hub
"""
import json
import torch
from pytorch_lightning import Trainer, callbacks, seed_everything
from pytorch_lightning.loggers import WandbLogger
from transformers import AutoConfig, SegformerForSemanticSegmentation, SegformerFeatureExtractor
from dataloader import SidewalkSegmentationDataLoader
from model import SidewalkSegmentationModel
def main(
    hub_dir: str,
    batch_size: int = 32,
    learning_rate: float = 6e-5,
    model_flavor: int = 0,
    seed: int = 42,
    split: str = "train",
    push_to_hub: bool = False,
):
    """Train a SegFormer sidewalk-segmentation model and optionally push it to the Hub.

    Args:
        hub_dir: Dataset identifier handed to ``SidewalkSegmentationDataLoader``
            (e.g. ``"segments/sidewalk-semantic"``).
        batch_size: Batch size used by the data module.
        learning_rate: Learning rate passed to ``SidewalkSegmentationModel``.
        model_flavor: Encoder variant; also selects the ``nvidia/mit-b{model_flavor}``
            config and the Hub repository name when pushing.
        seed: Global seed passed to ``seed_everything``.
        split: Dataset split handed to the data module.
        push_to_hub: When True, push the config and the best checkpoint's model
            to the Hub after training.

    Raises:
        RuntimeError: If ``push_to_hub`` is set but no checkpoint was saved
            during training (nothing to upload).
    """
    seed_everything(seed)
    logger = WandbLogger(project="sidewalk-segmentation")
    # Use a single GPU when one is available, otherwise train on CPU.
    gpu_value = 1 if torch.cuda.is_available() else 0

    # JSON object keys are always strings; convert them to int class ids.
    with open("id2label.json", "r", encoding="utf-8") as label_file:
        id2label_file = json.load(label_file)
    id2label = {int(k): v for k, v in id2label_file.items()}
    num_labels = len(id2label)

    model = SidewalkSegmentationModel(
        num_labels=num_labels,
        id2label=id2label,
        model_flavor=model_flavor,
        learning_rate=learning_rate,
    )
    data_module = SidewalkSegmentationDataLoader(
        hub_dir=hub_dir,
        batch_size=batch_size,
        split=split,
    )
    data_module.setup()

    # Keep only the checkpoint with the highest validation mean IoU.
    checkpoint_callback = callbacks.ModelCheckpoint(
        dirpath="checkpoints",
        save_top_k=1,
        verbose=True,
        monitor="val_mean_iou",
        mode="max",
    )
    # Stop once validation mean IoU has not improved for 5 validation runs.
    early_stopping_callback = callbacks.EarlyStopping(
        monitor="val_mean_iou",
        patience=5,
        verbose=True,
        mode="max",
    )
    trainer = Trainer(
        max_epochs=200,
        progress_bar_refresh_rate=10,
        gpus=gpu_value,
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback],
        deterministic=False,
    )
    trainer.fit(model, data_module)

    if push_to_hub:
        # ``ModelCheckpoint`` exposes the best checkpoint as ``best_model_path``
        # (``best_model_filepath`` does not exist). It is an empty string when
        # nothing was saved, so check it before uploading anything.
        checkpoint_path = checkpoint_callback.best_model_path
        if not checkpoint_path:
            raise RuntimeError(
                "No checkpoint was saved during training; nothing to push to the Hub."
            )

        config = AutoConfig.from_pretrained(f"nvidia/mit-b{model_flavor}")
        config.num_labels = num_labels
        config.id2label = id2label
        # Invert the int-keyed mapping so label2id values are int ids, consistent
        # with id2label (the raw JSON keys are strings).
        config.label2id = {v: k for k, v in id2label.items()}

        repo_path = f"flavors/b{model_flavor}"
        repo_url = f"https://huggingface.co/ChainYo/segformer-{model_flavor}-sidewalk"
        config.push_to_hub(repo_path, repo_url=repo_url)

        # NOTE(review): checkpoint_path points to a Lightning ``.ckpt`` file, not a
        # ``save_pretrained`` directory. Confirm that ``from_pretrained`` really
        # loads the trained weights from it. Alternatively, restore the
        # SidewalkSegmentationModel from the checkpoint and push its wrapped
        # Segformer model.
        hf_model = SegformerForSemanticSegmentation.from_pretrained(
            checkpoint_path, config=config
        )
        hf_model.push_to_hub(repo_path, repo_url=repo_url)
if __name__ == "__main__":
    # Command-line entry point. Every parser destination name matches a keyword
    # parameter of ``main``, so the parsed namespace is forwarded as-is.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--hub_dir", type=str, required=True)
    # Optional typed flags, registered in the same order as main's parameters.
    for flag, flag_type, flag_default in (
        ("--batch_size", int, 32),
        ("--learning_rate", float, 6e-5),
        ("--model_flavor", int, 0),
        ("--seed", int, 42),
        ("--split", str, "train"),
    ):
        cli.add_argument(flag, type=flag_type, default=flag_default)
    cli.add_argument("--push_to_hub", action="store_true")

    main(**vars(cli.parse_args()))