RamAnanth1 committed on
Commit
70bf707
•
1 Parent(s): 18a5832

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -29
app.py CHANGED
@@ -40,6 +40,7 @@ def get_conference_notes(venue, blind_submission=False):
40
  raw_notes = get_conference_notes(venue, blind_submission=True)
41
 
42
  st.set_page_config(page_title="ICLR2023 Papers Visualization", page_icon="🐞", layout="centered")
 
43
  st.write("Number of submissions at ICLR 2023:", len(raw_notes))
44
 
45
  df_raw = pd.json_normalize(raw_notes)
@@ -49,42 +50,42 @@ accepted_venues = ['ICLR 2023 poster', 'ICLR 2023 notable top 5%', 'ICLR 2023 no
49
  df = df_raw[df_raw["content.venue"].isin(accepted_venues)]
50
  st.write("Number of submissions accepted at ICLR 2023:", len(df))
51
 
52
- df_filtered = df[['id', 'content.title', 'content.keywords', 'content.abstract']]
53
- df = df_filtered
54
- if "CO_API_KEY" not in os.environ:
55
- raise KeyError("CO_API_KEY not found in st.secrets or os.environ. Please set it in "
56
- ".streamlit/secrets.toml or as an environment variable.")
57
 
58
- co = cohere.Client(os.environ["CO_API_KEY"])
59
 
60
- def get_visualizations():
61
- list_of_titles = list(df["content.title"].values)
62
- embeds = co.embed(texts=list_of_titles,
63
- model="small").embeddings
64
 
65
- embeds_npy = np.array(embeds)
66
 
67
- # Load and initialize BERTopic to use KMeans clustering with 8 clusters only.
68
- cluster_model = KMeans(n_clusters=8)
69
- topic_model = BERTopic(hdbscan_model=cluster_model)
70
 
71
- # df is a dataframe. df['title'] is the column of text we're modeling
72
- df['topic'], probabilities = topic_model.fit_transform(df['content.title'], embeds_npy)
73
 
74
- app = Topically(os.environ["CO_API_KEY"])
75
 
76
- df['topic_name'], topic_names = app.name_topics((df['content.title'], df['topic']), num_generations=5)
77
 
78
- #st.write("Topics extracted are:", topic_names)
79
 
80
- topic_model.set_topic_labels(topic_names)
81
- fig1 = topic_model.visualize_documents(df['content.title'].values,
82
- embeddings=embeds_npy,
83
- topics = list(range(8)),
84
- custom_labels=True)
85
- topic_model.set_topic_labels(topic_names)
86
- fig2 = topic_model.visualize_barchart(custom_labels=True)
87
- st.plotly_chart(fig1)
88
- st.plotly_chart(fig2)
89
 
90
- st.button("Run Visualization", on_click=get_visualizations)
 
40
  raw_notes = get_conference_notes(venue, blind_submission=True)
41
 
42
  st.set_page_config(page_title="ICLR2023 Papers Visualization", page_icon="🐞", layout="centered")
43
+
44
  st.write("Number of submissions at ICLR 2023:", len(raw_notes))
45
 
46
  df_raw = pd.json_normalize(raw_notes)
 
50
  df = df_raw[df_raw["content.venue"].isin(accepted_venues)]
51
  st.write("Number of submissions accepted at ICLR 2023:", len(df))
52
 
53
+ # df_filtered = df[['id', 'content.title', 'content.keywords', 'content.abstract']]
54
+ # df = df_filtered
55
+ # if "CO_API_KEY" not in os.environ:
56
+ # raise KeyError("CO_API_KEY not found in st.secrets or os.environ. Please set it in "
57
+ # ".streamlit/secrets.toml or as an environment variable.")
58
 
59
+ # co = cohere.Client(os.environ["CO_API_KEY"])
60
 
61
+ # def get_visualizations():
62
+ # list_of_titles = list(df["content.title"].values)
63
+ # embeds = co.embed(texts=list_of_titles,
64
+ # model="small").embeddings
65
 
66
+ # embeds_npy = np.array(embeds)
67
 
68
+ # # Load and initialize BERTopic to use KMeans clustering with 8 clusters only.
69
+ # cluster_model = KMeans(n_clusters=8)
70
+ # topic_model = BERTopic(hdbscan_model=cluster_model)
71
 
72
+ # # df is a dataframe. df['title'] is the column of text we're modeling
73
+ # df['topic'], probabilities = topic_model.fit_transform(df['content.title'], embeds_npy)
74
 
75
+ # app = Topically(os.environ["CO_API_KEY"])
76
 
77
+ # df['topic_name'], topic_names = app.name_topics((df['content.title'], df['topic']), num_generations=5)
78
 
79
+ # #st.write("Topics extracted are:", topic_names)
80
 
81
+ # topic_model.set_topic_labels(topic_names)
82
+ # fig1 = topic_model.visualize_documents(df['content.title'].values,
83
+ # embeddings=embeds_npy,
84
+ # topics = list(range(8)),
85
+ # custom_labels=True)
86
+ # topic_model.set_topic_labels(topic_names)
87
+ # fig2 = topic_model.visualize_barchart(custom_labels=True)
88
+ # st.plotly_chart(fig1)
89
+ # st.plotly_chart(fig2)
90
 
91
+ # st.button("Run Visualization", on_click=get_visualizations)