Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
from transformers import pipeline | |
from nltk.tokenize import sent_tokenize | |
import nltk | |
def download_punkt():
    """Ensure the NLTK 'punkt' tokenizer data is available locally.

    Streamlit re-executes the whole script on every widget interaction, so an
    unconditional ``nltk.download`` would contact the NLTK index each time.
    The download is therefore only attempted when the data cannot be found.
    """
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')


download_punkt()
def choose_text_menu(text):
    """Render the text area holding the passage to analyze and return its content.

    Args:
        text: Initial content of the text area. When falsy (e.g. ``''``), the
            built-in example sentence is used instead.

    Returns:
        The text currently entered in the text area.
    """
    # The parameter used to be overwritten immediately; honour it as the
    # initial value while keeping the original example as the fallback.
    initial_text = text or 'Several demonstrators were injured.'
    return st.text_area('Text to analyze', initial_text)
# Load Models in cache.
# Streamlit re-runs the script on every interaction; caching the pipeline keeps
# a single model instance alive instead of rebuilding it on each rerun.
# ``st.cache_resource`` is the current API; older Streamlit releases only offer
# ``st.cache`` (with output mutation allowed, since the pipeline is an object).
_cache_model = getattr(st, 'cache_resource', None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_model_prompting():
    """Build (once per process) the DistilBERT fill-mask pipeline used for prompting."""
    return pipeline("fill-mask", model="distilbert-base-uncased")
###### Prompting
def query_model_prompting(model, text, prompt_with_mask, top_k, targets):
    """Append a masked prompt to *text* and ask the fill-mask model for candidates.

    Args:
        model: Callable fill-mask model, invoked as
            ``model(sequence, top_k=top_k, targets=targets)``.
        text: Context passage placed before the prompt.
        prompt_with_mask: Prompt that already contains the model's mask token.
        top_k: Number of candidates requested from the model.
        targets: Optional restriction of candidate tokens, forwarded unchanged.

    Returns:
        Whatever ``model`` returns for the combined sequence.
    """
    # Join with exactly one space: plain concatenation glued the last word of
    # the text to the first word of the prompt whenever the text did not end
    # with whitespace/punctuation (e.g. "injured" + "This" -> "injuredThis").
    context = text.rstrip()
    sequence = f"{context} {prompt_with_mask.lstrip()}" if context else prompt_with_mask
    return model(sequence, top_k=top_k, targets=targets)
def display_pr_results_as_list(prompt, list_results, key='results_mix'):
    """Show the predicted tokens as a multiselect with every result pre-selected.

    Args:
        prompt: Label displayed above the widget.
        list_results: Option strings; all of them are selected by default.
        key: Streamlit widget key. Streamlit rejects duplicate keys within a
            run, so pass a distinct key when displaying several prompts.

    Returns:
        The list of options currently selected by the user (previously this
        value was computed and discarded).
    """
    return st.multiselect(prompt, list_results, list_results, key=key)
### App START
st.markdown("# Tag text based on the Prompt approach")
st.markdown("## Author: ...")
st.markdown("**Tips:** - If the [MASK] of your prompt is at the end of the sentence. Don't forget to put a punctuation sign after or the prompt will outputs punctuations.")
model_prompting = load_model_prompting()
prompt_mix = st.text_input('Prompt with a [MASK]:', 'This event involves [MASK].')
prompt_mix_list = [prompt_mix]
# At least one candidate must be requested; 0 would always yield no tags.
top_k = st.number_input('Number of max tokens to output (higher = more computation time)? ', step=1, min_value=1, max_value=50, value=10)
text = choose_text_menu('')
mask_token = model_prompting.tokenizer.mask_token
for prompt in prompt_mix_list:
    # The fill-mask call needs exactly one mask: with none the pipeline raises,
    # with several it returns nested lists that the formatting below can't read.
    if prompt.count('[MASK]') != 1:
        st.error('The prompt must contain exactly one [MASK].')
        continue
    model_load_state = st.text('Tagging Running')
    # Substitute directly instead of via str.format, which crashed on any
    # literal '{' or '}' typed into the prompt.
    prompt = prompt.replace('[MASK]', mask_token)
    output_tokens = query_model_prompting(model_prompting, text, prompt, top_k, targets=None)
    # "token NN%" with the score truncated to a whole percentage, as before.
    list_results = [f"{each['token_str']} {int(each['score'] * 100)}%" for each in output_tokens]
    display_pr_results_as_list(prompt, list_results)
    model_load_state.text("Tagging Done!")