cov-snn-app / app.py
smtnkc
Rebuild trigger
cfaba36
import streamlit as st
import os
import pandas as pd
import plotly.express as px
import numpy as np
from predict import process_target_data, get_average_embedding # Import your function
# Browser-tab title and icon for the Streamlit page.
st.set_page_config(page_icon="🧬", page_title="CoV-SNN")
# Column of the results DataFrame used to colour the scatter-plot points.
COLOR_BY = "Escape Potential"
def _load_instructions(path="INSTRUCTIONS.md"):
    """Return the Markdown text rendered at the bottom of the page.

    Falls back to a short "<path> file not found." message when the file is
    absent, so the app still renders.
    """
    try:
        # Explicit UTF-8 so the Markdown decodes identically on every platform;
        # the default text encoding is locale-dependent.
        with open(path, "r", encoding="utf-8") as instructions_file:
            return instructions_file.read()
    except FileNotFoundError:
        return f"{path} file not found."


def _read_uploaded_csv(uploaded_file, required_columns, label):
    """Parse an uploaded CSV file and check that it has the required columns.

    Args:
        uploaded_file: File object returned by ``st.file_uploader``.
        required_columns: Column names that must be present in the CSV.
        label: Human-readable name of the upload ("reference" / "target"),
            used only in error messages.

    Returns:
        The parsed DataFrame, or ``None`` after showing an error in the app
        when the file cannot be parsed or lacks a required column.
    """
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the {label} CSV file: {err}")
        return None
    missing = [column for column in required_columns if column not in df.columns]
    if missing:
        st.error(
            f"The {label} CSV file is missing required column(s): {', '.join(missing)}"
        )
        return None
    return df


def _build_figure(results_df):
    """Build the CoV-SNN scatter plot (``gra`` on x, ``sc`` on y).

    Points are coloured by the module-level ``COLOR_BY`` column and named by
    ``accession_id`` on hover. An optional ``source`` column is added to the
    hover data when present.
    """
    hover_data = {
        "sc": True,
        "grt": False,
        "gra": True,
        "raw_sc": False,
        "raw_grt": False,
        "raw_gra": False,
        "rank_sc": True,
        "rank_grt": False,
        "rank_gra": True,
        "rank_scgrt": False,
        "rank_scgra": True,
        "Escape Potential": False,
    }
    if "source" in results_df.columns:
        hover_data["source"] = True

    fig = px.scatter(
        # Round numeric columns to 6 decimals so hover labels stay short.
        # DataFrame.round replaces the element-wise applymap (deprecated in
        # pandas 2.1) and leaves non-numeric columns untouched.
        results_df.round(6),
        x="gra",
        y="sc",
        title="CoV-SNN Results",
        hover_name="accession_id",
        color=COLOR_BY,
        color_continuous_scale=["green", "yellow", "red"],
        # plotly only uses this map when the colour column is categorical.
        color_discrete_map={
            "Eris": "#C70039",
            "New": "#4caf50",
            "GPT": "#FFC300",
        },
        hover_data=hover_data,
    )

    # Hide the colorbar ticks and labels; only the colour gradient is shown.
    fig.update_coloraxes(
        colorbar=dict(
            title=None,
            tickvals=[],
            ticktext=[],
            y=0.5,
            len=0.7,
        )
    )

    if COLOR_BY == "Escape Potential":
        fig.update_layout(showlegend=False)
        # The colorbar title is hidden above, so label it with a rotated
        # annotation placed just right of the plot area (extra right margin
        # makes room for it).
        fig.update_layout(
            margin=dict(r=110),
            annotations=[
                dict(
                    text="Escape Potential",
                    font_size=14,
                    textangle=270,
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=1.14,
                    y=0.5,
                )
            ],
        )
    return fig


def _show_predictions(option, reference_file, target_file):
    """Score the uploaded target sequences and render the plot and table.

    Args:
        option: "Omicron" to load the pre-generated average embedding from
            ``average_omicron_embedding.npy``; any other value averages the
            embeddings of ``reference_file`` via ``get_average_embedding``.
        reference_file: Uploaded reference CSV (read only when option is not
            "Omicron").
        target_file: Uploaded target CSV.
    """
    if option == "Omicron":
        average_embedding = np.load("average_omicron_embedding.npy")
        print(f"Average Omicron embedding loaded from file with shape {average_embedding.shape}")
    else:
        ref_df = _read_uploaded_csv(reference_file, ["sequence"], "reference")
        if ref_df is None:
            return
        with st.spinner('Calculating average embedding...'):
            average_embedding = get_average_embedding(ref_df)

    target_dataset = _read_uploaded_csv(target_file, ["accession_id", "sequence"], "target")
    if target_dataset is None:
        return
    with st.spinner('Predicting escape potentials...'):
        results_df = process_target_data(average_embedding, target_dataset)
    print("Processing target data completed.")

    # Invert rank_scgra so the lowest rank gets the largest value:
    # rank r maps to max_rank + 1 - r.
    results_df['Escape Potential'] = results_df['rank_scgra'].max() + 1 - results_df['rank_scgra']

    fig = _build_figure(results_df)
    # NOTE(review): `border` / `border_color` are not documented st.plotly_chart
    # parameters; confirm they have any effect in the pinned Streamlit version.
    st.plotly_chart(fig, theme="streamlit", border=True, use_container_width=True, border_color="black")

    table_columns = ["accession_id", "sc", "gra", "grt", "rank_sc", "rank_gra", "rank_grt", "rank_scgra", "rank_scgrt"]
    if "source" in results_df.columns:
        table_columns = ["source"] + table_columns
    st.dataframe(results_df[table_columns], hide_index=True)


def main():
    """Render the CoV-SNN Streamlit page.

    Shows the reference-embedding choice and the two CSV uploaders. Once the
    required uploads are present, it shows the results plot and table. The
    contents of INSTRUCTIONS.md are always rendered at the bottom.
    """
    st.title("CoV-SNN")
    st.markdown("##### Predict viral escape potential of novel SARS-CoV-2 variants in seconds!")
    instructions_text = _load_instructions()

    option = st.radio(
        "Select a reference embedding:",
        ["Omicron", "Other"],
        captions=[
            "Use average embedding of Omicron sequences (Pre-generated)",
            "Generate average embedding of your own sequences (Takes longer)",
        ],
    )
    # The reference upload is only needed when averaging user-supplied sequences.
    reference_file = st.file_uploader(
        "Upload reference sequences. Make sure the CSV file has ``sequence`` column.",
        type=["csv"],
        disabled=option == "Omicron",
    )
    # Targets can be uploaded only once a reference source is available.
    target_file = st.file_uploader(
        "Upload target sequences. Make sure the CSV file has ``accession_id`` and ``sequence`` columns.",
        type=["csv"],
        disabled=option == "Other" and reference_file is None,
    )

    if target_file is not None and (option == "Omicron" or reference_file is not None):
        _show_predictions(option, reference_file, target_file)

    st.markdown(instructions_text)


if __name__ == "__main__":
    main()