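# trRosetta-style (TrR) network in Keras/TF1-compat, plus a design-time wrapper
# that returns the loss, its gradient w.r.t. the sequence logits, and the
# predicted distance/orientation distributions.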
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.compat.v1 as tf1
tf1.disable_eager_execution()  # tf.gradients (used below) requires graph mode

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Activation, Dense, Lambda, Layer, Concatenate

def get_TrR_weights(filename):
  weights = [np.squeeze(w) for w in np.load(filename, allow_pickle=True)]
  # remove weights for beta-beta pairing
  del weights[-4:-2]
  return weights

def get_TrR(blocks=12, trainable=False, weights=None, name="TrR"):
  ex = {"trainable":trainable}
  # custom layer(s)
  class PSSM(Layer):
    # modified from MRF to only output tiled 1D features
    def __init__(self, diag=0.4, use_entropy=False):
      super(PSSM, self).__init__()
      self.diag = diag
      self.use_entropy = use_entropy
    def call(self, inputs):
      x,y = inputs
      _,_,L,A = [tf.shape(y)[k] for k in range(4)]
      with tf.name_scope('1d_features'):
        # sequence
        x_i = x[0,0,:,:20]
        # pssm
        f_i = y[0,0]
        # entropy
        if self.use_entropy:
          h_i = K.sum(-f_i * K.log(f_i + 1e-8), axis=-1, keepdims=True)
        else:
          h_i = tf.zeros((L,1))
        # tile and combined 1D features
        feat_1D = tf.concat([x_i,f_i,h_i], axis=-1)
        feat_1D_tile_A = tf.tile(feat_1D[:,None,:], [1,L,1])
        feat_1D_tile_B = tf.tile(feat_1D[None,:,:], [L,1,1])

      with tf.name_scope('2d_features'):
        ic = self.diag * tf.eye(L*A)
        ic = tf.reshape(ic,(L,A,L,A))
        ic = tf.transpose(ic,(0,2,1,3))
        ic = tf.reshape(ic,(L,L,A*A))
        i0 = tf.zeros([L,L,1])
        feat_2D = tf.concat([ic,i0], axis=-1)

      # combine: two tiled copies of the 42 1D features (20 seq + 21 pssm +
      # 1 entropy) plus the 442 2D features (21*21 couplings + 1 placeholder)
      feat = tf.concat([feat_1D_tile_A, feat_1D_tile_B, feat_2D],axis=-1)
      return tf.reshape(feat, [1,L,L,442+2*42])
      
  class instance_norm(Layer):
    def __init__(self, axes=(1,2),trainable=True):
      super(instance_norm, self).__init__()
      self.axes = axes
      self.trainable = trainable
    def build(self, input_shape):
      self.beta  = self.add_weight(name='beta',shape=(input_shape[-1],),
                                  initializer='zeros',trainable=self.trainable)
      self.gamma = self.add_weight(name='gamma',shape=(input_shape[-1],),
                                  initializer='ones',trainable=self.trainable)
    def call(self, inputs):
      mean, variance = tf.nn.moments(inputs, self.axes, keepdims=True)
      return tf.nn.batch_normalization(inputs, mean, variance, self.beta, self.gamma, 1e-6)

  ## INPUT ##
  inputs = Input((None,None,21),batch_size=1)
  # the input sequence is passed as both the sequence and the "pssm"
  A = PSSM()([inputs,inputs])
  A = Dense(64, **ex)(A)
  A = instance_norm(**ex)(A)
  A = Activation("elu")(A)

  ## RESNET ##
  def resnet(X, dilation=1, filters=64, win=3):
    Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(X)
    Y = instance_norm(**ex)(Y)
    Y = Activation("elu")(Y)
    Y = Conv2D(filters, win, dilation_rate=dilation, padding='SAME', **ex)(Y)
    Y = instance_norm(**ex)(Y)
    return Activation("elu")(X+Y)

  for _ in range(blocks):
    for dilation in [1,2,4,8,16]:
      A = resnet(A, dilation)
  A = resnet(A, dilation=1)
  
  ## OUTPUT ##
  A_input   = Input((None,None,64))
  p_theta   = Dense(25, activation="softmax", **ex)(A_input)
  p_phi     = Dense(13, activation="softmax", **ex)(A_input)
  A_sym     = Lambda(lambda x: (x + tf.transpose(x,[0,2,1,3]))/2)(A_input)
  p_dist    = Dense(37, activation="softmax", **ex)(A_sym)
  p_omega   = Dense(25, activation="softmax", **ex)(A_sym)
  A_model   = Model(A_input,Concatenate()([p_theta,p_phi,p_dist,p_omega]))

  ## MODEL ##
  model = Model(inputs, A_model(A),name=name)
  if weights is not None: model.set_weights(weights)
  return model
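
# NOTE: `split_feat` is used by get_cce_loss below but is not defined in this
# file. A minimal sketch is given here, assuming the 100 output channels follow
# the Concatenate order in get_TrR: theta (25), phi (13), dist (37), omega (25).
def split_feat(y):
  feats, start = {}, 0
  for k, size in [("theta",25),("phi",13),("dist",37),("omega",25)]:
    feats[k] = y[...,start:start+size]
    start += size
  return feats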

def get_TrR_model(L=None, exclude_theta=False, use_idx=False, use_bkg=False, models_path="models"):
  def gather_idx(x):
    # subselect the (L,L,100) feature map at positions idx along both spatial axes
    idx = x[1][0]
    return tf.gather(tf.gather(x[0],idx,axis=-2),idx,axis=-3)

  def get_cce_loss(x, eps=1e-8, only_dist=False):
    if only_dist:
      true_x = split_feat(x[0])["dist"]
      pred_x = split_feat(x[1])["dist"]
      loss = -tf.reduce_mean(tf.reduce_sum(true_x*tf.math.log(pred_x + eps),-1),[-1,-2])
      return loss * 4  # rescale: only 1 of the 4 feature groups contributes
      
    elif exclude_theta:
      true_x = split_feat(x[0])
      pred_x = split_feat(x[1])
      true_x = tf.concat([true_x[k] for k in ["phi","dist","omega"]],-1)
      pred_x = tf.concat([pred_x[k] for k in ["phi","dist","omega"]],-1)
      loss = -tf.reduce_mean(tf.reduce_sum(true_x*tf.math.log(pred_x + eps),-1),[-1,-2])
      return loss * 4/3  # rescale: 3 of the 4 feature groups contribute

    else:
      return -tf.reduce_mean(tf.reduce_sum(x[0]*tf.math.log(x[1] + eps),-1),[-1,-2])
  
  def get_bkg_loss(x, eps=1e-8):
    # negative KL(pred || bkg): minimizing the total loss maximizes the
    # divergence of the prediction from the background distribution
    return -tf.reduce_mean(tf.reduce_sum(x[1]*(tf.math.log(x[1]+eps)-tf.math.log(x[0]+eps)),-1),[-1,-2])

  def prep_seq(x_logits):
    # straight-through estimator: forward pass uses the one-hot (argmax)
    # sequence, backward pass uses the softmax gradient
    x_soft = tf.nn.softmax(x_logits,-1)
    x_hard = tf.one_hot(tf.argmax(x_logits,-1),20)
    x = tf.stop_gradient(x_hard - x_soft) + x_soft
    # pad a zero 21st channel (gap) to match the 21-channel input of get_TrR
    x = tf.pad(x,[[0,0],[0,0],[0,1]])
    return x[None]

  I_seq_logits = Input((L,20),name="seq_logits")
  seq = Lambda(prep_seq,name="seq")(I_seq_logits)
  I_true = Input((L,L,100),name="true")
  if use_bkg:
    I_bkg = Input((L,L,100),name="bkg")
  if use_idx:
    I_idx = Input((None,),dtype=tf.int32,name="idx")
    I_idx_true = Input((None,),dtype=tf.int32,name="idx_true")
  
  # ensemble: average the predictions of the five model checkpoints
  pred = []
  for nam in ["xaa","xab","xac","xad","xae"]:
    print(nam)
    TrR = get_TrR(weights=get_TrR_weights(f"{models_path}/model_{nam}.npy"),name=nam)
    pred.append(TrR(seq))
  pred = sum(pred)/len(pred)

  if use_idx:
    pred_sub = Lambda(gather_idx, name="pred_sub")([pred,I_idx])
    true_sub = Lambda(gather_idx, name="true_sub")([I_true,I_idx_true])
  else:
    pred_sub = pred
    true_sub = I_true
  
  cce_loss = Lambda(get_cce_loss,name="cce_loss")([true_sub, pred_sub])
  if use_bkg:
    bkg_loss = Lambda(get_bkg_loss,name="bkg_loss")([I_bkg, pred])
    loss = Lambda(lambda x: x[0]+0.1*x[1])([cce_loss,bkg_loss])
  else:
    loss = cce_loss
  grad = Lambda(lambda x: tf.gradients(x[0],x[1]), name="grad")([loss,I_seq_logits])

  # setup model
  inputs = [I_seq_logits, I_true]
  outputs = [cce_loss]
  if use_bkg:
    inputs += [I_bkg]
    outputs += [bkg_loss]
  if use_idx: inputs += [I_idx, I_idx_true]  
  model = Model(inputs, outputs + [grad, pred], name="TrR_model")
  
  def TrR_model(seq, true, **kwargs):
    i = [seq[None],true[None]]
    if use_bkg:
      i += [kwargs["bkg"][None]]
    if use_idx:
      pos_idx = kwargs["pos_idx"]
      if "pos_idx_ref" not in kwargs or kwargs["pos_idx_ref"] is None:
        pos_idx_ref = pos_idx
      else:
        pos_idx_ref = kwargs["pos_idx_ref"]
      i += [pos_idx[None],pos_idx_ref[None]]
    
    o = model.predict(i)
    # outputs are ordered [cce_loss, (bkg_loss,) grad, pred]
    r = {"cce_loss":o[0][0],"grad":o[-2][0],"pred":o[-1][0]}
    if use_bkg: r["bkg_loss"] = o[1][0]
    return r
  
  return TrR_model
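
# Example usage (a sketch; assumes the five weight files
# models/model_xaa.npy ... models/model_xae.npy are available):
#
#   L = 50
#   TrR_model = get_TrR_model(L=L)
#   seq_logits = np.random.normal(size=(L,20)).astype(np.float32)
#   true_feat  = np.zeros((L,L,100), dtype=np.float32)
#   out = TrR_model(seq_logits, true_feat)
#   print(out["cce_loss"], out["grad"].shape, out["pred"].shape)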