leishmaniaModel / run.py
vannynakamura's picture
Update run.py
0010972
raw
history blame contribute delete
No virus
2.47 kB
# run.py
import os
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch
from torchvision import transforms
import math
from torch.utils.data import DataLoader
import wandb
import model
import dataset as data
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from dotenv import load_dotenv
# Populate os.environ from a local .env file, if one exists.
load_dotenv()
WANDB_API_KEY = os.getenv('WANDB_API_KEY')
# os.environ only accepts str values: assigning None (key not configured)
# would raise TypeError at import time. Only export the key when it is set;
# otherwise wandb has to obtain credentials some other way (e.g. `wandb login`).
if WANDB_API_KEY is not None:
    os.environ['WANDB_API_KEY'] = WANDB_API_KEY
def main():
    """Train ROI segmentation models over a small hyperparameter grid.

    Builds ``data.roiLeishDataset`` from ``./roi/roi_img`` and
    ``./roi/roi_mask``, passing it a 128x128 nearest-neighbour resize
    transform. It splits the dataset 80/20 into train/validation subsets and
    fits one ``model.ModelRoiLeish`` per combination of the values listed in
    ``config``.

    Each combination is logged as its own Weights & Biases run in the
    ``ROI_LEISHMANIA_BINARY`` project. Its best checkpoint (by maximum
    ``valid_jaccard``) is saved under ``checkpoints/``.
    """
    import itertools

    # Global seed for reproducibility; workers=True also derives per-worker seeds.
    pl.seed_everything(0, workers=True)

    # Every value in every list is swept; one model is trained per combination.
    config = {
        'batch_size': [32],
        'lr': [5e-4],
        'epoch': [15],
        'arch': ['fpn'],
        'encoder': ["resnext101_32x8d"],
    }

    file_img = './roi/roi_img'
    file_mask = './roi/roi_mask'
    # NOTE(review): NEAREST interpolation is presumably chosen so that masks
    # keep discrete labels; confirm how roiLeishDataset applies `transfs`.
    transfs = torch.nn.Sequential(
        transforms.Resize((128, 128), interpolation=torchvision.transforms.InterpolationMode.NEAREST),
    )
    dataset = data.roiLeishDataset(file_img, file_mask, transfs)

    # random_split requires the lengths to sum to len(dataset). Deriving the
    # train size from the validation size guarantees that, and gives the same
    # values as ceil(0.8 * n) / int(0.2 * n).
    valid_len = int(len(dataset) * 0.2)
    train_len = len(dataset) - valid_len
    train_dt, valid_dt = torch.utils.data.random_split(dataset, [train_len, valid_len])

    for batch_size in config['batch_size']:
        # The loaders depend only on the batch size, so build them once per
        # batch size rather than once per (lr, epoch, encoder, arch) combination.
        train_dataloader = DataLoader(train_dt, batch_size=batch_size, shuffle=True)
        valid_dataloader = DataLoader(valid_dt, batch_size=batch_size, shuffle=False)

        # Same iteration order as nested loops over lr -> epoch -> encoder -> arch.
        for lr, epoch, encoder, arch in itertools.product(
            config['lr'], config['epoch'], config['encoder'], config['arch']
        ):
            model_leish = model.ModelRoiLeish(arch, encoder, in_channels=3, out_classes=1, lr=lr)
            pl_logger = WandbLogger(project="ROI_LEISHMANIA_BINARY")
            callbacks = [
                pl.callbacks.ModelCheckpoint(
                    dirpath="checkpoints",
                    # every_n_train_steps=500,
                    monitor='valid_jaccard',
                    mode='max',
                ),
                # pl.callbacks.early_stopping.EarlyStopping(monitor='valid_jaccard', mode='max', patience=5)
            ]
            trainer = pl.Trainer(
                # NOTE: `gpus=` is deprecated since PL 1.7 and removed in 2.0;
                # use accelerator="gpu", devices=1 when upgrading.
                gpus=1,
                max_epochs=epoch,
                logger=pl_logger,
                callbacks=callbacks,
                # strategy='ddp'  # distributed
            )
            trainer.fit(
                model_leish,
                train_dataloaders=train_dataloader,
                val_dataloaders=valid_dataloader,
            )
            # A new WandbLogger reuses any wandb run that is still active, so
            # close this one; otherwise every grid combination would be logged
            # into the same run.
            wandb.finish()
if __name__ == "__main__":
    # Launch the training sweep only when run as a script, never on import.
    main()