import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v1.keras.backend as K1

# The model below relies on graph-mode `tf.gradients`, so eager execution is
# switched off before any Keras objects are imported or built.
tf1.disable_eager_execution()

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    Activation,
    Dense,
    Lambda,
    Layer,
    Concatenate,
)


def get_TrR_weights(filename):
    """Load a saved weight list and drop the beta-beta pairing weights.

    Args:
        filename: path to a ``.npy`` file holding a sequence of weight arrays.

    Returns:
        A list of squeezed arrays with the entries at ``[-4:-2]`` removed, in
        the order expected by ``get_TrR(...).set_weights``.
    """
    # NOTE(security): allow_pickle=True unpickles the file contents, which can
    # execute arbitrary code. Only load weight files from a trusted source.
    weights = [np.squeeze(w) for w in np.load(filename, allow_pickle=True)]
    # remove weights for beta-beta pairing
    del weights[-4:-2]
    return weights


def get_TrR(blocks=12, trainable=False, weights=None, name="TrR"):
    """Build the residual network that maps a sequence tensor to 2D outputs.

    Args:
        blocks: number of repetitions of the dilation cycle (1, 2, 4, 8, 16);
            one extra dilation-1 residual block is appended after the cycles.
        trainable: forwarded as ``trainable=`` to every Dense, Conv2D and
            instance_norm layer.
        weights: optional list passed to ``model.set_weights``. Assignment is
            positional, so it must match the layer creation order used here.
        name: name given to the returned Keras model.

    Returns:
        A Keras ``Model`` whose output is the channel-wise concatenation of
        four softmax heads: theta (25), phi (13), dist (37), omega (25).
    """
    ex = {"trainable": trainable}

    # custom layer(s)
    class PSSM(Layer):
        # modified from MRF to only output tiled 1D features
        def __init__(self, diag=0.4, use_entropy=False):
            super(PSSM, self).__init__()
            self.diag = diag
            self.use_entropy = use_entropy

        def call(self, inputs):
            """Assemble tiled 1D features plus a constant 2D feature map.

            ``inputs`` is a pair ``(x, y)``; only ``x[0, 0]`` and ``y[0, 0]``
            are read, so a single example is processed per call.
            """
            x, y = inputs
            _, _, L, A = [tf.shape(y)[k] for k in range(4)]
            with tf.name_scope('1d_features'):
                # sequence
                x_i = x[0, 0, :, :20]
                # pssm
                f_i = y[0, 0]
                # entropy
                if self.use_entropy:
                    h_i = K.sum(-f_i * K.log(f_i + 1e-8), axis=-1, keepdims=True)
                else:
                    h_i = tf.zeros((L, 1))
                # tile and combined 1D features
                feat_1D = tf.concat([x_i, f_i, h_i], axis=-1)
                feat_1D_tile_A = tf.tile(feat_1D[:, None, :], [1, L, 1])
                feat_1D_tile_B = tf.tile(feat_1D[None, :, :], [L, 1, 1])

            with tf.name_scope('2d_features'):
                # Scaled identity rearranged to (L, L, A*A), followed by one
                # all-zero channel.
                ic = self.diag * tf.eye(L * A)
                ic = tf.reshape(ic, (L, A, L, A))
                ic = tf.transpose(ic, (0, 2, 1, 3))
                ic = tf.reshape(ic, (L, L, A * A))
                i0 = tf.zeros([L, L, 1])
                feat_2D = tf.concat([ic, i0], axis=-1)

            feat = tf.concat([feat_1D_tile_A, feat_1D_tile_B, feat_2D], axis=-1)
            # The hard-coded width (442 + 2*42 = 526) only matches when A == 21:
            # 21*21 + 1 pairwise channels and 20 + 21 + 1 channels per tile.
            return tf.reshape(feat, [1, L, L, 442 + 2 * 42])

    class instance_norm(Layer):
        """Normalise over ``axes`` per example, with learned offset and scale."""

        def __init__(self, axes=(1, 2), trainable=True):
            super(instance_norm, self).__init__()
            self.axes = axes
            self.trainable = trainable

        def build(self, input_shape):
            self.beta = self.add_weight(name='beta', shape=(input_shape[-1],),
                                        initializer='zeros', trainable=self.trainable)
            self.gamma = self.add_weight(name='gamma', shape=(input_shape[-1],),
                                         initializer='ones', trainable=self.trainable)

        def call(self, inputs):
            mean, variance = tf.nn.moments(inputs, self.axes, keepdims=True)
            # beta is the offset, gamma is the scale, 1e-6 the variance epsilon.
            return tf.nn.batch_normalization(inputs, mean, variance,
                                             self.beta, self.gamma, 1e-6)

    ## INPUT ##
    inputs = Input((None, None, 21), batch_size=1)
    # The same tensor is fed as both the sequence and the PSSM argument.
    A = PSSM()([inputs, inputs])
    A = Dense(64, **ex)(A)
    A = instance_norm(**ex)(A)
    A = Activation("elu")(A)

    ## RESNET ##
    def resnet(X, dilation=1, filters=64, win=3):
        """Two dilated conv + instance-norm stages with a residual connection."""
        Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(X)
        Y = instance_norm(**ex)(Y)
        Y = Activation("elu")(Y)
        Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(Y)
        Y = instance_norm(**ex)(Y)
        return Activation("elu")(X + Y)

    for _ in range(blocks):
        for dilation in [1, 2, 4, 8, 16]:
            A = resnet(A, dilation)
    A = resnet(A, dilation=1)

    ## OUTPUT ##
    # NOTE: layer creation order below must not change; set_weights() assigns
    # the supplied arrays positionally.
    A_input = Input((None, None, 64))
    p_theta = Dense(25, activation="softmax", **ex)(A_input)
    p_phi = Dense(13, activation="softmax", **ex)(A_input)
    # dist and omega are predicted from features symmetrised over the two
    # spatial axes.
    A_sym = Lambda(lambda x: (x + tf.transpose(x, [0, 2, 1, 3])) / 2)(A_input)
    p_dist = Dense(37, activation="softmax", **ex)(A_sym)
    p_omega = Dense(25, activation="softmax", **ex)(A_sym)
    A_model = Model(A_input, Concatenate()([p_theta, p_phi, p_dist, p_omega]))

    ## MODEL ##
    model = Model(inputs, A_model(A), name=name)
    if weights is not None:
        model.set_weights(weights)
    return model


