import torch


class Params:
    """Plain container of configuration values with fixed defaults.

    The attribute names suggest training hyperparameters (batch size,
    learning rate, LR schedule, epoch count); how each one is consumed is
    decided by the code that reads this object, not by this class.

    Instances are mutable and compare by value, so they are deliberately
    unhashable.
    """

    def __init__(self) -> None:
        """Populate every setting with its default value."""
        self.batch_size = 148
        self.name = "resnet_50"
        self.num_workers = 48
        self.lr = 0.165
        self.momentum = 0.9
        self.weight_decay = 1e-4
        self.lr_step_size = 30
        self.lr_gamma = 0.1
        self.num_epochs = 50

    def __repr__(self) -> str:
        """Return the instance attributes rendered as a dict string."""
        return str(self.__dict__)

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is a ``Params`` with identical attributes.

        Returns ``NotImplemented`` for any other type so that Python falls
        back to its default comparison (unequal) instead of raising
        ``AttributeError`` on objects that have no ``__dict__``.
        """
        if not isinstance(other, Params):
            return NotImplemented
        return self.__dict__ == other.__dict__

    # Defining __eq__ already makes the class unhashable; spell it out so the
    # intent is visible to readers.
    __hash__ = None


def get_device() -> str:
    """Return the name of the preferred available torch device.

    Preference order is ``"cuda"``, then ``"mps"``, then ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is absent on some torch builds; treat a missing
    # module the same as "not available" rather than raising.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"