Debit_Analysis / app.py
sofarikasid's picture
Synced repo using 'sync_with_huggingface' Github Action
1956315
import re
import torch
import pandas as pd
import nltk
import streamlit as st
import matplotlib.pyplot as plt
from nltk.tokenize import word_tokenize
from transformers import BertTokenizer, BertForSequenceClassification
from wordcloud import WordCloud
# Ensure the NLTK tokenizer models used by word_tokenize are installed.
# NLTK >= 3.8.2 loads 'punkt_tab' (older releases use 'punkt'), so both are
# requested. Streamlit re-executes this script on every interaction, so the
# network is only contacted when a resource is actually missing locally.
for _nltk_resource in ('punkt', 'punkt_tab'):
    try:
        nltk.data.find(f'tokenizers/{_nltk_resource}')
    except LookupError:
        # On NLTK versions without 'punkt_tab' this just reports an error and
        # returns False; it does not raise.
        nltk.download(_nltk_resource)
# Map classifier output indices to transaction category labels
# Classifier output index -> human-readable transaction category.
# NOTE(review): the order presumably mirrors the label encoding used when
# 'finiancal_class.pth' was trained — confirm before reordering.
label_dict = dict(enumerate([
    'Transfer',
    'Other',
    'Food',
    'Grocery',
    'Entertainment',
    'Cash Withdrawal',
    'Shopping',
    'Gas',
    'Liquor/Club/Smoke',
    'Credit Card',
    'Transportation',
]))
# Load the uploaded CSV into a DataFrame (memoized by Streamlit across reruns)
@st.cache_data
def load_dataframe(file_path):
    """Parse a CSV source into a pandas DataFrame.

    Args:
        file_path: Anything ``pandas.read_csv`` accepts; in this app it is the
            object returned by ``st.file_uploader``.

    Returns:
        The parsed DataFrame. ``index_col=False`` keeps pandas from promoting
        the first column to the index, so every CSV column stays a column.
        ``st.cache_data`` memoizes the result, so reruns with the same input
        skip re-parsing.
    """
    frame = pd.read_csv(file_path, index_col=False)
    return frame
# Normalize raw transaction text before it is passed to the BERT tokenizer
# Deleted from the text: punctuation, URLs, runs of non-ASCII characters, digits.
_BERT_NOISE_RE = re.compile(
    r'[^\w\s]|https?://\S+|www\.\S+|https?:/\S+|[^\x00-\x7F]+|\d+'
)


def clean_text_BERT(text):
    """Normalize one raw text value for BERT classification.

    Processing steps:

    1. Lower-case the text and strip surrounding whitespace.
    2. Delete punctuation, URLs, non-ASCII runs and digits.
    3. Word-tokenize the result with NLTK and drop single-character tokens.
    4. Re-join the remaining tokens with single spaces.

    Args:
        text: The raw cell value. pandas yields floats (NaN) for missing CSV
            cells and ints/floats for numeric columns. Missing values are
            treated as an empty string, and other non-strings are converted
            with ``str()`` instead of crashing on ``.lower()``.

    Returns:
        The cleaned, space-separated string (may be empty).
    """
    if not isinstance(text, str):
        text = '' if pd.isna(text) else str(text)
    text = _BERT_NOISE_RE.sub('', text.lower().strip())
    return ' '.join(token for token in word_tokenize(text) if len(token) > 1)
# Setup BERT model for classification
@st.cache_resource
def setup_bert_model(num_classes, device):
    """Build the fine-tuned BERT classifier, ready for inference.

    A ``bert-base-uncased`` sequence-classification model with
    ``num_classes`` outputs is created. The weights saved in
    ``finiancal_class.pth`` are then loaded (mapped onto ``device``).
    ``st.cache_resource`` keeps one instance per argument combination, so
    reruns reuse it.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    bert = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_classes)
    checkpoint = torch.load('finiancal_class.pth', map_location=device)
    known_keys = bert.state_dict()
    # Drop checkpoint entries this model does not define (they would be
    # rejected as unexpected keys). load_state_dict stays strict, so a key the
    # model needs but the checkpoint lacks still raises.
    compatible = {name: weight for name, weight in checkpoint.items() if name in known_keys}
    bert.load_state_dict(compatible)
    bert.to(device)
    bert.eval()
    return bert
# Predict the class using the BERT model
def predict_class(model, tokenizer, text, device):
    """Classify one text and return its category label.

    The text is tokenized (padded/truncated, PyTorch tensors) and moved to
    ``device``. The model is then run without gradient tracking, and the
    ``label_dict`` entry for the highest logit is returned.

    Expects a single text: ``.item()`` requires exactly one prediction.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    ids = encoded['input_ids'].to(device)
    mask = encoded['attention_mask'].to(device)
    with torch.no_grad():
        logits = model(ids, attention_mask=mask).logits
    best = logits.argmax(dim=1).item()
    return label_dict[best]
