awacke1 commited on
Commit
072885d
·
1 Parent(s): 1112873

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -68
app.py CHANGED
@@ -1,70 +1,54 @@
1
  import streamlit as st
2
- from bertopic import BERTopic
3
- import streamlit.components.v1 as components
4
- from sentence_transformers import SentenceTransformer
5
- from umap import UMAP
6
- from hdbscan import HDBSCAN
7
-
8
- # Initialize BERTopic model
9
- model = BERTopic()
10
-
11
- st.subheader("Topic Modeling with Topic-Wizard")
12
- uploaded_file = st.file_uploader("Choose a text file", type=["txt"])
13
-
14
- if uploaded_file is not None:
15
- st.session_state["text"] = uploaded_file.getvalue().decode("utf-8")
16
-
17
- st.write("OR")
18
-
19
- input_text = st.text_area(
20
- label="Enter text separated by newlines",
21
- value="",
22
- key="text",
23
- height=150,
24
- )
25
-
26
- button = st.button("Get Segments")
27
-
28
- if button and (uploaded_file is not None or input_text != ""):
29
- if uploaded_file is not None:
30
- texts = st.session_state["text"].split("\n")
31
  else:
32
- texts = input_text.split("\n")
33
-
34
- # Fit BERTopic model
35
- topics, probabilities = model.fit_transform(texts)
36
-
37
- # Create embeddings
38
- embeddings_model = SentenceTransformer("distilbert-base-nli-mean-tokens")
39
- embeddings = embeddings_model.encode(texts)
40
-
41
- # Reduce dimensionality of embeddings using UMAP
42
- umap_model = UMAP(n_neighbors=15, n_components=2, metric="cosine")
43
- umap_embeddings = umap_model.fit_transform(embeddings)
44
-
45
- # Cluster topics using HDBSCAN
46
- cluster = HDBSCAN(
47
- min_cluster_size=15, metric="euclidean", cluster_selection_method="eom"
48
- ).fit(umap_embeddings)
49
-
50
- # Visualize BERTopic results with Streamlit
51
- st.title("BERTopic Visualization")
52
-
53
- # Display top N most representative topics and their documents
54
- num_topics = st.sidebar.slider("Select number of topics to display", 1, 20, 5, 1)
55
- topic_words = model.get_topics()
56
- topic_freq = model.get_topic_freq().head(num_topics + 1) # Add 1 to exclude -1 (outliers topic)
57
- for _, row in topic_freq.iterrows():
58
- topic_id = row["Topic"]
59
- if topic_id == -1:
60
- continue # Skip the outliers topic
61
- st.write(f"## Topic {topic_id}")
62
- st.write("Keywords:", ", ".join(topic_words[topic_id]))
63
- st.write("Documents:")
64
- doc_ids = [idx for idx, topic in enumerate(topics) if topic == topic_id][:5]
65
- for doc in doc_ids:
66
- st.write("-", texts[doc])
67
-
68
- # Display topic clusters
69
- st.write("## Topic Clusters")
70
- components.html(cluster.labels_.tolist(), height=500, width=800)
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import bertopic
4
+ import plotly.express as px
5
+
6
+ st.set_page_config(page_title="Topic Modeling with Bertopic")
7
+
8
+ # Function to read the uploaded file and return a Pandas DataFrame
9
+ def read_file(file):
10
+ if file.type == 'text/plain':
11
+ df = pd.read_csv(file, header=None, names=['data'])
12
+ elif file.type == 'text/csv':
13
+ df = pd.read_csv(file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  else:
15
+ st.error("Unsupported file format. Please upload a TXT or CSV file.")
16
+ return None
17
+ return df
18
+
19
+ # Sidebar to upload the file
20
+ st.sidebar.title("Upload File")
21
+ file = st.sidebar.file_uploader("Choose a file", type=["txt", "csv"])
22
+
23
+ # Perform topic modeling when the user clicks the "Visualize" button
24
+ if st.sidebar.button("Visualize"):
25
+
26
+ # Read the uploaded file
27
+ df = read_file(file)
28
+ if df is None:
29
+ st.stop()
30
+
31
+ # Perform topic modeling using Bertopic
32
+ model = bertopic.Bertopic()
33
+ topics, probabilities = model.fit_transform(df['data'])
34
+
35
+ # Create a plot of the topic distribution
36
+ fig = px.histogram(x=topics, nbins=max(topics)+1, color_discrete_sequence=px.colors.qualitative.Pastel)
37
+ fig.update_layout(
38
+ title="Distribution of Topics",
39
+ xaxis_title="Topic",
40
+ yaxis_title="Count",
41
+ )
42
+ st.plotly_chart(fig)
43
+
44
+ # Display the top words in each topic
45
+ st.write("Top words in each topic:")
46
+ for topic_id in range(max(topics)+1):
47
+ st.write(f"Topic {topic_id}: {model.get_topic(topic_id)}")
48
+
49
+ # Display the clusters
50
+ st.write("Clusters:")
51
+ for cluster_id, docs in model.get_clusters().items():
52
+ st.write(f"Cluster {cluster_id}:")
53
+ for doc in docs:
54
+ st.write(f"\t{doc}")