# myTest01 / models / ddc_model.py
# meng2003 -- Upload 85 files (bc32eea)
from .base_model import BaseModel
import torch.nn.functional as F
from torch import nn
import torch
#from models import constants
import numpy as np
import os
class DDCModel(nn.Module):
    """Training / inference wrapper around :class:`DDCNet`.

    Holds the network, its Adam optimizer, the most recent losses and
    metrics, and helpers to load checkpoints and to run generation on a
    feature array.

    The class used to derive from ``BaseModel``; it now derives directly from
    ``nn.Module``, so the one ``BaseModel`` helper still used here
    (``set_requires_grad``) is provided locally.
    """

    def __init__(self, opt):
        """Build the network and optimizer from the parsed options ``opt``.

        Reads ``opt.learning_rate``, ``opt.weight_decay`` and
        ``opt.checkpoints_dir`` (plus whatever ``DDCNet`` reads).
        """
        super().__init__()
        self.opt = opt
        self.loss_names = ['ce', 'humaneness_reg', 'total']
        self.metric_names = ['accuracy']
        self.module_names = ['']  # changed from 'model_names'
        self.schedulers = []
        self.net = DDCNet(opt)

        # Parameters whose name ends in 'bias' get twice the learning rate and
        # no weight decay; everything else gets weight decay.
        # NOTE(review): LSTM biases are named like 'bias_ih_l0', so they do not
        # match and land in the weight-decay group -- confirm this is intended.
        named_params = list(self.net.named_parameters())
        bias_params = [p for n, p in named_params if n[-4:] == 'bias']
        other_params = [p for n, p in named_params if n[-4:] != 'bias']
        self.optimizers = [torch.optim.Adam([
            {'params': bias_params, 'lr': 2 * opt.learning_rate},
            {'params': other_params, 'lr': opt.learning_rate,
             'weight_decay': opt.weight_decay},
        ])]

        self.loss_ce = None
        self.humaneness_reg = None  # kept for backward compatibility
        self.loss_humaneness_reg = None
        self.loss_total = None
        self.save_dir = opt.checkpoints_dir + "/block_placement_ddc2"
        self.device = "cpu"

    def name(self):
        """Return the model's display name."""
        return "DDCNet"

    def _unwrapped_net(self):
        """Return the underlying network, stripping a DataParallel wrapper if present."""
        if isinstance(self.net, torch.nn.DataParallel):
            return self.net.module
        return self.net

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of one network or a list of networks.

        ``None`` entries are skipped.
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is None:
                continue
            for param in net.parameters():
                param.requires_grad = requires_grad

    def load_networks(self, epoch):
        """Load ``<epoch>_net_<name>.pth`` from ``self.save_dir`` for each module name."""
        for name in self.module_names:
            if not isinstance(name, str):
                continue
            load_filename = '%s_net_%s.pth' % (epoch, name)
            load_path = os.path.join(self.save_dir, load_filename)
            net = getattr(self, 'net' + name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            print('loading the model from %s' % load_path)
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(load_path, map_location=str(self.device))
            if hasattr(state_dict, '_metadata'):
                del state_dict._metadata
            net.load_state_dict(state_dict)

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this model's command-line options on ``parser`` and return it."""
        parser.add_argument('--entropy_loss_coeff', type=float, default=0.0)
        parser.add_argument('--humaneness_reg_coeff', type=float, default=0.0)
        parser.add_argument('--hidden_dim', type=int, default=512)
        parser.add_argument('--num_classes', type=int, default=2)
        parser.add_argument('--dropout', type=float, default=0.0)
        return parser

    def set_input(self, data):
        """Store a batch from ``data['input']`` and ``data['target']``.

        The first two input dimensions (batch, window) are merged, and the
        last dimension (time) is moved to position 1, giving
        batch/window x time x temporal_context x frequency_features x mel_window_sizes.
        The target's four dimensions (batch, window, time, output channel) are
        flattened to match how ``forward`` flattens the logits.
        """
        input_ = data['input']
        target_ = data['target']
        in_shape = input_.shape
        tgt_shape = target_.shape
        # 0 batch, 1 window, 2 context time, 3 frequency, 4 mel_window_size, 5 time
        merged = input_.reshape((in_shape[0] * in_shape[1], in_shape[2],
                                 in_shape[3], in_shape[4], in_shape[5]))
        self.input = merged.to(self.device).permute(0, 4, 1, 2, 3)
        flat_len = tgt_shape[0] * tgt_shape[1] * tgt_shape[2] * tgt_shape[3]
        self.target = target_.reshape((flat_len,)).to(self.device)

    def forward(self):
        """Run the network on ``self.input`` and compute losses and accuracy.

        Sets ``self.output`` (logits), ``self.loss_ce``,
        ``self.metric_accuracy``, ``self.loss_humaneness_reg`` and
        ``self.loss_total``.
        """
        # The network returns (logits, lstm_out); only the logits feed the loss.
        logits, _ = self.net(self.input)
        self.output = logits
        n, l, classes = logits.size()
        x = logits.reshape(n * l, classes)

        self.loss_ce = F.cross_entropy(x, self.target)
        if self.opt.entropy_loss_coeff > 0:
            # Mean of -p*log(p) over all entries, added as an entropy penalty.
            entropy = -1.0 * (F.softmax(x, dim=1) * F.log_softmax(x, dim=1)).mean()
            self.loss_ce = self.loss_ce + self.opt.entropy_loss_coeff * entropy
        self.metric_accuracy = (torch.argmax(x, 1) == self.target).sum().float() / len(self.target)

        # TODO: implement humaneness_reg maybe. Past notes are not available in
        # the input, so it would have to be derived from the output instead.
        self.loss_humaneness_reg = 0
        self.loss_total = self.loss_ce

    def backward(self):
        """Backpropagate ``self.loss_total`` and take one optimizer step."""
        self.optimizers[0].zero_grad()
        self.loss_total.backward()
        self.optimizers[0].step()

    def optimize_parameters(self):
        """Run one training iteration: forward, backward, per-batch scheduler steps."""
        self.set_requires_grad(self.net, requires_grad=True)
        self.forward()
        self.backward()
        for scheduler in self.schedulers:
            # Only schedulers that update after each iteration expose batch_step.
            # Looking it up first avoids masking AttributeErrors raised inside it.
            batch_step = getattr(scheduler, 'batch_step', None)
            if batch_step is not None:
                batch_step()

    def prepare_input(self, y):
        """Turn a feature array into the network's windowed input tensor.

        ``y`` has dimensions features x window_sizes x time. It is zero-padded
        along time, then one shifted copy is taken per context offset, each
        normalised by subtracting its mean and dividing by its maximum
        absolute value. The result has dimensions
        1 x time x temporal_context x frequency_features x mel_window_sizes.
        """
        receptive_field = 1
        input_length = y.shape[-1]
        half_shift = self.opt.time_shifts // 2
        front_pad = np.zeros((y.shape[0], y.shape[1], receptive_field + half_shift))
        # Padding at the end lets generation cover the full length of the song.
        back_pad = np.zeros((y.shape[0], y.shape[1], half_shift))
        y = np.concatenate((front_pad, y, back_pad), 2)

        time_shifts = self.opt.time_shifts - 1
        windows = []
        for ii in range(-time_shifts // 2, time_shifts // 2 + 1):
            start = half_shift + ii
            # The new leading axis is the batch dimension (size 1).
            window = torch.as_tensor(y[np.newaxis, :, :, start:start + input_length])
            window = window - window.mean()
            peak = torch.abs(window).max()
            if peak > 0:
                # An all-zero window is left at zero rather than becoming 0/0 = NaN.
                window = window / peak
            windows.append(window.float())

        features = torch.stack(windows, dim=1).float().to(self.device)
        # batch/window x time x temporal_context x frequency_features x mel_window_sizes
        return features.permute(0, 4, 1, 2, 3)

    def generate(self, y):
        """Return class probabilities (softmax over the last axis) for features ``y``.

        Puts the network in eval mode. A ``DataParallel`` wrapper is detected
        and unwrapped automatically rather than assumed from ``opt.cuda``.
        """
        features = self.prepare_input(y)
        net = self._unwrapped_net()
        with torch.no_grad():
            net.eval()
            logits, _ = net(features)
            return F.softmax(logits, 2)

    def generate_features(self, y):
        """Return ``(lstm_out, probabilities)`` for features ``y``.

        Puts the network in eval mode. A ``DataParallel`` wrapper is detected
        and unwrapped automatically rather than assumed from ``opt.cuda``.
        """
        features = self.prepare_input(y)
        net = self._unwrapped_net()
        with torch.no_grad():
            net.eval()
            logits, h = net(features)
            return h, F.softmax(logits, 2)
class DDCNet(nn.Module):
    """Two-layer CNN followed by a two-layer LSTM and a linear classifier.

    Input is a 5-D tensor laid out as
    batch/window x time x temporal_context x frequency_features x mel_window_sizes.
    Each time step is passed through the CNN as an image whose channels are
    the mel window sizes; the flattened CNN features then form the LSTM's
    input sequence.
    """

    # Flattened CNN output per time step (channels x height x width). The LSTM
    # input size is fixed to this value; for example a 15 x 80 (context x
    # frequency) input comes out of the conv/pool stack as 20 x 7 x 8.
    cnn_feature_dim = 20 * 7 * 8

    def __init__(self, opt):
        """Build the layers; reads ``opt.hidden_dim`` and ``opt.num_classes``."""
        super().__init__()
        self.conv1 = nn.Conv2d(3, 20, (7, 3))  # assumes CHW format
        # Pooling acts along the last (frequency) axis only.
        self.pool = nn.MaxPool2d((1, 3), (1, 3))
        self.conv2 = nn.Conv2d(20, 20, 3)
        self.lstm = nn.LSTM(input_size=self.cnn_feature_dim,
                            hidden_size=opt.hidden_dim,
                            num_layers=2,
                            batch_first=True)
        self.hidden_to_state = nn.Linear(opt.hidden_dim, opt.num_classes)

    def forward(self, x):
        """Return ``(logits, lstm_out)`` for input ``x``.

        ``logits`` has shape batch x time x num_classes and ``lstm_out`` has
        shape batch x time x hidden_dim.
        """
        # batch/window x time x temporal_context x frequency_features x mel_window_sizes
        N, L, deltaT, dim, winsizes = x.shape
        # Fold time into the batch axis and move the window sizes to the channel axis.
        x = x.reshape(N * L, deltaT, dim, winsizes).permute(0, 3, 1, 2)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.reshape(N, L, self.cnn_feature_dim)  # batch x time x CNN_features
        lstm_out, _ = self.lstm(x)
        logits = self.hidden_to_state(lstm_out)
        return logits, lstm_out