friendshipkim committed on
Commit 86a9a82
1 Parent(s): fad6840

Add application file

Files changed (1)
  1. app.py +181 -2
app.py CHANGED
@@ -1,4 +1,183 @@
  import streamlit as st
+ import pandas as pd
+ import sys
+ import os
+ from datasets import load_from_disk, load_dataset
+ from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
+ from sklearn.metrics.pairwise import cosine_similarity
+ import numpy as np
+ import time
+ from annotated_text import annotated_text
 
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ from huggingface_hub import hf_hub_download
+ repo_id = "friendshipkim/IUR_Reddit"
+ 
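+ # Streamlit demo for browsing authorship-retrieval results on the IUR_Reddit
+ # test set: pick a query post, inspect its top recommended candidate posts
+ # (with precomputed cosine distances), and compare against the ground-truth
+ # posts written by the same author.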
+ # ABSOLUTE_PATH = os.path.dirname(__file__)
+ # ASSETS_PATH = os.path.join(ABSOLUTE_PATH, 'model_assets')
+ 
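+ # Split a post into word tokens, treating every non-alphanumeric character as a separator.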
+ @st.cache_data
+ def preprocess_text(s):
+     return list(filter(lambda x: x != '', (''.join(c if c.isalnum() or c == ' ' else ' ' for c in s)).split(' ')))
+ 
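+ # Load the precomputed query-candidate cosine distances from the Hub.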
+ @st.cache_data
+ def get_pairwise_distances(model):
+     # df = pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv").set_index('index')
+     df = pd.read_csv(hf_hub_download(repo_id=repo_id, filename="pairwise_distances.csv")).set_index('index')
+     return df
+ 
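+ # Chunked loading is stubbed out for now; fall back to the full table.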
+ @st.cache_data
+ def get_pairwise_distances_chunked(model, chunk):
+     # for df in pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv", chunksize = 16):
+     #     print(df.iloc[0]['queries'])
+     #     if chunk == int(df.iloc[0]['queries']):
+     #         return df
+     return get_pairwise_distances(model)
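+ 
+ # Load the test queries (and, in the twin function below, the test candidates)
+ # as DataFrames, adding a positional 'index' column used to join against the
+ # distance table.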
+ @st.cache_data
+ def get_query_strings():
+     # df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.jsonl", lines = True)
+     df = pd.read_json(hf_hub_download(repo_id=repo_id, filename="IUR_Reddit_test_queries_english.jsonl"), lines = True)
+     df['index'] = df.reset_index().index
+     return df
+     # df['partition'] = df['index']%100
+     # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", index = 'index', partition_cols = 'partition')
+ 
+     # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", columns=['fullText', 'index', 'authorIDs'])
+ 
+ @st.cache_data
+ def get_candidate_strings():
+     # df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.jsonl", lines = True)
+     df = pd.read_json(hf_hub_download(repo_id=repo_id, filename="IUR_Reddit_test_candidates_english.jsonl"), lines = True)
+     df['index'] = df.reset_index().index
+     return df
+     # df['partition'] = df['index']%100
+     # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", index = 'index', partition_cols = 'partition')
+     # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", columns=['fullText', 'index', 'authorIDs'])
+ 
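+ # Load the precomputed query/candidate embeddings from the Hub.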
+ @st.cache_data
+ def get_embedding_dataset(model):
+     # data = load_from_disk(f"{ASSETS_PATH}/{model}/embedding")
+     data = load_dataset("friendshipkim/luar_clone2_top_100_embedding")
+     return data
+ 
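+ # Restrict to the queries that appear in the precomputed distance table,
+ # i.e. the ones flagged for inspection.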
+ @st.cache_data
+ def get_bad_queries(model):
+     df = get_query_strings().iloc[list(get_pairwise_distances(model)['queries'].unique())][['fullText', 'index', 'authorIDs']]
+     return df
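+ 
+ # All candidate posts written by the given author (the query's ground truth).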
+ @st.cache_data
+ def get_gt_candidates(model, author):
+     gt_candidates = get_candidate_strings()
+     df = gt_candidates[gt_candidates['authorIDs'].apply(lambda x: x[0]) == author]
+     return df
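+ 
+ # Full text of a candidate post, looked up by dataset index.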
+ @st.cache_data
+ def get_candidate_text(l):
+     return get_candidate_strings().at[l, 'fullText']
+ 
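+ # Wrap the next occurrence of `word` (searching from `pos`) in a 'SELECTED'
+ # annotation for annotated_text, and return the end offset for the next search.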
+ @st.cache_data
+ def get_annotated_text(text, word, pos):
+     # print("here", word, pos)
+     start = text.index(word, pos)
+     end = start + len(word)
+     return (text[:start], (text[start:end], 'SELECTED'), text[end:]), end
+ 
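+ # Thin wrapper around st_aggrid; a fresh key on every rebuild forces the grid
+ # to re-render with the new table.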
+ class AgGridBuilder:
+     __static_key = 0
+ 
+     @staticmethod
+     def build_ag_grid(table, display_columns):
+         AgGridBuilder.__static_key += 1
+         options_builder = GridOptionsBuilder.from_dataframe(table[display_columns])
+         options_builder.configure_pagination(paginationAutoPageSize=False, paginationPageSize=10)
+         options_builder.configure_selection(selection_mode='single', pre_selected_rows=[0])
+         options = options_builder.build()
+         return AgGrid(table, gridOptions=options, fit_columns_on_grid_load=True, key=AgGridBuilder.__static_key, reload_data=True, update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED)
+ 
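+ # Page layout: a "Full Text" query panel, a recommended-candidates panel, and
+ # a same-author (ground truth) panel, each with its own numeric selector.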
+ if __name__ == '__main__':
+     st.set_page_config(layout="wide")
+ 
+     # models = filter(lambda file_name: os.path.isdir(f"{ASSETS_PATH}/{file_name}") and not file_name.endswith(".parquet"), os.listdir(ASSETS_PATH))
+     models = ['luar_clone2_top_100']
+ 
+     with st.sidebar:
+         current_model = st.selectbox(
+             "Select Model to analyze",
+             models
+         )
+ 
+     pairwise_distances = get_pairwise_distances(current_model)
+     embedding_dataset = get_embedding_dataset(current_model)
+ 
+     candidate_string_grid = None
+     gt_candidate_string_grid = None
+     with st.container():
+         t1 = time.time()
+         st.title("Full Text")
+         col1, col2 = st.columns([14, 2])
+         t2 = time.time()
+         query_table = get_bad_queries(current_model)
+         t3 = time.time()
+         # print(query_table)
+         with col2:
+             index = st.number_input('Enter Query number to inspect', min_value = 0, max_value = query_table.shape[0] - 1, step = 1)
+             query_text = query_table.iloc[index]['fullText']
+             preprocessed_query_text = preprocess_text(query_text)
+             text_highlight_index = st.number_input('Enter word #', min_value = 0, max_value = len(preprocessed_query_text), step = 1)
+             query_index = int(query_table.iloc[index]['index'])
+ 
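+         # Track the highlight position across reruns so repeated words resolve
+         # to successive occurrences: pos_history stores the end offset of each
+         # previously highlighted word.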
+         with col1:
+             if 'pos_highlight' not in st.session_state or text_highlight_index == 0:
+                 st.session_state['pos_highlight'] = text_highlight_index
+                 st.session_state['pos_history'] = [0]
+ 
+             if st.session_state['pos_highlight'] > text_highlight_index:
+                 st.session_state['pos_history'] = st.session_state['pos_history'][:-2]
+                 if len(st.session_state['pos_history']) == 0:
+                     st.session_state['pos_history'] = [0]
+             # print("pos", st.session_state['pos_history'], st.session_state['pos_highlight'], text_highlight_index)
+             annotated_text_, pos = get_annotated_text(query_text, preprocessed_query_text[text_highlight_index-1], st.session_state['pos_history'][-1]) if text_highlight_index >= 1 else ((query_text,), 0)
+             if st.session_state['pos_highlight'] < text_highlight_index:
+                 st.session_state['pos_history'].append(pos)
+             st.session_state['pos_highlight'] = text_highlight_index
+             annotated_text(*annotated_text_)
+             # annotated_text("Lol, this" , ('guy', 'SELECTED') , "is such a PR chameleon. \n\n In the Chan Zuckerberg Initiative announcement, he made it sound like he was giving away all his money to charity <PERSON> or <PERSON>. http://www.businessinsider.in/Mark-Zuckerberg-says-hes-giving-99-of-his-Facebook-shares-45-billion-to-charity/articleshow/50005321.cms Apparently, its just a VC fund. And there are still people out there who believe Facebook.org was an initiative to bring Internet to the poor.")
+         t4 = time.time()
+ 
+     # print(f"query time query text: {t3-t2}, total time: {t4-t1}")
+     with st.container():
+         st.title("Top 16 Recommended Candidates")
+         col1, col2, col3 = st.columns([10, 4, 2])
+         rec_candidates = pairwise_distances[pairwise_distances["queries"] == query_index]['candidates']
+         # print(rec_candidates)
+         l = list(rec_candidates)
+         with col3:
+             candidate_rec_index = st.number_input('Enter recommended candidate number to inspect', min_value = 0, max_value = len(l) - 1, step = 1)
+             print("l:", l, query_index)
+             pairwise_candidate_index = int(l[candidate_rec_index])
+         with col1:
+             st.header("Text")
+             t1 = time.time()
+             st.write(get_candidate_text(pairwise_candidate_index))
+             t2 = time.time()
+         with col2:
+             st.header("Cosine Distance")
+             st.write(float(pairwise_distances[
+                 (pairwise_distances['queries'] == query_index)
+                 & (pairwise_distances['candidates'] == pairwise_candidate_index)
+             ]['distances'].iloc[0]))
+         print(f"candidate string retrieval: {t2-t1}")
+     with st.container():
+         t1 = time.time()
+         st.title("Candidates With Same Authors As Query")
+         col1, col2, col3 = st.columns([10, 4, 2])
+         t2 = time.time()
+         gt_candidates = get_gt_candidates(current_model, query_table.iloc[index]['authorIDs'][0])
+         t3 = time.time()
+ 
+         with col3:
+             candidate_index = st.number_input('Enter ground truth number to inspect', min_value = 0, max_value = gt_candidates.shape[0] - 1, step = 1)
+             gt_candidate_index = int(gt_candidates.iloc[candidate_index]['index'])
+         with col1:
+             st.header("Text")
+             st.write(gt_candidates.iloc[candidate_index]['fullText'])
+         with col2:
+             t4 = time.time()
+             st.header("Cosine Distance")
+             st.write(1 - cosine_similarity(
+                 np.array([embedding_dataset['queries'][query_index]['embedding']]),
+                 np.array([embedding_dataset['candidates'][gt_candidate_index]['embedding']])
+             )[0, 0])
+             t5 = time.time()
+         print(f"find gt candidates: {t3-t2}, find cosine: {t5-t4}, total: {t5-t1}")