File size: 3,629 Bytes
35c98d9
 
 
 
 
 
 
 
553b9ec
 
35c98d9
 
 
 
 
 
5601530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18a5832
70bf707
a12dda5
5601530
 
 
 
 
 
 
 
 
ac0df45
a12dda5
 
 
 
553b9ec
a12dda5
553b9ec
ac0df45
 
 
 
 
 
 
 
 
 
a12dda5
ac0df45
 
 
 
 
 
 
a12dda5
 
 
2b70172
a12dda5
2b70172
a12dda5
 
 
2cdd7b2
a12dda5
 
2cdd7b2
a12dda5
2cdd7b2
a12dda5
2cdd7b2
a12dda5
2cdd7b2
a12dda5
 
 
 
 
 
 
 
 
ac0df45
5601530
a12dda5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import streamlit as st
import json
import requests
import csv
import pandas as pd
import tqdm

import cohere
import os


from topically import Topically
from bertopic import BERTopic
from sklearn.cluster import KMeans
import numpy as np

# OpenReview venue id; passed to get_conference_notes() as the invitation prefix.
venue = "ICLR.cc/2023/Conference"
# Compact tag for the same venue (not referenced in the visible part of this file).
venue_short = "iclr2023"

def get_conference_notes(venue, blind_submission=False, timeout=30):
    """
    Get all notes of a conference (data) from OpenReview API.
    If results are not final, you should set blind_submission=True.

    The ``/notes`` endpoint is queried page by page, starting at offset 0,
    until a page with no notes comes back.

    Args:
        venue: Venue id used as the invitation prefix,
            e.g. ``'ICLR.cc/2023/Conference'``.
        blind_submission: When True, query the ``-/Blind_Submission``
            invitation of the venue instead of the bare venue prefix.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        A list with every note dict returned by the API, in arrival order.

    Raises:
        requests.HTTPError: If the API answers with a non-success status.
        requests.Timeout: If a request exceeds ``timeout`` seconds.
        KeyError: If a successful response has no ``'notes'`` field.
    """

    blind_param = '-/Blind_Submission' if blind_submission else ''
    offset = 0
    notes = []
    while True:
        # Progress trace for long downloads.
        print('Offset:', offset, 'Data:', len(notes))
        url = f'https://api.openreview.net/notes?invitation={venue}/{blind_param}&offset={offset}'
        response = requests.get(url, timeout=timeout)
        # Fail loudly on HTTP errors instead of tripping over the error payload.
        response.raise_for_status()
        batch = response.json()['notes']
        if not batch:
            break
        notes.extend(batch)
        # Advance by what was actually received so no record is skipped or
        # fetched twice if the server's page size is not the assumed 1000.
        offset += len(batch)
    return notes

# Download every blind submission of the venue.
# NOTE(review): this runs at import time, so presumably on every Streamlit
# rerun as well — consider caching the result; confirm against the Streamlit
# version in use.
raw_notes = get_conference_notes(venue, blind_submission=True)


st.title("ICLR2023 Papers Visualization")
st.write("Number of submissions at ICLR 2023:", len(raw_notes))

# Flatten the nested note dicts; nested keys become dotted column names
# such as "content.venue".
df_raw = pd.json_normalize(raw_notes)
accepted_venues = ['ICLR 2023 poster', 'ICLR 2023 notable top 5%', 'ICLR 2023 notable top 25%']
df = df_raw[df_raw["content.venue"].isin(accepted_venues)]
st.write("Number of submissions accepted at ICLR 2023:", len(df))

# Keep only the columns used downstream. The explicit copy detaches the frame
# from df_raw, so get_visualizations() can add columns to it without pandas
# raising SettingWithCopyWarning.
df_filtered = df[['content.title', 'content.keywords', 'content.abstract', 'content.venue']].copy()
df = df_filtered
if "CO_API_KEY" not in os.environ:
    raise KeyError("CO_API_KEY not found in st.secrets or os.environ. Please set it in "
                   ".streamlit/secrets.toml or as an environment variable.")

# Shared Cohere client used to embed the paper titles.
co = cohere.Client(os.environ["CO_API_KEY"])

def to_html(df: pd.DataFrame, table_header: str) -> str:
    """Wrap pre-rendered table rows in a ``<table>`` element.

    Args:
        df: Frame whose ``html_table_content`` attribute/column holds one
            HTML string per row; the strings are concatenated in order.
        table_header: HTML placed before the rows (e.g. a header ``<tr>``).

    Returns:
        The HTML for the table, beginning with a newline.
    """
    rows = ''.join(df.html_table_content)
    # The leading whitespace is part of the returned string and is preserved.
    return (
        "\n"
        "        <table>\n"
        f"            {table_header}\n"
        f"            {rows}\n"
        "        </table>"
    )


def get_visualizations(n_clusters: int = 8):
    """Cluster the accepted paper titles into topics and plot them.

    Steps:
      1. Embed every title in ``df["content.title"]`` with the Cohere client.
      2. Fit BERTopic on the titles, using KMeans in place of its default
         clustering model so that exactly ``n_clusters`` topics are produced.
      3. Ask Topically to generate a name for each topic.
      4. Render a document scatter plot and a per-topic bar chart in Streamlit.

    Args:
        n_clusters: Number of KMeans clusters, and therefore of topics shown.
            Defaults to 8.

    Side effects:
        Adds the columns ``topic`` and ``topic_name`` to the module-level
        ``df`` and writes two Plotly charts to the Streamlit page.
    """
    list_of_titles = list(df["content.title"].values)
    embeds = co.embed(texts=list_of_titles, model="small").embeddings

    embeds_npy = np.array(embeds)

    # Use KMeans as BERTopic's clustering step to get a fixed topic count.
    cluster_model = KMeans(n_clusters=n_clusters)
    topic_model = BERTopic(hdbscan_model=cluster_model)

    # Precomputed embeddings are passed in so BERTopic does not embed again.
    df['topic'], _probabilities = topic_model.fit_transform(df['content.title'], embeds_npy)

    app = Topically(os.environ["CO_API_KEY"])

    df['topic_name'], topic_names = app.name_topics((df['content.title'], df['topic']), num_generations=5)

    # Register the generated names once; both figures read them through
    # custom_labels=True.
    topic_model.set_topic_labels(topic_names)
    fig1 = topic_model.visualize_documents(df['content.title'].values,
                                           embeddings=embeds_npy,
                                           # Same count as the clusters above, so the two cannot drift apart.
                                           topics=list(range(n_clusters)),
                                           custom_labels=True)
    fig2 = topic_model.visualize_barchart(custom_labels=True)
    st.plotly_chart(fig1)
    st.plotly_chart(fig2)


st.button("Run Visualization", on_click=get_visualizations)