Spaces:
Runtime error
Runtime error
| """ | |
| Build an editable user profile based recommender. | |
| - Read the users json and read their paper reps and keyphrases into memory. | |
| - Read the candidates document (first stage retrieval) and | |
| sentence embeddings into memory (second stage retrieval). | |
| - Display the keyphrases to users and ask them to check it. | |
| - Use the keyphrases and sentence embeddings to compute keyphrase values. | |
| - Display the keyphrase selection box to users for retrieval. | |
| - Use the selected keyphrases for performing retrieval. | |
| """ | |
| import copy | |
| import json | |
| import pickle | |
| import re | |
| import joblib | |
| import os | |
| import collections | |
| import streamlit as st | |
| import numpy as np | |
| from scipy.spatial import distance | |
| from scipy import special | |
| from sklearn.neighbors import NearestNeighbors | |
| from sentence_transformers import SentenceTransformer, models | |
| import torch | |
| import ot | |
| # import seaborn as sns | |
| # import matplotlib | |
| # matplotlib.use('Agg') | |
| # import matplotlib.pyplot as plt | |
| # plt.rcParams['figure.dpi'] = 400 | |
| # plt.rcParams.update({'axes.labelsize': 'small'}) | |
# Root data directory; the functions below read per-user files from its
# 'users' sub-directory and candidate files from its 'cands' sub-directory.
in_path = './data'
| ######################################## | |
| # BACKEND CODE # | |
| ######################################## | |
def read_user(seed_json):
    """
    Read the embedded seed documents of a user and cache them in the
    streamlit session state.

    The data is read from disk when nothing is cached yet or when the
    cached data belongs to a different username; otherwise the cached
    values are returned. Both paths return the same objects.
    :param seed_json: dict; parsed seed set json. The keys 'username',
        'user_kps' and 'papers' (a list of dicts with a 'title') are read.
    :return: tuple(doc_vectors_user, pid2idx_user, pid2sent_vectors_user, user_kps)
        pid2sent_vectors_user is an OrderedDict sorted by pid.
    """
    uname = seed_json['username']
    cache_is_current = ('doc_vectors_user' in st.session_state
                        and 'username' in st.session_state
                        and st.session_state['username'] == uname)
    if not cache_is_current:
        user_dir = os.path.join(in_path, 'users', uname)
        # Read document vectors.
        doc_vectors_user = np.load(os.path.join(user_dir, f'embeds-{uname}-doc.npy'))
        with open(os.path.join(user_dir, f'pid2idx-{uname}-doc.json'), 'r') as fp:
            pid2idx_user = json.load(fp)
        # Read sentence vectors and fix their order by sorting on pid.
        pid2sent_vectors = joblib.load(os.path.join(user_dir, f'embeds-{uname}-sent.pickle'))
        pid2sent_vectors_user = collections.OrderedDict(
            (pid, pid2sent_vectors[pid]) for pid in sorted(pid2sent_vectors))
        # Values computed from a previously loaded user's vectors are stale
        # now, drop them so they get recomputed for this user.
        for stale_key in ('first_stage_ret_pids', 'kp2val_vectors', 'user_tplan'):
            if stale_key in st.session_state:
                del st.session_state[stale_key]
        st.session_state['doc_vectors_user'] = doc_vectors_user
        st.session_state['pid2idx_user'] = pid2idx_user
        st.session_state['pid2sent_vectors_user'] = pid2sent_vectors_user
        st.session_state['user_kps'] = seed_json['user_kps']
        st.session_state['username'] = uname
        # Lower-cased, whitespace-normalized titles of the seed papers.
        st.session_state['seed_titles'] = [
            " ".join(paperd['title'].lower().strip().split())
            for paperd in seed_json['papers']
        ]
    return st.session_state.doc_vectors_user, st.session_state.pid2idx_user, \
        st.session_state.pid2sent_vectors_user, st.session_state.user_kps
