# SnakeCLEF2023 / exp5 / convnext2b_exp5_OBIDattention.py
# Author: BBracke
# Commit message: "Upload 13 files"
# Commit: ce0919b
from email.policy import strict
import os, time, pickle, shutil
import pandas as pd
import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler
from torch import autocast
import torchvision.transforms as transforms
import timm
from timm.models import create_model
from timm.utils import ModelEmaV2
from timm.optim import create_optimizer_v2
from torchmetrics import MeanMetric
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
from torchmetrics import MetricCollection
from pytorch_metric_learning.losses import ArcFaceLoss
import wandb
import matplotlib.pyplot as plt
# ### parameters
################## Settings #############################
#os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# let cuDNN benchmark convolution algorithms and cache the fastest one per input shape
torch.backends.cudnn.benchmark = True
################## Data Paths ##########################
MODEL_DIR = "./convnext2b_obdid_attention/"  # output folder: checkpoints, logs, wandb files, script copy
os.makedirs(MODEL_DIR, exist_ok=True)
# Archive the exact training script next to its checkpoints. Resolve the script via
# __file__ so the copy does not depend on the current working directory; fall back to the
# historical relative path when __file__ is undefined (e.g. interactive execution).
_SCRIPT_PATH = globals().get('__file__', './convnext2b_exp5_OBIDattention.py')
shutil.copyfile(_SCRIPT_PATH, f'{MODEL_DIR}convnext2b_exp5_OBIDattention.py')
TRAIN_DATA_DIR = "/SnakeCLEF2023-large_size/"  # root folder of the train images
ADD_TRAIN_DATA_DIR = "/HMP/"  # root folder of the additional train images
VAL_DATA_DIR = "/SnakeCLEF2023-large_size/"  # root folder of the validation images
TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-iNat.csv"
ADD_TRAINDATA_CONFIG = "/SnakeCLEF2023-TrainMetadata-HM.csv"  # set to a falsy value to train without the additional data
VALIDDATA_CONFIG = "/SnakeCLEF2023-ValMetadata.csv"
MISSING_FILES = "../missing_train_data.csv"  # csv listing train rows whose image files are missing; they are filtered out
CCM = "../code_class_mapping_obid.csv"  # code/class mapping csv, one column per metadata code (presumably one row per class - TODO confirm)
NUM_CLASSES = 1784
################## Hyperparameters ########################
NUM_EPOCHS = 50
WARMUP_EPOCHS = 0
# Resume model/optimizer/scaler from epoch 39 of experiment 4. The checkpoint files
# have to be copied into MODEL_DIR beforehand.
RESUME_EPOCH = 39
# learning rate per parameter group
LEARNING_RATE = dict(
    cnn=1e-05,
    embeddings=1e-04,
    classifier=1e-04,
    attention=1e-04,
)
# 'train'/'valid': observations per loader batch; 'grad_acc': number of batches whose
# gradients are accumulated per optimizer step (global batch = grad_acc * train);
# 'max_imgs_per_instance': maximum number of crops (incl. TTA crops) kept per observation_id
BATCH_SIZE = dict(
    train=1,
    valid=1,
    grad_acc=128,
    max_imgs_per_instance=100,
)
# same layout as BATCH_SIZE, for the phase after warmup
BATCH_SIZE_AFTER_WARMUP = dict(
    train=1,
    valid=1,
    grad_acc=128,
    max_imgs_per_instance=100,
)
TRANSFORMS = dict(
    IMAGE_SIZE_TRAIN=544,
    IMAGE_SIZE_VAL=544,
    RandAug=dict(m=7, n=2),  # RandAugment magnitude (m) and number of ops (n)
    num_rand_crops=5,  # random crops drawn per image during training
)
############# Focal Loss ####################
def _load_pickle(path):
    """Return the object stored in the pickle file at *path*, closing the file afterwards.

    NOTE: ``pickle.load`` can execute arbitrary code -- only use it on trusted,
    locally generated files such as the ones loaded below.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


FOCAL_LOSS = {
    # snake species frequencies counted at observation_id level, excluding the
    # observations whose image files are missing
    'class_dist': _load_pickle("../classDist_HMP_missedRemoved.p")['counts'],
    'gamma': 0.5,
}
############# Checkpoints ####################
# optional start-up checkpoint paths (None = not used)
CHECKPOINTS = {
    'fe_cnn': None,
    'model': None,
    'optimizer': None,
    'scaler': None,
    'arcloss': None,
}
# ####### Embedding Token Mappings ########################
META_SIZES = {'endemic': 2, 'code': 212}  # vocabulary size per metadata field
EMBEDDING_SIZES = {'endemic': 64, 'code': 64}  # embedding dimension per metadata field
CODE_TOKENS = _load_pickle("../meta_code_tokens.p")  # metadata code -> token id
ENDEMIC_TOKENS = _load_pickle("../meta_endemic_tokens.p")  # endemic value -> token id
################### WandB ##################
WANDB = True  # set to False to disable Weights & Biases logging
if WANDB:
    wandb.init(
        entity="snakeclef2023",  # team workspace on wandb
        project="exp5",  # sub-project that groups the runs of this experiment
        name="OBIDattention",  # name of this run
        # hyperparameters stored with the run (any key/value pairs are allowed);
        # add whatever else is needed to reproduce the result
        config=dict(
            learning_rate=LEARNING_RATE,
            focal_loss=FOCAL_LOSS,
            architecture="convnextv2_base.fcmae_ft_in22k_in1k_384",
            pretrained="iNat21",
            dataset=f"snakeclef2023, additional train data: {bool(ADD_TRAINDATA_CONFIG)}",
            epochs=NUM_EPOCHS,
            transforms=TRANSFORMS,
            checkpoints=CHECKPOINTS,
            model_dir=MODEL_DIR,
        ),
        save_code=True,  # upload the script file as a backup
        dir=MODEL_DIR,  # local folder for the wandb run files
    )
##################### Dataset & AugTransforms #####################################
# ### dataset & loaders
class SnakeInstanceDataset(Dataset):
    """Observation-level dataset: one sample bundles all images of an ``observation_id``.

    Args:
        data: DataFrame with one row per image. Used columns: ``observation_id``,
            ``code``, ``endemic``, ``class_id`` and ``image_path``.
        ccm: code/class mapping DataFrame; column ``ccm[code]`` is returned as a tensor
            (the train loop multiplies it with the predicted class probabilities).
        transform: callable applied to each RGB PIL image; its output must end in square
            (3, H, W) dims (it may return several crops per image).
        fix_num: if truthy, keep at most ``fix_num`` randomly chosen crops per observation.

    Each item is ``(imgs, label, ccm, meta)``: ``imgs`` has shape (n_crops, 3, H, W),
    ``label`` has shape (1,), and ``meta`` is ``[[code_token, endemic_token]]`` (shape (1, 2)).
    """
    def __init__(self, data, ccm, transform, fix_num=None):
        self.data = data
        # observation_id -> index labels of its image rows
        self.instance_groups = data.groupby('observation_id').groups
        self.instance_obids = list(self.instance_groups.keys())
        self.transform = transform  # image augmentation / preprocessing pipeline
        self.code_class_mapping = ccm
        self.code_tokens = CODE_TOKENS
        self.endemic_tokens = ENDEMIC_TOKENS
        self.fix_num = fix_num
        # NOTE(review): with num_workers > 0 each DataLoader worker gets its own copy of this
        # generator with the same seed, so workers draw identical permutation sequences.
        self.random_gen = torch.Generator().manual_seed(1)

    def __len__(self):
        return len(self.instance_obids)

    def __getitem__(self, index):
        obid = self.instance_obids[index]
        # ``groups`` stores index *labels*, so select with .loc; .iloc would only be
        # correct for a default RangeIndex.
        instances = self.data.loc[self.instance_groups[obid]]
        # observation-level values are taken from the first image row
        code = instances.code.tolist()[0]
        code = code if code in self.code_tokens else "unknown"  # unseen codes use the "unknown" token
        endemic = instances.endemic.tolist()[0]
        endemic = endemic if endemic in self.endemic_tokens else False
        label = torch.tensor([instances.class_id.tolist()[0]])
        ccm = torch.from_numpy(self.code_class_mapping[code].to_numpy())
        meta = torch.tensor([[self.code_tokens[code], self.endemic_tokens[endemic]]])
        # load and transform all images; the context manager releases each file handle
        imgs = []
        for file in instances.image_path.tolist():
            with Image.open(file) as img:
                imgs.append(self.transform(img.convert("RGB")))
        imgs = torch.stack(imgs)
        img_size = imgs.size(-1)
        imgs = imgs.view(-1, 3, img_size, img_size)  # merge per-file and per-crop axes
        # shuffle the crops and, if fix_num is set, keep at most fix_num of them
        idx = torch.randperm(imgs.size(0), generator=self.random_gen)
        if self.fix_num:
            idx = idx[:self.fix_num]
        imgs = imgs[idx, :, :, :]
        return (imgs, label, ccm, meta)
# valid data preprocessing pipeline
def get_val_preprocessing(img_size):
    """Build the deterministic validation pipeline with five-crop test-time augmentation.

    The image is resized with ``Resize(int(img_size * 1.25))``. It is then cut into five
    ``img_size`` x ``img_size`` crops (``FiveCrop``), which are stacked into a 4D tensor and
    normalized with the ImageNet mean/std.
    """
    print(f'IMG_SIZE_VAL: {img_size}')
    to_tensor = transforms.ToTensor()
    crops_to_stack = transforms.Compose([
        transforms.FiveCrop((img_size, img_size)),  # -> tuple of PIL images
        transforms.Lambda(lambda crops: torch.stack([to_tensor(c) for c in crops])),  # -> 4D tensor
    ])
    return transforms.Compose([
        transforms.Resize(int(img_size * 1.25)),  # enlarge before cropping
        crops_to_stack,
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
class MultipleRandomCropsWithAugmentation:
    """Draw ``num_crops`` independently augmented random crops from one image.

    Each crop is a random ``img_size`` x ``img_size`` crop, followed by random
    horizontal/vertical flips and RandAugment (settings from ``TRANSFORMS['RandAug']``).
    It is then converted with ``ToTensor``. Calling the object returns the crops stacked
    along a new first axis.
    """
    def __init__(self, img_size, num_crops=5):
        super().__init__()
        self.num_crops = num_crops
        self.random_crop = transforms.RandomCrop((img_size, img_size))
        rand_aug = TRANSFORMS['RandAug']
        self.augment = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandAugment(num_ops=rand_aug['n'], magnitude=rand_aug['m']),
        ])
        self.to_tensor = transforms.ToTensor()

    def __call__(self, x):
        crops = []
        for _ in range(self.num_crops):
            crop = self.random_crop(x)
            crops.append(self.to_tensor(self.augment(crop)))
        return torch.stack(crops)
# train data augmentation/ preprocessing pipeline
def get_train_augmentation_preprocessing(img_size):
    """Build the training pipeline.

    The pipeline resizes, draws ``TRANSFORMS['num_rand_crops']`` augmented random crops
    and applies the ImageNet normalization.
    """
    print(f'IMG_SIZE_TRAIN: {img_size}')
    enlarged = int(img_size * 1.25)  # enlarge first so the random crops can vary in position
    steps = [
        transforms.Resize(enlarged),
        MultipleRandomCropsWithAugmentation(img_size, TRANSFORMS['num_rand_crops']),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
def get_datasets(train_transfroms, val_transforms):
    """Load the metadata CSVs and build the train/valid ``SnakeInstanceDataset`` objects.

    Steps:
    1. Read the train/valid tables and drop duplicate image paths.
    2. Remove the rows listed in ``MISSING_FILES`` from the train table.
    3. Prefix the image paths with their data directories.
    4. Optionally append the additional train table.
    5. Shuffle both tables with seed 1 and wrap them into datasets.

    Args:
        train_transfroms: per-image training transform (parameter name kept for callers).
        val_transforms: per-image validation transform.

    Returns:
        ``(train_dataset, valid_dataset)``
    """
    nan_values = ['', '#N/A', '#N/A N/A', '#NA', '-1.#IND', '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', '<NA>', 'N/A', 'NULL', 'NaN', 'n/a', 'nan', 'null']
    read_opts = dict(na_values=nan_values, keep_default_na=False)

    train_data = pd.read_csv(TRAINDATA_CONFIG, **read_opts).drop_duplicates(subset='image_path', keep="first")
    missing_train_data = pd.read_csv(MISSING_FILES, **read_opts)
    valid_data = pd.read_csv(VALIDDATA_CONFIG, **read_opts).drop_duplicates(subset='image_path', keep="first")

    # anti-join: keep only the train rows that have no match in the missing-files table
    merged = pd.merge(train_data, missing_train_data, how='outer', indicator=True)
    kept_columns = ["observation_id", "endemic", "binomial_name", "code", "image_path", "class_id", "subset"]
    train_data = merged.loc[merged['_merge'] == 'left_only', kept_columns]

    # code/class mapping table (one column per metadata code)
    ccm = pd.read_csv(CCM, **read_opts)

    train_data["image_path"] = TRAIN_DATA_DIR + train_data['image_path']
    valid_data["image_path"] = VAL_DATA_DIR + valid_data['image_path']

    if ADD_TRAINDATA_CONFIG:
        extra = pd.read_csv(ADD_TRAINDATA_CONFIG, **read_opts)
        extra["image_path"] = ADD_TRAIN_DATA_DIR + extra['image_path']
        train_data = pd.concat([train_data, extra], axis=0)

    # (for quick debugging runs, limit the tables here, e.g. train_data.head(150))
    train_data = train_data.sample(frac=1, random_state=1).reset_index(drop=True)
    valid_data = valid_data.sample(frac=1, random_state=1).reset_index(drop=True)

    max_imgs = BATCH_SIZE['max_imgs_per_instance']
    train_dataset = SnakeInstanceDataset(train_data, ccm, transform=train_transfroms, fix_num=max_imgs)
    valid_dataset = SnakeInstanceDataset(valid_data, ccm, transform=val_transforms, fix_num=max_imgs)
    print(f'train dataset shape: {len(train_dataset)}')
    print(f'valid dataset shape: {len(valid_dataset)}')
    return train_dataset, valid_dataset
def get_collate_fn():
    """Return a collate function for DataLoaders that use ``batch_size=1``.

    Every dataset item is already a complete bag ``(imgs, label, ccm, meta)``. The
    returned function therefore unwraps the single item of the batch instead of stacking,
    and returns its first four fields as a list.
    """
    def collate_fn(batch):
        sample = batch[0]
        return [sample[i] for i in range(4)]
    return collate_fn
def get_dataloaders(imgsize_train, imgsize_val):
    """Create the train (shuffled) and valid DataLoaders, one observation bag per batch."""
    train_dataset, valid_dataset = get_datasets(
        train_transfroms=get_train_augmentation_preprocessing(imgsize_train),
        val_transforms=get_val_preprocessing(imgsize_val),
    )

    def make_loader(dataset, shuffle):
        return DataLoader(dataset=dataset, shuffle=shuffle, batch_size=1, num_workers=4, prefetch_factor=8,
                          collate_fn=get_collate_fn(), drop_last=False, pin_memory=True)

    return make_loader(train_dataset, shuffle=True), make_loader(valid_dataset, shuffle=False)
# #################### plot train history #########################
def plot_history(logs):
    """Plot the train vs. valid curves for loss, accuracy and F1, then save and show them.

    ``logs`` must hold per-epoch lists under 'loss'/'val_loss', 'acc'/'val_acc' and
    'f1'/'val_f1'. The figure is written to ``{MODEL_DIR}model_history.svg``.
    """
    # (log key, y label, y limits, title); the loss axis tops out at -log(1/NUM_CLASSES),
    # the cross-entropy of a uniform guess
    panels = (
        ('loss', "loss", [0, -np.log(1/NUM_CLASSES)], "train- vs. valid loss"),
        ('acc', "accuracy", [0, 1.01], "train- vs. valid accuracy"),
        ('f1', "f1", [0, 1.01], "train- vs. valid f1"),
    )
    fig, axes = plt.subplots(3, 1, figsize=(8, 12))
    for axis, (key, ylabel, ylim, title) in zip(axes, panels):
        axis.plot(logs[key], label="train data")
        axis.plot(logs['val_' + key], label="valid data")
        axis.legend(loc="best")
        axis.set_ylabel(ylabel)
        axis.set_ylim(ylim)
        axis.set_title(title)
    # only the bottom panel carries the shared x-axis label
    axes[-1].set_xlabel("epochs")
    fig.savefig(f'{MODEL_DIR}model_history.svg', dpi=150, format="svg")
    plt.show()
#################### Focal Loss ##################################
class FocalLoss(nn.Module):
    '''
    Multi-class focal loss with class-balanced ("effective number of samples") weights.

    For logits ``inputs`` and integer ``targets`` the per-sample loss is
    ``-w[t] * (1 - p_t) ** gamma * log(p_t)`` with ``p_t = softmax(inputs)[t]``. It is
    reduced by ``torch.nn.functional.nll_loss``; with ``'mean'`` this divides by the summed
    target weights.

    Args:
        gamma: focusing parameter; ``gamma=0`` gives weighted cross-entropy.
        class_dist: per-class sample counts (array-like of length C), or ``None`` for
            uniform weights of length ``NUM_CLASSES``.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``, passed to ``nll_loss``.
        device: device on which the class weights are allocated.
        beta: constant of the class-balanced weights ``w_c = (1 - beta) / (1 - beta ** n_c)``.
            The default 0.999 reproduces the previously hard-coded value.
    '''
    def __init__(self, gamma=2, class_dist=None, reduction='mean', device='cuda', beta=0.999):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        if class_dist is None:
            self.weight = torch.ones(NUM_CLASSES, device=device)
        else:
            counts = np.asarray(class_dist, dtype=np.float64)
            # A class with zero samples would give 1 - beta**0 == 0, i.e. an infinite
            # weight and an inf/NaN loss for that class; treat it like a single sample.
            counts = np.maximum(counts, 1.0)
            self.weight = torch.as_tensor((1.0 - beta) / (1.0 - beta ** counts), dtype=torch.float32, device=device)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        input: [N, C], float32 logits
        target: [N, ], int64 class indices
        """
        logpt = torch.nn.functional.log_softmax(inputs, dim=1)
        pt = torch.exp(logpt)
        # down-weight well-classified samples by the focal factor (1 - p) ** gamma
        logpt = (1 - pt) ** self.gamma * logpt
        loss = torch.nn.functional.nll_loss(logpt, targets, weight=self.weight, reduction=self.reduction)
        return loss
