Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
import imp | |
import numpy as np | |
class Base(nn.Module):
    """Common bookkeeping helpers layered on top of ``nn.Module``.

    Provides a persistent ``step`` counter and ``stop_threshold`` buffer,
    weight initialisation, partial-layer gradient toggling, checkpoint
    save/load, file logging and a trainable-parameter count.

    The ``r`` property reads and writes ``self.decoder.r``; this class does
    not create ``self.decoder`` itself, so a subclass is expected to provide
    it (with ``r`` stored as a tensor).
    """

    def __init__(self, stop_threshold):
        """Initialise the module and register its persistent buffers.

        Args:
            stop_threshold: Scalar stored as a float32 buffer named
                ``stop_threshold``.
        """
        super().__init__()

        # NOTE(review): both calls run before any layers added after
        # ``super().__init__()`` exist, so at this point they only see the
        # parameters already attached to ``self`` (none for a bare ``Base``).
        self.init_model()
        self.num_params()

        # Buffers are saved in ``state_dict`` and follow the module's device.
        self.register_buffer("step", torch.zeros(1, dtype=torch.long))
        self.register_buffer(
            "stop_threshold",
            torch.tensor(stop_threshold, dtype=torch.float32),
        )

    @property
    def r(self):
        """Return the value held in ``self.decoder.r`` as a Python scalar."""
        return self.decoder.r.item()

    @r.setter
    def r(self, value):
        """Replace ``self.decoder.r`` with a new non-grad tensor of ``value``.

        ``new_tensor`` keeps the dtype and device of the existing tensor.
        """
        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

    def init_model(self):
        """Apply Xavier-uniform initialisation to every parameter with more
        than one dimension; lower-rank parameters are left untouched."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def finetune_partial(self, whitelist_layers):
        """Clear gradients, then report and update the whitelisted children.

        Args:
            whitelist_layers: Container of direct child-module names (as
                yielded by ``named_children``) to act on.
        """
        self.zero_grad()
        for name, child in self.named_children():
            if name not in whitelist_layers:
                continue
            print("Trainable Layer: %s" % name)
            print(
                "Trainable Parameters: %.3f"
                % sum([np.prod(p.size()) for p in child.parameters()])
            )
            # NOTE(review): the messages above call these layers "Trainable",
            # yet the flag is set to False, which disables their gradients.
            # Behaviour is kept as-is; confirm the intent against the caller.
            for param in child.parameters():
                param.requires_grad = False

    def get_step(self):
        """Return the current value of the ``step`` buffer as a Python int."""
        return self.step.data.item()

    def reset_step(self):
        """Overwrite the ``step`` buffer with a fresh tensor holding 1."""
        # Assignment to parameters or buffers is overloaded by nn.Module and
        # updates the internal dict entry.
        # NOTE(review): this yields a 0-dim tensor with value 1, whereas the
        # buffer is created as ``zeros(1)`` -- confirm the value and shape.
        self.step = self.step.data.new_tensor(1)

    def log(self, path, msg):
        """Append ``msg`` followed by a newline to the file at ``path``."""
        with open(path, "a") as f:
            print(msg, file=f)

    def load(self, path, device, optimizer=None):
        """Restore model (and optionally optimizer) state from a checkpoint.

        Args:
            path: Checkpoint location; converted with ``str`` before use.
            device: Passed to ``torch.load`` as ``map_location``.
            optimizer: If given and the checkpoint has an
                ``"optimizer_state"`` entry, that state is loaded into it.
        """
        # SECURITY: torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(str(path), map_location=device)
        # strict=False tolerates missing / unexpected keys in the checkpoint.
        self.load_state_dict(checkpoint["model_state"], strict=False)

        if "optimizer_state" in checkpoint and optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state"])

    def save(self, path, optimizer=None):
        """Write a checkpoint containing ``"model_state"`` and, when an
        optimizer is supplied, ``"optimizer_state"``.

        Args:
            path: Destination; converted with ``str`` before use.
            optimizer: Optional optimizer whose ``state_dict`` is stored too.
        """
        checkpoint = {"model_state": self.state_dict()}
        if optimizer is not None:
            checkpoint["optimizer_state"] = optimizer.state_dict()
        torch.save(checkpoint, str(path))

    def num_params(self, print_out=True):
        """Count parameters with ``requires_grad`` set, in millions.

        Args:
            print_out: When true, also print the count.

        Returns:
            The number of trainable parameter elements divided by 1,000,000.
        """
        trainable = (p for p in self.parameters() if p.requires_grad)
        parameters = sum([np.prod(p.size()) for p in trainable]) / 1_000_000
        if print_out:
            print("Trainable Parameters: %.3fM" % parameters)
        return parameters