Pippoz committed on
Commit
4f14a8f
1 Parent(s): 4779f1e

Adding text-gen and fill-mask tasks

Browse files
__pycache__/multipage.cpython-38.pyc ADDED
Binary file (1.55 kB). View file
multipage.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is the framework for generating multiple Streamlit applications
3
+ through an object oriented framework.
4
+ """
5
+
6
+ # Import necessary libraries
7
+ import streamlit as st
8
+
9
+ # Define the multipage class to manage the multiple apps in our program
10
class MultiPage:
    """Framework for combining multiple Streamlit applications.

    Pages are registered with :meth:`add_page` and rendered by :meth:`run`,
    which draws a sidebar selectbox and invokes the chosen page's function.
    """

    def __init__(self) -> None:
        """Create the empty list which stores all registered pages."""
        # Each entry is a dict of the form {"title": str, "function": callable}.
        self.pages = []

    def add_page(self, title, func) -> None:
        """Register a page with the project.

        Args:
            title (str): The title of the page shown in the sidebar selector.
            func: Zero-argument Python function that renders this page in
                Streamlit when called.
        """
        self.pages.append({
            "title": title,
            "function": func,
        })

    def run(self):
        """Show the sidebar navigation dropdown and run the selected page."""
        # Guard: with no registered pages, selectbox would return None and
        # the subscript below would raise TypeError.
        if not self.pages:
            return

        # Dropdown to select the page to run.
        page = st.sidebar.selectbox(
            'App Navigation',
            self.pages,
            format_func=lambda page: page['title']
        )

        # Run the selected page's render function.
        page['function']()
pages/__pycache__/fill_mask.cpython-38.pyc ADDED
Binary file (2.66 kB). View file
pages/__pycache__/home_page.cpython-38.pyc ADDED
Binary file (423 Bytes). View file
pages/__pycache__/text_gen.cpython-38.pyc ADDED
Binary file (1.69 kB). View file
pages/fill_mask.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import time
3
+ import pandas as pd
4
+ import altair as alt
5
+ from multipage import MultiPage
6
+ from transformers import pipeline
7
+
8
def app():
    """Render the fill-mask demo page (BERT-style [MASK] completion)."""
    st.markdown('## Mask Fill task')
    st.write('Write a sentence with a [MASK] gap to fill')
    st.markdown('## ')

    @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
    def get_model(model_name):
        # Bug fix: the loader previously hard-coded 'bert-base-uncased',
        # silently ignoring the radio-button model selection below.
        return pipeline('fill-mask', model=model_name, skip_special_tokens=True)

    def create_graph(answer):
        """Plot the predicted tokens against their scores as a bar chart."""
        x_bar = [i['token_str'] for i in answer]
        y_bar = [i['score'] for i in answer]
        chart_data = pd.DataFrame(y_bar, index=x_bar)
        data = pd.melt(chart_data.reset_index(), id_vars=["index"])
        # Vertical bar chart: token on x, score on y.
        chart = (
            alt.Chart(data)
            .mark_bar(color='#d7abf5')
            .encode(
                x=alt.X("index", type="nominal", title='',
                        sort=alt.EncodingSortField(field="index", op="count", order='ascending')),
                y=alt.Y("value", type="quantitative", title="Score", sort='-x'),
            )
        )
        st.altair_chart(chart, use_container_width=True)

    col1, col2 = st.beta_columns([2, 1])

    with col1:
        prompt = st.text_area('Your prompt here',
                              '''Who is Elon [MASK]?''')

    with col2:
        select_model = st.radio(
            "Select the model to use:",
            ('Bert cased', 'Bert Un-cased'), index=1)
        if select_model == 'Bert cased':
            model = 'bert-base-cased'
        elif select_model == 'Bert Un-cased':
            model = 'bert-base-uncased'

    with st.spinner('Loading Model... (This may take a while)'):
        # NOTE(review): real inference is stubbed out below with canned
        # predictions; re-enable the two commented lines to run the model.
        # unmasker = get_model(model)
        st.success('Model loaded correctly!')

    gen = st.info('Generating text...')
    # answer = unmasker(prompt)
    answer = [{'sequence': "[CLS] hello i'm a fashion model. [SEP]", 'score': 0.1073106899857521, 'token': 4827, 'token_str': 'fashion'}, {'sequence': "[CLS] hello i'm a role model. [SEP]", 'score': 0.08774490654468536, 'token': 2535, 'token_str': 'role'}, {'sequence': "[CLS] hello i'm a new model. [SEP]", 'score': 0.05338378623127937, 'token': 2047, 'token_str': 'new'}, {'sequence': "[CLS] hello i'm a super model. [SEP]", 'score': 0.04667217284440994, 'token': 3565, 'token_str': 'super'}, {'sequence': "[CLS] hello i'm a fine model. [SEP]", 'score': 0.027095865458250046, 'token': 2986, 'token_str': 'fine'}]
    gen.empty()

    with col1:
        create_graph(answer)
63
+
64
+
pages/home_page.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from multipage import MultiPage
3
+
4
def app():
    """Landing page: project title plus a hint to pick a task."""
    # Static welcome content only; actual task pages are selected
    # through the sidebar navigation rendered by MultiPage.run().
    st.title('All-in-One')
    st.markdown('#### Select a specific task from the sidebar...')
7
+
8
+
9
+
pages/text_gen.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import time
3
+ from multipage import MultiPage
4
+ from transformers import pipeline
5
+ import torch
6
+
7
+
8
def app():
    """Render the text-generation demo page (OPT models)."""
    st.markdown('## Text Generation task')
    st.write('Write something and AI will continue the sentence ')
    st.markdown('## ')

    @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
    def get_model():
        # `model` is resolved from the enclosing scope at call time,
        # after the radio-button selection below has set it.
        return pipeline('text-generation', model=model, do_sample=True, skip_special_tokens=True)

    col1, col2 = st.beta_columns([2, 1])

    with col1:
        prompt = st.text_area('Your prompt here',
                              '''Who is Elon Musk?''')

    with col2:
        select_model = st.radio(
            "Select the model to use:",
            ('OPT-125m', 'OPT-350m'), index=1)
        if select_model == 'OPT-350m':
            model = 'facebook/opt-350m'
        elif select_model == 'OPT-125m':
            model = 'facebook/opt-125m'

    # Generation hyper-parameters. Bug fix: these names were referenced in
    # the generator call below but never defined, raising NameError at
    # runtime the moment a prompt was submitted.
    max_length = 50
    no_ngram_repeat = 2
    early_stopping = True
    num_beams = 5

    with st.spinner('Loading Model... (This may take a while)'):
        generator = get_model()
        st.success('Model loaded correctly!')

    with col1:
        gen = st.info('Generating text...')
        answer = generator(prompt,
                           max_length=max_length, no_repeat_ngram_size=no_ngram_repeat,
                           early_stopping=early_stopping, num_beams=num_beams,
                           skip_special_tokens=True)
        gen.empty()

    lst = answer[0]['generated_text']

    # Typewriter effect. Iterate to len(lst) + 1 so the final character is
    # displayed too (the original range(len(lst)) stopped one short).
    t = st.empty()
    for i in range(len(lst) + 1):
        t.markdown("#### %s..." % lst[0:i])
        time.sleep(0.04)
+ time.sleep(0.04)
51
+
52
+
53
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ streamlit
2
+ transformers
3
+ altair
4
+ pandas
5
+ torch