File size: 27,366 Bytes
0b11a42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 |
import random
import sys
from random import randint
import pandas as pd
import plotly.graph_objects as go
from anndata import AnnData
#add parent directory to path
sys.path.append('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/')
from src import (Results_Handler, correct_labels, load, predict_transforna,
predict_transforna_all_models,get_fused_seqs)
def get_mc_sc(infer_df,sequences,sub_classes_used_for_training,sc_to_mc_mapper_dict,ood_flag = False):
    '''
    Build aligned major-class (mc) and sub-class (sc) annotation frames for a
    set of sequences.

    Parameters
    ----------
    infer_df: DataFrame indexed by sequence with a 'subclass_name' column
        holding ';'-separated sub-class annotations.
    sequences: sequences (index keys into infer_df) to annotate.
    sub_classes_used_for_training: sub-class labels the model was trained on.
    sc_to_mc_mapper_dict: maps a sub-class label to its major class.
    ood_flag: when True, keep only sequences whose sub classes were NOT used
        for training (out-of-distribution selection); immediate neighbors of
        training classes (adjacent bins, matching tRNA parts) are excluded
        too, to avoid near-duplicates of familiar classes.

    Returns
    -------
    (mc_classes_df, sc_classes_df): DataFrames sharing the same index, one
    column per ';'-separated annotation.

    Raises
    ------
    Exception if no sequence survives the training-class filter.
    '''
    infered_seqs = infer_df.loc[sequences]
    sc_classes_df = infered_seqs['subclass_name'].str.split(';',expand=True)
    #drop rows that carry no annotation at all
    sc_classes_df = sc_classes_df[~sc_classes_df.isnull().all(axis=1)]
    if ood_flag:
        #augment the training classes with their immediate neighbors so that
        #near-identical classes are not mistaken for novel ones
        sub_classes_used_for_training_plus_neighbors = []
        for sub_class in sub_classes_used_for_training:
            sub_classes_used_for_training_plus_neighbors.append(sub_class)
            #for binned classes, also exclude the previous and next bin
            if 'bin' in sub_class:
                bin_num = int(sub_class.split('_bin-')[1])
                if bin_num > 0:
                    sub_classes_used_for_training_plus_neighbors.append(f'{sub_class.split("_bin-")[0]}_bin-{bin_num-1}')
                sub_classes_used_for_training_plus_neighbors.append(f'{sub_class.split("_bin-")[0]}_bin-{bin_num+1}')
            #for tRNAs, exclude every mapper entry that shares the 3p/5p part
            #and the part after '__'
            if 'tR' in sub_class:
                first_part = sub_class.split('-')[0]
                second_part = sub_class.split('__')[1]
                sub_classes_used_for_training_plus_neighbors += [sc for sc in sc_to_mc_mapper_dict.keys() if first_part in sc and second_part in sc]
        sub_classes_used_for_training_plus_neighbors = list(set(sub_classes_used_for_training_plus_neighbors))
        #BUGFIX: test pd.isnull(x) FIRST — evaluating `'hypermapper' not in x`
        #on a null cell (produced by str.split(expand=True) for rows with
        #fewer annotations) raised TypeError in the original ordering
        mask = sc_classes_df.applymap(lambda x: pd.isnull(x) or (x not in sub_classes_used_for_training_plus_neighbors and 'hypermapper' not in x))
    else:
        mask = sc_classes_df.applymap(lambda x: pd.isnull(x) or x in sub_classes_used_for_training)
    #at least one row must fully satisfy the mask, otherwise nothing to score
    if mask.apply(lambda x: all(x.tolist()), axis=1).sum() == 0:
        #TODO: change to log
        import logging
        log_ = logging.getLogger(__name__)
        log_.error('None of the sub classes used for training are in the sequences')
        raise Exception('None of the sub classes used for training are in the sequences')
    #keep only rows where every annotation passes the mask
    sc_classes_df = sc_classes_df[mask.apply(lambda x: all(x.tolist()), axis=1)]
    #map each sub class to its major class; unmapped cells become 'not_found'
    mc_classes_df = sc_classes_df.applymap(lambda x: sc_to_mc_mapper_dict[x] if x in sc_to_mc_mapper_dict else 'not_found')
    #drop rows containing any unmapped sub class
    mc_classes_df = mc_classes_df[mc_classes_df.apply(lambda x: 'not_found' not in x.tolist() ,axis=1)]
    #drop rows whose first major class is still ambiguous (contains ';')
    mc_classes_df = mc_classes_df[~mc_classes_df[0].str.contains(';')]
    #align both frames on the surviving index
    sc_classes_df = sc_classes_df.loc[mc_classes_df.index]
    mc_classes_df = mc_classes_df.loc[sc_classes_df.index]
    return mc_classes_df,sc_classes_df
