Spaces:
Runtime error
Runtime error
New changes to demo
Browse files- app.py +177 -16
- requirements.txt +2 -0
app.py
CHANGED
@@ -1,25 +1,20 @@
|
|
1 |
import random
|
2 |
from mtranslate import translate
|
3 |
import streamlit as st
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
LOGO = "https://huggingface.co/bertin-project/bertin-roberta-base-spanish/resolve/main/images/bertin.png"
|
8 |
-
|
9 |
-
MODELS = {
|
10 |
-
"RoBERTa Base Gaussian Seq Len 128": {
|
11 |
-
"url": "bertin-project/bertin-base-gaussian"
|
12 |
-
},
|
13 |
-
"RoBERTa Base Gaussian Seq Len 512": {
|
14 |
-
"url": "bertin-project/bertin-base-gaussian-exp-512seqlen"
|
15 |
-
},
|
16 |
-
"RoBERTa Base Random Seq Len 128": {
|
17 |
-
"url": "bertin-project/bertin-base-random"
|
18 |
-
},
|
19 |
-
"RoBERTa Base Stepwise Seq Len 128": {
|
20 |
-
"url": "bertin-project/bertin-base-stepwise"
|
21 |
-
},
|
22 |
-
}
|
23 |
|
24 |
PROMPT_LIST = [
|
25 |
"Fui a la librería a comprar un <mask>.",
|
@@ -37,6 +32,12 @@ PROMPT_LIST = [
|
|
37 |
"Al pan, pan, y al vino, <mask>.",
|
38 |
]
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
@st.cache(show_spinner=False, persist=True)
|
42 |
def load_model(masked_text, model_url):
|
@@ -47,6 +48,26 @@ def load_model(masked_text, model_url):
|
|
47 |
return result
|
48 |
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# Page
|
51 |
st.set_page_config(page_title="BERTIN Demo", page_icon=LOGO)
|
52 |
st.title("BERTIN")
|
@@ -84,6 +105,11 @@ st.markdown(
|
|
84 |
The first models have been trained (250.000 steps) on sequence length 128, and then training for Gaussian changed to sequence length 512 for the last 25.000 training steps to yield another version.
|
85 |
|
86 |
Please read our [full report](https://huggingface.co/bertin-project/bertin-roberta-base-spanish) for more details on the methodology and metrics on downstream tasks.
|
|
|
|
|
|
|
|
|
|
|
87 |
"""
|
88 |
)
|
89 |
|
@@ -112,6 +138,141 @@ if st.button("Fill the mask"):
|
|
112 |
st.write("_English_ _translation:_", translate(result_sequence, "en", "es"))
|
113 |
st.write(result)
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
st.markdown(
|
116 |
"""
|
117 |
### Team members
|
|
|
1 |
import random
|
2 |
from mtranslate import translate
|
3 |
import streamlit as st
|
4 |
+
import seaborn as sns
|
5 |
+
from spacy import displacy
|
6 |
+
from transformers import (
|
7 |
+
AutoConfig,
|
8 |
+
AutoTokenizer,
|
9 |
+
AutoModelForMaskedLM,
|
10 |
+
AutoModelForSequenceClassification,
|
11 |
+
AutoModelForTokenClassification,
|
12 |
+
pipeline
|
13 |
+
)
|
14 |
|
15 |
|
16 |
LOGO = "https://huggingface.co/bertin-project/bertin-roberta-base-spanish/resolve/main/images/bertin.png"
|
17 |
+
WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
PROMPT_LIST = [
|
20 |
"Fui a la librería a comprar un <mask>.",
|
|
|
32 |
"Al pan, pan, y al vino, <mask>.",
|
33 |
]
|
34 |
|
35 |
+
PAWS_X_PROMPT_LIST = [
|
36 |
+
"Te amo.</s>Te adoro.",
|
37 |
+
"Te odio.</s>Te detesto.",
|
38 |
+
"Me gusta montar en bicicleta.</s>París es una ciudad francesa."
|
39 |
+
]
|
40 |
+
|
41 |
|
42 |
@st.cache(show_spinner=False, persist=True)
|
43 |
def load_model(masked_text, model_url):
|
|
|
48 |
return result
|
49 |
|
50 |
|
51 |
+
@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_url):
    """Run a fill-mask pipeline from the checkpoint at `model_url` on `masked_text`.

    Returns the pipeline's prediction list for the masked token.
    """
    # NOTE(review): the diff adds this definition while an identical `load_model`
    # already exists earlier in the file — confirm the duplication is intentional.
    tokenizer = AutoTokenizer.from_pretrained(model_url)
    mlm = AutoModelForMaskedLM.from_pretrained(model_url)
    fill_mask = pipeline("fill-mask", model=mlm, tokenizer=tokenizer)
    return fill_mask(masked_text)
|
58 |
+
|
59 |
+
|
60 |
+
@st.cache(show_spinner=False, persist=True)
def load_model_pair_classification(text, model_url_pair_classification):
    """Classify a `</s>`-separated sentence pair as paraphrase or not.

    Parameters
    ----------
    text : str
        Two sentences separated by "</s>".
    model_url_pair_classification : str
        Hub id of a sequence-classification checkpoint fine-tuned on PAWS-X.

    Returns
    -------
    str
        Human-readable verdict ("Paraphrase" / "Different meaning") plus score.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_url_pair_classification)
    # BUG FIX: the tokenizer was loaded from an undefined name `model_url`,
    # which raises NameError at call time; load it from the same checkpoint
    # as the model.
    tokenizer = AutoTokenizer.from_pretrained(model_url_pair_classification)
    nlp = pipeline("text-classification", model=model, tokenizer=tokenizer)
    result = nlp(f"{text}</s>")
    # LABEL_0 appears to mean "different meaning" for this fine-tune —
    # TODO(review): confirm against the model card's id2label mapping.
    if result[0]["label"] == "LABEL_0":
        return f"Different meaning: {result[0]['score']:02f}"
    return f"Paraphrase: {result[0]['score']:02f}"
|
69 |
+
|
70 |
+
|
71 |
# Page
|
72 |
st.set_page_config(page_title="BERTIN Demo", page_icon=LOGO)
|
73 |
st.title("BERTIN")
|
|
|
105 |
The first models have been trained (250.000 steps) on sequence length 128, and then training for Gaussian changed to sequence length 512 for the last 25.000 training steps to yield another version.
|
106 |
|
107 |
Please read our [full report](https://huggingface.co/bertin-project/bertin-roberta-base-spanish) for more details on the methodology and metrics on downstream tasks.
|
108 |
+
|
109 |
+
### Masked language modeling
|
110 |
+
|
111 |
+
Here you can play with the filling the mask objective of all the models.
|
112 |
+
|
113 |
"""
|
114 |
)
|
115 |
|
|
|
138 |
st.write("_English_ _translation:_", translate(result_sequence, "en", "es"))
|
139 |
st.write(result)
|
140 |
|
141 |
+
# Paraphrase-identification demo section (PAWS-X fine-tune).
st.markdown(
    """
    ### Fine-tuning to PAWS-X for paraphrase identification
    Here you can play with the RoBERTa Base Gaussian Seq Len 512 model fine-tuned to PAWS-X.
    """
)

pawsx_model_url = "bertin-project/bertin-base-paws-x-es"
paraphrase_prompt = st.selectbox("Paraphrase Prompt", ["Random", "Custom"])
if paraphrase_prompt == "Custom":
    paraphrase_prompt_box = "Enter two sentences separated by </s> here..."
else:
    paraphrase_prompt_box = random.choice(PAWS_X_PROMPT_LIST)
text = st.text_area("Enter text", paraphrase_prompt_box)
# TYPO FIX: user-facing "Clasify"/"Clasifying" corrected to "Classify"/"Classifying".
if st.button("Classify paraphrasing"):
    with st.spinner(text="Classifying paraphrasing..."):
        st.subheader("Classification result")
        paraphrase_score = load_model_pair_classification(text, pawsx_model_url)
        st.write("_English_ _translation:_", translate(text, "en", "es"))
        st.write(paraphrase_score)
|
161 |
+
|
162 |
+
|
163 |
+
def make_color_palette(labels):
    """Map each label to a distinct hex color drawn from a seaborn palette."""
    palette = sns.color_palette(n_colors=len(labels))
    return {label: rgb2hex(*rgb) for label, rgb in zip(labels, palette)}
|
167 |
+
|
168 |
+
|
169 |
+
@st.cache(allow_output_mutation=True)
def get_colormap(labels):
    """Cached wrapper around `make_color_palette` so label colors stay stable across reruns."""
    return make_color_palette(labels)
|
173 |
+
|
174 |
+
|
175 |
+
def add_colormap(labels):
    """Return the cached colormap, assigning a random color to any label missing from it."""
    color_map = get_colormap(labels)
    missing = [label for label in labels if label not in color_map]
    for label in missing:
        # Random 24-bit color rendered as "#rrggbb".
        color_map[label] = "#" + "%06x" % random.randint(0, 0xFFFFFF)
    return color_map
|
182 |
+
|
183 |
+
|
184 |
+
|
185 |
+
def load_model_ner(model_url):
    """Build an aggregated NER pipeline from a token-classification checkpoint."""
    config = AutoConfig.from_pretrained(model_url)
    tokenizer = AutoTokenizer.from_pretrained(model_url, use_fast=True)
    model = AutoModelForTokenClassification.from_pretrained(model_url, config=config)
    # Keep every label (including "O") so callers can filter themselves;
    # "simple" strategy groups word pieces into whole-entity spans.
    ner_pipeline = pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        ignore_labels=[],
        aggregation_strategy="simple",
    )
    return ner_pipeline