# #################### Model #####################################
class FeatureExtractor(nn.Module):
    """ConvNeXt-V2 base backbone from timm, created with ``num_classes=0``.

    It optionally loads the weights from ``CHECKPOINTS['fe_cnn']`` (strict loading).
    """
    def __init__(self):
        super().__init__()
        self.conv_backbone = create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=True, num_classes=0, drop_path_rate=0.2)
        checkpoint_path = CHECKPOINTS['fe_cnn']
        if checkpoint_path:
            state = torch.load(checkpoint_path, map_location='cpu')
            self.conv_backbone.load_state_dict(state, strict=True)
            print(f"use FE_CHECKPOINTS: {checkpoint_path}")
        torch.cuda.empty_cache()

    def forward(self, img):
        return self.conv_backbone(img)
class MetaEmbeddings(nn.Module):
    """Embed the metadata tokens (code, endemic) and refine them with a small MLP.

    ``meta`` holds integer tokens with column 0 = code token and column 1 = endemic token.
    The output has ``sum(embedding_sizes.values())`` features per row.
    """
    def __init__(self, embedding_sizes: dict, meta_sizes: dict, dropout: float = None):
        super().__init__()
        # attribute names and the layer order inside embedding_net define the state_dict
        # keys of saved checkpoints -- keep them stable
        self.endemic_embedding = nn.Embedding(meta_sizes['endemic'], embedding_sizes['endemic'], max_norm=1.0)
        self.code_embedding = nn.Embedding(meta_sizes['code'], embedding_sizes['code'], max_norm=1.0)
        self.dim_embedding = sum(embedding_sizes.values())
        width = self.dim_embedding
        layers = [
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
            nn.Dropout(p=dropout, inplace=False) if dropout else nn.Identity(),
            nn.Linear(in_features=width, out_features=width, bias=True),
            nn.GELU(),
            nn.LayerNorm(width, eps=1e-06),
        ]
        self.embedding_net = nn.Sequential(*layers)

    def forward(self, meta):
        code_vec = self.code_embedding(meta[:, 0])
        endemic_vec = self.endemic_embedding(meta[:, 1])
        return self.embedding_net(torch.concat([code_vec, endemic_vec], dim=-1))
