import streamlit as st
import pandas as pd
import numpy as np
from transformers import pipeline
from nltk.tokenize import sent_tokenize
import nltk
@st.cache()
def download_punkt():
    # Download NLTK's 'punkt' tokenizer data once; st.cache keeps it from re-running on every rerun
    nltk.download('punkt')
download_punkt()
def choose_text_menu(text):
    # Text area for the input text to tag; returns the user-provided string
    text = st.text_area('Text to analyze', 'Several demonstrators were injured.')
    return text
# Load Models in cache
@st.cache(allow_output_mutation=True)
def load_model_prompting():
    # Load a fill-mask pipeline (DistilBERT) once and keep it in Streamlit's cache
    return pipeline("fill-mask", model="distilbert-base-uncased")
###### Prompting
def query_model_prompting(model, text, prompt_with_mask, top_k, targets):
    # Append the prompt (which contains the mask token) to the text and query the fill-mask pipeline
    sequence = text + ' ' + prompt_with_mask
    output_tokens = model(sequence, top_k=top_k, targets=targets)
    return output_tokens
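# For reference, the fill-mask pipeline returns a list of dicts, one per predicted token,
# roughly of the shape below (illustrative values, not real model outputs):
#   [{'token_str': 'violence', 'score': 0.23, 'token': 4808,
#     'sequence': 'several demonstrators were injured. this event involves violence.'},
#    ...]
# The 'token_str' and 'score' fields are what the display loop at the bottom relies on.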
def display_pr_results_as_list(prompt, list_results):
    # Display the predicted tokens as a multiselect, with every prediction selected by default
    prompt_mix_list = st.multiselect(prompt, list_results, list_results, key='results_mix')
### App START
st.markdown("# Tag text based on the Prompt approach")
st.markdown("## Author: ...")
st.markdown("**Tips:** - If the [MASK] of your prompt is at the end of the sentence. Don't forget to put a punctuation sign after or the prompt will outputs punctuations.")
model_prompting = load_model_prompting()
prompt_mix = st.text_input('Prompt with a [MASK]:', 'This event involves [MASK].')
prompt_mix_list = [prompt_mix]
top_k = st.number_input('Maximum number of tokens to output (higher = more computation time)', step=1, min_value=1, max_value=50, value=10)
text = choose_text_menu('')
for prompt in prompt_mix_list:
    model_load_state = st.text('Tagging Running')
    # Swap the generic [MASK] placeholder for the model's own mask token (e.g. <mask> for RoBERTa-style models)
    prompt = prompt.replace('[MASK]', '{}')
    prompt = prompt.format(model_prompting.tokenizer.mask_token)
    output_tokens = query_model_prompting(model_prompting, text, prompt, top_k, targets=None)
    # Format each prediction as "token score%" before displaying
    list_results = []
    for each in output_tokens:
        list_results.append(each["token_str"] + ' ' + str(int(each['score']*100)) + '%')
    display_pr_results_as_list(prompt, list_results)
    model_load_state.text("Tagging Done!")