# P8_PPS / app.py
# Committed by CASY85
# Commit 5020dd3 ("Create app.py", verified)
import streamlit as st
import joblib
from sentence_transformers import SentenceTransformer
# Load the pre-trained embedding model
@st.cache_resource  # Cache the embedding model to save loading time
def load_embedding_model():
    """Construct the sentence-embedding model used to vectorise input text.

    The ``st.cache_resource`` decorator makes Streamlit build the model once
    and hand back the same object on later calls/reruns.

    Returns:
        A ``SentenceTransformer`` for 'neuml/pubmedbert-base-embeddings'.
    """
    model_id = 'neuml/pubmedbert-base-embeddings'
    embedder = SentenceTransformer(model_id)
    return embedder
# Load the MLP model
@st.cache_resource  # Cache the loaded model
def load_mlp_model():
    """Deserialize the trained MLP classifier from ``MLP.pkl``.

    The path is relative, so it is resolved against the process's current
    working directory. ``joblib.load`` unpickles the file, so it must come
    from a trusted source.

    Returns:
        Whatever object was stored in ``MLP.pkl`` (used by the app via its
        ``predict`` method).
    """
    with open("MLP.pkl", "rb") as model_file:
        classifier = joblib.load(model_file)
    return classifier
# Embed text
def get_embeddings(title, abstract, embedding_model):
    """Embed a title and abstract as one piece of text.

    Args:
        title: Title text.
        abstract: Abstract text.
        embedding_model: Any object exposing ``encode(text)``; its result is
            returned unchanged.

    Returns:
        ``embedding_model.encode`` applied to title and abstract joined by a
        single space.
    """
    merged = " ".join((title, abstract))
    return embedding_model.encode(merged)
# Main Streamlit app
def main():
    """Render the predictor page and classify the entered text on request.

    Shows a title and abstract input. When "Predict Label" is pressed, it
    either reports missing input or embeds the combined text and displays
    the MLP's predicted label.
    """
    st.title("MLP Predictor for Titles and Abstracts")

    # Free-text inputs supplied by the user.
    title = st.text_input("Enter the Title:")
    abstract = st.text_area("Enter the Abstract:")

    # Load (or fetch from Streamlit's resource cache) both models up front.
    embedder = load_embedding_model()
    classifier = load_mlp_model()

    if not st.button("Predict Label"):
        return

    # Whitespace-only input counts as missing.
    if not title.strip() or not abstract.strip():
        st.error("Both Title and Abstract are required!")
        return

    vector = get_embeddings(title, abstract, embedder)
    # predict takes a batch (2-D input); wrap the single vector and unwrap
    # the single result.
    label = classifier.predict([vector])[0]
    st.success(f"The predicted label is: {label}")
# Build the page when this file is executed as the main script.
if __name__ == "__main__":
    main()