class Classifier(nn.Module):
    """Linear classification head with optional dropout on its input features."""
    def __init__(self, num_classes: int, dim_embeddings: int, dropout: float = None):
        super().__init__()
        # a falsy dropout disables it; Identity keeps the module attribute present
        self.dropout = nn.Identity() if not dropout else nn.Dropout(p=dropout, inplace=False)
        self.classifier = nn.Linear(in_features=dim_embeddings, out_features=num_classes, bias=True)

    def forward(self, embeddings):
        return self.classifier(self.dropout(embeddings))
class Attention(nn.Module):
    """Attention-based MIL pooling over the instances (crops) of one observation.

    Each of the N instance feature vectors is scored by ``Linear -> tanh -> Linear``. The
    scores are softmax-normalized over the instances, and the attention-weighted sum of
    the instance features is returned.

    Args:
        in_features: feature size L of each instance. The default 1024 is the image
            feature size that ``Model`` assumes.
        hidden_features: size D of the hidden scoring layer.
        num_heads: number K of attention score vectors.
    """
    def __init__(self, in_features: int = 1024, hidden_features: int = 256, num_heads: int = 1):
        super(Attention, self).__init__()
        self.L = in_features
        self.D = hidden_features
        self.K = num_heads
        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

    def forward(self, x):
        """Pool an (N, L) bag of instance features.

        Returns:
            M: pooled features, shape (1, L) for K == 1 (otherwise (1, K, L)).
            A: attention weights, shape (1, K, N), summing to 1 over N.
        """
        N, L = x.shape
        x = x.unsqueeze(0)  # 1xNxL (unsqueeze also works for non-contiguous inputs)
        A = self.attention(x)  # 1xNxK
        A = torch.transpose(A, 2, 1)  # 1xKxN
        A = nn.functional.softmax(A, dim=-1)  # softmax over N
        M = torch.bmm(A, x).squeeze(dim=1)  # 1xL for K == 1
        return M, A
