Spaces:
Running
on
T4
Running
on
T4
import os | |
import numpy as np | |
import torch | |
from torch.autograd import Variable | |
from pdb import set_trace as st | |
from IPython import embed | |
class BaseModel:
    """Minimal base class for trainable models.

    Most methods are no-op hooks intended to be overridden by subclasses.
    ``save_network``/``load_network`` persist a network's ``state_dict`` under
    the file name ``<epoch_label>_net_<network_label>.pth``.

    NOTE(review): several methods read attributes this class never assigns
    (``self.save_dir``, ``self.input``, ``self.image_paths``); they are
    presumably set by subclasses — confirm before relying on them.
    """

    def __init__(self):
        pass

    def name(self):
        """Return the model's display name."""
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=None):
        """Store device options on the instance.

        Args:
            use_gpu: Flag stored as ``self.use_gpu``.
            gpu_ids: GPU ids stored as ``self.gpu_ids``; defaults to ``[0]``.
                A fresh default list is built per call so instances never
                share (and accidentally mutate) a single default list.
        """
        self.use_gpu = use_gpu
        self.gpu_ids = [0] if gpu_ids is None else gpu_ids

    def forward(self):
        """Forward-pass hook; no-op in the base class."""

    def optimize_parameters(self):
        """Optimization-step hook; no-op in the base class."""

    def get_current_visuals(self):
        """Return ``self.input`` (assumed to be set by a subclass)."""
        return self.input

    def get_current_errors(self):
        """Return a dict of current errors; empty in the base class."""
        return {}

    def save(self, label):
        """Save hook; no-op in the base class."""

    @staticmethod
    def _network_filename(network_label, epoch_label):
        """Return the checkpoint file name shared by save/load helpers."""
        return '%s_net_%s.pth' % (epoch_label, network_label)

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        """Save ``network.state_dict()`` into directory ``path``.

        Args:
            network: Object exposing ``state_dict()`` (e.g. a torch module).
            path: Directory to write into; it is not created here.
            network_label: Name component of the checkpoint file.
            epoch_label: Epoch component of the checkpoint file.
        """
        save_path = os.path.join(path, self._network_filename(network_label, epoch_label))
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, path=None):
        """Load a checkpoint written by ``save_network`` into ``network``.

        Args:
            network: Object exposing ``load_state_dict()``; updated in place.
            network_label: Name component of the checkpoint file.
            epoch_label: Epoch component of the checkpoint file.
            path: Directory to read from; defaults to ``self.save_dir``
                (assumed to be set by a subclass).
        """
        if path is None:
            path = self.save_dir
        save_path = os.path.join(path, self._network_filename(network_label, epoch_label))
        print('Loading network from %s' % save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self, *args, **kwargs):
        """Learning-rate schedule hook; no-op in the base class.

        Accepts arbitrary arguments so subclass-style calls (e.g. with a decay
        count) do not fail on models that do not override it.
        """

    def get_image_paths(self):
        """Return ``self.image_paths`` (assumed to be set by a subclass)."""
        return self.image_paths

    def save_done(self, flag=False):
        """Record a completion flag in ``self.save_dir``.

        The flag is written twice: with ``np.save`` (NumPy appends ``.npy``,
        producing ``done_flag.npy``) and as plain text in ``done_flag`` using
        the ``'%i'`` format.
        """
        flag_path = os.path.join(self.save_dir, 'done_flag')
        np.save(flag_path, flag)
        np.savetxt(flag_path, [flag], fmt='%i')