# Author: Simon Duerr
# Commit 85bd48b: "add fast af"
# (Header recovered from a web file view: "raw" / "history blame", file size 6.91 kB.)
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v1.keras.backend as K1
tf1.disable_eager_execution()
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Activation, Dense, Lambda, Layer, Concatenate
def get_TrR_weights(filename):
    """Load a trRosetta weight file and drop the beta-beta pairing entries.

    Args:
        filename: path to a ``.npy`` file holding a pickled sequence of
            weight arrays.

    Returns:
        A list of ``np.ndarray`` with singleton dimensions squeezed out and
        the entries at positions ``[-4:-2]`` (the beta-beta pairing weights)
        removed, suitable for ``Model.set_weights`` on a ``get_TrR`` model.

    Raises:
        ValueError: if the file holds fewer than four arrays, in which case
            the ``[-4:-2]`` slice would not address the intended entries.
    """
    # numpy is not imported at module level in this file, so import it here.
    import numpy as np

    # NOTE: allow_pickle=True unpickles the file contents; only load trusted weight files.
    weights = [np.squeeze(w) for w in np.load(filename, allow_pickle=True)]
    if len(weights) < 4:
        raise ValueError(
            f"expected at least 4 weight arrays in {filename!r}, got {len(weights)}"
        )
    # remove weights for beta-beta pairing
    del weights[-4:-2]
    return weights
def get_TrR(blocks=12, trainable=False, weights=None, name="TrR"):
    """Build a trRosetta-style 2D residual network as a tf.keras Model.

    The model input has shape (1, N, L, 21) (``batch_size=1``). The PSSM
    layer reads only the first entry along axes 0 and 1 and takes axis 2 as
    the sequence length L. The output has shape (1, L, L, 100): four softmax
    heads concatenated as theta (25), phi (13), dist (37), omega (25).

    Args:
        blocks: number of repeats of the dilation cycle [1, 2, 4, 8, 16]; one
            more dilation-1 residual block follows, giving 5 * blocks + 1
            residual blocks in total.
        trainable: forwarded to every Dense, Conv2D and instance_norm layer.
        weights: optional list handed to ``model.set_weights``; its order must
            match ``model.get_weights()`` for the network built here, so the
            layer construction order below must not change.
        name: name of the returned Model.

    Returns:
        The assembled (uncompiled) Model.
    """
    # keyword arguments shared by every layer that owns weights
    ex = {"trainable":trainable}
    # custom layer(s)
    class PSSM(Layer):
        """Weight-free feature layer producing (1, L, L, 442 + 2*42) pair features."""
        # modified from MRF to only output tiled 1D features
        def __init__(self, diag=0.4, use_entropy=False):
            super(PSSM, self).__init__()
            # value placed on the i == j, a == b entries of the 2D identity feature
            self.diag = diag
            # if False, the entropy channel is all zeros
            self.use_entropy = use_entropy
        def call(self, inputs):
            # get_TrR calls this layer with [inputs, inputs], so x and y are the same tensor there
            x,y = inputs
            # dynamic sizes: L = axis 2 (sequence length), A = axis 3 (channels)
            _,_,L,A = [tf.shape(y)[k] for k in range(4)]
            with tf.name_scope('1d_features'):
                # sequence: first 20 channels of the first row of x -> (L, 20)
                x_i = x[0,0,:,:20]
                # pssm: all A channels of the first row of y -> (L, A)
                f_i = y[0,0]
                # entropy: per-position entropy of f_i, or zeros -> (L, 1)
                if self.use_entropy:
                    h_i = K.sum(-f_i * K.log(f_i + 1e-8), axis=-1, keepdims=True)
                else:
                    h_i = tf.zeros((L,1))
                # tile and combined 1D features: (L, 20 + A + 1), repeated along
                # axis 1 (tile_A) and along axis 0 (tile_B) to give (L, L, 20 + A + 1) each
                feat_1D = tf.concat([x_i,f_i,h_i], axis=-1)
                feat_1D_tile_A = tf.tile(feat_1D[:,None,:], [1,L,1])
                feat_1D_tile_B = tf.tile(feat_1D[None,:,:], [L,1,1])
            with tf.name_scope('2d_features'):
                # (L, L, A*A): for i == j the flattened A x A block is diag * identity,
                # for i != j it is all zeros
                ic = self.diag * tf.eye(L*A)
                ic = tf.reshape(ic,(L,A,L,A))
                ic = tf.transpose(ic,(0,2,1,3))
                ic = tf.reshape(ic,(L,L,A*A))
                # one extra all-zero channel
                i0 = tf.zeros([L,L,1])
                feat_2D = tf.concat([ic,i0], axis=-1)
            feat = tf.concat([feat_1D_tile_A, feat_1D_tile_B, feat_2D],axis=-1)
            # NOTE(review): 442 + 2*42 hard-codes A == 21 (A*A + 1 = 442, 20 + A + 1 = 42);
            # a different channel count would not match this reshape
            return tf.reshape(feat, [1,L,L,442+2*42])
    class instance_norm(Layer):
        """Per-channel normalisation over `axes` with learnable offset beta and scale gamma."""
        def __init__(self, axes=(1,2),trainable=True):
            super(instance_norm, self).__init__()
            self.axes = axes
            # sets the Layer's trainable flag and whether beta/gamma are trainable
            self.trainable = trainable
        def build(self, input_shape):
            # one offset and one scale per channel (last axis)
            self.beta = self.add_weight(name='beta',shape=(input_shape[-1],),
                                        initializer='zeros',trainable=self.trainable)
            self.gamma = self.add_weight(name='gamma',shape=(input_shape[-1],),
                                         initializer='ones',trainable=self.trainable)
        def call(self, inputs):
            # statistics are taken per example over `axes`, then applied with batch_normalization
            mean, variance = tf.nn.moments(inputs, self.axes, keepdims=True)
            return tf.nn.batch_normalization(inputs, mean, variance, self.beta, self.gamma, 1e-6)
    ## INPUT ##
    inputs = Input((None,None,21),batch_size=1)
    A = PSSM()([inputs,inputs])
    # project the 526 pair features down to 64 channels
    A = Dense(64, **ex)(A)
    A = instance_norm(**ex)(A)
    A = Activation("elu")(A)
    ## RESNET ##
    def resnet(X, dilation=1, filters=64, win=3):
        # residual block: two dilated same-padded Conv2D + instance_norm layers,
        # ELU in between, ELU applied after adding the skip connection X
        Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(X)
        Y = instance_norm(**ex)(Y)
        Y = Activation("elu")(Y)
        Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(Y)
        Y = instance_norm(**ex)(Y)
        return Activation("elu")(X+Y)
    # blocks * 5 dilated residual blocks, then one final dilation-1 block
    for _ in range(blocks):
        for dilation in [1,2,4,8,16]:
            A = resnet(A, dilation)
    A = resnet(A, dilation=1)
    ## OUTPUT ##
    # output heads as a separate sub-model on 64-channel pair features
    A_input = Input((None,None,64))
    # theta and phi are predicted from the unsymmetrised features
    p_theta = Dense(25, activation="softmax", **ex)(A_input)
    p_phi = Dense(13, activation="softmax", **ex)(A_input)
    # symmetrise over the two pair axes: (x + x^T) / 2
    A_sym = Lambda(lambda x: (x + tf.transpose(x,[0,2,1,3]))/2)(A_input)
    # dist and omega are predicted from the symmetrised features
    p_dist = Dense(37, activation="softmax", **ex)(A_sym)
    p_omega = Dense(25, activation="softmax", **ex)(A_sym)
    # 25 + 13 + 37 + 25 = 100 output channels, in this order
    A_model = Model(A_input,Concatenate()([p_theta,p_phi,p_dist,p_omega]))
    ## MODEL ##
    model = Model(inputs, A_model(A),name=name)
    # if given, pretrained weights overwrite the freshly initialised variables
    if weights is not None: model.set_weights(weights)
    return model