class Model(nn.Module):
    """Observation-level classifier.

    The model runs the CNN features of all crops, attention-pools them into one vector,
    concatenates the metadata embedding and classifies the result.

    ``forward(img, meta)`` returns ``(logits, fused_features)``. The fused features
    (image + metadata) are also fed to the ArcFace loss.
    """
    def __init__(self):
        super().__init__()
        # submodule names and creation order define checkpoint keys and the init RNG order
        self.feature_extractor = FeatureExtractor()
        self.embedding_net = MetaEmbeddings(embedding_sizes=EMBEDDING_SIZES, meta_sizes=META_SIZES, dropout=0.25)
        self.mil_pooling = Attention()
        self.classifier = Classifier(num_classes=NUM_CLASSES, dim_embeddings=1024+128, dropout=0.25)

    def forward(self, img, meta):
        pooled, _attention_weights = self.mil_pooling(self.feature_extractor(img))
        fused = torch.concat([pooled, self.embedding_net(meta)], dim=-1)
        return self.classifier(fused), fused
class LossLayer(nn.Module):
    """Training objective combining two terms.

    The total loss is the class-balanced focal loss on the logits plus the ArcFace loss
    on the fused features.
    """
    def __init__(self):
        super().__init__()
        self.arcloss = ArcFaceLoss(num_classes=NUM_CLASSES, embedding_size=1024+128, margin=28.6, scale=64)
        self.celoss = FocalLoss(gamma=FOCAL_LOSS['gamma'], class_dist=FOCAL_LOSS['class_dist'])

    def forward(self, classifier_outputs, cat_features, labels):
        return self.celoss(classifier_outputs, labels) + self.arcloss(cat_features, labels)
