First version of app
app.py
ADDED
@@ -0,0 +1,212 @@
import os

import streamlit as st
import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM


# Runtime configuration, all overridable through environment variables.
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
DEVICE = os.environ.get("DEVICE", "cpu")  # e.g. "cuda:0"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
HEADER_INFO = """
# BERTIN-GPT-J-6B
Spanish BERTIN GPT-J-6B Model.
""".strip()
SIDEBAR_INFO = """
# Configuration
""".strip()
PROMPT_BOX = "Introduzca su texto..."  # "Enter your text..."
EXAMPLES = [
    "¿Cuál es la capital de Francia? Respuesta:",
]


def style():
    """Inject the custom font and result-highlighting CSS into the page."""
    st.markdown("""
    <link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap" rel="stylesheet">
    <style>
    .ltr,
    textarea {
        font-family: Roboto !important;
        text-align: left;
        direction: ltr !important;
    }
    .ltr-box {
        border-bottom: 1px solid #ddd;
        padding-bottom: 20px;
    }
    .rtl {
        text-align: left;
        direction: ltr !important;
    }
    span.result-text {
        padding: 3px 3px;
        line-height: 32px;
    }
    span.generated-text {
        background-color: rgb(118 200 147 / 13%);
    }
    </style>""", unsafe_allow_html=True)


class Normalizer:
    """Light post-processing for the generated text."""

    def remove_repetitions(self, text):
        """Keep only the first occurrence of each sentence."""
        first_occurrences = []
        for sentence in text.split("."):
            if sentence not in first_occurrences:
                first_occurrences.append(sentence)
        return ".".join(first_occurrences)

    def trim_last_sentence(self, text):
        """Drop a trailing incomplete sentence (anything after the last period)."""
        return text[:text.rfind(".") + 1]

    def clean_txt(self, text):
        return self.trim_last_sentence(self.remove_repetitions(text))


class TextGeneration:
    def __init__(self):
        self.tokenizer = None
        self.generator = None
        self.task = "text-generation"
        self.model_name_or_path = MODEL_NAME
        set_seed(42)

    def load(self):
        """Load tokenizer and model, then wrap them in a text-generation pipeline."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN,
            pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
            torch_dtype=DTYPE, low_cpu_mem_usage=DEVICE != "cpu",
        ).to(device=DEVICE, non_blocking=True)
        _ = self.model.eval()
        device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])
        self.generator = pipeline(self.task, model=self.model, tokenizer=self.tokenizer, device=device_number)

    def generate(self, prompt, generation_kwargs):
        # "max_length" arrives as the requested number of *new* tokens; turn it
        # into an absolute length (prompt + new), capped at the context window.
        max_length = len(self.tokenizer(prompt)["input_ids"]) + generation_kwargs["max_length"]
        generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
        # "do_clean" is a UI option, not a generation argument: drop it before
        # calling the pipeline so it is not forwarded to model.generate().
        generation_kwargs.pop("do_clean", None)
        # generation_kwargs["num_return_sequences"] = 1
        # generation_kwargs["return_full_text"] = False
        return self.generator(
            prompt,
            **generation_kwargs,
        )[0]["generated_text"]


@st.cache(allow_output_mutation=True)
def load_text_generator():
    """Load the model once and reuse it across Streamlit reruns."""
    generator = TextGeneration()
    generator.load()
    return generator


def main():
    st.set_page_config(
        page_title="BERTIN-GPT-J-6B",
        page_icon="🇪🇸",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    style()
    with st.spinner('Loading the model. Please, wait...'):
        generator = load_text_generator()

    st.sidebar.markdown(SIDEBAR_INFO)

    # Sidebar controls; labels and help texts are in Spanish, the app's UI language.
    max_length = st.sidebar.slider(
        label='Longitud máxima',  # "Maximum length"
        help="Número máximo aproximado de palabras a generar.",
        min_value=1,
        max_value=MAX_LENGTH,
        value=50,
        step=1
    )
    top_k = st.sidebar.slider(
        label='Top-k',
        help="Número de palabras con alta probabilidad a mantener para el filtrado `top-k`.",
        min_value=40,
        max_value=80,
        value=50,
        step=1
    )
    top_p = st.sidebar.slider(
        label='Top-p',
        help="Solo las palabras más probables con probabilidades que sumen `top_p` o más se mantienen para la generación.",
        min_value=0.0,
        max_value=1.0,
        value=0.95,
        step=0.01
    )
    temperature = st.sidebar.slider(
        label='Temperatura',
        help="Valor utilizado para modular las probabilidades de las siguientes palabras generadas.",
        min_value=0.1,
        max_value=10.0,
        value=0.8,
        step=0.05
    )
    do_sample = st.sidebar.selectbox(
        label='¿Muestrear?',  # "Sample?"
        options=(True, False),
        help="Si no se muestrea se usará una decodificación voraz (_greedy_).",
    )
    do_clean = st.sidebar.selectbox(
        label='¿Limpiar texto?',  # "Clean text?"
        options=(True, False),
        help="Si eliminar o no las palabras repetidas y recortar las últimas frases sin terminar.",
    )
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }
    st.markdown(HEADER_INFO)
    prompts = EXAMPLES + ["Personalizado"]  # "Custom"
    prompt = st.selectbox('Ejemplos', prompts, index=len(prompts) - 1)

    if prompt == "Personalizado":
        prompt_box = PROMPT_BOX
    else:
        prompt_box = prompt

    text = st.text_area("Texto", prompt_box)
    generation_kwargs_ph = st.empty()
    cleaner = Normalizer()
    if st.button("¡Generar!"):
        with st.spinner(text="Generando..."):
            generation_kwargs_ph.markdown(", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
            if text:
                generated_text = generator.generate(text, generation_kwargs)
                if do_clean:
                    generated_text = cleaner.clean_txt(generated_text)
                # Strip the echoed prompt so only the continuation is highlighted.
                if generated_text.strip().startswith(text):
                    generated_text = generated_text.replace(text, "", 1).strip()
                st.markdown(
                    f'<p class="ltr ltr-box">'
                    f'<span class="result-text">{text} </span>'
                    f'<span class="result-text generated-text">{generated_text}</span>'
                    f'</p>',
                    unsafe_allow_html=True
                )


if __name__ == '__main__':
    main()
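For quick testing outside the Space, the generation path that app.py wraps can be exercised directly with the `transformers` pipeline. A minimal sketch, assuming `torch` and `transformers` are installed and the 6B checkpoint fits in memory (roughly 24 GB in float32 on CPU); the sampling parameters mirror the app's defaults:

# Standalone sketch of the app's generation call, without the Streamlit UI.
from transformers import pipeline, set_seed

set_seed(42)
generator = pipeline(
    "text-generation",
    model="bertin-project/bertin-gpt-j-6B",
    device=-1,  # -1 = CPU; 0 = cuda:0, matching the DEVICE variable in app.py
)
output = generator(
    "¿Cuál es la capital de Francia? Respuesta:",
    max_length=50,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=0.8,
)
print(output[0]["generated_text"])

The app itself starts with `streamlit run app.py`; `DEVICE`, `MODEL_NAME`, `HF_AUTH_TOKEN`, and `MAX_LENGTH` can be set in the environment, as read at the top of the file.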