def plot_confusion_false_novel(df,sc_df,mc_df,save_figs:bool=False):
    '''
    Plot a row-normalised major-class confusion matrix for sequences that
    were falsely flagged as novel, and optionally save it as a png.

    Parameters
    ----------
    df: predictions for the false-novel sequences, indexed by sequence and
        holding a 'predicted_sc_labels' column.
    sc_df: sub-class annotations per sequence (see get_mc_sc).
    mc_df: major-class annotations per sequence (see get_mc_sc).
    save_figs: when True, write the figure to transforna/bin/lc_figures.

    NOTE(review): reads the module-level global `sc_to_mc_mapper_dict`, so
    this function is only safe to call from this script's __main__ context.
    '''
    #restrict the annotation frames to sequences present in df
    curr_sc_classes_df = sc_df.loc[[i for i in df.index if i in sc_df.index]]
    curr_mc_classes_df = mc_df.loc[[i for i in df.index if i in mc_df.index]]
    #derive predicted major classes; labels missing from the mapper fall back
    #to substring heuristics (miR/tRNA/rRNA/snRNA/snoRNA/SNOR/RPL37A/SNHG1)
    df = df.assign(predicted_mc_labels=df.apply(lambda x: sc_to_mc_mapper_dict[x['predicted_sc_labels']] if x['predicted_sc_labels'] in sc_to_mc_mapper_dict else 'miRNA' if 'miR' in x['predicted_sc_labels'] else 'tRNA' if 'tRNA' in x['predicted_sc_labels'] else 'rRNA' if 'rRNA' in x['predicted_sc_labels'] else 'snRNA' if 'snRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'snoRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'SNOR' in x['predicted_sc_labels'] else 'protein_coding' if 'RPL37A' in x['predicted_sc_labels'] else 'lncRNA' if 'SNHG1' in x['predicted_sc_labels'] else x['predicted_sc_labels'], axis=1))
    #attach actual major and sub classes
    df = df.assign(actual_mc_labels=curr_mc_classes_df[0].values.tolist())
    df = df.assign(actual_sc_labels=curr_sc_classes_df[0].values.tolist())
    #per-row accuracies
    df = df.assign(mc_accuracy=df.apply(lambda x: 1 if x['actual_mc_labels'] == x['predicted_mc_labels'] else 0, axis=1))
    df = df.assign(sc_accuracy=df.apply(lambda x: 1 if x['actual_sc_labels'] == x['predicted_sc_labels'] else 0, axis=1))
    #row-normalised confusion matrix on major classes
    mc_confusion_matrix = df.groupby(['actual_mc_labels','predicted_mc_labels'])['mc_accuracy'].count().unstack()
    mc_confusion_matrix = mc_confusion_matrix.fillna(0)
    mc_confusion_matrix = mc_confusion_matrix.apply(lambda x: x/x.sum(), axis=1)
    mc_confusion_matrix = mc_confusion_matrix.applymap(lambda x: round(x,2))
    #fold predicted-only classes into a single 'other' column.
    #BUGFIX: the previous list accumulator (`other_col += series`) EXTENDED the
    #list with the Series' elements instead of adding element-wise, breaking
    #the column assignment whenever predicted-only classes existed.
    extra_cols = [i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()]
    if extra_cols:
        mc_confusion_matrix['other'] = mc_confusion_matrix[extra_cols].sum(axis=1)
    else:
        mc_confusion_matrix['other'] = 0
    #add an all-zero 'other' row so the matrix stays square
    mc_confusion_matrix.loc['other'] = [0]*mc_confusion_matrix.shape[1]
    #drop columns that have no corresponding row
    mc_confusion_matrix = mc_confusion_matrix.drop([i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()], axis=1)
    #plot confusion matrix
    fig = go.Figure(data=go.Heatmap(
                z=mc_confusion_matrix.values,
                x=mc_confusion_matrix.columns,
                y=mc_confusion_matrix.index,
                hoverongaps = False))
    #overlay the cell values on the heatmap
    for i in range(len(mc_confusion_matrix.index)):
        for j in range(len(mc_confusion_matrix.columns)):
            fig.add_annotation(text=str(mc_confusion_matrix.values[i][j]), x=mc_confusion_matrix.columns[j], y=mc_confusion_matrix.index[i],
                                showarrow=False, font_size=25, font_color='red')
    #title and axis labels
    fig.update_layout(title_text='Confusion matrix based on mc classes for false novel sequences')
    fig.update_xaxes(title_text='Predicted mc class')
    fig.update_yaxes(title_text='Actual mc class')
    #save
    if save_figs:
        fig.write_image('transforna/bin/lc_figures/confusion_matrix_mc_classes_false_novel.png')
def compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers = False,fig_prefix:str = '',save_figs:bool=False):
    '''
    Score every model's predictions against the LC annotations and plot a
    row-normalised major-class confusion matrix per model.

    Parameters
    ----------
    prediction_pd: predictions of all models; needs the columns 'Model',
        'Sequence', 'Net-Label', 'NLD' and 'Novelty Threshold'.
    sc_classes_df: sub-class annotations per sequence (see get_mc_sc).
    mc_classes_df: major-class annotations per sequence (see get_mc_sc).
    seperate_outliers: if True, sequences flagged as novel are removed before
        scoring; otherwise they are scored as an explicit 'Outlier' class.
    fig_prefix: prefix used in saved figure/csv file names.
    save_figs: if True, save figures under transforna/bin/lc_figures.

    Side effects: writes csv/png/svg files, prints metrics, calls
    plot_confusion_false_novel, and (when seperate_outliers is False) mutates
    the module-level global sc_to_mc_mapper_dict by adding an 'Outlier' entry.
    '''
    font_size = 25
    #the familiar split has many classes; use a smaller annotation font
    if fig_prefix == 'LC-familiar':
        font_size = 10
    #rename Net-Label to predicted_sc_labels
    prediction_pd = prediction_pd.rename(columns={'Net-Label':'predicted_sc_labels'})
    for model in prediction_pd['Model'].unique():
        num_rows = sc_classes_df.shape[0]
        model_prediction_pd = prediction_pd[prediction_pd['Model'] == model]
        model_prediction_pd = model_prediction_pd.set_index('Sequence')
        #keep only sequences that have annotations
        model_prediction_pd = model_prediction_pd.loc[[i for i in model_prediction_pd.index if i in sc_classes_df.index]]
        try: #ensemble models do not have their own folder; fall back to Seq-Rev embedds
            embedds_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/{model}/embedds'
            results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
        except Exception:
            embedds_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/Seq-Rev/embedds'
            results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
        #report the overlap between the train split and the inference set
        train_seqs = set(results.splits_df_dict['train_df']['RNA Sequences']['0'].values.tolist())
        common_seqs = train_seqs.intersection(set(model_prediction_pd.index.tolist()))
        print(f'Number of common seqs between train_df and {model} is {len(common_seqs)}')
        #sequences whose NLD exceeds the novelty threshold are (falsely) novel
        num_outliers = sum(model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'])
        false_novel_df = model_prediction_pd[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold']]
        plot_confusion_false_novel(false_novel_df,sc_classes_df,mc_classes_df,save_figs)
        #pie chart: false-novel count per actual major class
        fig_outl = mc_classes_df.loc[false_novel_df.index][0].value_counts().plot.pie(autopct='%1.1f%%',figsize=(6, 6))
        fig_outl.set_title(f'False Novel per MC for {model}: {num_outliers}')
        if save_figs:
            fig_outl.get_figure().savefig(f'transforna/bin/lc_figures/false_novel_mc_{model}.png')
            fig_outl.get_figure().clf()
        #unique sub classes per major class among the false-novel sequences
        false_novel_sc_freq_df = sc_classes_df.loc[false_novel_df.index][0].value_counts().to_frame()
        #NOTE(review): assumes pandas<2.0 where value_counts().to_frame() keeps
        #column name 0; on pandas>=2.0 the column is named 'count' — confirm
        false_novel_sc_freq_df['MC'] = false_novel_sc_freq_df.index.map(lambda x: sc_to_mc_mapper_dict[x])
        fig_outl_sc = false_novel_sc_freq_df.groupby('MC')[0].sum().plot.pie(autopct='%1.1f%%',figsize=(6, 6))
        fig_outl_sc.set_title(f'False novel: No. Unique sub classes per MC {model}: {num_outliers}')
        if save_figs:
            fig_outl_sc.get_figure().savefig(f'transforna/bin/lc_figures/{fig_prefix}_false_novel_sc_{model}.png')
            fig_outl_sc.get_figure().clf()
        if seperate_outliers:
            #drop novel-flagged sequences entirely
            model_prediction_pd = model_prediction_pd[model_prediction_pd['NLD'] <= model_prediction_pd['Novelty Threshold']]
        else:
            #score novel-flagged sequences as an explicit 'Outlier' class
            model_prediction_pd.loc[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'],'predicted_sc_labels'] = 'Outlier'
            model_prediction_pd.loc[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'],'predicted_mc_labels'] = 'Outlier'
            sc_to_mc_mapper_dict['Outlier'] = 'Outlier'
        #align annotations with the remaining predictions
        curr_sc_classes_df = sc_classes_df.loc[[i for i in model_prediction_pd.index if i in sc_classes_df.index]]
        curr_mc_classes_df = mc_classes_df.loc[[i for i in model_prediction_pd.index if i in mc_classes_df.index]]
        #derive predicted major classes; labels missing from the mapper fall
        #back to substring heuristics
        model_prediction_pd = model_prediction_pd.assign(predicted_mc_labels=model_prediction_pd.apply(lambda x: sc_to_mc_mapper_dict[x['predicted_sc_labels']] if x['predicted_sc_labels'] in sc_to_mc_mapper_dict else 'miRNA' if 'miR' in x['predicted_sc_labels'] else 'tRNA' if 'tRNA' in x['predicted_sc_labels'] else 'rRNA' if 'rRNA' in x['predicted_sc_labels'] else 'snRNA' if 'snRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'snoRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'SNOR' in x['predicted_sc_labels'] else 'protein_coding' if 'RPL37A' in x['predicted_sc_labels'] else 'lncRNA' if 'SNHG1' in x['predicted_sc_labels'] else x['predicted_sc_labels'], axis=1))
        #attach actual major and sub classes
        model_prediction_pd = model_prediction_pd.assign(actual_mc_labels=curr_mc_classes_df[0].values.tolist())
        model_prediction_pd = model_prediction_pd.assign(actual_sc_labels=curr_sc_classes_df[0].values.tolist())
        #harmonise equivalent sub-class labels before scoring
        model_prediction_pd['predicted_sc_labels'] = correct_labels(model_prediction_pd['predicted_sc_labels'],model_prediction_pd['actual_sc_labels'],sc_to_mc_mapper_dict)
        #per-row accuracies
        model_prediction_pd = model_prediction_pd.assign(mc_accuracy=model_prediction_pd.apply(lambda x: 1 if x['actual_mc_labels'] == x['predicted_mc_labels'] else 0, axis=1))
        model_prediction_pd = model_prediction_pd.assign(sc_accuracy=model_prediction_pd.apply(lambda x: 1 if x['actual_sc_labels'] == x['predicted_sc_labels'] else 0, axis=1))
        if not seperate_outliers:
            #persist false/true major-class predictions with a novelty flag
            cols_to_save = ['actual_mc_labels','predicted_mc_labels','predicted_sc_labels','actual_sc_labels']
            total_false_mc_predictions_df = model_prediction_pd[model_prediction_pd.actual_mc_labels != model_prediction_pd.predicted_mc_labels].loc[:,cols_to_save]
            total_false_mc_predictions_df['is_novel'] = model_prediction_pd.loc[total_false_mc_predictions_df.index]['NLD'] > model_prediction_pd.loc[total_false_mc_predictions_df.index]['Novelty Threshold']
            total_false_mc_predictions_df.to_csv(f'transforna/bin/lc_files/{fig_prefix}_total_false_mcs_w_out_{model}.csv')
            total_true_mc_predictions_df = model_prediction_pd[model_prediction_pd.actual_mc_labels == model_prediction_pd.predicted_mc_labels].loc[:,cols_to_save]
            total_true_mc_predictions_df['is_novel'] = model_prediction_pd.loc[total_true_mc_predictions_df.index]['NLD'] > model_prediction_pd.loc[total_true_mc_predictions_df.index]['Novelty Threshold']
            total_true_mc_predictions_df.to_csv(f'transforna/bin/lc_files/{fig_prefix}_total_true_mcs_w_out_{model}.csv')
        #print plain and class-balanced accuracies
        print('Model: ', model)
        print('num_outliers: ', num_outliers)
        print('mc_accuracy: ', model_prediction_pd['mc_accuracy'].mean())
        print('sc_accuracy: ', model_prediction_pd['sc_accuracy'].mean())
        print('mc_balanced_accuracy: ', model_prediction_pd.groupby('actual_mc_labels')['mc_accuracy'].mean().mean())
        print('sc_balanced_accuracy: ', model_prediction_pd.groupby('actual_sc_labels')['sc_accuracy'].mean().mean())
        #row-normalised confusion matrix on major classes
        mc_confusion_matrix = model_prediction_pd.groupby(['actual_mc_labels','predicted_mc_labels'])['mc_accuracy'].count().unstack()
        mc_confusion_matrix = mc_confusion_matrix.fillna(0)
        mc_confusion_matrix = mc_confusion_matrix.apply(lambda x: x/x.sum(), axis=1)
        mc_confusion_matrix = mc_confusion_matrix.applymap(lambda x: round(x,4))
        #fold predicted-only classes into a single 'other' column.
        #BUGFIX: the previous list accumulator (`other_col += series`) EXTENDED
        #the list with the Series' elements instead of adding element-wise,
        #breaking the column assignment whenever predicted-only classes existed.
        extra_cols = [i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()]
        if extra_cols:
            mc_confusion_matrix['other'] = mc_confusion_matrix[extra_cols].sum(axis=1)
        else:
            mc_confusion_matrix['other'] = 0
        #add an all-zero 'other' row so the matrix stays square
        mc_confusion_matrix.loc['other'] = [0]*mc_confusion_matrix.shape[1]
        #drop columns that have no corresponding row
        mc_confusion_matrix = mc_confusion_matrix.drop([i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()], axis=1)
        #plot confusion matrix
        fig = go.Figure(data=go.Heatmap(
                    z=mc_confusion_matrix.values,
                    x=mc_confusion_matrix.columns,
                    y=mc_confusion_matrix.index,
                    colorscale='Blues',
                    hoverongaps = False))
        #overlay the cell values on the heatmap
        for i in range(len(mc_confusion_matrix.index)):
            for j in range(len(mc_confusion_matrix.columns)):
                fig.add_annotation(text=str(round(mc_confusion_matrix.values[i][j],2)), x=mc_confusion_matrix.columns[j], y=mc_confusion_matrix.index[i],
                                    showarrow=False, font_size=font_size, font_color='black')
        fig.update_layout(
            title='Confusion matrix for mc classes - ' + model + ' - ' + 'mc B. Acc: ' + str(round(model_prediction_pd.groupby('actual_mc_labels')['mc_accuracy'].mean().mean(),2)) \
                + ' - ' + 'sc B. Acc: ' + str(round(model_prediction_pd.groupby('actual_sc_labels')['sc_accuracy'].mean().mean(),2)) + '<br>' + \
                'percent false novel: ' + str(round(num_outliers/num_rows,2)),
            xaxis_nticks=36)
        #label x axis and y axis
        fig.update_xaxes(title_text='Predicted mc class')
        fig.update_yaxes(title_text='Actual mc class')
        if save_figs:
            #file names encode whether outliers were separated or scored
            if seperate_outliers:
                fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_no_out_' + model + '.png')
                fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_no_out_' + model + '.svg')
            else:
                fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_outliers_' + model + '.png')
                fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_outliers_' + model + '.svg')
        print('\n')
if __name__ == '__main__':
    #####################################################################################################################
    #paths to the sub-class -> major-class mapping, the LC annotation csv and the trained models
    mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'
    LC_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
    path_to_models = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'
    trained_on = 'full' #id or full
    save_figs = True
    #exactly one of these split flags is switched on by the dispatch below
    infer_aa = infer_relaxed_mirna = infer_hico = infer_ood = infer_other_affixes = infer_random = infer_fused = infer_na = infer_loco = False
    #NOTE(review): split is hard-coded; the commented-out sys.argv[1] suggests
    #it was meant to be a command-line argument — confirm before batch runs
    split = 'infer_hico'#sys.argv[1]
    print(f'Running inference for {split}')
    if split == 'infer_aa':
        infer_aa = True
    elif split == 'infer_relaxed_mirna':
        infer_relaxed_mirna = True
    elif split == 'infer_hico':
        infer_hico = True
    elif split == 'infer_ood':
        infer_ood = True
    elif split == 'infer_other_affixes':
        infer_other_affixes = True
    elif split == 'infer_random':
        infer_random = True
    elif split == 'infer_fused':
        infer_fused = True
    elif split == 'infer_na':
        infer_na = True
    elif split == 'infer_loco':
        infer_loco = True
    #####################################################################################################################
    #only one of infer_aa or infer_relaxed_mirna or infer_normal or infer_ood or infer_hico should be true
    if sum([infer_aa,infer_relaxed_mirna,infer_hico,infer_ood,infer_other_affixes,infer_random,infer_fused,infer_na,infer_loco]) != 1:
        raise Exception('Only one of infer_aa or infer_relaxed_mirna or infer_normal or infer_ood or infer_hico or infer_other_affixes or infer_random or infer_fused or infer_na should be true')
    #set fig_prefix (used in all output file names) from the active split
    if infer_aa:
        fig_prefix = '5\'A-affixes'
    elif infer_other_affixes:
        fig_prefix = 'other_affixes'
    elif infer_relaxed_mirna:
        fig_prefix = 'Relaxed-miRNA'
    elif infer_hico:
        fig_prefix = 'LC-familiar'
    elif infer_ood:
        fig_prefix = 'LC-novel'
    elif infer_random:
        fig_prefix = 'Random'
    elif infer_fused:
        fig_prefix = 'Fused'
    elif infer_na:
        fig_prefix = 'NA'
    elif infer_loco:
        fig_prefix = 'LOCO'
    #load annotations; load() may return an AnnData, in which case the
    #per-sequence table is in .var
    infer_df = load(LC_path)
    if isinstance(infer_df,AnnData):
        infer_df = infer_df.var
    infer_df.set_index('sequence',inplace=True)
    sc_to_mc_mapper_dict = load(mapping_dict_path)
    #select hico (high-confidence) sequences
    hico_seqs = infer_df.index[infer_df['hico']].tolist()
    #sequences failing the 5' adapter filter (artificial affixes)
    art_affix_seqs = infer_df[~infer_df['five_prime_adapter_filter']].index.tolist()
    if infer_hico:
        #no-op: hico_seqs already holds the hico sequences
        hico_seqs = hico_seqs
    if infer_aa:
        hico_seqs = art_affix_seqs
    if infer_other_affixes:
        hico_seqs = infer_df[~infer_df['hbdx_spikein_affix_filter']].index.tolist()
    if infer_na:
        hico_seqs = infer_df[infer_df.subclass_name == 'no_annotation'].index.tolist()
    if infer_loco:
        #NOTE(review): chained boolean indexing with a mask built on the full
        #frame — pandas emits a UserWarning here; verify the intended rows
        hico_seqs = infer_df[~infer_df['hico']][infer_df.subclass_name != 'no_annotation'].index.tolist()
    #for mirnas
    if infer_relaxed_mirna:
        #subclass name must contain miR or let, must not contain ';', and the
        #sequence must not already be hico
        mirnas_seqs = infer_df[infer_df.subclass_name.str.contains('miR') | infer_df.subclass_name.str.contains('let')][~infer_df.subclass_name.str.contains(';')].index.tolist()
        #remove the ones that are true in ad.hico column
        hico_seqs = list(set(mirnas_seqs).difference(set(hico_seqs)))
    #novel mirnas
    #mirnas_not_in_train_mask = (ad['hico']==True).values * ~(ad['subclass_name'].isin(mirna_train_sc)).values * (ad['small_RNA_class_annotation'].isin(['miRNA']))
    #hicos = ad[mirnas_not_in_train_mask].index.tolist()
    if infer_random:
        #create 200 unique random sequences of length 18-30
        random_seqs = []
        while len(random_seqs) < 200:
            random_seq = ''.join(random.choices(['A','C','G','T'], k=randint(18,30)))
            if random_seq not in random_seqs:
                random_seqs.append(random_seq)
        hico_seqs = random_seqs
    if infer_fused:
        hico_seqs = get_fused_seqs(hico_seqs,num_sequences=200)
    #hico_seqs = ad[ad.subclass_name.str.contains('mir')][~ad.subclass_name.str.contains(';')]['subclass_name'].index.tolist()
    #models only accept sequences up to 30 nt
    hico_seqs = [seq for seq in hico_seqs if len(seq) <= 30]
    #set cuda 1
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    #run prediction with every model
    prediction_pd = predict_transforna_all_models(hico_seqs,trained_on=trained_on,path_to_models=path_to_models)
    prediction_pd['split'] = fig_prefix
    #the if condition here is to make sure to filter seqs with sub classes not used in training
    if not infer_ood and not infer_relaxed_mirna and not infer_hico:
        prediction_pd.to_csv(f'transforna/bin/lc_files/{fig_prefix}_lev_dist_df.csv')
    if infer_aa or infer_other_affixes or infer_random or infer_fused:
        #for artificial/random splits, familiarity should be LOW
        for model in prediction_pd.Model.unique():
            num_non_novel = sum(prediction_pd[prediction_pd.Model == model]['Is Familiar?'])
            print(f'Number of non novel sequences for {model} is {num_non_novel}')
            print(f'Percent non novel for {model} is {num_non_novel/len(prediction_pd[prediction_pd.Model == model])}, the lower the better')
    else:
        if infer_na or infer_loco:
            #print number of Is Familiar per model; here familiarity should be HIGH
            for model in prediction_pd.Model.unique():
                num_non_novel = sum(prediction_pd[prediction_pd.Model == model]['Is Familiar?'])
                print(f'Number of non novel sequences for {model} is {num_non_novel}')
                print(f'Percent non novel for {model} is {num_non_novel/len(prediction_pd[prediction_pd.Model == model])}, the higher the better')
                print('\n')
        else:
            #only to get classes used for training: a single logits prediction
            #exposes the training sub classes as its columns
            prediction_single_pd = predict_transforna(hico_seqs[0],model='Seq',logits_flag = True,trained_on=trained_on,path_to_models=path_to_models)
            sub_classes_used_for_training = prediction_single_pd.columns.tolist()
            mc_classes_df,sc_classes_df = get_mc_sc(infer_df,hico_seqs,sub_classes_used_for_training,sc_to_mc_mapper_dict,ood_flag=infer_ood)
            if infer_ood:
                #for the novel split, count how many annotated-novel sequences
                #each model still (wrongly) labels as familiar
                for model in prediction_pd.Model.unique():
                    #filter sequences in prediction_pd to only include sequences in sc_classes_df
                    curr_prediction_pd = prediction_pd[prediction_pd['Sequence'].isin(sc_classes_df.index)]
                    #filter curr_prediction to only include model
                    curr_prediction_pd = curr_prediction_pd[curr_prediction_pd.Model == model]
                    num_seqs = curr_prediction_pd.shape[0]
                    #keep only sequences the model considers familiar
                    curr_prediction_pd = curr_prediction_pd[curr_prediction_pd['Is Familiar?']]
                    #filter sc_classes_df to only include sequences in curr_prediction_pd
                    curr_sc_classes_df = sc_classes_df[sc_classes_df.index.isin(curr_prediction_pd['Sequence'].values)]
                    #correct labels and remove the correct labels from the curr_prediction_pd
                    curr_prediction_pd['Net-Label'] = correct_labels(curr_prediction_pd['Net-Label'].values,curr_sc_classes_df[0].values,sc_to_mc_mapper_dict)
                    #keep rows where the predicted label disagrees with the annotation
                    curr_prediction_pd = curr_prediction_pd[curr_prediction_pd['Net-Label'].values != curr_sc_classes_df[0].values]
                    num_non_novel = len(curr_prediction_pd)
                    print(f'Number of non novel sequences for {model} is {num_non_novel}')
                    print(f'Percent non novel for {model} is {num_non_novel/num_seqs}, the lower the better')
                    print('\n')
            else:
                #familiar splits: full accuracy evaluation (outliers separated)
                #compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers=False,fig_prefix = fig_prefix,save_figs=save_figs)
                compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers=True,fig_prefix = fig_prefix,save_figs=save_figs)
    if infer_ood or infer_relaxed_mirna or infer_hico:
        #keep only annotated sequences before persisting predictions
        prediction_pd = prediction_pd[prediction_pd['Sequence'].isin(sc_classes_df.index)]
        #save lev_dist_df
        prediction_pd.to_csv(f'transforna/bin/lc_files/{fig_prefix}_lev_dist_df.csv')
|