RamAnanth1 commited on
Commit
553b9ec
1 Parent(s): 5601530

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -3
app.py CHANGED
@@ -6,6 +6,8 @@ import pandas as pd
6
  import tqdm
7
 
8
  import cohere
 
 
9
 
10
  from topically import Topically
11
  from bertopic import BERTopic
@@ -47,10 +49,34 @@ st.write("Number of submissions accepted at ICLR 2023:", len(df))
47
 
48
  df_filtered = df[['id', 'content.title', 'content.keywords', 'content.abstract']]
49
  df = df_filtered
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- list_of_abstracts = list(df["content.title"].values)
52
 
 
53
 
 
54
 
55
- x = st.slider('Select a value')
56
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
6
  import tqdm
7
 
8
  import cohere
9
+ import os
10
+
11
 
12
  from topically import Topically
13
  from bertopic import BERTopic
 
49
 
50
  df_filtered = df[['id', 'content.title', 'content.keywords', 'content.abstract']]
51
  df = df_filtered
52
+ if "CO_API_KEY" not in os.environ:
53
+ raise KeyError("CO_API_KEY not found in st.secrets or os.environ. Please set it in "
54
+ ".streamlit/secrets.toml or as an environment variable.")
55
+
56
+ co = cohere.Client(os.environ["CO_API_KEY"])
57
+ list_of_titles = list(df["content.title"].values)
58
+
59
+ embeds = co.embed(texts=list_of_titles,
60
+ model="small").embeddings
61
+
62
+ embeds_npy = np.array(embeds)
63
+ # Load and initialize BERTopic to use KMeans clustering with 8 clusters only.
64
+ cluster_model = KMeans(n_clusters=8)
65
+ topic_model = BERTopic(hdbscan_model=cluster_model)
66
+
67
+ # df is a dataframe. df['title'] is the column of text we're modeling
68
+ df['topic'], probabilities = topic_model.fit_transform(df['content.title'], embeds_npy)
69
 
70
+ app = Topically(os.environ["CO_API_KEY"])
71
 
72
+ df['topic_name'], topic_names = app.name_topics((df['content.title'], df['topic']), num_generations=5)
73
 
74
+ st.write("Topics extracted are:", topic_names)
75
 
76
+ topic_model.set_topic_labels(topic_names)
77
+ # topic_model.visualize_documents(df['content.title'].values,
78
+ # embeddings=embeds_npy,
79
+ # topics = list(range(8)),
80
+ # custom_labels=True,
81
+ # width=900,
82
+ # height=600)