HamidRezaAttar committed on
Commit
d7acda5
1 Parent(s): 641df6e

First demo version

Browse files
__pycache__/examples.cpython-39.pyc ADDED
Binary file (446 Bytes). View file
 
__pycache__/meta.cpython-39.pyc ADDED
Binary file (292 Bytes). View file
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.22 kB). View file
 
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline, set_seed
3
+ from transformers import AutoTokenizer
4
+ import random
5
+
6
+ import meta
7
+ import examples
8
+ from utils import (
9
+ remote_css,
10
+ local_css
11
+ )
12
+
13
+
14
class TextGeneration:
    """Wraps a Hugging Face text-generation pipeline for product descriptions.

    When ``debug`` is True, ``load`` is a no-op and ``generate`` returns
    ``dummy_output`` so the UI can be exercised without downloading the model.
    """

    def __init__(self):
        self.debug = False
        self.dummy_output = None          # returned by generate() in debug mode
        self.tokenizer = None             # set by load()
        self.generator = None             # set by load()
        self.task = "text-generation"
        self.model_name_or_path = "HamidRezaAttar/gpt2-product-description-generator"
        set_seed(42)                      # deterministic sampling across reruns

    def load(self):
        """Download and instantiate the tokenizer and generation pipeline."""
        if not self.debug:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            # Reuse the tokenizer we just loaded instead of passing the model
            # path again, which made the pipeline load a second tokenizer copy.
            self.generator = pipeline(self.task, model=self.model_name_or_path, tokenizer=self.tokenizer)

    def generate(self, prompt, generation_kwargs):
        """Generate a continuation of ``prompt``.

        Args:
            prompt: Seed text to continue.
            generation_kwargs: Generation options; ``max_length`` is treated as
                the number of NEW tokens to add on top of the prompt length.

        Returns:
            The generated continuation (prompt excluded), or ``dummy_output``
            in debug mode.
        """
        if not self.debug:
            # Work on a copy so the caller's dict is not silently mutated
            # (the original code overwrote ``max_length`` and injected keys).
            kwargs = dict(generation_kwargs)
            kwargs["num_return_sequences"] = 1

            # HF ``max_length`` counts prompt tokens too; extend it by the
            # prompt length so the slider value means "new tokens".
            prompt_tokens = len(self.tokenizer(prompt)["input_ids"])
            kwargs["max_length"] = prompt_tokens + kwargs["max_length"]

            kwargs["return_full_text"] = False

            return self.generator(
                prompt,
                **kwargs,
            )[0]["generated_text"]

        return self.dummy_output