def first_stage_ranked_docs(user_doc_queries, per_doc_to_rank, total_to_rank=2000):
    """
    Return a list of candidate pids ranked by nearest neighbor search
    around the query document vectors. The result is cached in the
    session state and returned as is on later calls.
    :param user_doc_queries: array; one query embedding per row.
    :param per_doc_to_rank: int; neighbors to retrieve per query. Clamped
        to the number of candidate vectors.
    :param total_to_rank: int; maximum number of pids returned.
    :return: list(pid); unique pids ordered by rank position first and
        query index second.
    """
    if 'first_stage_ret_pids' in st.session_state:
        return st.session_state.first_stage_ret_pids
    # Read the document vectors.
    doc_vectors = np.load(os.path.join(in_path, 'cands', 'embeds-mlconfs-18_23.npy'))
    # NOTE: pickle can execute arbitrary code while loading; this file must
    # come from a trusted source.
    with open(os.path.join(in_path, 'cands', 'pid2idx-mlconfs-18_23.pickle'), 'rb') as fp:
        pid2idx_cands = pickle.load(fp)
    idx2pid_cands = {idx: pid for pid, idx in pid2idx_cands.items()}
    # kneighbors raises when asked for more neighbors than indexed vectors.
    n_neighbors = min(per_doc_to_rank, doc_vectors.shape[0])
    # Index the vectors into a nearest neighbors structure.
    neighbors = NearestNeighbors(n_neighbors=n_neighbors)
    neighbors.fit(doc_vectors)
    st.session_state['neighbors'] = neighbors
    st.session_state['idx2pid_cands'] = idx2pid_cands
    # Only the neighbor indices are needed: num_queries x n_neighbors.
    nearest_idxs = neighbors.kneighbors(user_doc_queries, return_distance=False)
    top_pids = []
    uniq_top = set()
    # Walk rank position by rank position so that the best neighbors of
    # every query come before the second best of any query.
    for rank_idxs in nearest_idxs.T:
        if len(top_pids) >= total_to_rank:
            break
        for idx in rank_idxs:
            pid = idx2pid_cands[idx]
            # Only save the unique papers (ignore multiple retrievals of the same paper).
            if pid not in uniq_top:
                top_pids.append(pid)
                uniq_top.add(pid)
    top_pids = top_pids[:total_to_rank]
    st.session_state['first_stage_ret_pids'] = top_pids
    return top_pids
def read_kp_encoder(in_path):
    """
    Build the keyphrase encoder once, cache it in the session state under
    'kp_enc_model' and return it.
    :param in_path: string; currently unused, the transformer is loaded by
        its model name and not from a file under in_path.
    :return: SentenceTransformer; transformer followed by mean pooling.
    """
    if 'kp_enc_model' not in st.session_state:
        word_embedding_model = models.Transformer('Sheshera/lace-kp-encoder-compsci',
                                                  max_seq_length=512)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                       pooling_mode='mean')
        st.session_state['kp_enc_model'] = SentenceTransformer(
            modules=[word_embedding_model, pooling_model])
    # Return on every call; the first call used to fall through and return None.
    return st.session_state.kp_enc_model
def read_candidates(in_path):
    """
    Read the candidate papers and their sentence vectors, caching both in
    the session state ('pid2abstract' and 'pid2sent_vectors_cands').
    :param in_path: string; directory containing the 'cands' directory.
    :return: tuple(pid2abstract, pid2sent_vectors) as loaded from
        'abstract-mlconfs-18_23.pickle' and 'embeds-sent-mlconfs-18_23.pickle'.
    """
    if 'pid2abstract' in st.session_state:
        # Already read in this session.
        return st.session_state.pid2abstract, st.session_state.pid2sent_vectors_cands
    cands_dir = os.path.join(in_path, 'cands')
    with open(os.path.join(cands_dir, 'abstract-mlconfs-18_23.pickle'), 'rb') as abstracts_file:
        abstracts = pickle.load(abstracts_file)
    # Read the sentence vectors.
    sent_vectors = joblib.load(os.path.join(cands_dir, 'embeds-sent-mlconfs-18_23.pickle'))
    st.session_state['pid2sent_vectors_cands'] = sent_vectors
    st.session_state['pid2abstract'] = abstracts
    return abstracts, sent_vectors
def get_kp_embeddings(profile_keyphrases):
    """
    Embed the passed profile keyphrases, encoding only the ones which are
    not already in the session cache 'kp_vectors_user'.
    :param profile_keyphrases: list(string)
    :return: OrderedDict(string: vector); the whole session cache, it may
        also hold keyphrases embedded on earlier calls.
    """
    kp_enc_model = st.session_state['kp_enc_model']
    if 'kp_vectors_user' not in st.session_state:
        st.session_state['kp_vectors_user'] = collections.OrderedDict()
    kp_vectors_user = st.session_state['kp_vectors_user']
    # dict.fromkeys drops repeated keyphrases while keeping their order.
    uncached_kps = [kp for kp in dict.fromkeys(profile_keyphrases) if kp not in kp_vectors_user]
    # Skip the encoder call when everything is cached already.
    if uncached_kps:
        kp_embeddings = kp_enc_model.encode(uncached_kps)
        for i, kp in enumerate(uncached_kps):
            kp_vectors_user[kp] = kp_embeddings[i, :]
    return kp_vectors_user