def get_TrR_model(L=None, exclude_theta=False, use_idx=False, use_bkg=False,
                  models_path="models"):
    """Build the five-model ensemble and return a loss/gradient callable.

    Args:
        L: sequence length used for the input shapes (``None`` leaves it
            unspecified).
        exclude_theta: if true, the cross-entropy loss uses only the phi, dist
            and omega channels.
        use_idx: if true, two extra integer index inputs select sub-matrices
            of the prediction and of the target before the loss is computed.
        use_bkg: if true, a background input is added and a background loss is
            computed and added (weighted by 0.1) to the differentiated loss.
        models_path: directory containing ``model_xaa.npy`` ... ``model_xae.npy``.

    Returns:
        A function ``TrR_model(seq, true, **kwargs)`` that runs the model on a
        single example and returns a dict with ``cce_loss``, ``grad``, ``pred``
        and, when ``use_bkg`` is set, ``bkg_loss``.
    """

    def gather_idx(x):
        # x = [tensor, idx]; the first row of idx selects entries along the
        # two axes preceding the channel axis.
        idx = x[1][0]
        return tf.gather(tf.gather(x[0], idx, axis=-2), idx, axis=-3)

    def get_cce_loss(x, eps=1e-8, only_dist=False):
        """Categorical cross-entropy between x[0] (target) and x[1] (prediction)."""
        # NOTE(review): `split_feat` is not defined in the visible part of this
        # file; it is presumably provided elsewhere -- confirm before enabling
        # `only_dist` or `exclude_theta`.
        if only_dist:
            true_x = split_feat(x[0])["dist"]
            pred_x = split_feat(x[1])["dist"]
            loss = -tf.reduce_mean(
                tf.reduce_sum(true_x * tf.math.log(pred_x + eps), -1), [-1, -2])
            # presumably rescaled because one of four heads is used -- confirm
            return loss * 4
        elif exclude_theta:
            true_x = split_feat(x[0])
            pred_x = split_feat(x[1])
            true_x = tf.concat([true_x[k] for k in ["phi", "dist", "omega"]], -1)
            pred_x = tf.concat([pred_x[k] for k in ["phi", "dist", "omega"]], -1)
            loss = -tf.reduce_mean(
                tf.reduce_sum(true_x * tf.math.log(pred_x + eps), -1), [-1, -2])
            # presumably rescaled because three of four heads are used -- confirm
            return loss * 4 / 3
        else:
            return -tf.reduce_mean(
                tf.reduce_sum(x[0] * tf.math.log(x[1] + eps), -1), [-1, -2])

    def get_bkg_loss(x, eps=1e-8):
        # x = [background, prediction]; returns the negated sum of
        # pred * (log pred - log bkg), averaged over the two spatial axes.
        return -tf.reduce_mean(
            tf.reduce_sum(
                x[1] * (tf.math.log(x[1] + eps) - tf.math.log(x[0] + eps)), -1),
            [-1, -2])

    def prep_seq(x_logits):
        """Turn logits into a one-hot sequence while keeping softmax gradients."""
        x_soft = tf.nn.softmax(x_logits, -1)
        x_hard = tf.one_hot(tf.argmax(x_logits, -1), 20)
        # Forward value equals x_hard; the gradient flows through x_soft only.
        x = tf.stop_gradient(x_hard - x_soft) + x_soft
        # Append one zero channel (20 -> 21) to match the network input width.
        x = tf.pad(x, [[0, 0], [0, 0], [0, 1]])
        return x[None]

    I_seq_logits = Input((L, 20), name="seq_logits")
    seq = Lambda(prep_seq, name="seq")(I_seq_logits)
    I_true = Input((L, L, 100), name="true")
    if use_bkg:
        I_bkg = Input((L, L, 100), name="bkg")
    if use_idx:
        I_idx = Input((None,), dtype=tf.int32, name="idx")
        I_idx_true = Input((None,), dtype=tf.int32, name="idx_true")

    # Average the predictions of the five separately saved networks.
    pred = []
    for nam in ["xaa", "xab", "xac", "xad", "xae"]:
        print(nam)
        TrR = get_TrR(weights=get_TrR_weights(f"{models_path}/model_{nam}.npy"),
                      name=nam)
        pred.append(TrR(seq))
    pred = sum(pred) / len(pred)

    if use_idx:
        pred_sub = Lambda(gather_idx, name="pred_sub")([pred, I_idx])
        true_sub = Lambda(gather_idx, name="true_sub")([I_true, I_idx_true])
    else:
        pred_sub = pred
        true_sub = I_true

    cce_loss = Lambda(get_cce_loss, name="cce_loss")([true_sub, pred_sub])
    if use_bkg:
        bkg_loss = Lambda(get_bkg_loss, name="bkg_loss")([I_bkg, pred])
        loss = Lambda(lambda x: x[0] + 0.1 * x[1])([cce_loss, bkg_loss])
    else:
        loss = cce_loss
    grad = Lambda(lambda x: tf.gradients(x[0], x[1]),
                  name="grad")([loss, I_seq_logits])

    # setup model
    inputs = [I_seq_logits, I_true]
    outputs = [cce_loss]
    if use_bkg:
        inputs += [I_bkg]
        outputs += [bkg_loss]
    if use_idx:
        inputs += [I_idx, I_idx_true]
    # Output order: cce_loss, [bkg_loss], grad, pred.
    model = Model(inputs, outputs + [grad, pred], name="TrR_model")

    def TrR_model(seq, true, **kwargs):
        """Run the model on one example.

        Args:
            seq: sequence logits for one example (a leading batch axis is
                added here).
            true: target feature tensor for one example.
            **kwargs: ``bkg`` is required when ``use_bkg`` is set; ``pos_idx``
                is required when ``use_idx`` is set, and ``pos_idx_ref``
                optionally overrides the target indices (defaults to
                ``pos_idx``).

        Returns:
            dict with ``cce_loss``, ``grad``, ``pred`` and optionally
            ``bkg_loss``, each with the batch axis removed.
        """
        i = [seq[None], true[None]]
        if use_bkg:
            i += [kwargs["bkg"][None]]
        if use_idx:
            pos_idx = kwargs["pos_idx"]
            pos_idx_ref = kwargs.get("pos_idx_ref")
            if pos_idx_ref is None:
                pos_idx_ref = pos_idx
            i += [pos_idx[None], pos_idx_ref[None]]
        o = model.predict(i)
        # grad is the second-to-last model output and pred is the last one.
        r = {"cce_loss": o[0][0], "grad": o[-2][0], "pred": o[-1][0]}
        if use_bkg:
            r["bkg_loss"] = o[1][0]
        return r

    return TrR_model