Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import pandas as pd | |
import plotly.express as px | |
import numpy as np | |
from predict import process_target_data, get_average_embedding # Import your function | |
# Browser-tab title and icon for the app.
st.set_page_config(
    page_icon="🧬",
    page_title="CoV-SNN",
)

# Results column that drives the scatter-plot colouring in main().
COLOR_BY = "Escape Potential"
def _load_instructions(path="INSTRUCTIONS.md"):
    """Return the Markdown instructions rendered at the bottom of the page.

    Falls back to a short notice when the file is missing so the page still renders.
    """
    try:
        # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) can fail
        # on non-ASCII Markdown.
        with open(path, "r", encoding="utf-8") as instructions_file:
            return instructions_file.read()
    except FileNotFoundError:
        return f"{path} file not found."


def _read_uploaded_csv(uploaded_file, required_columns, label):
    """Parse an uploaded CSV and check that it contains ``required_columns``.

    Returns the DataFrame, or None after showing an ``st.error`` message when the
    file cannot be parsed or lacks one of the required columns.
    """
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the {label} CSV file: {err}")
        return None
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        st.error(f"The {label} CSV file is missing required column(s): {', '.join(missing)}")
        return None
    return df


def _build_results_figure(results_df):
    """Build the ``gra`` (x) vs ``sc`` (y) scatter plot of results, coloured by COLOR_BY."""
    # Columns shown (True) or hidden (False) in the hover tooltip.
    hover_data = {
        "sc": True,
        "grt": False,
        "gra": True,
        "raw_sc": False,
        "raw_grt": False,
        "raw_gra": False,
        "rank_sc": True,
        "rank_grt": False,
        "rank_gra": True,
        "rank_scgrt": False,
        "rank_scgra": True,
        "Escape Potential": False,
    }
    if "source" in results_df.columns:
        hover_data["source"] = True

    fig = px.scatter(
        # Round numeric columns to 6 decimals for readable hover text. This
        # replaces the element-wise applymap (deprecated since pandas 2.1),
        # and DataFrame.round leaves bool columns untouched.
        results_df.round(6),
        x="gra",
        y="sc",
        title="CoV-SNN Results",
        hover_name="accession_id",
        color=COLOR_BY,
        color_continuous_scale=["green", "yellow", "red"],
        # Only applies when COLOR_BY names a categorical column; a numeric
        # column uses color_continuous_scale instead.
        color_discrete_map={
            "Eris": "#C70039",
            "New": "#4caf50",
            "GPT": "#FFC300",
        },
        hover_data=hover_data,
    )

    # Hide the colorbar ticks and labels; the colour gradient alone conveys rank.
    fig.update_coloraxes(
        colorbar=dict(
            title=None,
            tickvals=[],
            ticktext=[],
            y=0.5,
            len=0.7,
        )
    )

    if COLOR_BY == "Escape Potential":
        fig.update_layout(showlegend=False)
        # Colorbar title was removed above; draw a rotated label beside it instead,
        # reserving right margin so it is not clipped.
        fig.update_layout(
            margin=dict(r=110),
            annotations=[
                dict(
                    text="Escape Potential",
                    font_size=14,
                    textangle=270,
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=1.14,
                    y=0.5,
                )
            ],
        )
    return fig


def _show_predictions(option, reference_file, target_file):
    """Compute the reference embedding, score the target sequences and render the results.

    Shows an error message and returns early (without raising) when an uploaded
    CSV is unreadable or lacks a required column.
    """
    if option == "Omicron":
        # Pre-generated average embedding of Omicron sequences.
        average_embedding = np.load("average_omicron_embedding.npy")
        print(f"Average Omicron embedding loaded from file with shape {average_embedding.shape}")
    else:
        with st.spinner('Calculating average embedding...'):
            ref_df = _read_uploaded_csv(reference_file, ("sequence",), "reference")
            if ref_df is None:
                return
            average_embedding = get_average_embedding(ref_df)

    with st.spinner('Predicting escape potentials...'):
        target_dataset = _read_uploaded_csv(target_file, ("accession_id", "sequence"), "target")
        if target_dataset is None:
            return
        results_df = process_target_data(average_embedding, target_dataset)
        print("Processing target data completed.")

    # Invert rank_scgra: rank 1 maps to the largest Escape Potential and the
    # largest rank maps to 1 (presumably rank 1 = most likely escape; confirm
    # against predict.process_target_data).
    results_df['Escape Potential'] = results_df['rank_scgra'].max() + 1 - results_df['rank_scgra']

    fig = _build_results_figure(results_df)
    # `border`/`border_color` were dropped: they are not st.plotly_chart
    # parameters and were silently swallowed by its deprecated **kwargs.
    st.plotly_chart(fig, theme="streamlit", use_container_width=True)

    cols = ["accession_id", "sc", "gra", "grt", "rank_sc", "rank_gra", "rank_grt", "rank_scgra", "rank_scgrt"]
    if "source" in results_df.columns:
        cols = ["source"] + cols
    st.dataframe(results_df[cols], hide_index=True)


def main():
    """Render the CoV-SNN page: inputs, escape-potential plot, results table, instructions."""
    # Rebuild trigger
    st.title("CoV-SNN")
    st.markdown("##### Predict viral escape potential of novel SARS-CoV-2 variants in seconds!")

    readme_text = _load_instructions()

    option = st.radio(
        "Select a reference embedding:",
        ["Omicron", "Other"],
        captions=["Use average embedding of Omicron sequences (Pre-generated)", "Generate average embedding of your own sequences (Takes longer)"],
    )

    # The reference upload is only needed when the user supplies their own sequences.
    reference_file = st.file_uploader(
        "Upload reference sequences. Make sure the CSV file has ``sequence`` column.",
        type=["csv"],
        disabled=option == "Omicron",
    )
    # The target upload stays disabled until a reference is available.
    target_file = st.file_uploader(
        "Upload target sequences. Make sure the CSV file has ``accession_id`` and ``sequence`` columns.",
        type=["csv"],
        disabled=option == "Other" and reference_file is None,
    )

    if target_file is not None and (option == "Omicron" or reference_file is not None):
        _show_predictions(option, reference_file, target_file)

    # Always show the instructions, even when prediction was skipped or failed.
    st.markdown(readme_text)


if __name__ == "__main__":
    main()