# ssk / app.py
# Uploaded by ssk3232 in commit 9439064 ("Upload 3 files", verified); 2.47 kB.
import streamlit as st
import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import OneClassSVM
# Load the pre-trained artefacts saved at training time: the fitted model
# (a one-class SVM, judging by its file name) and the TF-IDF vectorizer that
# was used to build its training features. The model is loaded first.
model, vectorizer = (
    joblib.load('one_class_svm_model.pkl'),
    joblib.load('tfidf_vectorizer.pkl'),
)
def _combine_title_abstract(frame):
    """Return the model input text for each row: ``title + ' ' + abstract``.

    Missing titles or abstracts are treated as empty strings. Concatenating a
    NaN cell would otherwise make the whole row NaN, and
    ``TfidfVectorizer.transform`` rejects NaN documents with a ValueError.
    """
    title = frame['title'].fillna('').astype(str)
    abstract = frame['abstract'].fillna('').astype(str)
    return title + ' ' + abstract


def predict(n=5, retrain=False, positive_labelled_file=None, unlabelled_labelled_file=None):
    """Return ``id`` and ``title`` of the top ``n`` unlabelled papers the model accepts.

    Each unlabelled paper's ``title`` and ``abstract`` are vectorised with the
    module-level ``vectorizer`` and passed to the module-level ``model``.
    Papers the model predicts as ``1`` are ranked by
    ``model.decision_function`` (highest first), and the first ``n`` are kept.

    Args:
        n: Maximum number of papers to return.
        retrain: If true, refit ``model`` on the positive labelled papers
            before predicting. Otherwise the pre-loaded model is used as-is.
        positive_labelled_file: CSV path or file-like object with ``title``
            and ``abstract`` columns. Required only when ``retrain`` is true.
        unlabelled_labelled_file: CSV path or file-like object with ``id``,
            ``title`` and ``abstract`` columns to score. Always required.

    Returns:
        A DataFrame with columns ``id`` and ``title`` (possibly empty),
        ordered by decreasing decision score.

    Raises:
        ValueError: If a required file is missing. Errors from ``pd.read_csv``
            and ``KeyError`` for missing columns propagate unchanged.
    """
    if unlabelled_labelled_file is None:
        raise ValueError("An unlabelled CSV file is required to make predictions.")

    if retrain:
        if positive_labelled_file is None:
            raise ValueError("Retraining requires a positive labelled CSV file.")
        positive_labelled_info = pd.read_csv(positive_labelled_file)
        X_pos = vectorizer.transform(_combine_title_abstract(positive_labelled_info))
        # NOTE: this refits the shared module-level model in place.
        model.fit(X_pos)

    # Prediction happens whether or not we retrained. Previously it only ran
    # inside the retrain branch, so an unchecked "retrain" produced no result.
    unlabelled_labelled = pd.read_csv(unlabelled_labelled_file)
    X_unlabelled = vectorizer.transform(_combine_title_abstract(unlabelled_labelled))
    predictions = model.predict(X_unlabelled)
    scores = pd.Series(model.decision_function(X_unlabelled), index=unlabelled_labelled.index)

    # Keep only papers classified as positive (1), then rank them by decision
    # score so "top n" means most confident rather than first in the file.
    # The stable sort keeps file order among equal scores.
    inlier_scores = scores[predictions == 1]
    top_index = inlier_scores.sort_values(ascending=False, kind='stable').index[:n]
    return unlabelled_labelled.loc[top_index, ['id', 'title']]
# Streamlit input widgets.
n_input = st.slider("Top N papers to return:", min_value=1, max_value=20, value=5)
retrain_input = st.checkbox("Retrain model?")
positive_labelled_file = st.file_uploader("Upload Positive Labelled Data:", type=['csv'])
unlabelled_labelled_file = st.file_uploader("Upload Unlabelled Labelled Data:", type=['csv'])

# Only run a prediction once both CSV files have been uploaded.
if positive_labelled_file is not None and unlabelled_labelled_file is not None:
    try:
        result = predict(n_input, retrain_input, positive_labelled_file, unlabelled_labelled_file)
    except (KeyError, ValueError) as err:
        # Bad uploads are user errors, not app crashes. A missing 'id',
        # 'title' or 'abstract' column raises KeyError. An empty or malformed
        # CSV (pandas' EmptyDataError / ParserError) or an invalid document
        # raises a ValueError subclass. Report these in the page instead of
        # showing a raw traceback.
        st.error(f"Could not process the uploaded files: {err}")
    else:
        st.write(result)
else:
    st.info("Please upload the CSV files.")