versae committed
Commit 0f096f0 · 1 Parent(s): d3358f8

First version of app

Files changed (1): app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
import os

import streamlit as st
import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM


# Configuration is read from environment variables so the app can be pointed
# at another model or device without code changes.
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
DEVICE = os.environ.get("DEVICE", "cpu")  # e.g., "cuda:0"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
HEADER_INFO = """
# BERTIN-GPT-J-6B
Spanish BERTIN GPT-J-6B Model.
""".strip()
SIDEBAR_INFO = """
# Configuration
""".strip()
PROMPT_BOX = "Introduzca su texto..."  # "Enter your text..."
EXAMPLES = [
    "¿Cuál es la capital de Francia? Respuesta:",
]


def style():
    st.markdown("""
        <link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap" rel="stylesheet">
        <style>
        .ltr,
        textarea {
            font-family: Roboto !important;
            text-align: left;
            direction: ltr !important;
        }
        .ltr-box {
            border-bottom: 1px solid #ddd;
            padding-bottom: 20px;
        }
        .rtl {
            text-align: left;
            direction: ltr !important;
        }
        span.result-text {
            padding: 3px 3px;
            line-height: 32px;
        }
        span.generated-text {
            background-color: rgb(118 200 147 / 13%);
        }
        </style>""", unsafe_allow_html=True)


class Normalizer:
    def remove_repetitions(self, text):
        """Keep only the first occurrence of each sentence."""
        first_occurrences = []
        for sentence in text.split("."):
            if sentence not in first_occurrences:
                first_occurrences.append(sentence)
        return ".".join(first_occurrences)

    def trim_last_sentence(self, text):
        """Trim the last sentence if it is left incomplete."""
        return text[:text.rfind(".") + 1]

    def clean_txt(self, text):
        return self.trim_last_sentence(self.remove_repetitions(text))
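# Illustrative note (not in the original commit): Normalizer only collapses
# exact duplicate sentences and trims a trailing unfinished fragment, e.g.
#   Normalizer().clean_txt("Es azul. Es azul. Es azul. El mar")
# returns "Es azul. Es azul.": splitting on "." leaves a leading space on the
# repeated sentences, so only exact repeats are dropped, and the unfinished
# "El mar" is cut at the last period.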
class TextGeneration:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.generator = None
        self.task = "text-generation"
        self.model_name_or_path = MODEL_NAME
        set_seed(42)

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN,
            pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
            torch_dtype=DTYPE, low_cpu_mem_usage=(DEVICE != "cpu"),
        ).to(device=DEVICE, non_blocking=True)
        _ = self.model.eval()
        # transformers pipelines take -1 for CPU, or the CUDA device index.
        device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])
        self.generator = pipeline(self.task, model=self.model, tokenizer=self.tokenizer, device=device_number)

    def generate(self, prompt, generation_kwargs):
        # "do_clean" is a UI option, not a generation argument; copy the kwargs
        # and drop it so the pipeline does not receive an unknown parameter.
        generation_kwargs = {k: v for k, v in generation_kwargs.items() if k != "do_clean"}
        # Treat the requested max_length as new tokens on top of the prompt,
        # capped at the model's context window.
        max_length = len(self.tokenizer(prompt)["input_ids"]) + generation_kwargs["max_length"]
        generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
        return self.generator(
            prompt,
            **generation_kwargs,
        )[0]["generated_text"]
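# Worked example (added for clarity): with a 10-token prompt and max_length=50
# from the UI, generate() asks the pipeline for 60 tokens in total, capped at
# GPT-J's 2048-token context window (model.config.n_positions).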
@st.cache(allow_output_mutation=True)
def load_text_generator():
    # Cache the loaded model so Streamlit reruns reuse a single instance.
    generator = TextGeneration()
    generator.load()
    return generator


def main():
    st.set_page_config(
        page_title="BERTIN-GPT-J-6B",
        page_icon="🇪🇸",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    style()
    with st.spinner('Loading the model. Please, wait...'):
        generator = load_text_generator()

    st.sidebar.markdown(SIDEBAR_INFO)

    max_length = st.sidebar.slider(
        label='Longitud máxima',
        help="Número máximo aproximado de palabras a generar.",
        min_value=1,
        max_value=MAX_LENGTH,
        value=50,
        step=1
    )
    top_k = st.sidebar.slider(
        label='Top-k',
        help="Número de palabras con alta probabilidad a mantener para el filtrado `top-k`.",
        min_value=40,
        max_value=80,
        value=50,
        step=1
    )
    top_p = st.sidebar.slider(
        label='Top-p',
        help="Solo las palabras más probables con probabilidades que sumen `top_p` o más se mantienen para la generación.",
        min_value=0.0,
        max_value=1.0,
        value=0.95,
        step=0.01
    )
    temperature = st.sidebar.slider(
        label='Temperatura',
        help="Valor utilizado para modular las probabilidades de las siguientes palabras generadas.",
        min_value=0.1,
        max_value=10.0,
        value=0.8,
        step=0.05
    )
    do_sample = st.sidebar.selectbox(
        label='¿Muestrear?',
        options=(True, False),
        help="Si no se muestrea se usará una decodificación voraz (_greedy_).",
    )
    do_clean = st.sidebar.selectbox(
        label='¿Limpiar texto?',
        options=(True, False),
        help="Si se deben eliminar o no las palabras repetidas y recortar las últimas frases sin terminar.",
    )
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }
    st.markdown(HEADER_INFO)
    prompts = EXAMPLES + ["Personalizado"]
    prompt = st.selectbox('Ejemplos', prompts, index=len(prompts) - 1)

    if prompt == "Personalizado":
        prompt_box = PROMPT_BOX
    else:
        prompt_box = prompt

    text = st.text_area("Texto", prompt_box)
    generation_kwargs_ph = st.empty()
    cleaner = Normalizer()
    if st.button("¡Generar!"):
        with st.spinner(text="Generando..."):
            generation_kwargs_ph.markdown(", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
            if text:
                generated_text = generator.generate(text, generation_kwargs)
                if do_clean:
                    generated_text = cleaner.clean_txt(generated_text)
                # The pipeline returns the prompt plus the continuation; keep
                # only the continuation so it can be highlighted separately.
                if generated_text.strip().startswith(text):
                    generated_text = generated_text.replace(text, "", 1).strip()
                st.markdown(
                    f'<p class="ltr ltr-box">'
                    f'<span class="result-text">{text} </span>'
                    f'<span class="result-text generated-text">{generated_text}</span>'
                    f'</p>',
                    unsafe_allow_html=True
                )


if __name__ == '__main__':
    main()
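For reference, a minimal sketch of driving the generator outside Streamlit (hypothetical usage, not part of this commit; assumes app.py is importable and that enough memory is available for the 6B checkpoint):

    from app import TextGeneration

    generator = TextGeneration()
    generator.load()  # loads MODEL_NAME onto DEVICE
    print(generator.generate(
        "¿Cuál es la capital de Francia? Respuesta:",
        {"max_length": 50, "top_k": 50, "top_p": 0.95, "temperature": 0.8, "do_sample": True},
    ))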