|
|
|
import os |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torchvision |
|
import torch |
|
from torchvision import transforms |
|
import math |
|
from torch.utils.data import DataLoader |
|
import wandb |
|
import model |
|
import dataset as data |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.loggers import WandbLogger |
|
from pytorch_lightning import Trainer |
|
from dotenv import load_dotenv |
|
|
|
# Load environment variables from a local .env file (python-dotenv) so secrets
# such as WANDB_API_KEY, read just below, need not be hard-coded in this script.
load_dotenv()
|
|
|
|
|
# Weights & Biases credentials, taken from the environment (possibly populated
# from .env by load_dotenv). Only re-export the key when it is actually set:
# os.environ values must be strings, so assigning None (key missing from both
# the shell and .env) would raise TypeError at import time instead of letting
# wandb fall back to whatever other login it has configured.
WANDB_API_KEY = os.getenv('WANDB_API_KEY')

if WANDB_API_KEY is not None:
    os.environ['WANDB_API_KEY'] = WANDB_API_KEY
|
|
|
def main():
    """Train ROI Leishmania binary-segmentation models over a hyper-parameter grid.

    Loads the ROI image/mask data from ``./roi/roi_img`` and ``./roi/roi_mask``,
    resizes it to 128x128 and splits it 80/20 into train/validation subsets.
    For every combination of values in ``config`` it then:

    - builds a ``model.ModelRoiLeish``;
    - trains it with PyTorch Lightning on one GPU;
    - logs to the ``ROI_LEISHMANIA_BINARY`` W&B project;
    - checkpoints on the ``valid_jaccard`` metric (higher is better).
    """
    import itertools  # local import: only used to flatten the grid below

    pl.seed_everything(0, workers=True)

    # Hyper-parameter grid: each key maps to the list of values to try and
    # every combination is trained (currently a single configuration).
    config = {
        'batch_size': [32],
        'lr': [5e-4],
        'epoch': [15],
        'arch': ['fpn'],
        'encoder': ["resnext101_32x8d"],
    }

    file_img = './roi/roi_img'
    file_mask = './roi/roi_mask'

    # NEAREST interpolation, presumably so resized masks keep their original
    # label values instead of interpolated ones.
    # TODO(review): confirm roiLeishDataset applies this transform to masks too.
    transfs = torch.nn.Sequential(
        transforms.Resize(
            (128, 128),
            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        ),
    )

    dataset = data.roiLeishDataset(file_img, file_mask, transfs)

    # 80/20 split. The validation length is derived from the training length
    # so the two always sum to len(dataset) exactly, as random_split requires,
    # regardless of floating-point rounding in len * 0.8 / len * 0.2.
    train_len = math.ceil(len(dataset) * 0.8)
    lengths = [train_len, len(dataset) - train_len]
    train_dt, valid_dt = torch.utils.data.random_split(dataset, lengths)

    # Same iteration order as nesting batch_size > lr > epoch > encoder > arch.
    grid = itertools.product(
        config['batch_size'],
        config['lr'],
        config['epoch'],
        config['encoder'],
        config['arch'],
    )

    for batch_size, lr, epoch, encoder, arch in grid:
        train_dataloader = DataLoader(train_dt, batch_size=batch_size, shuffle=True)
        valid_dataloader = DataLoader(valid_dt, batch_size=batch_size, shuffle=False)

        model_leish = model.ModelRoiLeish(arch, encoder, in_channels=3, out_classes=1, lr=lr)

        pl_logger = WandbLogger(project="ROI_LEISHMANIA_BINARY")

        callbacks = [
            # Keep the checkpoint with the highest 'valid_jaccard' (presumably
            # logged by ModelRoiLeish during validation).
            # NOTE(review): every grid point writes to the same "checkpoints"
            # directory; with more than one configuration, consider a per-run
            # dirpath or filename so checkpoints are not mixed or overwritten.
            pl.callbacks.ModelCheckpoint(
                dirpath="checkpoints",
                monitor='valid_jaccard',
                mode='max',
            ),
        ]

        # NOTE(review): `gpus=` is deprecated in newer PyTorch Lightning releases
        # (removed in 2.0 in favour of accelerator="gpu", devices=1); kept here
        # for compatibility with the installed version.
        trainer = pl.Trainer(
            gpus=1,
            max_epochs=epoch,
            logger=pl_logger,
            callbacks=callbacks,
        )

        trainer.fit(
            model_leish,
            train_dataloaders=train_dataloader,
            val_dataloaders=valid_dataloader,
        )

        # Close the active W&B run so the next grid point's WandbLogger starts
        # a fresh run instead of logging into this one.
        wandb.finish()
|
|
|
if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    main()