def load_checkpoints(model=None, ema_model=None, optimizer=None, scaler=None, arcloss=None, checkpoints=None):
    """Load optional start-up checkpoints into the given objects.

    A target is loaded only when it is passed in and its entry in the checkpoint mapping
    is set. Keys that are absent from the mapping count as "no checkpoint" instead of
    raising KeyError; ``CHECKPOINTS`` for example has no 'ema_model' key. Model and EMA
    model are loaded with ``strict=False``.

    Args:
        model, ema_model, optimizer, scaler, arcloss: objects exposing ``load_state_dict``
            (``None`` skips them).
        checkpoints: mapping of target name -> checkpoint path; defaults to ``CHECKPOINTS``.
    """
    if checkpoints is None:
        checkpoints = CHECKPOINTS
    # (name in the mapping / log message, target, extra load_state_dict kwargs)
    targets = (
        ('model', model, {'strict': False}),
        ('ema_model', ema_model, {'strict': False}),
        ('optimizer', optimizer, {}),
        ('scaler', scaler, {}),
        ('arcloss', arcloss, {}),
    )
    for name, target, load_kwargs in targets:
        path = checkpoints.get(name)
        if path and target is not None:
            target.load_state_dict(torch.load(path, map_location='cpu'), **load_kwargs)
            print(f"use {name} checkpoints: {path}")
    torch.cuda.empty_cache()
