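"""Streamlit app: tag a text by prompting a masked language model.

The user writes a prompt containing a [MASK] placeholder; the app appends the
prompt to the text to analyze and displays the model's top predictions for the
mask as a multiselect list.
"""
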
import streamlit as st
import nltk

from transformers import pipeline

# Fetch NLTK's 'punkt' tokenizer data once; the resource cache keeps the
# download from re-running on every Streamlit rerun.
@st.cache_resource
def download_punkt():
    nltk.download('punkt')

download_punkt()

def choose_text_menu(default_text='Several demonstrators were injured.'):
    """Text area where the user enters the text to tag."""
    return st.text_area('Text to analyze', default_text)


# Load the fill-mask model once and keep it in Streamlit's resource cache.
@st.cache_resource
def load_model_prompting():
    return pipeline("fill-mask", model="distilbert-base-uncased")
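
# distilbert-base-uncased uses '[MASK]' as its mask token; the loop below swaps
# in model_prompting.tokenizer.mask_token so other fill-mask models also work.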


###### Prompting
def query_model_prompting(model, text, prompt_with_mask, top_k, targets):
    """Append the masked prompt to the text and return the model's top_k
    predictions for the mask, optionally restricted to `targets` tokens."""
    sequence = text + ' ' + prompt_with_mask
    return model(sequence, top_k=top_k, targets=targets)
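
# The fill-mask pipeline returns a list of dicts, one per candidate token, e.g.
#   {'score': 0.12, 'token': 2154, 'token_str': 'protesters', 'sequence': '...'}
# (illustrative values; actual tokens and scores depend on the model and input).
# The `targets` argument can restrict predictions to a fixed label set, e.g.
#   query_model_prompting(model, text, prompt, top_k, targets=['violence', 'protests'])
# (hypothetical labels for illustration).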


def display_pr_results_as_list(prompt, list_results):
    """Show the predictions as a multiselect, all selected by default.
    The key is derived from the prompt so multiple prompts don't collide."""
    st.multiselect(prompt, list_results, list_results, key=f'results_{prompt}')


### App START
st.markdown("# Tag text based on the Prompt approach")
st.markdown("## Author: ...")
st.markdown("**Tips:** - If the [MASK] of your prompt is at the end of the sentence. Don't forget to put a punctuation sign after or the prompt will outputs punctuations.")


model_prompting = load_model_prompting()

prompt_mix = st.text_input('Prompt with a [MASK]:', 'This event involves [MASK].')
prompt_mix_list = [prompt_mix]



top_k = st.number_input('Maximum number of tokens to output (higher values take longer)', step=1, min_value=1, max_value=50, value=10)

text = choose_text_menu()

for prompt in prompt_mix_list:
    model_load_state = st.text('Tagging running...')
    # Swap the user-facing [MASK] placeholder for the model's own mask token.
    prompt = prompt.replace('[MASK]', model_prompting.tokenizer.mask_token)
    output_tokens = query_model_prompting(model_prompting, text, prompt, top_k, targets=None)
    # Render each prediction as "token score%".
    list_results = [f"{each['token_str']} {int(each['score'] * 100)}%" for each in output_tokens]
    display_pr_results_as_list(prompt, list_results)
    model_load_state.text("Tagging done!")