# Source: Hugging Face repository upload by SaffalPoosh (commit e5765b1),
# uploaded via huggingface_hub. Original file size: 5.58 kB.
import argparse
import os
from pathlib import Path
from typing import Dict
import pytorch_lightning as pl
import torch
import yaml
from albumentations.core.serialization import from_dict
from iglovikov_helper_functions.config_parsing.utils import object_from_dict
from iglovikov_helper_functions.dl.pytorch.lightning import find_average
from iglovikov_helper_functions.dl.pytorch.utils import state_dict_from_disk
from pytorch_lightning.loggers import WandbLogger
from pytorch_toolbelt.losses import JaccardLoss, BinaryFocalLoss
from torch.utils.data import DataLoader
from cloths_segmentation.dataloaders import SegmentationDataset
from cloths_segmentation.metrics import binary_mean_iou
from cloths_segmentation.utils import get_samples
# Dataset locations come from the environment so paths are not hard-coded;
# both variables must be set before this module is imported (KeyError otherwise).
image_path = Path(os.environ["IMAGE_PATH"])
mask_path = Path(os.environ["MASK_PATH"])
def get_args():
    """Parse command-line arguments: a single required path to the YAML config."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_path",
        type=Path,
        required=True,
        help="Path to the config.",
    )
    return parser.parse_args()
class SegmentPeople(pl.LightningModule):
    """Binary-segmentation training module.

    Builds the model, losses, optimizers and dataloaders from a config
    dictionary (``hparams``); trains with a weighted sum of Jaccard and
    binary focal losses and tracks mean IoU on the validation set.
    """

    def __init__(self, hparams):
        super().__init__()
        # NOTE(review): direct assignment to ``self.hparams`` only works on
        # older pytorch-lightning releases; newer ones require
        # ``self.save_hyperparameters(hparams)`` — confirm the pinned version.
        self.hparams = hparams

        self.model = object_from_dict(self.hparams["model"])

        if "resume_from_checkpoint" in self.hparams:
            # Checkpoints store weights under the "model." prefix; strip it so
            # the state dict matches the bare model.
            corrections: Dict[str, str] = {"model.": ""}
            state_dict = state_dict_from_disk(
                file_path=self.hparams["resume_from_checkpoint"],
                rename_in_layers=corrections,
            )
            self.model.load_state_dict(state_dict)

        # (name, weight, loss_fn) triples applied to the raw logits.
        self.losses = [
            ("jaccard", 0.1, JaccardLoss(mode="binary", from_logits=True)),
            ("focal", 0.9, BinaryFocalLoss()),
        ]

    def forward(self, batch: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Run the underlying segmentation model on a batch of images."""
        return self.model(batch)

    def setup(self, stage=0):
        """Collect (image, mask) samples and split them into train/val.

        The split fraction comes from ``hparams["val_split"]``; the module-level
        ``image_path`` / ``mask_path`` env-derived paths define the dataset root.
        """
        samples = get_samples(image_path, mask_path)

        num_train = int((1 - self.hparams["val_split"]) * len(samples))

        self.train_samples = samples[:num_train]
        self.val_samples = samples[num_train:]

        print("Len train samples = ", len(self.train_samples))
        print("Len val samples = ", len(self.val_samples))

    def train_dataloader(self):
        """Build the training DataLoader with augmentations from the config."""
        train_aug = from_dict(self.hparams["train_aug"])

        # Optional cap on samples per epoch; None means the full dataset.
        epoch_length = self.hparams["train_parameters"].get("epoch_length")

        result = DataLoader(
            SegmentationDataset(self.train_samples, train_aug, epoch_length),
            batch_size=self.hparams["train_parameters"]["batch_size"],
            num_workers=self.hparams["num_workers"],
            shuffle=True,
            pin_memory=True,
            drop_last=True,
        )

        print("Train dataloader = ", len(result))
        return result

    def val_dataloader(self):
        """Build the validation DataLoader (no shuffling, keep every sample)."""
        val_aug = from_dict(self.hparams["val_aug"])

        result = DataLoader(
            SegmentationDataset(self.val_samples, val_aug, length=None),
            batch_size=self.hparams["val_parameters"]["batch_size"],
            num_workers=self.hparams["num_workers"],
            shuffle=False,
            pin_memory=True,
            drop_last=False,
        )

        print("Val dataloader = ", len(result))
        return result

    def configure_optimizers(self):
        """Instantiate the optimizer and scheduler declared in the config.

        Only parameters with ``requires_grad`` set are handed to the optimizer.
        """
        optimizer = object_from_dict(
            self.hparams["optimizer"],
            params=[x for x in self.model.parameters() if x.requires_grad],
        )

        scheduler = object_from_dict(self.hparams["scheduler"], optimizer=optimizer)

        # Keep a private reference for _get_current_lr: assigning to
        # ``self.optimizers`` would shadow LightningModule.optimizers().
        self._optimizers = [optimizer]
        return self._optimizers, [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute the weighted training loss and per-loss logging values."""
        features = batch["features"]
        masks = batch["masks"]

        logits = self.forward(features)

        total_loss = 0
        logs = {}
        for loss_name, weight, loss in self.losses:
            ls_mask = loss(logits, masks)
            total_loss += weight * ls_mask
            logs[f"train_mask_{loss_name}"] = ls_mask

        logs["train_loss"] = total_loss
        logs["lr"] = self._get_current_lr()

        return {"loss": total_loss, "log": logs}

    def _get_current_lr(self) -> torch.Tensor:
        """Return the first param group's learning rate as a scalar CUDA tensor."""
        lr = self._optimizers[0].param_groups[0]["lr"]  # type: ignore
        return torch.Tensor([lr])[0].cuda()

    def validation_step(self, batch, batch_id):
        """Compute unweighted validation losses and the batch mean IoU."""
        features = batch["features"]
        masks = batch["masks"]

        logits = self.forward(features)

        result = {}
        for loss_name, _, loss in self.losses:
            result[f"val_mask_{loss_name}"] = loss(logits, masks)

        result["val_iou"] = binary_mean_iou(logits, masks)
        return result

    def validation_epoch_end(self, outputs):
        """Average per-batch IoUs into an epoch-level ``val_iou`` metric."""
        logs = {"epoch": self.trainer.current_epoch}

        avg_val_iou = find_average(outputs, "val_iou")
        logs["val_iou"] = avg_val_iou

        return {"val_iou": avg_val_iou, "log": logs}
def main():
    """Script entry point: load the YAML config, build the pipeline, and train."""
    cli_args = get_args()

    with open(cli_args.config_path) as stream:
        config = yaml.safe_load(stream)

    lightning_module = SegmentPeople(config)

    # Make sure the checkpoint directory exists before the trainer writes to it.
    checkpoint_dir = Path(config["checkpoint_callback"]["filepath"])
    checkpoint_dir.mkdir(exist_ok=True, parents=True)

    trainer = object_from_dict(
        config["trainer"],
        logger=WandbLogger(config["experiment_name"]),
        checkpoint_callback=object_from_dict(config["checkpoint_callback"]),
    )

    trainer.fit(lightning_module)
if __name__ == "__main__":
main()