def resume_checkpoints(model=None, optimizer=None, scaler=None, arcloss=None, model_dir=None, epoch=None):
    """Restore the training state that the training loop saved at the end of an epoch.

    Files are read from ``{model_dir}{stem}_epoch{epoch}.pth``, using the stems the loop
    writes: ``model``, ``optimizer``, ``mp_scaler`` and ``arcloss``. Only objects that are
    passed in are restored, and the model is loaded with ``strict=False``.

    Args:
        model, optimizer, scaler, arcloss: objects exposing ``load_state_dict``
            (``None`` skips them).
        model_dir: checkpoint folder as a string prefix ending with a path separator;
            defaults to ``MODEL_DIR``.
        epoch: epoch index of the checkpoint files; defaults to ``RESUME_EPOCH``.
    """
    model_dir = MODEL_DIR if model_dir is None else model_dir
    epoch = RESUME_EPOCH if epoch is None else epoch
    # (log label, target, file stem, extra load_state_dict kwargs)
    targets = (
        ('model', model, 'model', {'strict': False}),
        ('optimizer', optimizer, 'optimizer', {}),
        ('scaler', scaler, 'mp_scaler', {}),
        ('arcloss', arcloss, 'arcloss', {}),
    )
    for label, target, stem, load_kwargs in targets:
        if target is None:
            continue
        path = f'{model_dir}{stem}_epoch{epoch}.pth'
        target.load_state_dict(torch.load(path, map_location='cpu'), **load_kwargs)
        print(f"use {label} checkpoints: {path}")
    torch.cuda.empty_cache()
def resume_logs(logs, history_csv=None, resume_epoch=None):
    """Prepend the metric history of the resumed run to ``logs`` (in place).

    The function reads the history CSV that the training loop writes. For every key of
    ``logs`` it appends the values of epochs ``0..resume_epoch``. Later rows of the old run
    are dropped, so that each list holds exactly one value per epoch before training
    continues at ``resume_epoch + 1``.

    Args:
        logs: dict of metric name -> list; every key must be a column of the CSV.
        history_csv: path of the history CSV; defaults to ``{MODEL_DIR}train_history.csv``.
        resume_epoch: last epoch to keep; defaults to ``RESUME_EPOCH``.
    """
    if history_csv is None:
        history_csv = f"{MODEL_DIR}train_history.csv"
    if resume_epoch is None:
        resume_epoch = RESUME_EPOCH
    old_logs = pd.read_csv(history_csv).head(resume_epoch + 1)
    for name, values in logs.items():
        values.extend(old_logs[name].tolist())
######################## Optimizer #####################################
def get_optm_group(module):
    """Split the parameters of ``module`` into weight-decay and no-weight-decay sets.

    The rules follow the minGPT optimizer setup:
      * every parameter whose name ends with ``bias`` -> no decay
      * ``weight`` of Linear / Conv1d / Conv2d / timm GlobalResponseNormMlp -> decay
      * ``weight`` of BatchNorm1d / BatchNorm2d / LayerNorm / Embedding -> no decay

    Returns:
        ``(param_dict, decay, no_decay)``: a mapping of full parameter name -> parameter,
        plus the two disjoint sets of full names. It does not build an optimizer.

    Raises:
        AssertionError: if a parameter ends up in both sets or in neither.
    """
    decay_modules = (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d, timm.layers.GlobalResponseNormMlp)
    no_decay_modules = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.LayerNorm, torch.nn.Embedding)
    decay, no_decay = set(), set()
    for module_name, sub in module.named_modules():
        # named_parameters() recurses, so a whitelisted container (e.g. GlobalResponseNormMlp)
        # also classifies the weights of its children
        for param_name, _ in sub.named_parameters():
            full_name = f'{module_name}.{param_name}' if module_name else param_name
            if param_name.endswith('bias'):
                no_decay.add(full_name)
            elif param_name.endswith('weight'):
                if isinstance(sub, decay_modules):
                    decay.add(full_name)
                elif isinstance(sub, no_decay_modules):
                    no_decay.add(full_name)
    param_dict = dict(module.named_parameters())
    overlap = decay & no_decay
    unassigned = param_dict.keys() - (decay | no_decay)
    assert not overlap, f"parameters {overlap} made it into both decay/no_decay sets!"
    assert not unassigned, f"parameters {unassigned} were not separated into either decay/no_decay set!"
    return param_dict, decay, no_decay
def get_warmup_optimizer(model):
    """Build the warmup AdamW optimizer over the embedding net and the classifier.

    Each module contributes two parameter groups in this order: decayed weights
    (weight_decay=0.05), then non-decayed parameters (weight_decay=0.0). Both groups use
    the module's learning rate from ``LEARNING_RATE``. The group order must match the
    saved optimizer checkpoints.
    """
    groups = []
    for sub, lr in ((model.embedding_net, LEARNING_RATE['embeddings']),
                    (model.classifier, LEARNING_RATE['classifier'])):
        params, decay_names, no_decay_names = get_optm_group(sub)
        groups.append({"params": [params[n] for n in sorted(decay_names)], "weight_decay": 0.05, 'lr': lr})
        groups.append({"params": [params[n] for n in sorted(no_decay_names)], "weight_decay": 0.0, 'lr': lr})
    return torch.optim.AdamW(groups)
def get_after_warmup_optimizer(model, old_opt):
    """Build the post-warmup optimizer.

    A timm AdamW is created for the CNN backbone with layer-wise decay 0.85 and lr
    ``LEARNING_RATE['cnn']``. All parameter groups of ``old_opt`` are then appended to it.
    """
    backbone = model.feature_extractor.conv_backbone
    new_opt = create_optimizer_v2(backbone, opt='adamw', lr=LEARNING_RATE['cnn'], weight_decay=1e-8,
                                  layer_decay=0.85, filter_bias_and_bn=True)
    for previous_group in old_opt.param_groups:
        new_opt.add_param_group(previous_group)
    return new_opt
