"""Streamlit app: classify Arabic news text (Medical / Sports / Tech) and extract named entities."""
import os
import pickle

import nltk
import numpy as np
import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC
from transformers import pipeline

# st.cache(allow_output_mutation=True) is deprecated; st.cache_resource is the
# replacement for caching unserializable global resources such as models.
# Fall back to the legacy decorator on Streamlit versions that predate it.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


def _load_pickle(path):
    """Unpickle and return the object stored in the file at *path*."""
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # artifacts produced by this project, never user-supplied files.
    with open(path, "rb") as f:
        return pickle.load(f)


def _preprocess(text):
    """Tokenize *text* with nltk.wordpunct_tokenize and drop 1-char tokens.

    Returns the remaining tokens joined by single spaces (possibly "").
    """
    tokens = nltk.wordpunct_tokenize(text)
    return " ".join(tok for tok in tokens if len(tok) > 1)


# Function to load the pre-trained model
@_cache_resource
def load_pretrained_model():
    """Load the classification artifacts and the Arabic NER pipeline.

    Returns:
        tuple: (feature_extractor, encoder, model, pipe), unpickled from
        'tfidf_scorer.pkl', 'encoder.pkl' and 'classifier.pkl' in the working
        directory, plus a Hugging Face token-classification pipeline for
        "hatmimoha/arabic-ner".

    If a required file is missing (FileNotFoundError), an error is shown in
    the app and the script run is halted with st.stop().
    """
    try:
        feature_extractor = _load_pickle("tfidf_scorer.pkl")
        encoder = _load_pickle("encoder.pkl")
        model = _load_pickle("classifier.pkl")
        pipe = pipeline(
            "token-classification",
            model="hatmimoha/arabic-ner",
            aggregation_strategy="max",
        )
    except FileNotFoundError as err:
        # Name the missing file so the user knows which artifact to restore.
        missing = f" Missing: {err.filename}" if err.filename else ""
        st.error(
            "Pre-trained model not found. Please make sure the model file exists."
            + missing
        )
        st.stop()
    return feature_extractor, encoder, model, pipe


# Streamlit App
st.title("Text Classification App")
st.write("This app demonstrates text classification using a pre-trained scikit-learn-based machine learning model.")

# Information about the app
st.sidebar.title("App Information")
st.sidebar.info(
    """This Streamlit app showcases text classification using a pre-trained scikit-learn-based machine learning model on Arabic texts.
The data is sourced from Arabic news articles organized into 3 balanced categories from www.alkhaleej.ae
Labels are categorized in: Medical, Sports, Tech.
Enter text in the provided area, and the model will predict the label."""
)

# Load the pre-trained model (cached across reruns)
tfidf, encode, trained_model, pipeline_obj = load_pretrained_model()

# User input for text classification
user_text = st.text_area("Enter text for classification:")

# Classify user input; whitespace-only input is ignored.
if user_text.strip():
    cleaned_text = _preprocess(user_text)
    if cleaned_text:
        x_test = tfidf.transform([cleaned_text])
        predicted = trained_model.predict(x_test)
        predicted_class = encode.inverse_transform(predicted)[0]
        st.write(f"Predicted Label: {predicted_class}")
    else:
        # With no usable tokens the TF-IDF vector is all zeros, so any
        # prediction would be meaningless.
        st.warning("No usable words found in the text; cannot classify it.")

    if st.button("Extract entities"):
        with st.spinner('Calculating...'):
            entities = pipeline_obj(user_text)
        if entities:
            entity_df = pd.DataFrame(entities)
            st.table(entity_df[["entity_group", "word"]])
        else:
            st.write("No entities found")

# TODO: explainability analysis (disabled). It used to live in a bare
# triple-quoted string, which Streamlit "magic" renders into the page, so it
# is kept as comments instead. To re-enable, import LimeTextExplainer
# (from lime.lime_text) and uncomment. NOTE(review): c.predict_proba needs a
# classifier trained with probability support (e.g. SVC(probability=True)),
# and this pipeline skips _preprocess -- confirm both before enabling.
#
# if user_text and st.button("Perform explainability analysis"):
#     c = make_pipeline(tfidf, trained_model)
#     explainer = LimeTextExplainer(class_names=encode.classes_, random_state=42)
#     exp = explainer.explain_instance(user_text, c.predict_proba, num_features=20, top_labels=3)
#     components.html(exp.as_html(), height=800)