"""Streamlit dashboard demonstrating the PR-ENT approach.

PR-ENT fills a user-designed template with a fill-mask ("prompting") model and
keeps the candidate answers that an NLI model judges to be entailed by the
event description. User inputs are stored in ``st.session_state`` so they
survive Streamlit's rerun-on-every-interaction execution model.
"""
import streamlit as st
from PIL import Image
from transformers import pipeline
import nltk
import spacy
import en_core_web_lg

from helpers import prompt_to_nli, display_nli_pr_results_as_list


@st.cache()
def download_punkt():
    """Download the NLTK ``punkt`` resource (memoized by ``st.cache``)."""
    nltk.download('punkt')


@st.cache(allow_output_mutation=True)
def load_spacy_pipeline():
    """Load and cache the large English spaCy pipeline.

    ``allow_output_mutation=True`` stops Streamlit from hashing the returned
    pipeline object on every rerun.
    """
    return en_core_web_lg.load()


def choose_text_menu(text):
    """Render the event-description text area and return its content.

    Args:
        text: Unused; kept so existing callers keep working.

    Returns:
        The text currently in the text area. It is pre-filled from
        ``st.session_state.text``, which defaults to a sample sentence.
    """
    if 'text' not in st.session_state:
        st.session_state.text = 'Several demonstrators were injured.'
    text = st.text_area('Event description', st.session_state.text)
    return text


# Load Models in cache
@st.cache(allow_output_mutation=True)
def load_model_prompting():
    """Return the cached fill-mask pipeline used to propose template answers."""
    return pipeline("fill-mask", model="distilbert-base-uncased")


@st.cache(allow_output_mutation=True)
def load_model_nli():
    """Return the cached MNLI classifier used to score entailment.

    "sentiment-analysis" is the transformers alias of the text-classification
    pipeline; with ``roberta-large-mnli`` it yields NLI labels.
    """
    return pipeline(task="sentiment-analysis", model="roberta-large-mnli")


download_punkt()
nlp = load_spacy_pipeline()

### App START
st.markdown(
    """
# Demonstration dashboard of the PR-ENT Approach.
### *Rethinking the Event Coding Pipeline with Prompt Entailment*
### https://arxiv.org/abs/2210.05257
### Clément Lefebvre (Swiss Data Science Center)
### Niklas Stoehr (ETH Zürich, https://niklas-stoehr.com/)
##### Version: 1.0
"""
)

st.markdown("""
### 1. PR-ENT summary
""")


@st.cache()
def load_prent_image():
    """Open and cache the PR-ENT pipeline diagram shown in section 1."""
    return Image.open('pipeline_flow.png')


st.image(load_prent_image(), caption="""PR-ENT Flow. First, we concatenate the event description e and the template t.
Then we feed them through a pretrained prompting model to obtain a list of answer candidates.
Then, for each answer candidate we build a hypothesis by filling the template and check for entailment with the premise (the event description).
Finally, by filtering on the entailment score, we obtain a list of entailed answer candidates related to the event description.
""")

model_nli = load_model_nli()
model_prompting = load_model_prompting()

st.markdown("""
### 2. Write an event description:
The first step is to write an event description that will be fed to the pipeline. This can be any text in English.
""")
text = choose_text_menu('')
st.session_state.text = text

st.markdown("""
### 3. Template design:
The second step is to design a template while keeping in mind the objective of the classification.
- A good starting point is to use `This event involves [Z].`. This template will ideally be filled with a 1 word summary of the event.
- Another good example is `People were [Z].`. With this one we mostly expect a verb that describes the action.

You can also use any template you design. Keep in mind that if the masked slot `[Z]` is at the end of the sentence, to not forget punctuation, otherwise the model may fill the template with punctuation signs.
""")

if 'prompt' not in st.session_state:
    st.session_state.prompt = 'This event involves [Z].'
prompt = st.text_input('Template:', st.session_state.prompt)
st.session_state.prompt = prompt

st.markdown("""
### 4. Select the two parameters:
- The first parameter `top_k` is the maximum number of tokens that will be given by the prompting model. It's also the number of tokens that will be tried for entailment. Ideally, you want a high enough number of tokens, otherwise you may miss critical information. However, each additional token will increase the computation time as it needs to go through the entailment model. From our experiments, a good choice is between `[10,50]`, lower and you miss information, higher and you start getting unrelated tokens and long computation time.
- The second parameter is the minimum entailment score to confirm that the token is entailed with the event description. By default, we set it at `0.5` (more entailed than not) but it can be modified depending on needs.
""")