44
+
45
+
46
@st.cache(allow_output_mutation=True)
def load_text_generator():
    """Build, load, and cache the shared TextGeneration instance.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases;
    consider ``st.cache_resource`` when upgrading.
    """
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator
51
+
52
+
53
def main():
    """Render the Streamlit page: sidebar controls, prompt box, and output."""
    st.set_page_config(
        page_title="GPT2 - Home",
        page_icon="🏡",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # Vazir font + RTL styles must be injected before any styled markdown.
    remote_css("https://cdn.jsdelivr.net/gh/rastikerdar/vazir-font/dist/font-face.css")
    local_css("assets/rtl.css")
    generator = load_text_generator()

    st.sidebar.markdown(meta.SIDEBAR_INFO)

    max_length = st.sidebar.slider(
        label='Max Length',
        help="The maximum length of the sequence to be generated.",
        min_value=1,
        max_value=128,
        value=50,
        step=1
    )
    top_k = st.sidebar.slider(
        label='Top-k',
        help="The number of highest probability vocabulary tokens to keep for top-k-filtering",
        min_value=40,
        max_value=80,
        value=50,
        step=1
    )
    top_p = st.sidebar.slider(
        label='Top-p',
        help="Only the most probable tokens with probabilities that add up to `top_p` or higher are kept for "
             "generation.",
        min_value=0.0,
        max_value=1.0,
        value=0.95,
        step=0.01
    )
    temperature = st.sidebar.slider(
        label='Temperature',
        help="The value used to module the next token probabilities",
        min_value=0.1,
        max_value=10.0,
        value=1.0,
        step=0.05
    )
    do_sample = st.sidebar.selectbox(
        label='Sampling ?',
        options=(True, False),
        help="Whether or not to use sampling; use greedy decoding otherwise.",
    )
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
    }

    st.markdown(meta.HEADER_INFO)
    # "Custom" (free-form entry) is appended last and preselected.
    prompts = list(examples.EXAMPLES.keys()) + ["Custom"]
    prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)

    if prompt == "Custom":
        prompt_box = meta.PROMPT_BOX
    else:
        prompt_box = random.choice(examples.EXAMPLES[prompt])

    text = st.text_area("Enter text", prompt_box)
    generation_kwargs_ph = st.empty()

    if st.button("Generate !"):
        with st.spinner(text="Generating ..."):
            # Echo the active generation settings above the result.
            generation_kwargs_ph.markdown(", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
            if text:
                generated_text = generator.generate(text, generation_kwargs)
                # FIX: the prompt span was never closed (a second "<span>"
                # appeared where "</span>" belonged), producing broken HTML.
                # NOTE(review): user-entered ``text`` is interpolated into raw
                # HTML with unsafe_allow_html=True — an XSS vector; consider
                # escaping it (html.escape) before rendering.
                st.markdown(
                    f'<p class="rtl rtl-box">'
                    f'<span class="result-text">{text} </span>'
                    f'<span class="result-text generated-text">{generated_text}</span>'
                    f'</p>',
                    unsafe_allow_html=True
                )

if __name__ == '__main__':
    main()
assets/rtl.css ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Right-to-left text styling for the demo's result box.
   NOTE(review): the Vazir font is loaded separately via remote_css in app.py —
   confirm the stylesheet is injected before these rules take effect. */
.rtl,
textarea {
font-family: Vazir !important;
text-align: right;
direction: rtl !important;
}
/* Underlined container around the prompt + generated text paragraph. */
.rtl-box {
border-bottom: 1px solid #ddd;
padding-bottom: 20px;
}
/* Left-to-right override (not referenced by the visible app.py markup). */
.ltr {
text-align: left;
direction: ltr !important;
}

/* Both the echoed prompt and the generated text share this spacing. */
span.result-text {
padding: 3px 3px;
line-height: 32px;
}
/* Light green highlight distinguishing model output from the prompt. */
span.generated-text {
background-color: rgb(118 200 147 / 13%);
}
examples.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Prompt presets for the app's "Examples" selectbox (used in app.py).
# Keys are product categories; values are lists of seed prompts — one entry
# is chosen with random.choice when the category is selected.
EXAMPLES = {
    "Table": [
        "Handcrafted of solid acacia in weathered gray, our round Jozy drop-leaf dining table is a space-saving."
    ],
    "Bed": [
        "Maximize your bedroom space without sacrificing style with the storage bed."
    ],
    "Sofa": [
        "Our plush and luxurious Emmett modular sofa brings custom comfort to your living space."
    ]
}
meta.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Markdown rendered at the top of the main page (app.py: st.markdown).
HEADER_INFO = """
# GPT2 - Home
English GPT-2 home product description generator demo.
""".strip()
# Markdown rendered at the top of the sidebar (app.py: st.sidebar.markdown).
SIDEBAR_INFO = """
# Configuration
""".strip()
# Default text shown in the text area when the "Custom" example is selected.
PROMPT_BOX = "Enter your text..."
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Python dependencies for the Streamlit product-description demo.
streamlit
# NOTE(review): hazm and mtranslate are not imported by any visible module —
# confirm they are needed before keeping them.
hazm
Pillow
mtranslate
torch
transformers
utils.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ from PIL import Image
4
+
5
+
6
def load_image(image_path, image_resize=None):
    """Load an image from disk, optionally resizing it.

    Args:
        image_path: Path to an image file readable by PIL.
        image_resize: Optional (width, height) tuple; when given, the image
            is resized to that size.

    Returns:
        A PIL Image instance.
    """
    image = Image.open(image_path)
    if isinstance(image_resize, tuple):
        # FIX: Image.resize returns a NEW image; the original code discarded
        # the result, so image_resize silently had no effect.
        image = image.resize(image_resize)
    return image
11
+
12
+
13
def load_text(text_path):
    """Return the entire contents of the text file at ``text_path``."""
    with open(text_path) as handle:
        return handle.read()
19
+
20
+
21
def load_json(json_path):
    """Parse and return the JSON document stored at ``json_path``."""
    with open(json_path) as handle:
        return json.load(handle)
27
+
28
+
29
def local_css(css_path):
    """Inject a local CSS file into the Streamlit page."""
    with open(css_path) as css_file:
        styles = css_file.read()
    st.markdown(f'<style>{styles}</style>', unsafe_allow_html=True)
32
+
33
+
34
def remote_css(css_url):
    """Inject a remote stylesheet link into the Streamlit page."""
    link_tag = f'<link href="{css_url}" rel="stylesheet">'
    st.markdown(link_tag, unsafe_allow_html=True)
36
+