File size: 626 Bytes
17b870c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch
class Params:
    """Bundle of run configuration values for the ``resnet_50`` setup.

    Holds the model name, batch size, worker count, learning-rate settings,
    momentum, weight decay and epoch count. Every value has a default equal
    to the previously hard-coded constant, so ``Params()`` behaves exactly as
    before; any value can now be overridden by keyword, e.g.
    ``Params(lr=0.1, num_epochs=90)``.

    NOTE(review): the field names suggest these feed a torch DataLoader,
    an SGD-style optimizer and a step LR scheduler elsewhere — confirm
    against the callers.
    """

    # Instances are mutable and compare by value, so they are deliberately
    # unhashable (defining ``__eq__`` alone already implies this; spelled out
    # here for readers).
    __hash__ = None

    def __init__(
        self,
        *,
        batch_size: int = 148,
        name: str = "resnet_50",
        num_workers: int = 48,
        lr: float = 0.165,
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        lr_step_size: int = 30,
        lr_gamma: float = 0.1,
        num_epochs: int = 50,
    ) -> None:
        """Store the configuration values as instance attributes.

        All arguments are keyword-only so call sites stay self-describing and
        adding new fields later cannot silently shift positional arguments.
        """
        # Assignment order fixes the key order of ``self.__dict__`` and hence
        # the text produced by ``__repr__``; keep it stable.
        self.batch_size = batch_size
        self.name = name
        self.num_workers = num_workers
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.num_epochs = num_epochs

    def __repr__(self) -> str:
        """Return the attribute dict rendered as a string."""
        return str(self.__dict__)

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is a ``Params`` with identical attributes.

        For operands that are not ``Params`` this returns ``NotImplemented``,
        letting Python fall back to its default (identity) comparison, instead
        of raising ``AttributeError`` on objects without a ``__dict__`` such
        as ``None`` or ``5``.
        """
        if not isinstance(other, Params):
            return NotImplemented
        return self.__dict__ == other.__dict__
def get_device():
    """Choose the device string to run torch work on.

    Backends are checked in a fixed order and the first one torch reports
    as available wins: CUDA, then MPS, then the CPU fallback.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # Only consulted when CUDA is absent, matching the original precedence.
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"