pr-ent committed on
Commit
f35867d
1 Parent(s): ead2fc0

app prompting

Browse files
Files changed (2) hide show
  1. app.py +74 -2
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,4 +1,76 @@
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
1
  import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ from transformers import pipeline
6
+
7
+
8
+ from nltk.tokenize import sent_tokenize
9
+ import nltk
10
+
11
@st.cache()
def download_punkt():
    """Download the NLTK ``punkt`` resource.

    The function is wrapped in ``st.cache`` so Streamlit can memoise the
    call. It takes no arguments and returns ``None``.
    """
    resource_name = 'punkt'
    nltk.download(resource_name)


# Trigger the download when the script is executed.
download_punkt()
16
+
17
def choose_text_menu(text):
    """Render the input text area and return what the user typed.

    Args:
        text: Not read by this function; the widget always starts with the
            hard-coded sample sentence below.

    Returns:
        The current content of the 'Text to analyze' text area.
    """
    return st.text_area(
        'Text to analyze',
        'Several demonstrators were injured.',
    )
22
+
23
+
24
+ # Load Models in cache
25
@st.cache(allow_output_mutation=True)
def load_model_prompting():
    """Build and return the fill-mask pipeline used for prompting.

    The pipeline is created for the "distilbert-base-uncased" checkpoint.
    ``allow_output_mutation=True`` is passed to ``st.cache`` for the
    returned object.
    """
    fill_mask_pipeline = pipeline(
        "fill-mask",
        model="distilbert-base-uncased",
    )
    return fill_mask_pipeline
28
+
29
+
30
+ ###### Prompting
31
def query_model_prompting(model, text, prompt_with_mask, top_k, targets):
    """Run ``model`` on the text followed directly by the prompt.

    Args:
        model: Callable invoked as ``model(sequence, top_k=..., targets=...)``.
        text: Text placed at the start of the sequence.
        prompt_with_mask: Prompt appended to ``text`` with no separator.
        top_k: Forwarded unchanged to ``model`` as ``top_k``.
        targets: Forwarded unchanged to ``model`` as ``targets``.

    Returns:
        Whatever ``model`` returns for the concatenated sequence.
    """
    return model(text + prompt_with_mask, top_k=top_k, targets=targets)
36
+
37
+
38
def display_pr_results_as_list(prompt, list_results):
    """Show the results in a multiselect widget with every entry selected.

    Args:
        prompt: Label displayed above the widget.
        list_results: Entries used both as the available options and as
            the default selection.

    Returns:
        None. The widget's selection is not returned to the caller.
    """
    st.multiselect(
        label=prompt,
        options=list_results,
        default=list_results,
        key='results_mix',
    )
44
+
45
+
46
+ ### App START
47
st.markdown("# Tag text based on the Prompt approach")
st.markdown("## Author: ...")
st.markdown("**Tips:** - If the [MASK] of your prompt is at the end of the sentence. Don't forget to put a punctuation sign after or the prompt will outputs punctuations.")

model_prompting = load_model_prompting()

prompt_mix = st.text_input('Prompt with a [MASK]:', 'This event involves [MASK].')
prompt_mix_list = [prompt_mix]

top_k = st.number_input(
    'Number of max tokens to output (higher = more computation time)? ',
    step=1,
    min_value=0,
    max_value=50,
    value=10,
)

text = choose_text_menu('')

for prompt in prompt_mix_list:
    # The fill-mask pipeline needs the mask token in its input, so a prompt
    # without it is reported to the user instead of being sent to the model.
    if '[MASK]' not in prompt:
        st.warning('The prompt must contain a [MASK] token.')
        continue

    model_load_state = st.text('Tagging Running')

    # The '{}' form is used only as the label of the results widget. The
    # model receives the prompt with '[MASK]' intact; sending the '{}' form
    # would leave the pipeline with no mask token to fill.
    prompt_label = prompt.replace('[MASK]', '{}')
    output_tokens = query_model_prompting(
        model_prompting, text, prompt, top_k, targets=None
    )

    # One entry per prediction: "<token> <score as a truncated percentage>%".
    list_results = [
        f"{each['token_str']} {int(each['score'] * 100)}%"
        for each in output_tokens
    ]
    display_pr_results_as_list(prompt_label, list_results)
    model_load_state.text("Tagging Done!")
72
+
73
+
74
+
75
+
76
 
 
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
1
+ # see environments.yml
2
+ numpy
3
+ pandas
4
+ transformers[torch]
5
+ nltk
6
+ sentence_transformers