# Generate and display word cloud
def display_word_cloud(data, selected_column, predicted_class):
    """Render a word cloud built from the rows predicted as one class.

    Args:
        data: DataFrame containing ``selected_column`` and a
            ``Predicted_Class`` column.
        selected_column: Name of the text column whose words are plotted.
        predicted_class: Label used to select rows via ``Predicted_Class``.

    When the selected rows contain no plottable words, a Streamlit warning is
    shown instead of a chart.
    """
    class_data = data[data['Predicted_Class'] == predicted_class]
    # Cells may be NaN or non-string (numeric columns); str.join needs str.
    class_text = " ".join(class_data[selected_column].dropna().astype(str))
    try:
        wordcloud = WordCloud(width=800, height=400, background_color='white').generate(class_text)
    except ValueError:
        # WordCloud raises ValueError when the text yields zero words.
        st.warning(f"No words available to build a word cloud for '{predicted_class}'.")
        return
    # Use an explicit Figure rather than pyplot's global state so it can be
    # closed; Streamlit reruns the script on every interaction and unclosed
    # figures would otherwise accumulate.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.imshow(wordcloud, interpolation='bilinear')
    ax.set_title(f"Word Cloud for '{predicted_class}' Class")
    ax.axis('off')
    st.pyplot(fig)
    plt.close(fig)
# Main Streamlit app
def main():
    """Streamlit entry point: classify a CSV text column and visualize results.

    Workflow:

    1. Upload a CSV and pick a text column.
    2. Every row is cleaned and classified with the fine-tuned BERT model.
    3. The table is shown with the predicted classes.
    4. Rows can be filtered by "Debit"/"Credit" in the chosen column, with an
       optional class-distribution chart and a per-class word cloud.
    """
    st.set_page_config(page_title="Text Classification", page_icon="📚")
    st.title("📚🔗 Text Classification with BERT")
    uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
    if uploaded_file is None:
        return

    data = load_dataframe(uploaded_file)
    st.sidebar.header("Settings")
    selected_column = st.sidebar.selectbox("Select a text column", data.columns)
    st.write(f"Predicting classes for '{selected_column}' column...")

    # CSV columns can hold NaN or numbers. Normalize them to strings once, so
    # the text cleaner and the `.str` filters below work for every row.
    texts = data[selected_column].fillna('').astype(str)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = _load_tokenizer()
    model = setup_bert_model(num_classes=len(label_dict), device=device)
    data['Predicted_Class'] = _predict_labels(tuple(texts), model, tokenizer, device)

    st.subheader("Predictions and Class Distribution")
    st.dataframe(data)

    st.sidebar.subheader("Filter and Visualization")
    selected_filter = st.sidebar.radio("Select a filter", ["All", "Debit", "Credit"])
    if selected_filter in ("Debit", "Credit"):
        filtered_data = data[texts.str.contains(selected_filter, case=False, regex=False)]
    else:
        filtered_data = data

    st.subheader("Filtered Predicted Class Distribution")
    class_distribution = filtered_data['Predicted_Class'].value_counts()
    if st.checkbox("Show Distribution"):
        st.bar_chart(class_distribution)

    # Word Cloud Visualization
    st.subheader("Word Cloud Visualization")
    selected_class = st.selectbox("Select a class for Word Cloud visualization", class_distribution.index)
    if st.button("Generate Word Cloud"):
        # With no rows left after filtering the selectbox has no options and
        # returns None; there is nothing to plot in that case.
        if selected_class is None:
            st.info("No rows match the selected filter.")
        else:
            display_word_cloud(filtered_data, selected_column, selected_class)


@st.cache_resource
def _load_tokenizer():
    """Load the BERT tokenizer once instead of on every Streamlit rerun."""
    return BertTokenizer.from_pretrained('bert-base-uncased')


@st.cache_data
def _predict_labels(texts, _model, _tokenizer, device):
    """Clean and classify each text, memoizing the labels for identical input.

    Streamlit excludes underscore-prefixed parameters from the cache key, so
    the key is ``texts`` and ``device``. Widget interactions (filter, checkbox,
    button) therefore reuse the predictions instead of re-running BERT over
    every row.
    """
    return [
        predict_class(_model, _tokenizer, clean_text_BERT(text), device)
        for text in texts
    ]
# Run the app when this file is executed as the main script (e.g. `streamlit run`).
if __name__ == "__main__":
    main()