import streamlit as st
import joblib
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import OneClassSVM

# Load the saved model and the TF-IDF vectorizer fitted during training.
# Streamlit re-executes this script on every interaction, so both objects are
# reloaded from disk each run and a retrain (model.fit below) only lasts for the
# current run.
# NOTE: joblib.load unpickles arbitrary objects; only load these files from a
# trusted source.
model = joblib.load('one_class_svm_model.pkl')
vectorizer = joblib.load('tfidf_vectorizer.pkl')  # Load the vectorizer used for training


def _read_csv(uploaded_file):
    """Read an uploaded CSV file-like object into a DataFrame."""
    # Rewind first in case the buffer was already consumed earlier.
    if hasattr(uploaded_file, 'seek'):
        uploaded_file.seek(0)
    return pd.read_csv(uploaded_file)


def _combine_text(df):
    """Return ``title + ' ' + abstract`` per row, treating missing values as ''."""
    # Without fillna a single missing title/abstract turns the whole row into
    # NaN, which the TF-IDF vectorizer refuses to transform.
    title = df['title'].fillna('').astype(str)
    abstract = df['abstract'].fillna('').astype(str)
    return title + ' ' + abstract


def predict(n=5, retrain=False, positive_labelled_file=None, unlabelled_labelled_file=None):
    """Return the ``n`` highest-scoring papers the one-class SVM labels positive.

    Args:
        n: Maximum number of papers to return.
        retrain: If True and ``positive_labelled_file`` is given, refit the
            module-level model on the positive papers before predicting.
        positive_labelled_file: CSV (path or file-like) with ``title`` and
            ``abstract`` columns; only used when retraining.
        unlabelled_labelled_file: CSV (path or file-like) with ``id``,
            ``title`` and ``abstract`` columns to score.

    Returns:
        DataFrame with the ``id`` and ``title`` columns of up to ``n`` papers
        predicted positive (``model.predict == 1``), ordered by descending
        ``model.decision_function`` score. Empty if there is nothing to score.
    """
    empty_result = pd.DataFrame(columns=['id', 'title'])
    if unlabelled_labelled_file is None:
        return empty_result

    if retrain and positive_labelled_file is not None:
        # Refit on the positive examples only, reusing the training vocabulary.
        positive_labelled_info = _read_csv(positive_labelled_file)
        X_pos = vectorizer.transform(_combine_text(positive_labelled_info))
        model.fit(X_pos)

    unlabelled_labelled = _read_csv(unlabelled_labelled_file)
    if unlabelled_labelled.empty:
        return empty_result

    X_unlabelled = vectorizer.transform(_combine_text(unlabelled_labelled))
    predictions = model.predict(X_unlabelled)
    scores = model.decision_function(X_unlabelled)

    # Keep predicted positives and rank them by SVM score so "top N" means the
    # most confident papers rather than the first ones in file order.
    positive_mask = predictions == 1
    positives = unlabelled_labelled.loc[positive_mask].assign(_score=scores[positive_mask])
    top_n_positive_papers = positives.sort_values('_score', ascending=False, kind='stable').head(n)

    # Return titles and IDs of selected papers
    return top_n_positive_papers[['id', 'title']]


# Define the input components
n_input = st.slider("Top N papers to return:", min_value=1, max_value=20, value=5)
retrain_input = st.checkbox("Retrain model?")
positive_labelled_file = st.file_uploader("Upload Positive Labelled Data:", type=['csv'])
unlabelled_labelled_file = st.file_uploader("Upload Unlabelled Labelled Data:", type=['csv'])

# The unlabelled file is always required; the positive file only when retraining.
files_ready = unlabelled_labelled_file is not None and (
    positive_labelled_file is not None or not retrain_input
)
if files_ready:
    # Call the predict function with uploaded data and display the result
    result = predict(n_input, retrain_input, positive_labelled_file, unlabelled_labelled_file)
    st.write(result)
else:
    st.info("Please upload the CSV files.")