Spaces:
Sleeping
Sleeping
import os
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st

from predict import process_target_data, get_average_embedding  # Import your function
# Configure the Streamlit page title and icon.
st.set_page_config(
    page_title="CoV-SNN",
    page_icon="🧬",
)

# Name of the results column passed as ``color=`` to the scatter plot.
COLOR_BY = "Escape Potential"
# Columns each uploaded CSV must provide, as stated in the uploader labels.
_REFERENCE_COLUMNS = ("sequence",)
_TARGET_COLUMNS = ("accession_id", "sequence")


def _load_instructions():
    """Return the text of INSTRUCTIONS.md, or a placeholder if it is missing."""
    try:
        return Path("INSTRUCTIONS.md").read_text(encoding="utf-8")
    except FileNotFoundError:
        return "INSTRUCTIONS.md file not found."


def _read_uploaded_csv(uploaded_file, required_columns, label):
    """Parse an uploaded CSV and check that it has the required columns.

    Args:
        uploaded_file: File-like object returned by ``st.file_uploader``.
        required_columns: Column names that must be present in the CSV.
        label: Short human-readable name of the upload, used in messages.

    Returns:
        The parsed DataFrame, or ``None`` after showing an ``st.error``
        message when the file cannot be parsed or a column is missing.
    """
    try:
        frame = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the {label} CSV file: {err}")
        return None

    missing = [col for col in required_columns if col not in frame.columns]
    if missing:
        st.error(
            f"The {label} CSV file is missing required column(s): {', '.join(missing)}"
        )
        return None
    return frame


def _build_figure(results_df):
    """Build the gra-vs-sc scatter plot coloured by ``COLOR_BY``."""
    hover_data = {
        "sc": True,
        "grt": False,
        "gra": True,
        "raw_sc": False,
        "raw_grt": False,
        "raw_gra": False,
        "rank_sc": True,
        "rank_grt": False,
        "rank_gra": True,
        "rank_scgrt": False,
        "rank_scgra": True,
        "Escape Potential": False,
    }
    if "source" in results_df.columns:
        hover_data["source"] = True

    fig = px.scatter(
        # Round numeric columns to 6 decimals for display; DataFrame.round
        # replaces the deprecated element-wise DataFrame.applymap call.
        results_df.round(6),
        x="gra",
        y="sc",
        title="CoV-SNN Results",
        hover_name="accession_id",
        color=COLOR_BY,
        color_continuous_scale=["green", "yellow", "red"],
        color_discrete_map={
            "Eris": "#C70039",
            "New": "#4caf50",
            "GPT": "#FFC300",
        },
        hover_data=hover_data,
    )

    # Hide the colorbar ticks and labels.
    fig.update_coloraxes(
        colorbar=dict(
            title=None,
            tickvals=[],
            ticktext=[],
            y=0.5,
            len=0.7,
        )
    )

    if COLOR_BY == "Escape Potential":
        # Hide the legend.
        fig.update_layout(showlegend=False)
        # Add a rotated title next to the colorbar via an annotation.
        fig.update_layout(
            margin=dict(r=110),
            annotations=[
                dict(
                    text="Escape Potential",
                    font_size=14,
                    textangle=270,
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=1.14,
                    y=0.5,
                )
            ],
        )
    return fig


def _run_prediction(option, reference_file, target_file):
    """Compute escape potentials for the uploaded targets and render them.

    Args:
        option: Selected reference mode, ``"Omicron"`` or ``"Other"``.
        reference_file: Uploaded reference CSV (read only when ``option`` is
            not ``"Omicron"``).
        target_file: Uploaded target CSV.

    Returns early, after showing an error, if an uploaded CSV is unreadable
    or lacks a required column.
    """
    if option == "Omicron":
        # Use the pre-generated average embedding shipped with the app.
        average_embedding = np.load("average_omicron_embedding.npy")
        print(f"Average Omicron embedding loaded from file with shape {average_embedding.shape}")
    else:
        with st.spinner('Calculating average embedding...'):
            ref_df = _read_uploaded_csv(reference_file, _REFERENCE_COLUMNS, "reference")
            if ref_df is None:
                return
            average_embedding = get_average_embedding(ref_df)

    with st.spinner('Predicting escape potentials...'):
        target_dataset = _read_uploaded_csv(target_file, _TARGET_COLUMNS, "target")
        if target_dataset is None:
            return
        # NOTE(review): process_target_data is assumed to return a DataFrame
        # with the sc/gra/grt, raw_* and rank_* columns used below - confirm
        # against predict.py.
        results_df = process_target_data(average_embedding, target_dataset)
    print("Processing target data completed.")

    # Invert rank_scgra: a row with rank 1 receives the value max(rank_scgra).
    results_df['Escape Potential'] = results_df['rank_scgra'].max() + 1 - results_df['rank_scgra']

    # Display the plot in Streamlit.
    # NOTE(review): `border` and `border_color` do not appear among
    # st.plotly_chart's documented parameters - confirm they have any effect
    # on the installed Streamlit version.
    st.plotly_chart(
        _build_figure(results_df),
        theme="streamlit",
        border=True,
        use_container_width=True,
        border_color="black",
    )

    # Columns shown in the results table; "source" is prepended when present.
    cols = ["accession_id", "sc", "gra", "grt", "rank_sc", "rank_gra", "rank_grt", "rank_scgra", "rank_scgrt"]
    print(results_df.columns)
    if "source" in results_df.columns:
        cols = ["source"] + cols

    # Display the DataFrame in Streamlit.
    st.dataframe(results_df[cols], hide_index=True)


def main():
    """Render the CoV-SNN Streamlit page.

    Shows the title, the reference-embedding selector and the two CSV
    uploaders; once the required uploads are available it runs the
    prediction and renders the scatter plot and results table. The contents
    of INSTRUCTIONS.md are always rendered at the bottom of the page.
    """
    st.title("CoV-SNN")
    st.markdown("##### Predict viral escape potential of novel SARS-CoV-2 variants in seconds!")

    # Read the INSTRUCTIONS.md file.
    readme_text = _load_instructions()

    option = st.radio(
        "Select a reference embedding:",
        ["Omicron", "Other"],
        captions=[
            "Use average embedding of Omicron sequences (Pre-generated)",
            "Generate average embedding of your own sequences (Takes longer)",
        ],
    )

    # File uploader for the reference CSV; disabled in "Omicron" mode.
    reference_file = st.file_uploader(
        "Upload reference sequences. Make sure the CSV file has ``sequence`` column.",
        type=["csv"],
        disabled=option == "Omicron",
    )

    # File uploader for the target CSV; disabled in "Other" mode until a
    # reference file has been uploaded.
    target_file = st.file_uploader(
        "Upload target sequences. Make sure the CSV file has ``accession_id`` and ``sequence`` columns.",
        type=["csv"],
        disabled=option == "Other" and reference_file is None,
    )

    if target_file is not None and (option == "Omicron" or reference_file is not None):
        _run_prediction(option, reference_file, target_file)

    # Display the INSTRUCTIONS.md text.
    st.markdown(readme_text)
# Start the app only when this file is run as the main module, not on import.
if __name__ == "__main__":
    main()