def generate_profile_values(profile_keyphrases):
    """
    Compute a value vector for every profile keyphrase.
    - Embed the keyphrases (get_kp_embeddings)
    - Stack the user sentence embeddings from the session state
    - Compute an entropic partial transport plan from sentences to keyphrases
    - Barycenter project the sentences onto the keyphrases with the plan
    :param profile_keyphrases: list(string)
    :return: tuple(kp2valvectors, tplan)
        kp2valvectors: dict(string: vector)
        tplan: array; num_sentences x num_keyphrases transport plan.
    """
    kp2embed = get_kp_embeddings(profile_keyphrases)
    # Sentence embeddings of all the seed papers stacked row-wise.
    sent_matrix = np.vstack(list(st.session_state.pid2sent_vectors_user.values()))
    # Keyphrase embeddings in the order of profile_keyphrases.
    kp_matrix = np.vstack([kp2embed[kp] for kp in profile_keyphrases])
    num_sents = sent_matrix.shape[0]
    num_kps = kp_matrix.shape[0]
    # Uniform mass on the sentences and on the keyphrases.
    cost = distance.cdist(sent_matrix, kp_matrix, 'euclidean')
    sent_mass = [1 / num_sents] * num_sents
    kp_mass = [1 / num_kps] * num_kps
    tplan = ot.partial.entropic_partial_wasserstein(sent_mass, kp_mass, cost, 0.05, m=0.8)
    # Plan-weighted sum of sentences per keyphrase: num_keyphrases x embedding_dim
    projected = np.matmul(sent_matrix.T, tplan).T
    # Divide by the mass each keyphrase received to get a weighted mean.
    kp_mass_received = np.sum(tplan, axis=0)
    kp_values = projected/kp_mass_received[:, np.newaxis]
    kp2valvectors = {kp: kp_values[row, :] for row, kp in enumerate(profile_keyphrases)}
    return kp2valvectors, tplan
def second_stage_ranked_docs(selected_query_kps, first_stage_pids, pid2abstract, pid2sent_reps_cand, to_rank=30):
    """
    Re-rank the first stage candidates with the selected profile keyphrases
    and return the top papers, skipping papers of the seed set.
    :param selected_query_kps: list(string); keyphrases to score with, each
        must be a key of st.session_state.kp2val_vectors.
    :param first_stage_pids: list(string)
    :param pid2abstract: dict(pid: paperd); paperd is read for 'title',
        'abstract', 'author_names' and 'url'.
    :param pid2sent_reps_cand: dict(pid: array); sentence vectors per candidate.
    :param to_rank: int; number of papers to return.
    :return: OrderedDict(title: dict) in ranked order. 'kp_explanations' is
        always an empty list, explanations are not computed currently.
    """
    kp2val_vectors = st.session_state.kp2val_vectors
    if len(selected_query_kps) < 3:
        topk = len(selected_query_kps)
    else:  # Use 20% of the profile keyphrases for scoring or 3 whichever is larger
        topk = max(int(len(kp2val_vectors)*0.2), 3)
    query_kp_values = np.vstack([kp2val_vectors[kp] for kp in selected_query_kps])
    pid2topkdist = {}
    for pid in first_stage_pids:
        sent_reps = pid2sent_reps_cand[pid]
        pair_dists = distance.cdist(query_kp_values, sent_reps)
        # Pick the topk keyphrases closest to any sentence of the candidate.
        kp_ind = np.argsort(pair_dists.min(axis=1))[:topk]
        sub_pair_dists = pair_dists[kp_ind, :]
        # Use attention instead of OT for distance computation: softmax over
        # all the (keyphrase, sentence) pairs weights the distances.
        attention = special.softmax(-1 * sub_pair_dists)
        pid2topkdist[pid] = np.sum(sub_pair_dists * attention)
    top_pids = sorted(pid2topkdist, key=pid2topkdist.get)
    # Build the set once instead of scanning the list for every candidate.
    seed_titles = set(st.session_state.seed_titles)
    retrieved_papers = collections.OrderedDict()
    for pid in top_pids:
        paperd = pid2abstract[pid]
        norm_title = " ".join(paperd['title'].lower().strip().split())
        # The mlconf pid2abstract has braces in the titles sometimes - remove them
        norm_title = norm_title.replace('{', '').replace('}', '')
        # Exclude papers from the seed set in the result set.
        if norm_title in seed_titles:
            continue
        retrieved_papers[paperd['title']] = {
            'title': paperd['title'],
            'kp_explanations': [],
            'abstract': paperd['abstract'],
            'author_names': paperd['author_names'],
            'url': paperd['url'],
        }
        if len(retrieved_papers) >= to_rank:
            break
    return retrieved_papers