def get_TrR_model(L=None, exclude_theta=False, use_idx=False, use_bkg=False, models_path="models"):
    """Build an ensemble trRosetta loss/gradient graph and return a callable wrapper.

    Five networks are built with ``get_TrR`` from the weight files
    ``{models_path}/model_{xaa..xae}.npy``. Each is fed the same
    straight-through one-hot sequence, and their predictions are averaged.

    Args:
        L: sequence length used for the Keras inputs (None for variable length).
        exclude_theta: if True, the cross-entropy ignores the theta channels
            and is rescaled by 4/3.
        use_idx: if True, the wrapper also takes ``pos_idx`` / ``pos_idx_ref``
            and the cross-entropy is computed on the gathered sub-matrices.
        use_bkg: if True, the wrapper requires ``bkg``, and ``0.1 * bkg_loss``
            is added to the loss whose gradient is returned.
        models_path: directory containing the ``model_<name>.npy`` files.

    Returns:
        ``TrR_model(seq, true, **kwargs)``. It runs the graph on one example
        and returns a dict with ``cce_loss``, ``grad`` (gradient of the loss
        w.r.t. the sequence logits) and ``pred``. With ``use_bkg`` the dict
        also has ``bkg_loss``.
    """
    # Channel layout of the 100-channel prediction; matches the
    # Concatenate([p_theta, p_phi, p_dist, p_omega]) order in get_TrR.
    feat_slices = {"theta": (0, 25), "phi": (25, 38), "dist": (38, 75), "omega": (75, 100)}

    def split_feat(x):
        # Split a (..., 100) tensor into its four named distribution blocks.
        return {k: x[..., i:j] for k, (i, j) in feat_slices.items()}

    def gather_idx(x):
        # x = [feat (B, L, L, C), idx (B, n)]: keep the rows and columns listed
        # in the first batch entry of idx.
        idx = x[1][0]
        return tf.gather(tf.gather(x[0], idx, axis=-2), idx, axis=-3)

    def cce(true_x, pred_x, eps):
        # Cross-entropy summed over channels, averaged over the two pair axes.
        return -tf.reduce_mean(tf.reduce_sum(true_x * tf.math.log(pred_x + eps), -1), [-1, -2])

    def get_cce_loss(x, eps=1e-8, only_dist=False):
        # The full target concatenates four distributions, so partial losses are
        # rescaled to stay on the same scale as the full four-term loss.
        if only_dist:
            return cce(split_feat(x[0])["dist"], split_feat(x[1])["dist"], eps) * 4
        if exclude_theta:
            keys = ["phi", "dist", "omega"]
            true_feat, pred_feat = split_feat(x[0]), split_feat(x[1])
            true_x = tf.concat([true_feat[k] for k in keys], -1)
            pred_x = tf.concat([pred_feat[k] for k in keys], -1)
            return cce(true_x, pred_x, eps) * 4/3
        return cce(x[0], x[1], eps)

    def get_bkg_loss(x, eps=1e-8):
        # x = [bkg, pred]; returns the negative KL(pred || bkg), so minimising
        # the total loss pushes the prediction away from the background.
        return -tf.reduce_mean(tf.reduce_sum(x[1]*(tf.math.log(x[1]+eps)-tf.math.log(x[0]+eps)),-1),[-1,-2])

    def prep_seq(x_logits):
        # Straight-through one-hot: the forward value is the argmax one-hot,
        # while gradients flow through the softmax.
        x_soft = tf.nn.softmax(x_logits, -1)
        x_hard = tf.one_hot(tf.argmax(x_logits, -1), 20)
        x = tf.stop_gradient(x_hard - x_soft) + x_soft
        # append an all-zero 21st channel, then add a leading axis of size 1
        x = tf.pad(x, [[0, 0], [0, 0], [0, 1]])
        return x[None]

    I_seq_logits = Input((L, 20), name="seq_logits")
    seq = Lambda(prep_seq, name="seq")(I_seq_logits)
    I_true = Input((L, L, 100), name="true")
    if use_bkg:
        I_bkg = Input((L, L, 100), name="bkg")
    if use_idx:
        I_idx = Input((None,), dtype=tf.int32, name="idx")
        I_idx_true = Input((None,), dtype=tf.int32, name="idx_true")

    # Ensemble: average the predictions of the five pretrained networks.
    pred = []
    for nam in ["xaa", "xab", "xac", "xad", "xae"]:
        print(nam)
        TrR = get_TrR(weights=get_TrR_weights(f"{models_path}/model_{nam}.npy"), name=nam)
        pred.append(TrR(seq))
    pred = sum(pred) / len(pred)

    if use_idx:
        pred_sub = Lambda(gather_idx, name="pred_sub")([pred, I_idx])
        true_sub = Lambda(gather_idx, name="true_sub")([I_true, I_idx_true])
    else:
        pred_sub = pred
        true_sub = I_true

    cce_loss = Lambda(get_cce_loss, name="cce_loss")([true_sub, pred_sub])
    if use_bkg:
        bkg_loss = Lambda(get_bkg_loss, name="bkg_loss")([I_bkg, pred])
        loss = Lambda(lambda x: x[0] + 0.1 * x[1])([cce_loss, bkg_loss])
    else:
        loss = cce_loss
    # tf.gradients returns a list with one entry per source; unwrap it so the
    # model output is a plain tensor.
    grad = Lambda(lambda x: tf.gradients(x[0], x[1])[0], name="grad")([loss, I_seq_logits])

    # setup model; outputs are ordered [cce_loss, (bkg_loss,) grad, pred]
    inputs = [I_seq_logits, I_true]
    outputs = [cce_loss]
    if use_bkg:
        inputs += [I_bkg]
        outputs += [bkg_loss]
    if use_idx:
        inputs += [I_idx, I_idx_true]
    model = Model(inputs, outputs + [grad, pred], name="TrR_model")

    def TrR_model(seq, true, **kwargs):
        """Run the ensemble on a single example.

        Args:
            seq: sequence logits, expected (L, 20) to match the seq_logits
                input; a batch axis is added here.
            true: target distributions, expected (L, L, 100).
            **kwargs: ``bkg`` (required when use_bkg), ``pos_idx`` (required
                when use_idx) and optional ``pos_idx_ref`` (defaults to
                ``pos_idx`` when missing or None). Presumably NumPy arrays,
                since ``[None]`` is used to add the batch axis.

        Returns:
            dict with ``cce_loss``, ``grad``, ``pred`` and, when use_bkg,
            ``bkg_loss``. The first batch entry is taken from each output.
        """
        i = [seq[None], true[None]]
        if use_bkg:
            i += [kwargs["bkg"][None]]
        if use_idx:
            pos_idx = kwargs["pos_idx"]
            pos_idx_ref = kwargs.get("pos_idx_ref")
            if pos_idx_ref is None:
                pos_idx_ref = pos_idx
            i += [pos_idx[None], pos_idx_ref[None]]
        o = model.predict(i)
        # grad and pred are the last two model outputs, in that order
        r = {"cce_loss": o[0][0], "grad": o[-2][0], "pred": o[-1][0]}
        if use_bkg:
            r["bkg_loss"] = o[1][0]
        return r

    return TrR_model