versae committed
Commit 2388248
1 Parent(s): 9ed2311

New changes to demo

Files changed (2)
  1. app.py +177 -16
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,25 +1,20 @@
 import random
 from mtranslate import translate
 import streamlit as st
-from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
+import seaborn as sns
+from spacy import displacy
+from transformers import (
+    AutoConfig,
+    AutoTokenizer,
+    AutoModelForMaskedLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    pipeline
+)
 
 
 LOGO = "https://huggingface.co/bertin-project/bertin-roberta-base-spanish/resolve/main/images/bertin.png"
-
-MODELS = {
-    "RoBERTa Base Gaussian Seq Len 128": {
-        "url": "bertin-project/bertin-base-gaussian"
-    },
-    "RoBERTa Base Gaussian Seq Len 512": {
-        "url": "bertin-project/bertin-base-gaussian-exp-512seqlen"
-    },
-    "RoBERTa Base Random Seq Len 128": {
-        "url": "bertin-project/bertin-base-random"
-    },
-    "RoBERTa Base Stepwise Seq Len 128": {
-        "url": "bertin-project/bertin-base-stepwise"
-    },
-}
+WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
 
 PROMPT_LIST = [
     "Fui a la librería a comprar un <mask>.",
@@ -37,6 +32,12 @@ PROMPT_LIST = [
     "Al pan, pan, y al vino, <mask>.",
 ]
 
+PAWS_X_PROMPT_LIST = [
+    "Te amo.</s>Te adoro.",
+    "Te odio.</s>Te detesto.",
+    "Me gusta montar en bicicleta.</s>París es una ciudad francesa."
+]
+
 
 @st.cache(show_spinner=False, persist=True)
 def load_model(masked_text, model_url):
@@ -47,6 +48,26 @@ def load_model(masked_text, model_url):
     return result
 
 
+@st.cache(show_spinner=False, persist=True)
+def load_model(masked_text, model_url):
+    model = AutoModelForMaskedLM.from_pretrained(model_url)
+    tokenizer = AutoTokenizer.from_pretrained(model_url)
+    nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)
+    result = nlp(masked_text)
+    return result
+
+
+@st.cache(show_spinner=False, persist=True)
+def load_model_pair_classification(text, model_url_pair_classification):
+    model = AutoModelForSequenceClassification.from_pretrained(model_url_pair_classification)
+    tokenizer = AutoTokenizer.from_pretrained(model_url_pair_classification)
+    nlp = pipeline("text-classification", model=model, tokenizer=tokenizer)
+    result = nlp(f"{text}</s>")
+    if result[0]["label"] == "LABEL_0":
+        return f"Different meaning: {result[0]['score']:.2f}"
+    return f"Paraphrase: {result[0]['score']:.2f}"
+
+
 # Page
 st.set_page_config(page_title="BERTIN Demo", page_icon=LOGO)
 st.title("BERTIN")
@@ -84,6 +105,11 @@ st.markdown(
 The first models were trained for 250,000 steps at sequence length 128; training of the Gaussian model then continued at sequence length 512 for a final 25,000 steps to yield another version.
 
 Please read our [full report](https://huggingface.co/bertin-project/bertin-roberta-base-spanish) for more details on the methodology and metrics on downstream tasks.
+
+### Masked language modeling
+
+Here you can play with the fill-mask objective of all the models.
+
 """
 )
 
@@ -112,6 +138,141 @@ if st.button("Fill the mask"):
         st.write("_English_ _translation:_", translate(result_sequence, "en", "es"))
         st.write(result)
 
+st.markdown(
+    """
+### Fine-tuning to PAWS-X for paraphrase identification
+Here you can play with the RoBERTa Base Gaussian Seq Len 512 model fine-tuned to PAWS-X.
+    """
+)
+
+pawsx_model_url = "bertin-project/bertin-base-paws-x-es"
+paraphrase_prompt = st.selectbox("Paraphrase Prompt", ["Random", "Custom"])
+if paraphrase_prompt == "Custom":
+    paraphrase_prompt_box = "Enter two sentences separated by </s> here..."
+else:
+    paraphrase_prompt_box = random.choice(PAWS_X_PROMPT_LIST)
+text = st.text_area("Enter text", paraphrase_prompt_box)
+if st.button("Classify paraphrasing"):
+    with st.spinner(text="Classifying paraphrasing..."):
+        st.subheader("Classification result")
+        paraphrase_score = load_model_pair_classification(text, pawsx_model_url)
+        st.write("_English_ _translation:_", translate(text, "en", "es"))
+        st.write(paraphrase_score)
+
+
+def make_color_palette(labels):
+    color_palette = sns.color_palette(n_colors=len(labels))
+    color_map = {x: rgb2hex(*y) for x, y in zip(labels, color_palette)}
+    return color_map
+
+
+@st.cache(allow_output_mutation=True)
+def get_colormap(labels):
+    color_map = make_color_palette(labels)
+    return color_map
+
+
+def add_colormap(labels):
+    color_map = get_colormap(labels)
+    for label in labels:
+        if label not in color_map:
+            rand_color = "#" + "%06x" % random.randint(0, 0xFFFFFF)
+            color_map[label] = rand_color
+    return color_map
+
+
+
+def load_model_ner(model_url):
+    config = AutoConfig.from_pretrained(model_url)
+    model = AutoModelForTokenClassification.from_pretrained(
+        model_url, config=config
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_url, use_fast=True)
+    return pipeline(
+        "ner",
+        model=model,
+        tokenizer=tokenizer,
+        ignore_labels=[],
+        aggregation_strategy="simple",
+    )
+
+
+def display(entities):
+    doc = model_entities_to_displacy_format(entities, ignore_entities=["O"])
+    labels = list(set([ent["label"] for ent in doc["ents"]]))
+    color_map = add_colormap(labels)
+    html = displacy.render(
+        doc,
+        manual=True,
+        style="ent",
+        options={"colors": color_map}
+    )
+    html = html.replace("\n", " ")
+    st.write(WRAPPER.format(html), unsafe_allow_html=True)
+
+
+def rgb2hex(r, g, b):
+    return "#{:02x}{:02x}{:02x}".format(
+        int(r * 255), int(g * 255), int(b * 255)
+    )
+
+
+def model_entities_to_displacy_format(ents, ignore_entities=[]):
+    s_ents = {}
+    s_ents["text"] = " ".join([e["word"] for e in ents])
+    spacy_ents = []
+    start_pointer = 0
+    if isinstance(ents, list) and "entity_group" in ents[0]:
+        entity_key = "entity_group"
+    else:
+        entity_key = "entity"
+    for i, ent in enumerate(ents):
+        if ent[entity_key] not in ignore_entities:
+            spacy_ents.append({
+                "start": start_pointer,
+                "end": start_pointer + len(ent["word"]),
+                "label": ent[entity_key],
+            })
+        start_pointer = start_pointer + len(ent["word"]) + 1
+    s_ents["ents"] = spacy_ents
+    s_ents["title"] = None
+    return s_ents
+
+st.markdown("""
+
+### Fine-tuning to CoNLL 2002 es for Named Entity Recognition (NER)
+
+Here you can play with the RoBERTa Base Gaussian Seq Len 512 model fine-tuned to conll2002-es.
+
+""")
+text_input = str(st.text_input(
+    "Text",
+    "Mi nombre es Íñigo Montoya. Viajo a Los Acantilados de la Locura "
+))
+ner_model_url = "bertin-project/bertin-base-ner-conll2002-es"
+label2id = AutoConfig.from_pretrained(ner_model_url, cache=False).label2id
+color_map = get_colormap(list(label2id.keys()))
+if st.button("Recognize named entities"):
+    with st.spinner(text="Recognizing named entities..."):
+        ner = load_model_ner(ner_model_url)
+        entities = ner(str(text_input))
+        st.write("_English_ _translation:_", translate(str(text_input), "en", "es"))
+        if entities:
+            if isinstance(entities, dict) and "error" in entities:
+                st.write(entities)
+            else:
+                display(entities)
+                raw_entities = []
+                for entity in entities:
+                    raw_entity = entity
+                    raw_entity["start"] = int(raw_entity["start"])
+                    raw_entity["end"] = int(raw_entity["end"])
+                    raw_entity["score"] = float(raw_entity["score"])
+                    raw_entities.append(raw_entity)
+                st.write(raw_entities)
+        else:
+            st.write("No entities found")
+
 st.markdown(
     """
 ### Team members
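
The pair-classification path added in this commit can be sanity-checked outside Streamlit. Below is a minimal sketch, assuming the `bertin-project/bertin-base-paws-x-es` checkpoint ships without an `id2label` mapping, so `transformers` falls back to the generic `LABEL_0`/`LABEL_1` names that the demo's label check relies on:

```python
# Minimal sketch: exercising the PAWS-X paraphrase classifier outside Streamlit.
# Mirrors load_model_pair_classification above; the label-to-meaning mapping
# (LABEL_0 = different meaning, LABEL_1 = paraphrase) follows the demo's check.
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

model_url = "bertin-project/bertin-base-paws-x-es"
model = AutoModelForSequenceClassification.from_pretrained(model_url)
tokenizer = AutoTokenizer.from_pretrained(model_url)
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

# The demo encodes a sentence pair as one string joined by the RoBERTa
# separator token </s>, appending a trailing separator before inference.
result = classifier("Te amo.</s>Te adoro.</s>")
label, score = result[0]["label"], result[0]["score"]
print("Paraphrase" if label == "LABEL_1" else "Different meaning", f"{score:.2f}")
```

Joining the pair with `</s>` reproduces the two-segment encoding the model saw during PAWS-X fine-tuning, which is why the demo prompts use that separator rather than two input boxes.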
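The NER section renders entities through displaCy's manual mode rather than a spaCy pipeline. A minimal sketch of that path, fed hand-written entities (illustrative words and scores, not real model output) and mirroring the word-joining and offset arithmetic of `model_entities_to_displacy_format`:

```python
# Minimal sketch of the displaCy "manual" rendering path used by display() above.
from spacy import displacy

# Shape of pipeline("ner", ..., aggregation_strategy="simple") output with
# ignore_labels=[]; "O" spans are kept by the pipeline but skipped when drawing.
entities = [
    {"entity_group": "PER", "word": "Íñigo Montoya", "score": 0.99},
    {"entity_group": "O", "word": "viaja a", "score": 0.98},
    {"entity_group": "LOC", "word": "Florin", "score": 0.97},
]

# Same conversion the demo performs: join the words into one string and emit
# character offsets for every non-"O" group, advancing past ignored words too.
text, ents, pointer = " ".join(e["word"] for e in entities), [], 0
for e in entities:
    if e["entity_group"] != "O":
        ents.append({"start": pointer, "end": pointer + len(e["word"]), "label": e["entity_group"]})
    pointer += len(e["word"]) + 1

doc = {"text": text, "ents": ents, "title": None}
html = displacy.render(doc, style="ent", manual=True)
print(html)  # the demo wraps this HTML in WRAPPER and hands it to st.write
```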
requirements.txt CHANGED
@@ -2,3 +2,5 @@ streamlit
 mtranslate
 transformers
 torch
+seaborn
+spacy