Spaces:
Running
Running
File size: 1,209 Bytes
3cc4a06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import os
import torch
import torch.nn as nn
from torch.nn import init
from torch.optim import lr_scheduler
class BaseModel(nn.Module):
    """Common base for models configured from an options object.

    Stores the options, a step counter, the checkpoint directory and the
    target device. The helper methods below read ``self.model``,
    ``self.optimizer`` and call ``self.forward()``; none of these are set
    or defined here, so presumably a subclass provides them (TODO confirm).
    """

    def __init__(self, opt):
        """Initialise shared state from ``opt``.

        Args:
            opt: options object; this constructor reads
                ``opt.checkpoints_dir``, ``opt.name`` and ``opt.gpu_ids``.
        """
        super(BaseModel, self).__init__()
        self.opt = opt
        # Step counter; persisted in checkpoints by save_networks().
        self.total_steps = 0
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        # Use the first listed GPU id; an empty/None gpu_ids means CPU.
        if opt.gpu_ids:
            self.device = torch.device(f"cuda:{opt.gpu_ids[0]}")
        else:
            print("gpu is not available! ")
            self.device = torch.device("cpu")

    def save_networks(self, save_filename):
        """Write a checkpoint to ``os.path.join(self.save_dir, save_filename)``.

        The checkpoint is a dict with keys ``'model'`` (model state dict),
        ``'optimizer'`` (optimizer state dict) and ``'total_steps'``.

        Args:
            save_filename: file name, joined onto ``self.save_dir``.
        """
        save_path = os.path.join(self.save_dir, save_filename)
        # torch.save does not create missing parent directories.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # serialize model and optimizer to dict
        state_dict = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'total_steps': self.total_steps,
        }
        torch.save(state_dict, save_path)

    def eval(self):
        """Put ``self.model`` into evaluation mode and return ``self``.

        Overrides ``nn.Module.eval``: only ``self.model`` is switched.
        Returning ``self`` matches the ``nn.Module.eval`` contract, so
        chained calls like ``net = Model(opt).eval()`` keep a reference.
        """
        self.model.eval()
        return self

    def test(self):
        """Call ``self.forward()`` with autograd disabled."""
        with torch.no_grad():
            self.forward()