Nazarshia2889
quick second update
3905f18
raw history blame
No virus
2.68 kB
import streamlit as st
from transformers import TFAutoModelForSequenceClassification
from transformers import AutoTokenizer
import pandas as pd
import tensorflow as tf
# Page heading.
st.title('Ravens AI')

# Free-text box holding the amino acid sequence to classify.
sequence = st.text_input(label='Enter Amino Acid Sequence')

# Radio group selecting which linear-epitope classifier should score the
# sequence; the chosen label decides which model is loaded further down.
model_type = st.radio(
    label="Choose Linear Epitope Classifier",
    options=(
        'Linear T-Cells (MHC Class I Restriction)',
        'Linear T-Cells (MHC Class II Restriction)',
        'Linear B-Cell',
    ),
)

# NOTE: a window-length slider (range 1-50, default 10) used to sit here and
# is currently disabled.

# Cut-off between 0.0 and 1.0 (default 0.5); the result table colours a score
# red when it is below this value and green otherwise.
threshold = st.slider(
    label='Probability Threshold',
    min_value=0.0,
    max_value=1.0,
    value=0.5,
)

# Tokenizer for the ESM-2 checkpoint; it turns the raw sequence string into
# model inputs when the user submits.
model_checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Load the classifier that matches the radio selection.
#
# Each entry maps a radio label to the checkpoint name/path handed to
# ``from_pretrained`` and the warning shown when loading fails.
_CLASSIFIER_CHOICES = {
    'Linear T-Cells (MHC Class I Restriction)': (
        'classifier',
        "We're experiencing server issues. Please try again later!",
    ),
    'Linear T-Cells (MHC Class II Restriction)': (
        'classifier2',
        "We're experiencing server issues. Please try again later!",
    ),
    'Linear B-Cell': (
        'bcell',
        "We're experiencing server issues. Please refresh and try again!",
    ),
}

_selected_classifier = _CLASSIFIER_CHOICES.get(model_type)
if _selected_classifier is not None:
    _checkpoint_name, _load_failure_message = _selected_classifier
    try:
        model = TFAutoModelForSequenceClassification.from_pretrained(_checkpoint_name)
    except Exception:
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        # ``model`` is deliberately left unbound on failure: the submit
        # handler later in this script relies on the resulting NameError to
        # show its own warning.
        st.warning(_load_failure_message, icon="⚠️")
def color_survived(x: float) -> str:
    """Return the CSS background colour for one score cell.

    The colouring is a two-way split, not a gradient: a score below the
    slider-selected ``threshold`` is shaded red, and every other score
    (including one exactly equal to the threshold) is shaded green.
    """
    if x < threshold:
        r, g, b = 179, 40, 2  # red
    else:
        r, g, b = 18, 150, 6  # green
    return f'background-color: rgb({r}, {g}, {b})'


# submit button
if st.button('Submit'):
    if not sequence.strip():
        # Nothing to classify; scoring an empty string would only produce a
        # meaningless row.
        st.warning("Please enter an amino acid sequence before submitting.", icon="⚠️")
    else:
        try:
            # ``model`` is unbound when loading the classifier failed earlier
            # in the script. Only this lookup is guarded, so an unrelated
            # NameError is not misreported as a server issue.
            classifier = model
        except NameError:
            st.warning("We're experiencing server issues. Please refresh and try again!", icon="⚠️")
        else:
            peptide = tokenizer(sequence, return_tensors="tf")
            output = classifier(peptide)

            # First logit of the first (and only) sequence in the batch.
            # NOTE(review): this is the raw logit, with no sigmoid/softmax
            # applied here, yet it is labelled 'Probability' and compared to
            # a 0-1 threshold -- confirm the checkpoints emit values on that
            # scale.
            score = output.logits.numpy()[0][0]
            locations = pd.DataFrame(
                [[sequence, score]],
                columns=['Peptide', 'Probability'],
            )

            # Display the table with sequence and score, colouring the score
            # cell. ``Styler.applymap`` was renamed to ``Styler.map`` in
            # pandas 2.1; prefer the new name when it exists and fall back
            # for older pandas.
            styler = locations.style
            style_cells = styler.map if hasattr(styler, 'map') else styler.applymap
            st.table(style_cells(color_survived, subset=['Probability']))

# TODO(review): TensorFlow errors raised for invalid input (e.g.
# InvalidArgumentError) are not handled here -- decide whether they need
# their own user-facing warning.