# #################### Model Warmup #####################################
def warmup_start(model):
    """Freeze the CNN backbone and the metadata embedding net for the warmup phase.

    Sets ``requires_grad = False`` on every parameter of
    ``model.feature_extractor.conv_backbone`` and of ``model.embedding_net``.
    """
    for param in model.feature_extractor.conv_backbone.parameters():
        param.requires_grad = False
    print('--> freeze feature_extractor.conv_backbone during warmup phase')
    # the embedding net is a direct child of the model, not of the feature extractor
    for param in model.embedding_net.parameters():
        param.requires_grad = False
    print('--> freeze embedding_net during warmup phase')
def warmup_end(model):
    """Unfreeze the CNN backbone and the metadata embedding net after the warmup phase.

    Sets ``requires_grad = True`` on every parameter of
    ``model.feature_extractor.conv_backbone`` and of ``model.embedding_net``.
    """
    for param in model.feature_extractor.conv_backbone.parameters():
        param.requires_grad = True
    print('--> unfreeze feature_extractor.conv_backbone after warmup phase')
    # the embedding net is a direct child of the model, not of the feature extractor
    for param in model.embedding_net.parameters():
        param.requires_grad = True
    print('--> unfreeze embedding_net after warmup phase')
# #################### Train Loop #####################################
# ### train
def _collect_epoch_results(loss_metric, metrics, metric_ccm):
    """Compute the accumulated epoch metrics as Python floats and reset all metric states.

    Returns a dict with the keys 'loss', 'acc', 'acc_top3', 'f1' and 'f1country', which
    are the train-side key names of the ``logs`` dict in ``main``.
    """
    epoch_loss = loss_metric.compute()
    epoch_metrics = metrics.compute()
    epoch_metric_ccm = metric_ccm.compute()
    loss_metric.reset()
    metrics.reset()
    metric_ccm.reset()
    return {
        'loss': epoch_loss.cpu().item(),
        'acc': epoch_metrics['acc'].cpu().item(),
        'acc_top3': epoch_metrics['top3_acc'].cpu().item(),
        'f1': epoch_metrics['f1'].cpu().item(),
        'f1country': epoch_metric_ccm.detach().cpu().item(),
    }


def _format_results(results, prefix=''):
    """Format ``results`` as 'key: value' pairs (5 decimals), each key prefixed by *prefix*."""
    return ", ".join(f"{prefix}{key}: {value:.5f}" for key, value in results.items())


def _train_one_epoch(model, ema_model, loader, loss_fn, optimizer, scaler, loss_metric, metrics, metric_ccm, device):
    """Run one mixed-precision training epoch with gradient accumulation.

    Gradients of ``BATCH_SIZE['grad_acc']`` consecutive bags are accumulated per optimizer
    step; a value of 0 means a step after every bag. The EMA model is updated after every
    optimizer step. Gradients of a trailing, incomplete accumulation window are discarded
    at the end of the epoch. Returns the epoch results (see ``_collect_epoch_results``).
    """
    model.train()
    optimizer.zero_grad(set_to_none=True)
    grad_acc_steps = BATCH_SIZE['grad_acc'] or 1
    # divide each bag's loss so the accumulated gradient is the mean over the window
    loss_div = torch.tensor(grad_acc_steps, dtype=torch.float16, device=device, requires_grad=False)
    for batch_idx, (inputs, labels, ccm, meta) in enumerate(loader):
        inputs = inputs.to(device, non_blocking=True)
        meta = meta.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        ccm = ccm.to(device, non_blocking=True)
        with autocast(device_type='cuda', dtype=torch.float16):
            outputs, embeddings = model(inputs, meta)
            loss = loss_fn(outputs, embeddings, labels) / loss_div
        scaler.scale(loss).backward()
        # metrics use the undivided loss
        loss_metric.update((loss * loss_div).detach())
        preds = outputs.softmax(dim=-1).detach()
        metrics.update(preds, labels)
        metric_ccm.update(preds * ccm, labels)  # probabilities weighted by the code/class mapping
        if (batch_idx + 1) % grad_acc_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            ema_model.update(model)
    results = _collect_epoch_results(loss_metric, metrics, metric_ccm)
    optimizer.zero_grad(set_to_none=True)
    return results


def _evaluate(net, loader, loss_fn, loss_metric, metrics, metric_ccm, device):
    """Evaluate ``net`` on ``loader`` without gradients and return the results dict."""
    net.eval()
    with torch.no_grad():
        for inputs, labels, ccm, meta in loader:
            inputs = inputs.to(device, non_blocking=True)
            meta = meta.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            ccm = ccm.to(device, non_blocking=True)
            with autocast(device_type='cuda', dtype=torch.float16):
                outputs, embeddings = net(inputs, meta)
                loss = loss_fn(outputs, embeddings, labels)
            loss_metric.update(loss.detach())
            preds = outputs.softmax(dim=-1).detach()
            metrics.update(preds, labels)
            metric_ccm.update(preds * ccm, labels)
    return _collect_epoch_results(loss_metric, metrics, metric_ccm)


