Spaces:
Sleeping
Sleeping
import re | |
import torch | |
import pandas as pd | |
import nltk | |
import streamlit as st | |
import matplotlib.pyplot as plt | |
from nltk.tokenize import word_tokenize | |
from transformers import BertTokenizer, BertForSequenceClassification | |
from wordcloud import WordCloud | |
# Make sure the NLTK tokenizer data used by ``word_tokenize`` is installed.
# Recent NLTK releases (3.8.2 and later) load the ``punkt_tab`` resource instead
# of the pickled ``punkt`` model, so both are requested. Each resource is only
# downloaded when ``nltk.data.find`` cannot locate it, which keeps Streamlit
# reruns from re-contacting the download server every time.
for _nltk_resource in ('punkt', 'punkt_tab'):
    try:
        nltk.data.find(f'tokenizers/{_nltk_resource}')
    except LookupError:
        nltk.download(_nltk_resource)
# Output index of the classifier -> human-readable category name.
# ``main`` sizes the classification head with ``len(label_dict)``, so the
# order of this list must match the order used when the model was trained.
label_dict = dict(enumerate([
    'Transfer',
    'Other',
    'Food',
    'Grocery',
    'Entertainment',
    'Cash Withdrawal',
    'Shopping',
    'Gas',
    'Liquor/Club/Smoke',
    'Credit Card',
    'Transportation',
]))
def load_dataframe(file_path):
    """Read a CSV file into a pandas DataFrame.

    Args:
        file_path: A filesystem path or any file-like object accepted by
            ``pandas.read_csv`` (``main`` passes the Streamlit upload object).

    Returns:
        The parsed DataFrame. ``index_col=False`` stops pandas from turning the
        first column into the index, so every CSV column stays a regular column.
    """
    frame = pd.read_csv(file_path, index_col=False)
    return frame
# Matches single punctuation characters, http(s)/www URLs, runs of non-ASCII
# characters and runs of digits; every match is deleted. Compiled once at
# module load instead of being looked up on every call.
_BERT_CLEAN_PATTERN = re.compile(
    r'[^\w\s]|https?://\S+|www\.\S+|https?:/\S+|[^\x00-\x7F]+|\d+'
)


def clean_text_BERT(text):
    """Normalise a raw text value before it is fed to the BERT classifier.

    The text is lower-cased and stripped. Then every character or span matched
    by ``_BERT_CLEAN_PATTERN`` is removed: punctuation, URLs, non-ASCII runs
    and digit runs. The remainder is tokenised with NLTK's ``word_tokenize``,
    and single-character tokens are discarded.

    Args:
        text: The value to clean. Missing values (``None`` / NaN, which pandas
            produces for empty CSV cells) return ``''``. Other non-string
            values, such as numbers from a numeric column, are converted with
            ``str`` first instead of raising AttributeError.

    Returns:
        The surviving tokens joined with single spaces (possibly ``''``).
    """
    if not isinstance(text, str):
        if pd.isna(text):
            return ''
        text = str(text)
    text = _BERT_CLEAN_PATTERN.sub('', text.lower().strip())
    tokens = [token for token in word_tokenize(text) if len(token) > 1]
    return ' '.join(tokens)
def setup_bert_model(num_classes, device):
    """Build the fine-tuned BERT classifier and prepare it for inference.

    A ``bert-base-uncased`` sequence-classification model with ``num_classes``
    output labels is created. The weights stored in ``finiancal_class.pth``
    (the checkpoint file name as shipped with the app) are then loaded into it.

    Args:
        num_classes: Number of output labels of the classification head.
        device: ``torch.device`` the checkpoint is mapped to and the model is
            moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict mode) if the checkpoint
            lacks parameters the model needs or a tensor shape does not match.
    """
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=num_classes
    )
    # NOTE(review): torch.load unpickles the file. That is acceptable only
    # because this checkpoint ships with the app; never point it at uploads.
    state_dict = torch.load('finiancal_class.pth', map_location=device)
    # Keep only checkpoint entries the freshly built model defines; any extra
    # keys would make the strict load below fail. Build the key view once:
    # calling model.state_dict() per entry rebuilt the whole dict each time.
    expected_keys = model.state_dict().keys()
    filtered_state_dict = {
        key: tensor for key, tensor in state_dict.items() if key in expected_keys
    }
    model.load_state_dict(filtered_state_dict)
    model.to(device)
    model.eval()
    return model