|
198 |
+
|
199 |
+
|
200 |
+
def display(entities):
    """Render NER pipeline output as a displaCy entity visualization inside Streamlit."""
    doc = model_entities_to_displacy_format(entities, ignore_entities=["O"])
    labels = list({ent["label"] for ent in doc["ents"]})
    color_map = add_colormap(labels)
    markup = displacy.render(
        doc,
        manual=True,
        style="ent",
        options={"colors": color_map}
    )
    # Flatten the HTML onto one line before embedding it in the wrapper div.
    st.write(WRAPPER.format(markup.replace("\n", " ")), unsafe_allow_html=True)
|
212 |
+
|
213 |
+
|
214 |
+
def rgb2hex(r, g, b):
    """Convert an RGB triple of floats in [0, 1] to a "#rrggbb" hex string."""
    channels = (int(channel * 255) for channel in (r, g, b))
    return "#" + "".join("{:02x}".format(value) for value in channels)
|
218 |
+
|
219 |
+
|
220 |
+
def model_entities_to_displacy_format(ents, ignore_entities=()):
    """Convert HF token-classification pipeline output to displaCy "ent" manual format.

    Parameters
    ----------
    ents : list[dict]
        Pipeline entities; each has "word" and either "entity_group"
        (aggregated pipelines) or "entity" (raw pipelines).
    ignore_entities : iterable
        Labels to drop from the visualization (e.g. ["O"]).

    Returns
    -------
    dict
        {"text", "ents", "title"} as expected by displacy.render(manual=True).
    """
    # BUG FIX: mutable default `ignore_entities=[]` replaced by an immutable
    # tuple (behavior-compatible: only membership tests are performed).
    s_ents = {}
    s_ents["text"] = " ".join(e["word"] for e in ents)
    spacy_ents = []
    start_pointer = 0
    # Aggregated pipeline results use "entity_group"; raw results use "entity".
    # BUG FIX: guard `ents[0]` so an empty entity list no longer raises IndexError.
    if ents and isinstance(ents, list) and "entity_group" in ents[0]:
        entity_key = "entity_group"
    else:
        entity_key = "entity"
    for ent in ents:
        if ent[entity_key] not in ignore_entities:
            spacy_ents.append({
                "start": start_pointer,
                "end": start_pointer + len(ent["word"]),
                "label": ent[entity_key],
            })
        # +1 accounts for the single space joining words in s_ents["text"].
        start_pointer = start_pointer + len(ent["word"]) + 1
    s_ents["ents"] = spacy_ents
    s_ents["title"] = None
    return s_ents
