"""Streamlit demo: classify SARS-CoV-2 sequences into GISAID clades.

Each sequence is turned into its Frequency matrix of Chaos Game
Representation (FCGR) and fed to a convolutional neural network. The app
either classifies every sequence of a multifasta file, or runs the full
analysis (prediction + FCGR image) on the first sequence only.
"""
import base64
import json
from collections import namedtuple
from io import BytesIO, StringIO
from pathlib import Path

import pandas as pd
import streamlit as st
from Bio import SeqIO

from predict import predict_single_seq, process_output
from src.fcgr import FCGR
from src.model_loader import ModelLoader
from src.utils import count_seqs, generate_fcgr

# load CLADES (order output models)
with open("trained-models/config.json") as fp:
    config = json.load(fp)
CLADES = config["CLADES"]

# One row of the results table.
Result = namedtuple("Result", ["id", "description", "clade", "score"])


@st.cache(allow_output_mutation=True)
def load_model(kmer, order_output):
    """Load the trained model for a given k-mer size.

    Args:
        kmer: k-mer size; selects the ``trained-models/{kmer}mers`` folder
            and the ``resnet50_{kmer}mers`` architecture name.
        order_output: sequence of output labels; only its length is used,
            as the number of model outputs.

    Returns:
        Whatever ``ModelLoader()`` returns for that architecture, number of
        outputs and weights file.

    Raises:
        FileNotFoundError: if no ``*.hdf5`` file exists under the folder.
    """
    n_output = len(order_output)
    weights_dir = Path(f"trained-models/{kmer}mers")
    # Take the first weights file found, searching recursively.
    path_weights = next(weights_dir.rglob("*.hdf5"), None)
    if path_weights is None:
        raise FileNotFoundError(f"No .hdf5 weights file found under '{weights_dir}'")
    loader = ModelLoader()
    model = loader(f"resnet50_{kmer}mers", n_output, path_weights)
    return model


def get_image_download_link(img):
    """Generate an HTML link allowing the PIL image to be downloaded.

    Args:
        img: PIL image; it is saved as JPEG into an in-memory buffer.

    Returns:
        An ``<a>`` element (as a string) whose ``href`` is a base64 data URI
        holding the JPEG bytes. Render it with
        ``st.markdown(..., unsafe_allow_html=True)``.
    """
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    # Embed the image in the link itself so no file has to be served.
    href = (
        f'<a href="data:image/jpeg;base64,{img_str}" download="fcgr.jpg">'
        "Download FCGR</a>"
    )
    return href


@st.cache
def convert_df(df):
    """Return ``df`` serialised as CSV (pandas defaults), UTF-8 encoded."""
    return df.to_csv().encode('utf-8')


def _predict_record(record, fcgr, model):
    """Classify one parsed fasta record and return it as a ``Result`` row."""
    pred = predict_single_seq(str(record.seq), fcgr, model)
    label, score = process_output(pred, CLADES)
    return Result(record.id, record.description, label, score)


# --- Sidebar ---
# NOTE(review): the file's original indentation was not recoverable; the
# sidebar scope is taken to run up to the "Main panel" marker -- confirm.
with st.sidebar:
    button = st.button(label="Run")
    multifasta = st.checkbox(
        "Multifasta",
        value=False,
        help="If selected, only inferences will be computed for all the sequences in the fasta file.",
    )
    st.write("Options")
    kmer = st.slider(
        label="kmer",
        min_value=6,
        max_value=9,
        value=6,
        help="There is one trained model for each kmer",
    )

    # Instantiate FCGR generator
    fcgr = FCGR(kmer)

    # Load model for selected kmer
    with st.spinner(f"Loading model for {kmer}mers..."):
        model = load_model(kmer, CLADES)
    st.success(f"Model for {kmer}mers loaded!")

    if multifasta:
        st.warning("For multifasta files, only the predictions will be computed")
        st.info("Deselecting 'Multifasta' box means that all other available analysis will be computed for the first sequence only.")
    else:
        st.warning("If Multifasta is not selected, only the first sequence in the uploaded fasta file will be considered")

# --- Main panel ---
st.title("Sars-cov-2 classification with FCGR")
st.text("Demo for the classification of Sars-Cov-2 sequences into 11 GISAID clades:")
st.text(", ".join(CLADES))
st.text("A sequence is represented by its Frequency matrix of Chaos Game Representation")
st.text("Which is then fed to a Convolutional Neural Network.")

# load fasta file
uploaded_file = st.file_uploader(label="Load fasta file")

if uploaded_file is not None and button:
    # Decode once; each consumer below gets its own fresh text stream.
    fasta_text = uploaded_file.getvalue().decode("utf-8")

    # count number of sequences in the fasta file
    n_seqs = count_seqs(StringIO(fasta_text))

    # read and parse fasta file
    records = SeqIO.parse(StringIO(fasta_text), "fasta")

    # --- Multifasta case: compute only inference for all the sequences ---
    if multifasta:
        if not n_seqs:
            # Avoids dividing by zero when computing the progress fraction.
            st.error("No sequences found in the uploaded fasta file.")
        else:
            st.info(f"Computing inference on {n_seqs} sequences")
            progress_bar = st.progress(0)

            results = []
            for current_seq, fasta in enumerate(records, start=1):
                print(fasta.id)
                print(current_seq)
                results.append(_predict_record(fasta, fcgr, model))

                # update progress bar; clamp because the fraction must not
                # exceed 1.0 even if the count and the parser disagree.
                progress_bar.progress(min(current_seq / n_seqs, 1.0))

            results_df = pd.DataFrame(results)
            st.dataframe(results_df)

            # Download results
            csv = convert_df(results_df)
            st.download_button(
                "Download results",
                csv,
                "results.csv",
                "text/csv",
                key="download-csv"
            )

    # --- All for one sequence ---
    else:
        # Default of None instead of letting StopIteration escape on an
        # empty file.
        fasta = next(records, None)
        if fasta is None:
            st.error("No sequences found in the uploaded fasta file.")
        else:
            with st.spinner("Inference..."):
                result = _predict_record(fasta, fcgr, model)
            label, score = result.clade, result.score
            st.success("Done!")

            st.write("### Results ")
            st.dataframe(pd.DataFrame([result]))
            st.write("Prediction: ", label)
            st.write("Confidence: ", score)

            # To generate the image to show
            with st.spinner("Plotting FCGR"):
                img = generate_fcgr(kmer, fasta, fcgr)

            # Show FCGR
            st.image(
                image=img,
                caption="FCGR \n Predicted Clade: {} | Confidence: {:.3f}".format(label, score),
                use_column_width="auto",
                width=20)
            st.markdown(get_image_download_link(img), unsafe_allow_html=True)