Andrew Stirn committed on
Commit
d5e0e34
1 Parent(s): 12459b9

updated tiger.py

Browse files
Files changed (1) hide show
  1. tiger.py +6 -6
tiger.py CHANGED
@@ -24,7 +24,7 @@ GUIDE_LEN = 23
24
  CONTEXT_5P = 3
25
  CONTEXT_3P = 0
26
  TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
27
- UNIT_INTERVAL_MAP = 'exp-lin-exp'
28
 
29
  # reference transcript files
30
  REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
@@ -130,12 +130,12 @@ def calibrate_predictions(predictions: np.array, num_mismatches: np.array, param
130
  return correction * predictions
131
 
132
 
133
- def transform_predictions(predictions: np.array, params: dict = None):
134
  if params is None:
135
- with open('transform_params.pkl', 'rb') as f:
136
- params = pickle.load(f)
137
 
138
  if UNIT_INTERVAL_MAP == 'sigmoid':
 
139
  return 1 - 1 / (1 + np.exp(params['a'] * predictions + params['b']))
140
 
141
  elif UNIT_INTERVAL_MAP == 'min-max':
@@ -180,7 +180,7 @@ def get_on_target_predictions(transcripts: pd.DataFrame, model: tf.keras.Model,
180
  # get predictions
181
  lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
182
  lfc_estimate = calibrate_predictions(lfc_estimate, num_mismatches=np.zeros_like(lfc_estimate))
183
- scores = transform_predictions(lfc_estimate)
184
  predictions = pd.concat([predictions, pd.DataFrame({
185
  ID_COL: [index] * len(scores),
186
  TARGET_COL: target_seq,
@@ -310,7 +310,7 @@ def predict_off_target(off_targets: pd.DataFrame, model: tf.keras.Model):
310
  ], axis=-1)
311
  lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
312
  lfc_estimate = calibrate_predictions(lfc_estimate, off_targets['Number of Mismatches'].to_numpy())
313
- off_targets[SCORE_COL] = transform_predictions(lfc_estimate)
314
 
315
  return off_targets.reset_index(drop=True)
316
 
 
24
  CONTEXT_5P = 3
25
  CONTEXT_3P = 0
26
  TARGET_LEN = CONTEXT_5P + GUIDE_LEN + CONTEXT_3P
27
+ UNIT_INTERVAL_MAP = 'sigmoid'
28
 
29
  # reference transcript files
30
  REFERENCE_TRANSCRIPTS = ('gencode.v19.pc_transcripts.fa.gz', 'gencode.v19.lncRNA_transcripts.fa.gz')
 
130
  return correction * predictions
131
 
132
 
133
+ def score_predictions(predictions: np.array, params: pd.DataFrame = None):
134
  if params is None:
135
+ params = pd.read_pickle('scoring_params.pkl')
 
136
 
137
  if UNIT_INTERVAL_MAP == 'sigmoid':
138
+ params = params.iloc[0]
139
  return 1 - 1 / (1 + np.exp(params['a'] * predictions + params['b']))
140
 
141
  elif UNIT_INTERVAL_MAP == 'min-max':
 
180
  # get predictions
181
  lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
182
  lfc_estimate = calibrate_predictions(lfc_estimate, num_mismatches=np.zeros_like(lfc_estimate))
183
+ scores = score_predictions(lfc_estimate)
184
  predictions = pd.concat([predictions, pd.DataFrame({
185
  ID_COL: [index] * len(scores),
186
  TARGET_COL: target_seq,
 
310
  ], axis=-1)
311
  lfc_estimate = model.predict(model_inputs, batch_size=BATCH_SIZE_COMPUTE, verbose=False)[:, 0]
312
  lfc_estimate = calibrate_predictions(lfc_estimate, off_targets['Number of Mismatches'].to_numpy())
313
+ off_targets[SCORE_COL] = score_predictions(lfc_estimate)
314
 
315
  return off_targets.reset_index(drop=True)
316