|
240 |
+
|
241 |
+
# NER demo section (conll2002-es fine-tune).
st.markdown("""

### Fine-tuning to CoNLL 2002 es for Named Entity Recognition (NER)

Here you can play with the RoBERTa Base Gaussian Seq Len 512 model fine-tuned to conll2002-es.

""")
text_input = str(st.text_input(
    "Text",
    "Mi nombre es Íñigo Montoya. Viajo a Los Acantilados de la Locura "
))
ner_model_url = "bertin-project/bertin-base-ner-conll2002-es"
label2id = AutoConfig.from_pretrained(ner_model_url, cache=False).label2id
color_map = get_colormap(list(label2id.keys()))
if st.button("Recognize named entities"):
    with st.spinner(text="Recognizing named entities..."):
        ner = load_model_ner(ner_model_url)
        entities = ner(str(text_input))
        st.write("_English_ _translation:_", translate(str(text_input), "en", "es"))
        if not entities:
            st.write("No entities found")
        elif isinstance(entities, dict) and "error" in entities:
            # The pipeline surfaced an error payload instead of entities.
            st.write(entities)
        else:
            display(entities)
            cleaned = []
            for ent in entities:
                # Cast numpy scalars to plain Python types so st.write can serialize them.
                ent["start"] = int(ent["start"])
                ent["end"] = int(ent["end"])
                ent["score"] = float(ent["score"])
                cleaned.append(ent)
            st.write(cleaned)
|
275 |
+
|
276 |
st.markdown(
|
277 |
"""
|
278 |
### Team members
|
requirements.txt
CHANGED
@@ -2,3 +2,5 @@ streamlit
|
|
2 |
mtranslate
|
3 |
transformers
|
4 |
torch
|
|
|
|
|
|
2 |
mtranslate
|
3 |
transformers
|
4 |
torch
|
5 |
+
seaborn
|
6 |
+
spacy
|