Files changed (1) hide show
  1. app.py +24 -4
app.py CHANGED
@@ -48,6 +48,23 @@ class Sequence:
48
  return to_categorical(np.array([padded_ie[0], all_ohe]))[:1]
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def residual_block(data, filters, d_rate):
52
  """
53
  _data: input
@@ -97,19 +114,22 @@ def get_model():
97
  return model2
98
 
99
 
100
# Build the classifier once at start-up so every request reuses it.
model = get_model()

# Download the label-mapping JSON from the Hub; greet() reads its
# 'id2fam_asc' and 'fam_asc2fam_id' tables.
mappings_path = cached_download(hf_hub_url("jonathang/Protein_Family_Models", 'prot_mappings.json'))
# JSON text is UTF-8; be explicit instead of relying on the platform's
# locale-dependent default encoding.
with open(mappings_path, encoding="utf-8") as f:
    prot_mappings = json.load(f)
104
 
105
  def greet(Amino_Acid_Sequence):
106
  processed_seq = Sequence.prepare(Amino_Acid_Sequence)
107
- raw_prediction = model.predict(processed_seq)[0]
108
- idx = raw_prediction.argmax()
 
 
109
  fam_asc = prot_mappings['id2fam_asc'][str(idx)]
110
  fam_id = prot_mappings['fam_asc2fam_id'][fam_asc]
111
  gc.collect()
112
- return f"Input is {Amino_Acid_Sequence}.\nProcessed input is:\n{processed_seq}\n\nModel makes prediction which maps to:\nFamily Accession={fam_asc} and ID={fam_id}\n\nRaw Prediction:\n{raw_prediction}\n"
113
 
114
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
115
  iface.launch()
 
48
  return to_categorical(np.array([padded_ie[0], all_ohe]))[:1]
49
 
50
 
51
def get_lstm_model():
    """Build the LSTM classifier and load its pretrained weights.

    The network is: integer input of length 100 -> Embedding(21, 128) ->
    CuDNNLSTM(64) with L2 regularisation -> Dropout(0.3) -> 1000-way
    softmax. Weights come from ``model1.h5`` in the
    ``jonathang/Protein_Family_Models`` Hub repository.

    Returns:
        The compiled Keras ``Model`` with weights loaded.
    """
    max_length = 100

    seq_input = Input(shape=(max_length,))
    embedded = Embedding(21, 128, input_length=max_length)(seq_input)

    # A single, unidirectional LSTM layer.
    # NOTE(review): CuDNNLSTM is the cuDNN-backed kernel and is documented as
    # GPU-only — confirm the deployment host provides a GPU.
    lstm_out = CuDNNLSTM(
        64,
        kernel_regularizer=l2(0.01),
        recurrent_regularizer=l2(0.01),
        bias_regularizer=l2(0.01),
    )(embedded)
    dropped = Dropout(0.3)(lstm_out)

    # softmax classifier
    class_probs = Dense(1000, activation='softmax')(dropped)

    lstm_net = Model(inputs=seq_input, outputs=class_probs)
    lstm_net.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    # Fetch the weight file (cached locally after the first download).
    weights_path = cached_download(hf_hub_url("jonathang/Protein_Family_Models", 'model1.h5'))
    lstm_net.load_weights(weights_path)
    return lstm_net
66
+
67
+
68
  def residual_block(data, filters, d_rate):
69
  """
70
  _data: input
 
114
  return model2
115
 
116
 
117
# Build both classifiers once at start-up so every request reuses them.
cnn_model = get_model()
lstm_model = get_lstm_model()

# Download the label-mapping JSON from the Hub; greet() reads its
# 'id2fam_asc' and 'fam_asc2fam_id' tables.
mappings_path = cached_download(hf_hub_url("jonathang/Protein_Family_Models", 'prot_mappings.json'))
# JSON text is UTF-8; be explicit instead of relying on the platform's
# locale-dependent default encoding.
with open(mappings_path, encoding="utf-8") as f:
    prot_mappings = json.load(f)
122
 
123
  def greet(Amino_Acid_Sequence):
124
  processed_seq = Sequence.prepare(Amino_Acid_Sequence)
125
+ cnn_raw_prediction = cnn_model.predict(processed_seq)[0]
126
+ lstm_raw_prediction = lstm_model.predict(processed_seq)[0]
127
+ joined_prediction = (cnn_raw_prediction + lstm_raw_prediction) / 2.0
128
+ idx = joined_prediction.argmax()
129
  fam_asc = prot_mappings['id2fam_asc'][str(idx)]
130
  fam_id = prot_mappings['fam_asc2fam_id'][fam_asc]
131
  gc.collect()
132
+ return f"Input is {Amino_Acid_Sequence}.\nProcessed input is:\n{processed_seq}\n\nModel makes prediction which maps to:\nFamily Accession={fam_asc} and ID={fam_id}\n\nRaw Joined Prediction:\n{joined_prediction}\n"
133
 
134
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
135
  iface.launch()