| ######################################## | |
| # HELPER CODE # | |
| ######################################## | |
def parse_input_kps(unparsed_kps, initial_user_kps):
    """
    Function to parse the input keyphrase string.
    Keyphrases are comma separated; surrounding whitespace is removed,
    empty fragments (e.g. from a trailing comma) are skipped and repeated
    keyphrases are kept once, in order of first appearance.
    :param unparsed_kps: string; comma separated keyphrases.
    :param initial_user_kps: list(string); used when the input contains no
        keyphrases at all.
    :return: list(string); the parsed keyphrases or a shallow copy of
        initial_user_kps.
    """
    parsed_user_kps = []
    uniq_kps = set()
    for kp in unparsed_kps.split(','):
        kp = kp.strip()
        # An empty fragment is not a keyphrase.
        if kp and kp not in uniq_kps:
            parsed_user_kps.append(kp)
            uniq_kps.add(kp)
    if not parsed_user_kps:
        # Nothing usable was typed (empty, whitespace or only commas).
        parsed_user_kps = copy.copy(initial_user_kps)
    return parsed_user_kps
| # def plot_sent_kp_alignment(tplan, kp_labels, sent_labels): | |
| # """ | |
| # Plot the sentence keyphrase alignment. | |
| # :return: | |
| # """ | |
| # fig, ax = plt.subplots() | |
| # h = sns.heatmap(tplan.T, linewidths=.3, xticklabels=sent_labels, | |
| # yticklabels=kp_labels, cmap='Blues') | |
| # h.tick_params('y', labelsize=5) | |
| # h.tick_params('x', labelsize=2) | |
| # plt.tight_layout() | |
| # return fig | |
def multiselect_title_formatter(title):
    """
    Shorten a title for display in a multi-select: keep at most the first
    five whitespace separated words and append an ellipsis.
    :param title: string
    :return: string: formatted title
    """
    leading_words = title.split()[:5]
    shortened = ' '.join(leading_words)
    return '{}...'.format(shortened)
def format_abstract(paperd, to_display=3, markdown=True):
    """
    Given a dict with title and abstract return
    a formatted text for rendering with markdown.
    Braces are removed from the title and the abstract in every output form.
    :param paperd: dict; 'title' (string) and 'abstract' (list of sentence
        strings) are required, 'url' and 'kp_explanations' are optional.
    :param to_display: int; number of abstract sentences to show. An
        ellipsis is appended only when sentences were left out.
    :param markdown: bool; return a html paragraph if True, else plain text.
    :return: string
    """
    abstract_sents = paperd['abstract']
    sents = ' '.join(abstract_sents[:to_display])
    if len(abstract_sents) > to_display:
        sents += '...'
    kp_expl = ', '.join(paperd.get('kp_explanations', []))
    # Titles and abstracts have braces sometimes - remove them.
    title = paperd['title'].replace('{', '').replace('}', '')
    sents = sents.replace('{', '').replace('}', '')
    if not markdown:
        return f'Title: {title}; Abstract: {sents}'
    url = paperd.get('url')
    if url:
        # Link the title when the paper has a url.
        title_html = f'<a href="{url}">{title}</a>'
    else:
        title_html = title
    return f'<p><b>Title</b>: <i>{title_html}</i><br><b>Abstract</b>: {sents}<br><i>{kp_expl}</i></p>'