def select_top_k():
    """Render the ``top_k`` input (integer in [0, 50]) and return its value."""
    if 'top_k' not in st.session_state:
        st.session_state.top_k = 10
    # step must fit inside [min_value, max_value]; the former step=100 made
    # the +/- buttons useless on a range of width 50.
    return st.number_input(
        'Number of max tokens to output (default: 10, min: 0, max: 50)? ',
        step=1, min_value=0, max_value=50,
        value=int(st.session_state.top_k),
    )


def select_nli_limit():
    """Render the entailment-threshold input (float in [0, 1]) and return it."""
    if 'nli_limit' not in st.session_state:
        st.session_state.nli_limit = 0.5
    # Fine-grained step for a probability threshold; the former step=100.0 was
    # 100x the whole allowed range.
    return st.number_input(
        'Minimum score of entailment (default: 0.5, min: 0, max: 1)? ',
        step=0.05, min_value=0.0, max_value=1.0,
        value=st.session_state.nli_limit,
    )


def update_session_state_callback(value, key):
    """Copy ``st.session_state[key]`` into ``st.session_state[value]``.

    Not referenced in the visible code; presumably meant to be used as a
    widget ``on_change`` callback.
    """
    st.session_state[value] = st.session_state[key]


top_k = select_top_k()
st.session_state.top_k = top_k

nli_limit = select_nli_limit()
st.session_state.nli_limit = nli_limit

st.markdown("""
### 5. Remove similar tokens from output:
An additional option is to remove similar tokens (e.g. `protest, protests`) from the output.
This computes the lemma of each word (based on the template) and removes duplicate lemmas.
""")

if 'remove_lemma' not in st.session_state:
    st.session_state.remove_lemma = False
remove_lemma = st.checkbox('Remove similar lemma (e.g. protest, protests) from output?', value=st.session_state.remove_lemma)
st.session_state.remove_lemma = remove_lemma

# Save settings to display before the results.
# The old_* keys snapshot the inputs of the last run, so the header shown
# above the results matches them even after the user edits the widgets.
if "old_prompt" not in st.session_state:
    st.session_state.old_text = st.session_state.text
    st.session_state.old_prompt = st.session_state.prompt
    st.session_state.old_top_k = st.session_state.top_k
    st.session_state.old_nli_limit = st.session_state.nli_limit

st.markdown("""
### 6. Run the pipeline
""")
st.markdown("""The entailed tokens are given as a list of words associated with the probability of entailment.""")

if st.button("Run PR-ENT"):
    # Replace the mask: accept both the documented `[Z]` slot and a raw
    # `[MASK]` token; the helper receives the slot as a `{}` placeholder.
    # NOTE(review): other literal braces in the template may clash with `{}`
    # if prompt_to_nli uses str.format — confirm against helpers.
    prompt = prompt.replace('[Z]', '{}')
    prompt = prompt.replace('[MASK]', '{}')
    if '{}' not in prompt:
        # Without a slot there is nothing for the fill-mask model to predict.
        st.error('The template must contain a masked slot `[Z]`.')
    else:
        computation_state_prent = st.text("PR-ENT Computation Running.")
        st.session_state.old_text = st.session_state.text
        st.session_state.old_prompt = st.session_state.prompt
        st.session_state.old_top_k = st.session_state.top_k
        st.session_state.old_nli_limit = st.session_state.nli_limit
        results = prompt_to_nli(text, prompt, model_prompting, model_nli, nlp, top_k, nli_limit, remove_lemma)
        # NOTE(review): assumes each result x holds the candidate token at
        # x[0][0] and its entailment probability (0-1) at x[1][1] — confirm
        # against helpers.prompt_to_nli.
        list_results = [x[0][0] + ' ' + str(int(x[1][1] * 100)) + '%' for x in results]
        st.session_state.list_results = list_results
        computation_state_prent.text("PR-ENT Computation Done.")

if 'list_results' in st.session_state:
    st.write('**Event Description**: {}'.format(st.session_state.old_text))
    st.write('**Template**: "{}"; **Top K**: {}; **Entailment Threshold**: {}.'.format(
        st.session_state.old_prompt, st.session_state.old_top_k, st.session_state.old_nli_limit))
    display_nli_pr_results_as_list('', st.session_state.list_results)

st.markdown("""
### 7. Actor-target coding (experimental)
Available in actor-target tab (on the left)
""")