PR-ENT_Dashboard / PR_ENT.py
PR ENT
Push dashboard
88d8172
raw history blame
No virus
6.56 kB
import streamlit as st
from PIL import Image
from transformers import pipeline
import nltk
import spacy
import en_core_web_lg
from helpers import prompt_to_nli, display_nli_pr_results_as_list
@st.cache()
def download_punkt():
    """Download NLTK's ``punkt`` resource.

    Wrapped in ``st.cache`` so the download is not repeated on every Streamlit
    rerun of the script.
    """
    resource_name = 'punkt'
    nltk.download(resource_name)
@st.cache(allow_output_mutation=True)
def load_spacy_pipeline():
    """Load and return the ``en_core_web_lg`` spaCy pipeline.

    ``allow_output_mutation=True`` stops ``st.cache`` from hashing the returned
    pipeline object on each access.
    """
    spacy_nlp = en_core_web_lg.load()
    return spacy_nlp
def choose_text_menu(text):
    """Render the "Event description" text area and return what it contains.

    The pre-filled value comes from ``st.session_state.text``, which is seeded
    with a sample sentence the first time the page runs.

    Args:
        text: Not read by this function; kept for caller compatibility.

    Returns:
        The current contents of the text area.
    """
    state = st.session_state
    if 'text' not in state:
        state.text = 'Several demonstrators were injured.'
    return st.text_area('Event description', state.text)
# Load Models in cache
@st.cache(allow_output_mutation=True)
def load_model_prompting():
    """Build the ``fill-mask`` pipeline (DistilBERT, uncased).

    This pipeline proposes answer candidates for the template's masked slot.
    It is cached so the model is only loaded once.
    """
    fill_mask = pipeline(task="fill-mask", model="distilbert-base-uncased")
    return fill_mask
@st.cache(allow_output_mutation=True)
def load_model_nli():
    """Build the entailment classifier from ``roberta-large-mnli``.

    It is created through the ``sentiment-analysis`` task name and cached so
    the model is only loaded once.
    """
    entailment_model = pipeline("sentiment-analysis", model="roberta-large-mnli")
    return entailment_model
# Warm the cached resources before rendering: the NLTK punkt data and the spaCy
# pipeline (`nlp` is passed to prompt_to_nli when the pipeline is run below).
download_punkt()
nlp = load_spacy_pipeline()
### App START
# Page title, followed by the section 1 header.
st.markdown("""# Rethinking the Event Coding Pipeline with Prompt Entailment
## Author: Anonymized for submission""")
st.markdown("""
### 1. PR-ENT summary
""")
@st.cache()
def load_prent_image():
    """Open the PR-ENT flow diagram (``pipeline_flow.png``) shown in section 1."""
    image_path = 'pipeline_flow.png'
    flow_image = Image.open(image_path)
    return flow_image
# Section 1: the pipeline diagram with its explanatory caption.
st.image(load_prent_image(), caption="""PR-ENT Flow. First, we concatenate the event description e and the template t. Then we feed them through a pretrained
prompting model to obtain a list of answer candidates. Then, for each answer candidate we build a
hypothesis by filling the template and check for entailment with the premise (the event description).
Finally, by filtering on the entailment score, we obtain a list of entailed answer candidates related to the event description.
""")
# Both models come from the st.cache-decorated loaders, so reruns reuse them.
model_nli = load_model_nli()
model_prompting = load_model_prompting()
# Section 2: event description input.
st.markdown("""
### 2. Write an event description:
The first step is to write an event description that will be fed to the pipeline. This can be any text in English.
""")
text = choose_text_menu('')
# Store the typed text; choose_text_menu reads it back as the default next rerun.
st.session_state.text = text
# Section 3: template input. The masked slot is written as `[Z]`.
st.markdown("""
### 3. Template design:
The second step is to design a template while keeping in mind the objective of the classification.
- A good starting point is to use `This event involves [Z].`. This template will ideally be filled with a 1 word summary of the event.
- Another good example is `People were [Z].`. With this one we mostly expect a verb that describes the action.
You can also use any template you design. Keep in mind that if the masked slot `[Z]` is at the end of the sentence, to not forget punctuation,
otherwise the model may fill the template with punctuation signs.
""")
# Seed a default template on first run, then keep the user's edit in session state.
if 'prompt' not in st.session_state:
    st.session_state.prompt = 'This event involves [Z].'
prompt = st.text_input('Template:',st.session_state.prompt)
st.session_state.prompt = prompt
# Section 4: explanation of the top_k and entailment-threshold parameters.
st.markdown("""
### 4. Select the two parameters:
- The first parameter `top_k` is the maximum number of tokens that will be given by the prompting model.
It's also the number of tokens that will be tried for entailment. Ideally, you want a high enough number of tokens, otherwise you may miss critical information.
However, each additional token will increase the computation time as it needs to go through the entailment model.
From our experiments, a good choice is between `[10,50]`, lower and you miss information, higher and you start getting unrelated tokens and long computation time.
- The second parameter is the minimum entailment score to confirm that the token is entailed with the event description.
By default, we set it at `0.5` (more entailed than not) but it can be modified depending on needs.
""")
def select_top_k():
    """Render the ``top_k`` number input and return the chosen integer.

    The default comes from ``st.session_state.top_k``, which is seeded with 10
    on first run.

    Returns:
        The selected maximum number of candidate tokens, within [0, 50].
    """
    if 'top_k' not in st.session_state:
        st.session_state.top_k = 10
    # step=1: the previous step of 100 was larger than the whole [0, 50] range,
    # so the widget's +/- stepper could never usefully change the value.
    return st.number_input(
        'Number of max tokens to output (default: 10, min: 0, max: 50)? ',
        step=1,
        min_value=0,
        max_value=50,
        value=int(st.session_state.top_k),
    )
