ResNetonImageNet / src /config.py
AravindKumarRajendran's picture
fix issues
17b870c
raw
history blame contribute delete
626 Bytes
import torch
class Params:
def __init__(self):
self.batch_size = 148
self.name = "resnet_50"
self.num_workers = 48
self.lr = 0.165
self.momentum = 0.9
self.weight_decay = 1e-4
self.lr_step_size = 30
self.lr_gamma = 0.1
self.num_epochs = 50
def __repr__(self):
return str(self.__dict__)
def __eq__(self, other):
return self.__dict__ == other.__dict__
def get_device():
return (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)