import argparse
import os

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .pr_datasets_all import FUNC_DICT
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = (20, 3)
plt.rcParams["figure.dpi"] = 300
plt.rcParams["savefig.dpi"] = 300


def model_fn(x, t, y=None, rule=None,
             model=nn.Identity(), num_classes=3, class_cond=True, cfg=False, w=0.):
    """Class-conditional denoiser wrapper with optional classifier-free guidance (CFG).

    `num_classes` doubles as the "null" label for the unconditional branch, and
    `rule` is unused here but kept so all guidance entry points share one signature.
    """
    y_null = th.tensor([num_classes] * x.shape[0], device=x.device)
    if class_cond:
        if cfg:
            # CFG: extrapolate the conditional prediction away from the unconditional one.
            return (1 + w) * model(x, t, y) - w * model(x, t, y_null)
        else:
            return model(x, t, y)
    else:
        return model(x, t, y_null)

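# A minimal usage sketch (the model and tensors below are hypothetical, not part of
# this module): samplers typically bind everything except (x, t, y) with
# functools.partial so they only see a plain eps-prediction callable.
#
#   from functools import partial
#   eps_fn = partial(model_fn, model=trained_model, num_classes=3, cfg=True, w=4.0)
#   eps = eps_fn(x_t, t, y=labels)
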
def dc_model_fn(x, t, y=None, rule=None,
                model=nn.Identity(), num_classes=3, class_cond=True, cfg=False, w=0.):
    """Same as `model_fn`, but swaps the last two axes around the model call
    (for backbones that expect the input with those axes transposed)."""
    x = x.permute(0, 1, 3, 2)
    y_null = th.tensor([num_classes] * x.shape[0], device=x.device)
    if class_cond:
        if cfg:
            eps = (1 + w) * model(x, t, y) - w * model(x, t, y_null)
            return eps.permute(0, 1, 3, 2)
        else:
            return model(x, t, y).permute(0, 1, 3, 2)
    else:
        return model(x, t, y_null).permute(0, 1, 3, 2)


def grad_nn_zt_xentropy(x, t, y=None, rule=None, classifier_scale=1., classifier=nn.Identity()):
    """Gradient of the classifier's log-probability of the target class `rule` w.r.t. x.

    `t` and `classifier_scale` are accepted for signature compatibility with the
    other `grad_nn_zt_*` functions dispatched through `composite_nn_zt`; note that
    the classifier is always queried at t = 0.
    """
    assert rule is not None
    t = th.zeros(x.shape[0], device=x.device)
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        # Log-probability assigned to each sample's target class.
        selected = log_probs[range(len(logits)), rule.view(-1)]
        return th.autograd.grad(selected.sum(), x_in)[0] * classifier_scale


def grad_nn_zt_mse(x, t, y=None, rule=None, classifier_scale=10., classifier=nn.Identity()):
    """Gradient of a negative-MSE log-likelihood between the classifier output at
    (x, t) and the target attribute `rule`, w.r.t. x."""
    assert rule is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = -F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
        return th.autograd.grad(log_probs.sum(), x_in)[0] * classifier_scale

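# Minimal sketch of how these gradients are consumed in classifier-guided sampling
# (`mean` and `variance` are hypothetical sampler internals, following the standard
# guided-diffusion mean shift; they are not defined in this module):
#
#   grad = grad_nn_zt_mse(x_t, t, rule=target, classifier_scale=10., classifier=clf)
#   guided_mean = mean + variance * grad
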
def grad_nn_zt_chord(x, t, y=None, rule=None, classifier_scale=10., classifier=nn.Identity(), both=False):
    """Gradient of key/chord cross-entropy log-likelihoods w.r.t. x.

    With `both=True`, column 0 of `rule` holds the key target and the remaining
    columns hold per-step chord targets; otherwise `rule` holds chord targets only.
    """
    assert rule is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        key_logits, chord_logits = classifier(x_in, t)
        if both:
            rule_key = rule[:, 0]  # cross_entropy needs class-index targets of shape (B,)
            rule_chord = rule[:, 1:].reshape(-1)
            # Flatten the sequence axis so each step is scored independently.
            chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
            key_log_probs = -F.cross_entropy(key_logits, rule_key, reduction="none")
            chord_log_probs = -F.cross_entropy(chord_logits, rule_chord, reduction="none")
            chord_log_probs = chord_log_probs.reshape(x_in.shape[0], -1).mean(dim=-1)
            log_probs = key_log_probs + chord_log_probs
        else:
            rule = rule.reshape(-1)
            chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
            log_probs = -F.cross_entropy(chord_logits, rule, reduction="none")
        return th.autograd.grad(log_probs.sum(), x_in)[0] * classifier_scale


def nn_z0_chord_dummy(x, t, y=None, rule=None, classifier_scale=0.1, classifier=nn.Identity(), both=False):
    """Scaled key/chord log-likelihood of `rule`, with the classifier queried at t = 0
    (the incoming `t` is ignored). Returns one scalar per sample, not a gradient."""
    assert rule is not None
    t = th.zeros(x.shape[0], device=x.device)
    key_logits, chord_logits = classifier(x, t)
    if both:
        rule_key = rule[:, 0]  # class-index targets of shape (B,)
        rule_chord = rule[:, 1:].reshape(-1)
        chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
        key_log_probs = -F.cross_entropy(key_logits, rule_key, reduction="none")
        chord_log_probs = -F.cross_entropy(chord_logits, rule_chord, reduction="none")
        chord_log_probs = chord_log_probs.reshape(x.shape[0], -1).mean(dim=-1)
        log_probs = key_log_probs + chord_log_probs
    else:
        rule = rule.reshape(-1)
        chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
        log_probs = -F.cross_entropy(chord_logits, rule, reduction="none")
        log_probs = log_probs.reshape(x.shape[0], -1).mean(dim=-1)
    return log_probs * classifier_scale


def nn_z0_mse_dummy(x, t, y=None, rule=None, classifier_scale=0.1, classifier=nn.Identity()):
    """Scaled negative-MSE log-likelihood between classifier output and `rule`,
    with the classifier queried at t = 0 (the incoming `t` is ignored)."""
    assert rule is not None
    t = th.zeros(x.shape[0], device=x.device)
    logits = classifier(x, t)
    log_probs = -F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs * classifier_scale


def nn_z0_mse(x, rule=None, classifier=nn.Identity()):
    """Unscaled variant of `nn_z0_mse_dummy` with a narrower signature (no `t`, `y`,
    or `classifier_scale`), for callers that pass only (x, rule)."""
    t = th.zeros(x.shape[0], device=x.device)
    logits = classifier(x, t)
    log_probs = -F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs


def rule_x0_mse_dummy(x, t, y=None, rule=None, rule_name='pitch_hist'):
    """Negative-MSE log-likelihood under a hand-crafted rule function from FUNC_DICT
    instead of a learned classifier; `t` and `y` are unused but keep the dispatch
    signature used by `composite_rule`."""
    logits = FUNC_DICT[rule_name](x)
    log_probs = -F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs


def rule_x0_mse(x, rule=None, rule_name='pitch_hist', soft=False):
    """Like `rule_x0_mse_dummy`, but with the rule's optional `soft` variant and a
    narrower (x, rule) signature."""
    logits = FUNC_DICT[rule_name](x, soft=soft)
    log_probs = -F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs


class _WrappedFn:
    """Wraps a guidance function behind the uniform (x, t, y, rule) call signature."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x, t, y=None, rule=None):
        return self.fn(x, t, y, rule)


# Registry mapping names to guidance functions for the composite dispatchers below.
function_map = {
    "grad_nn_zt_xentropy": grad_nn_zt_xentropy,
    "grad_nn_zt_mse": grad_nn_zt_mse,
    "grad_nn_zt_chord": grad_nn_zt_chord,
    "nn_z0_chord_dummy": nn_z0_chord_dummy,
    "nn_z0_mse_dummy": nn_z0_mse_dummy,
    "nn_z0_mse": nn_z0_mse,
    "rule_x0_mse_dummy": rule_x0_mse_dummy,
    "rule_x0_mse": rule_x0_mse,
}


def composite_nn_zt(x, t, y=None, rule=None, fns=None, classifier_scales=None, classifiers=None, rule_names=None):
    """Sum guidance terms from several classifiers, one per entry in `rule_names`.

    `rule` is a dict mapping each rule name to its target tensor; `fns` holds
    `function_map` keys aligned index-by-index with `classifier_scales` and `classifiers`.
    """
    num_classifiers = len(classifiers)
    out = 0
    for i in range(num_classifiers):
        out += function_map[fns[i]](x, t, y=y, rule=rule[rule_names[i]],
                                    classifier_scale=classifier_scales[i],
                                    classifier=classifiers[i])
    return out

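# A minimal usage sketch (classifier objects and rule names below are hypothetical):
#
#   grads = composite_nn_zt(
#       x_t, t,
#       rule={"pitch_hist": target_hist, "chord": target_chords},
#       fns=["grad_nn_zt_mse", "grad_nn_zt_chord"],
#       classifier_scales=[10., 10.],
#       classifiers=[hist_clf, chord_clf],
#       rule_names=["pitch_hist", "chord"],
#   )
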
def composite_rule(x, t, y=None, rule=None, fns=None, classifier_scales=None, rule_names=None):
    """Like `composite_nn_zt`, but for the classifier-free rule functions that score
    x directly via FUNC_DICT."""
    out = 0
    for i in range(len(fns)):
        out += function_map[fns[i]](x, t, y=y, rule=rule[rule_names[i]],
                                    rule_name=rule_names[i]) * classifier_scales[i]
    return out

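# A minimal usage sketch for the rule-based variant (tensor names hypothetical):
#
#   log_probs = composite_rule(
#       x0_pred, t,
#       rule={"pitch_hist": target_hist},
#       fns=["rule_x0_mse_dummy"],
#       classifier_scales=[1.0],
#       rule_names=["pitch_hist"],
#   )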