# Hugging Face Space app (author: pr-ent) — commit e88235f "Fix mask bug".
# (Viewer-page residue removed so this file is valid Python.)
import streamlit as st
import pandas as pd
import numpy as np
from transformers import pipeline
from nltk.tokenize import sent_tokenize
import nltk
@st.cache()
def download_punkt():
    """Fetch the NLTK 'punkt' tokenizer data once per session (cached)."""
    nltk.download('punkt')


# Ensure sentence-tokenizer data is available before the app runs.
download_punkt()
def choose_text_menu(text):
    """Render a text area and return the user-entered text to analyze.

    Fix: the original ignored its ``text`` parameter entirely (it was
    immediately overwritten). Now a non-empty ``text`` is used as the
    widget's default; an empty/falsy argument falls back to the sample
    sentence, which preserves the behavior of the existing call site
    (it passes ``''``).

    Parameters
    ----------
    text : str
        Initial text to pre-fill the text area with; falls back to a
        sample sentence when empty.

    Returns
    -------
    str
        The current contents of the Streamlit text area.
    """
    default = text if text else 'Several demonstrators were injured.'
    return st.text_area('Text to analyze', default)
# Load models once and keep them in Streamlit's cache.
@st.cache(allow_output_mutation=True)
def load_model_prompting():
    """Build and return a cached fill-mask pipeline (distilbert-base-uncased)."""
    fill_mask = pipeline("fill-mask", model="distilbert-base-uncased")
    return fill_mask
###### Prompting
def query_model_prompting(model, text, prompt_with_mask, top_k, targets):
    """Run the fill-mask model on ``text`` followed by the masked prompt.

    Parameters
    ----------
    model : callable
        A fill-mask pipeline; called as ``model(seq, top_k=..., targets=...)``.
    text : str
        The text to analyze.
    prompt_with_mask : str
        Prompt containing the model's mask token, appended after ``text``.
    top_k : int
        Number of candidate tokens to return.
    targets : list[str] | None
        Optional restriction of candidate tokens.

    Returns
    -------
    The model's token predictions for the masked position.
    """
    full_sequence = text + prompt_with_mask
    return model(full_sequence, top_k=top_k, targets=targets)
def display_pr_results_as_list(prompt, list_results):
    """Show the prompt's results as a multiselect, all options pre-selected.

    The widget's selection is rendered for display only; it is not
    returned to the caller.
    """
    _selected = st.multiselect(prompt, list_results, list_results,
                               key='results_mix')
### App START
st.markdown("# Tag text based on the Prompt approach")
st.markdown("## Author: ...")
# Fix: corrected broken grammar in the user-facing tip
# ("will outputs punctuations" -> "will output punctuation").
st.markdown("**Tips:** - If the [MASK] of your prompt is at the end of the sentence, don't forget to put a punctuation sign after it, or the prompt will output punctuation.")

model_prompting = load_model_prompting()

# Prompt entry; the literal "[MASK]" placeholder is swapped for the
# tokenizer's actual mask token before querying the model.
prompt_mix = st.text_input('Prompt with a [MASK]:', 'This event involves [MASK].')
prompt_mix_list = [prompt_mix]
top_k = st.number_input('Number of max tokens to output (higher = more computation time)? ', step=1, min_value=0, max_value=50, value=10)
text = choose_text_menu('')

for prompt in prompt_mix_list:
    model_load_state = st.text('Tagging Running')
    # Replace the user-friendly "[MASK]" with the model's real mask token
    # (e.g. distilbert uses "[MASK]" but other tokenizers may differ).
    prompt = prompt.replace('[MASK]', '{}')
    prompt = prompt.format(model_prompting.tokenizer.mask_token)
    output_tokens = query_model_prompting(model_prompting, text, prompt, top_k, targets=None)
    # Format each candidate as "token NN%" (score truncated to an int percent).
    list_results = []
    for each in output_tokens:
        list_results.append(each["token_str"] + ' ' + str(int(each['score'] * 100)) + '%')
    display_pr_results_as_list(prompt, list_results)
    model_load_state.text("Tagging Done!")