def perp_result_json():
    """
    Serialize the session to a json string: per iteration the profile
    keyphrases selected and the saved flag of every paper shown, plus the
    'condition' and the 'username'.
    Asserts that i_selections, i_resultps and i_savedps all have tuning_i
    entries.
    :return: string; json.
    """
    state = st.session_state
    assert(len(state.i_selections) == len(state.i_resultps)
           == len(state.i_savedps) == state.tuning_i)
    result = {}
    per_iteration = zip(range(state.tuning_i), state.i_selections, state.i_savedps.values())
    for iter_idx, selections, saved_flags in per_iteration:
        result[iter_idx] = {
            'iteration': iter_idx,
            'profile_selections': copy.deepcopy(selections),
            # list of (paper title, saved or not) pairs
            'saved_papers': copy.deepcopy(list(saved_flags.items()))
        }
    result['condition'] = 'maple'
    result['username'] = state.username
    return json.dumps(result)
| ######################################## | |
| # APP CODE # | |
| ######################################## | |
# NOTE(review): the nesting of this script was reconstructed from a copy of
# the file which had lost its indentation; confirm the placement of the
# sidebar elements against the deployed app.
st.title('\U0001F341 Maple Paper Recommender \U0001F341')
st.markdown(
    '\U0001F341 Maple \U0001F341 uses a seed set of authored papers to make paper recommendations from ML and NLP conferences: NeurIPS, ICLR, ICML, UAI, AISTATS, ACL*, and EMNLP from years 2018 to 2023.'
    '\n1. :white_check_mark: Select your username on the left\n2. :eyes: Verify keyphrases inferred for the papers and click '
    '"\U0001F9D1 Generate profile \U0001F9D1"\n3. :mag: Request recommendations\n4. :repeat: Tune recommendations')

# Load candidate documents and models.
pid2abstract_cands, pid2sent_vectors_cands = read_candidates(in_path)
kp_encoding_model = read_kp_encoder(in_path)

# Initialize the session state:
if 'tuning_i' not in st.session_state:
    st.session_state['tuning_i'] = 0
    # Save the profile keyphrases at every run
    # (run is every time the script runs, iteration is every time recs are requested)
    st.session_state['run_user_kps'] = []
    # Save the profile selections at each iteration
    st.session_state['i_selections'] = []
    # dict of dicts: tuning_i: dict(paper_title: paper)
    st.session_state['i_resultps'] = {}
    # dict of dicts: tuning_i: dict(paper_title: saved or not bool)
    st.session_state['i_savedps'] = collections.defaultdict(dict)

# Ask user to select their seed query papers.
with st.sidebar:
    available_users = os.listdir(os.path.join(in_path, 'users'))
    available_users.sort()
    # None comes first so that no user is selected initially.
    available_users = (None,) + tuple(available_users)
    selected_user = st.selectbox('Select your username from the drop-down',
                                 available_users)
    if selected_user is not None:
        seedset_fname = os.path.join(in_path, 'users', selected_user, f'seedset-{selected_user}-maple.json')
        # Close the file once it is parsed.
        with open(seedset_fname) as seedset_fp:
            user_papers = json.load(seedset_fp)
        # Read user data.
        doc_vectors_user, pid2idx_user, pid2sent_vectors_user, user_kps = read_user(user_papers)
        st.session_state.run_user_kps.append(copy.copy(user_kps))
        display_profile_kps = ', '.join(user_kps)
        # Perform first stage retrieval.
        first_stage_ret_pids = first_stage_ranked_docs(user_doc_queries=doc_vectors_user, per_doc_to_rank=500)
        with st.expander("Examine seed papers"):
            st.markdown('**Initial profile keyphrases**:')
            st.markdown(display_profile_kps)
            st.markdown('**Seed papers**: {:d}'.format(len(user_papers['papers'])))
            for paper in user_papers['papers']:
                par = format_abstract(paperd=paper, to_display=6)
                st.markdown(par, unsafe_allow_html=True)
    st.markdown('\u2b50 Saved papers')

