# run.py
"""Train ROI Leishmania segmentation models over a small hyperparameter grid.

Running this module as a script loads environment variables from a ``.env``
file, builds the ROI dataset, splits it 80/20 into train/validation subsets
and trains one model per hyperparameter combination, logging each training
to the Weights & Biases project ``ROI_LEISHMANIA_BINARY``.
"""
import itertools
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
import wandb
from dotenv import load_dotenv
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from torchvision import transforms

import dataset as data
import model

load_dotenv()
WANDB_API_KEY = os.getenv('WANDB_API_KEY')
# os.environ only accepts strings: assigning None (key not configured, e.g.
# when the user authenticated through `wandb login`) would raise TypeError
# at import time, so only export the key when one was actually found.
if WANDB_API_KEY is not None:
    os.environ['WANDB_API_KEY'] = WANDB_API_KEY


def main():
    """Train and checkpoint one model per hyperparameter combination.

    The dataset is read from ``./roi/roi_img`` and ``./roi/roi_mask``, resized
    to 128x128 with nearest-neighbour interpolation and split once (80% train,
    20% validation); the same split is reused for every combination in
    ``config``. Checkpoints are written to ``checkpoints/``.
    """
    pl.seed_everything(0, workers=True)  # seed

    # Each key lists the values to sweep; the full cartesian product is trained.
    config = {
        'batch_size': [32],
        'lr': [5e-4],
        'epoch': [15],
        'arch': ['fpn'],
        'encoder': ["resnext101_32x8d"],
    }

    file_img = './roi/roi_img'
    file_mask = './roi/roi_mask'

    transfs = torch.nn.Sequential(
        transforms.Resize(
            (128, 128),
            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        ),
    )

    dataset = data.roiLeishDataset(file_img, file_mask, transfs)

    # Derive the validation size as the remainder so the two lengths always
    # add up to len(dataset), which random_split requires.
    n_train = math.ceil(len(dataset) * 0.8)
    lengths = [n_train, len(dataset) - n_train]
    train_dt, valid_dt = torch.utils.data.random_split(dataset, lengths)

    # product() varies the right-most iterable fastest, i.e. the same order as
    # nesting the loops batch_size > lr > epoch > encoder > arch.
    grid = itertools.product(
        config['batch_size'],
        config['lr'],
        config['epoch'],
        config['encoder'],
        config['arch'],
    )
    for batch_size, lr, epoch, encoder, arch in grid:
        train_dataloader = DataLoader(train_dt, batch_size=batch_size, shuffle=True)
        valid_dataloader = DataLoader(valid_dt, batch_size=batch_size, shuffle=False)

        model_leish = model.ModelRoiLeish(
            arch, encoder, in_channels=3, out_classes=1, lr=lr
        )

        pl_logger = WandbLogger(project="ROI_LEISHMANIA_BINARY")

        callbacks = [
            # Keeps the checkpoint with the highest 'valid_jaccard'
            # (presumably logged by model.ModelRoiLeish -- confirm).
            pl.callbacks.ModelCheckpoint(
                dirpath="checkpoints",
                monitor='valid_jaccard',
                mode='max',
            ),
            # Optional early stopping:
            # pl.callbacks.early_stopping.EarlyStopping(
            #     monitor='valid_jaccard', mode='max', patience=5)
        ]

        # NOTE(review): `gpus=` was deprecated in PyTorch Lightning 1.7 and
        # removed in 2.0; use accelerator="gpu", devices=1 when upgrading.
        trainer = pl.Trainer(
            gpus=1,
            max_epochs=epoch,
            logger=pl_logger,
            callbacks=callbacks,
            # strategy='ddp'  # distributed
        )

        trainer.fit(
            model_leish,
            train_dataloaders=train_dataloader,
            val_dataloaders=valid_dataloader,
        )

        # Close the active W&B run so the next combination starts a fresh run
        # instead of appending its metrics to this one.
        wandb.finish()


if __name__ == '__main__':
    main()