"""Streamlit app that classifies text (single query or a table column) with a fine-tuned model."""
import torch
import pandas as pd
import streamlit as st
from transformers import (  # noqa: F401  (AutoModel kept for callers importing it from here)
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
# NOTE(review): these GPT helpers are imported but never used below. The
# checkpoint name ("..._774M") suggests a GPT-2-sized model, while the code
# builds a BERT architecture -- confirm which architecture the weights match.
from gpt_download import download_and_load_gpt2  # noqa: F401
from previous_scope import GPTModel, load_weights_into_gpt  # noqa: F401

MODEL_PATH = "clv__classifier_774M.pth"
BASE_MODEL_NAME = "bert-base-uncased"  # replace with your model's config


def load_model(model_path, config_name=BASE_MODEL_NAME):
    """Build a sequence-classification model and load saved weights into it.

    Args:
        model_path: Path to a ``state_dict`` saved with ``torch.save``.
        config_name: Hugging Face model id or path whose config describes the
            architecture the weights were trained with.

    Returns:
        The model on CPU, switched to evaluation mode.
    """
    config = AutoConfig.from_pretrained(config_name)
    # classify_review() reads ``outputs.logits``. A bare AutoModel has no
    # classification head and returns no logits, so build the
    # sequence-classification variant of the architecture instead.
    model = AutoModelForSequenceClassification.from_config(config)
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs.
    state_dict = torch.load(
        model_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    # Disable dropout so repeated predictions on the same text agree.
    model.eval()
    return model


def classify_review(text, model, tokenizer, max_length=512):
    """Classify one piece of text.

    Args:
        text: The text to classify.
        model: A model whose outputs expose ``.logits`` (see load_model).
        tokenizer: Tokenizer matching ``model``.
        max_length: Inputs longer than this many tokens are truncated.

    Returns:
        ``'Positive'`` when the highest-scoring class index is 1, otherwise
        ``'Negative'``.
    """
    # A single sequence needs no padding. Padding to max_length only wastes
    # compute, and it fails for tokenizers without a pad token.
    inputs = tokenizer(
        text, max_length=max_length, truncation=True, return_tensors="pt"
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    _, predicted = torch.max(outputs.logits, dim=1)
    return "Positive" if predicted.item() == 1 else "Negative"


@st.cache_resource
def _load_resources(model_path, config_name=BASE_MODEL_NAME):
    """Load the model and tokenizer once per process.

    Streamlit reruns the whole script on every widget interaction, so
    without caching both would be reloaded from disk on each click.
    """
    model = load_model(model_path, config_name)
    tokenizer = AutoTokenizer.from_pretrained(config_name)
    return model, tokenizer


def _classify_texts(texts, model, tokenizer, max_length=512):
    """Classify each cell of ``texts``; missing cells (NaN/None) get ``None``."""
    return [
        None if pd.isna(text) else classify_review(str(text), model, tokenizer, max_length)
        for text in texts
    ]


def _read_table(uploaded_file):
    """Read an uploaded CSV or Excel file into a DataFrame (extension is case-insensitive)."""
    if uploaded_file.name.lower().endswith(".csv"):
        return pd.read_csv(uploaded_file)
    return pd.read_excel(uploaded_file)


def main():
    """Render the app: load the model, then classify a query or an uploaded table."""
    st.title("Text Classification App")

    try:
        model, tokenizer = _load_resources(MODEL_PATH)
    except FileNotFoundError:
        st.error(f"Model file not found: {MODEL_PATH}")
        st.stop()
    st.write("Model loaded successfully!")

    input_option = st.radio("Select input option", ("Single Text Query", "Upload Table"))

    if input_option == "Single Text Query":
        text_query = st.text_input("Enter text query")
        if st.button("Classify"):
            # Treat whitespace-only input the same as empty input.
            if text_query.strip():
                predicted_label = classify_review(text_query, model, tokenizer, max_length=512)
                st.write("Predicted Label:")
                st.write(predicted_label)
            else:
                st.warning("Please enter a text query.")

    elif input_option == "Upload Table":
        uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
        if uploaded_file is None:
            return
        try:
            df = _read_table(uploaded_file)
        except (pd.errors.EmptyDataError, ValueError) as err:
            st.error(f"Could not read the uploaded file: {err}")
            return

        text_column = st.selectbox("Select the text column", df.columns)
        if text_column is None:  # the file has no columns to choose from
            return

        with st.spinner("Classifying..."):
            df["Predicted Label"] = _classify_texts(df[text_column], model, tokenizer, max_length=512)
        st.write(df)


if __name__ == "__main__":
    main()