if selected_user is not None:
    # Create a text box where users can see their profile keyphrases.
    st.subheader('\U0001F4DD Seed paper keyphrases')
    with st.form('profile_kps'):
        input_kps = st.text_area(
            'Add/remove keyphrases to fix redundancy, inaccuracy, incompleteness, or being nonsensical:',
            display_profile_kps,
            help='Edit the profile keyphrases if they are redundant, incomplete, nonsensical, '
                 'or dont accurately describe the seed papers. You can also add keyphrases to '
                 'capture aspects of the seed papers that the keyphrases dont currently capture.',
            placeholder='If left empty initial profile keyphrases will be used...')
        input_user_kps = parse_input_kps(unparsed_kps=input_kps, initial_user_kps=user_kps)
        col1, col2, col3 = st.columns([1, 1, 1])
        with col2:
            generate_profile = st.form_submit_button('\U0001F9D1 Generate profile \U0001F9D1')
    if generate_profile:
        # Reuse the cached profile only when it was generated for exactly
        # these keyphrases. Comparing with the keyphrases of the previous
        # run is not enough: reverting an edited profile to the initial
        # keyphrases would keep the values of the edited profile and the
        # ranking would fail on the keyphrases missing from it.
        profile_is_current = ('kp2val_vectors' in st.session_state
                              and list(st.session_state.kp2val_vectors) == list(input_user_kps))
        if not profile_is_current:
            with st.spinner(text="Generating profile..."):
                kp2val_vectors, user_tplan = generate_profile_values(profile_keyphrases=input_user_kps)
            st.session_state['kp2val_vectors'] = kp2val_vectors
            st.session_state['user_tplan'] = user_tplan
            st.session_state.run_user_kps.append(copy.copy(input_user_kps))
    # Create a multiselect dropdown
    if 'kp2val_vectors' in st.session_state:
        st.subheader('\U0001F9D1 Profile keyphrases for ranking')
        with st.form('profile_input'):
            # Let long keyphrases use more width in the multiselect.
            st.markdown(
                '\n<style>\n'
                '.stMultiSelect [data-baseweb=select] span{\n'
                'max-width: 500px;\n'
                '}\n'
                '</style>\n',
                unsafe_allow_html=True)
            profile_selections = st.multiselect(label='Include or exclude profile keyphrases to use for recommendations:',
                                                default=input_user_kps,  # Use all the values by default.
                                                options=input_user_kps,
                                                help='Items selected here will be used for creating your '
                                                     'recommended list')
            col1, col2, col3 = st.columns([1, 1, 1])
            with col2:
                generate_recs = st.form_submit_button('\U0001F9ED Recommend papers \U0001F9ED')
        # Use the selected keyphrases to create a ranked list of items.
        if generate_recs and profile_selections:
            st.session_state.tuning_i += 1
            st.session_state.i_selections.append(copy.deepcopy(profile_selections))
            with st.spinner(text="Recommending papers..."):
                top_papers = second_stage_ranked_docs(first_stage_pids=first_stage_ret_pids,
                                                      selected_query_kps=profile_selections,
                                                      pid2abstract=pid2abstract_cands,
                                                      pid2sent_reps_cand=pid2sent_vectors_cands,
                                                      to_rank=30)
            st.session_state.i_resultps[st.session_state.tuning_i] = copy.deepcopy(top_papers)
        # Read off from the result cache and allow users to save some papers.
        if st.session_state.tuning_i in st.session_state.i_resultps:
            cached_top_papers = st.session_state.i_resultps[st.session_state.tuning_i]
            for paper in cached_top_papers.values():
                # This statement ensures correctness for when users unselect a previously selected item.
                st.session_state.i_savedps[st.session_state.tuning_i][paper['title']] = False
                dcol1, dcol2 = st.columns([1, 16])
                with dcol1:
                    save_paper = st.checkbox('\u2b50', key=paper['title'])
                with dcol2:
                    plabel = format_abstract(paperd=paper, to_display=2, markdown=True)
                    st.markdown(plabel, unsafe_allow_html=True)
                    with st.expander('See more..'):
                        full_abstract = ' '.join(paper['abstract'])
                        st.markdown(full_abstract, unsafe_allow_html=True)
                if save_paper:
                    st.session_state.i_savedps[st.session_state.tuning_i].update({paper['title']: True})

# Print the saved papers across iterations in the sidebar.
with st.sidebar:
    with st.expander("Examine saved papers"):
        for iteration, savedps in st.session_state.i_savedps.items():
            st.markdown('Iteration: {:}'.format(iteration))
            for papert, saved in savedps.items():
                if saved:
                    fpapert = '<p style="color:Gray;">- {:}</p>'.format(papert)
                    st.markdown('{:}'.format(fpapert), unsafe_allow_html=True)
    if st.session_state.tuning_i > 0:
        st.download_button('Download papers', perp_result_json(), mime='application/json',
                           help='Download the papers saved in the session.')
        with st.expander("Copy saved papers to clipboard"):
            st.write(json.loads(perp_result_json()))