def predict_class(model, tokenizer, text, device):
    """Classify one text and return the matching category name.

    Args:
        model: Sequence-classification model whose output exposes ``.logits``.
        tokenizer: Tokenizer returning PyTorch ``input_ids`` and
            ``attention_mask`` tensors (``return_tensors='pt'``).
        text: A single input string (``main`` passes it already cleaned).
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        The ``label_dict`` entry for the highest-scoring class.
    """
    batch = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    ids, mask = (batch[key].to(device) for key in ('input_ids', 'attention_mask'))
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits
    best_index = logits.argmax(dim=1).item()
    return label_dict[best_index]
def display_word_cloud(data, selected_column, predicted_class):
    """Render a word cloud of the text in rows predicted as ``predicted_class``.

    Args:
        data: DataFrame holding ``selected_column`` and a ``Predicted_Class``
            column.
        selected_column: Name of the column whose words are plotted.
        predicted_class: Label whose rows contribute text.

    Missing cells are skipped and non-string cells are converted with ``str``.
    If no usable words remain, an informational message is shown instead of a
    plot, rather than the app crashing.
    """
    class_data = data[data['Predicted_Class'] == predicted_class]
    class_text = " ".join(class_data[selected_column].dropna().astype(str))
    try:
        wordcloud = WordCloud(
            width=800, height=400, background_color='white'
        ).generate(class_text)
    except ValueError:
        # wordcloud raises ValueError when the text yields no words at all
        # (empty input, or only stopwords).
        st.info(f"Not enough text to build a word cloud for '{predicted_class}'.")
        return
    # Use an explicit Figure instead of the global pyplot state, and close it
    # afterwards so repeated Streamlit reruns do not accumulate open figures.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.imshow(wordcloud, interpolation='bilinear')
    ax.set_title(f"Word Cloud for '{predicted_class}' Class")
    ax.axis('off')
    st.pyplot(fig)
    plt.close(fig)
def _predict_column(texts, model, tokenizer, device):
    """Clean each value of ``texts`` and return its predicted label, in order."""
    return [
        predict_class(model, tokenizer, clean_text_BERT(text), device)
        for text in texts
    ]


def _filter_by_transaction_type(data, selected_column, selected_filter):
    """Return the rows of ``data`` matching the sidebar "Debit"/"Credit" choice.

    "Debit" and "Credit" keep rows whose ``selected_column`` contains that word
    (case-insensitive substring match). Any other choice, i.e. "All", returns
    ``data`` unchanged. Cells are compared as strings, so missing or numeric
    values simply do not match instead of raising.
    """
    if selected_filter not in ("Debit", "Credit"):
        return data
    mask = data[selected_column].astype(str).str.contains(
        selected_filter, case=False, regex=False, na=False
    )
    return data[mask]


def main():
    """Streamlit entry point for the app.

    Uploads a CSV and classifies one chosen text column with BERT. Then it
    shows the predictions, a Debit/Credit-filtered class distribution, and a
    word cloud for a selected class.
    """
    # NOTE(review): the "π" characters below look like emoji that were
    # mis-encoded upstream. Confirm the intended icon/title text.
    st.set_page_config(page_title="Text Classification", page_icon="π")
    st.title("ππ Text Classification with BERT")

    uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
    if uploaded_file is None:
        return

    try:
        data = load_dataframe(uploaded_file)
    except pd.errors.EmptyDataError:
        st.error("The uploaded CSV file is empty.")
        return

    st.sidebar.header("Settings")
    selected_column = st.sidebar.selectbox("Select a text column", data.columns)
    st.write(f"Predicting classes for '{selected_column}' column...")

    # NOTE(review): Streamlit re-executes the script on every widget
    # interaction, so the tokenizer/model are reloaded and every row is
    # re-classified each time. Consider st.cache_resource / st.cache_data.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = setup_bert_model(num_classes=len(label_dict), device=device)
    data['Predicted_Class'] = _predict_column(
        data[selected_column], model, tokenizer, device
    )

    st.subheader("Predictions and Class Distribution")
    st.dataframe(data)

    st.sidebar.subheader("Filter and Visualization")
    selected_filter = st.sidebar.radio("Select a filter", ["All", "Debit", "Credit"])
    filtered_data = _filter_by_transaction_type(data, selected_column, selected_filter)

    st.subheader("Filtered Predicted Class Distribution")
    class_distribution = filtered_data['Predicted_Class'].value_counts()
    if st.checkbox("Show Distribution"):
        st.bar_chart(class_distribution)

    # Word Cloud Visualization
    st.subheader("Word Cloud Visualization")
    if class_distribution.empty:
        # Nothing survived the filter, so there is no class to choose from.
        st.info("No rows match the selected filter.")
        return
    selected_class = st.selectbox(
        "Select a class for Word Cloud visualization", class_distribution.index
    )
    if st.button("Generate Word Cloud"):
        display_word_cloud(filtered_data, selected_column, selected_class)


if __name__ == "__main__":
    main()