def main():
    """Train the attention-MIL model, then evaluate its EMA copy.

    Each epoch validates the model, saves checkpoints and logs. After the last epoch the
    EMA model is evaluated.
    """
    device = torch.device('cuda:1')
    torch.cuda.set_device(device)
    # prepare the datasets
    train_loader, valid_loader = get_dataloaders(imgsize_train=TRANSFORMS['IMAGE_SIZE_TRAIN'],
                                                 imgsize_val=TRANSFORMS['IMAGE_SIZE_VAL'])
    # instantiate the model
    model = Model().to(device)
    if RESUME_EPOCH > 0:
        resume_checkpoints(model=model)
    ema_model = ModelEmaV2(model, decay=0.9998, device=device)
    # NOTE(review): warmup_end() is never called in this function (and WARMUP_EPOCHS == 0),
    # so the backbone and the embedding net stay frozen for the whole run -- confirm intent.
    warmup_start(model)
    loss_fn = LossLayer().to(device)
    if RESUME_EPOCH > 0:
        resume_checkpoints(arcloss=loss_fn.arcloss)
    # optimizer & grad scaler
    optimizer = get_warmup_optimizer(model)
    optimizer.add_param_group({"params": loss_fn.arcloss.parameters(), "weight_decay": 0.0, 'lr': LEARNING_RATE['classifier']})
    scaler = GradScaler()
    if RESUME_EPOCH > 0:
        resume_checkpoints(optimizer=optimizer, scaler=scaler)
    # The attention (MIL pooling) groups are added only after restoring the optimizer state,
    # so the restored param groups still line up with the checkpoint.
    param_dict, decay, no_decay = get_optm_group(model.mil_pooling)
    optimizer.add_param_group({"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": 0.05, 'lr': LEARNING_RATE['attention']})
    optimizer.add_param_group({"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0, 'lr': LEARNING_RATE['attention']})
    # running metrics
    loss_metric = MeanMetric().to(device)
    metrics = MetricCollection(metrics={
        'acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro'),
        'top3_acc': MulticlassAccuracy(num_classes=NUM_CLASSES, average='macro', top_k=3),
        'f1': MulticlassF1Score(num_classes=NUM_CLASSES, average='macro')
    }).to(device)
    metric_ccm = MulticlassF1Score(num_classes=NUM_CLASSES, average='macro').to(device)
    start_training = time.perf_counter()
    logs = {'loss': [], 'acc': [], 'acc_top3': [], 'f1': [], 'f1country': [], 'val_loss': [], 'val_acc': [], 'val_acc_top3': [], 'val_f1': [], 'val_f1country': []}
    if RESUME_EPOCH > 0:
        resume_logs(logs)
    start_epoch = RESUME_EPOCH + 1 if RESUME_EPOCH > 0 else 0
    for epoch in range(start_epoch, NUM_EPOCHS):
        epoch_start = time.perf_counter()
        print(f'Epoch {epoch+1}/{NUM_EPOCHS}')
        ############################## train phase ####################################
        train_results = _train_one_epoch(model, ema_model, train_loader, loss_fn, optimizer, scaler,
                                         loss_metric, metrics, metric_ccm, device)
        for key, value in train_results.items():
            logs[key].append(value)
        print(_format_results(train_results), end=' || ')
        torch.cuda.empty_cache()
        ############################## valid phase ####################################
        val_results = _evaluate(model, valid_loader, loss_fn, loss_metric, metrics, metric_ccm, device)
        for key, value in val_results.items():
            logs[f'val_{key}'].append(value)
        print(_format_results(val_results, prefix='val_'), end=' || ')
        torch.cuda.empty_cache()
        # save logs as csv
        logs_df = pd.DataFrame(logs)
        logs_df.to_csv(f'{MODEL_DIR}train_history.csv', index_label='epoch', sep=',', encoding='utf-8')
        if WANDB:
            # [-1] is the current epoch even when the lists were pre-filled by resume_logs
            wandb.log({k: v[-1] for k, v in logs.items()}, step=epoch)
        # save all training state for this epoch
        torch.save(model.state_dict(), f'{MODEL_DIR}model_epoch{epoch}.pth')
        torch.save(ema_model.module.state_dict(), f'{MODEL_DIR}ema_model_epoch{epoch}.pth')
        torch.save(optimizer.state_dict(), f'{MODEL_DIR}optimizer_epoch{epoch}.pth')
        torch.save(scaler.state_dict(), f'{MODEL_DIR}mp_scaler_epoch{epoch}.pth')
        torch.save(loss_fn.arcloss.state_dict(), f'{MODEL_DIR}arcloss_epoch{epoch}.pth')
        epoch_end = time.perf_counter()
        print(f"epoch runtime: {epoch_end-epoch_start:5.3f} sec.")
        del logs_df
        torch.cuda.empty_cache()
    ################################## EMA Model Validation ################################
    del model
    torch.cuda.empty_cache()
    ema_net = ema_model.module
    ema_results = _evaluate(ema_net, valid_loader, loss_fn, loss_metric, metrics, metric_ccm, device)
    ema_summary = _format_results(ema_results, prefix='ema_')
    print(ema_summary)
    with open(f'{MODEL_DIR}ema_results.txt', 'w') as f:
        print(ema_summary, file=f)
    plot_history(logs)
    end_training = time.perf_counter()
    print(f'Training succeeded in {(end_training - start_training):5.3f}s')
    if WANDB:
        wandb.finish()
if __name__ == "__main__":
    # start training only when executed as a script, not on import
    main()