def select_nli_limit():
    """Render the entailment-threshold number input and return the chosen float.

    The default comes from ``st.session_state.nli_limit``, which is seeded with
    0.5 on first run.

    Returns:
        The minimum entailment score, within [0.0, 1.0].
    """
    if 'nli_limit' not in st.session_state:
        st.session_state.nli_limit = 0.5
    # step=0.05: the previous step of 100.0 was larger than the whole [0, 1]
    # range, so the widget's +/- stepper could never usefully change the value.
    # The value is coerced to float to match the float min/max bounds.
    return st.number_input(
        'Minimum score of entailment (default: 0.5, min: 0, max: 1)? ',
        step=0.05,
        min_value=0.0,
        max_value=1.0,
        value=float(st.session_state.nli_limit),
    )
def update_session_state_callback(value, key):
    """Copy the session-state entry stored under ``key`` into ``value``.

    Args:
        value: Session-state key to write.
        key: Session-state key to read.
    """
    state = st.session_state
    state[value] = state[key]
# Read the two numeric parameters and store them in session state;
# select_top_k / select_nli_limit read them back as widget defaults.
top_k = select_top_k()
st.session_state.top_k = top_k
nli_limit = select_nli_limit()
st.session_state.nli_limit = nli_limit
# Section 5: optional lemma-based de-duplication of the output tokens.
st.markdown("""
### 5. Remove similar tokens from output:
An additional option is to remove similar tokens (e.g. `protest, protests`) from the output.
This computes the lemma of each word (based on the template) and removes duplicate lemmas.
""")
if 'remove_lemma' not in st.session_state:
    st.session_state.remove_lemma = False
remove_lemma = st.checkbox('Remove similar lemma (e.g. protest, protests) from output?', value= st.session_state.remove_lemma)
st.session_state.remove_lemma = remove_lemma
# Save settings to display before the results. They are seeded once from the
# current inputs and afterwards refreshed only by a successful run, so the
# summary printed above the results always describes the run that produced them.
if "old_prompt" not in st.session_state:
    st.session_state.old_text = st.session_state.text
    st.session_state.old_prompt = st.session_state.prompt
    st.session_state.old_top_k = st.session_state.top_k
    st.session_state.old_nli_limit = st.session_state.nli_limit
st.markdown("""
### 6. Run the pipeline
""")
st.markdown("""The entailed tokens are given as a list of words associated with the probability of entailment.""")
if st.button("Run PR-ENT"):
    # Both `[Z]` (the documented slot) and `[MASK]` are accepted; each is turned
    # into a str.format placeholder before the template reaches prompt_to_nli.
    prompt = prompt.replace('[Z]', '{}').replace('[MASK]', '{}')
    if '{}' not in prompt:
        # Without a masked slot there is nothing for the fill-mask model to
        # predict, so report the problem instead of running the pipeline.
        st.error('The template must contain the masked slot `[Z]`.')
    else:
        computation_state_prent = st.text("PR-ENT Computation Running.")
        results = prompt_to_nli(text, prompt, model_prompting, model_nli, nlp, top_k, nli_limit, remove_lemma)
        # Format each result as "<token> <percent>%".
        # NOTE(review): assumes x[0][0] is the token string and x[1][1] the
        # entailment score in [0, 1], as returned by helpers.prompt_to_nli —
        # confirm against helpers.py.
        st.session_state.list_results = [
            x[0][0] + ' ' + str(int(x[1][1] * 100)) + '%' for x in results
        ]
        # Record the settings only after prompt_to_nli succeeded. If it raised,
        # the previous results and their settings stay paired.
        st.session_state.old_text = st.session_state.text
        st.session_state.old_prompt = st.session_state.prompt
        st.session_state.old_top_k = st.session_state.top_k
        st.session_state.old_nli_limit = st.session_state.nli_limit
        computation_state_prent.text("PR-ENT Computation Done.")
# Show the latest results (kept across reruns) with the settings that made them.
if 'list_results' in st.session_state:
    st.write('**Event Description**: {}'.format(st.session_state.old_text))
    st.write('**Template**: "{}"; **Top K**: {}; **Entailment Threshold**: {}.'.format(st.session_state.old_prompt, st.session_state.old_top_k, st.session_state.old_nli_limit))
    display_nli_pr_results_as_list('', st.session_state.list_results)
st.markdown("""
### 7. Actor-target coding (experimental)
Available in actor-target tab (on the left)
""")