File size: 6,541 Bytes
f78244e
 
 
 
 
 
92627cb
f78244e
 
 
 
 
 
 
 
 
 
 
 
 
92627cb
f78244e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92627cb
 
f78244e
92627cb
 
 
f78244e
92627cb
 
 
 
f78244e
 
 
 
 
 
92627cb
f78244e
 
 
 
 
 
 
0880270
87596b4
f78244e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a843e3
f78244e
 
 
 
 
88ec4e4
f78244e
 
c1f4179
88ec4e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f78244e
a1028e2
f78244e
 
 
 
 
 
 
 
74e144a
 
ab26b95
a4645a6
f78244e
 
92627cb
f78244e
 
 
 
 
 
 
 
 
92627cb
f78244e
 
 
 
92627cb
 
f78244e
 
 
 
92627cb
f78244e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import html
import random
import re

import langid
import numpy as np
import streamlit as st
import torch
from icu_tokenizer import Tokenizer
from transformers import pipeline

from poems import SAMPLE_POEMS

# Model identifiers handed to ``transformers.pipeline``; the display name
# is the key, the Hugging Face checkpoint name is the value.
MODELS = dict(
    ALBERTI="flax-community/alberti-bert-base-multilingual-cased",
    mBERT="bert-base-multilingual-cased",
)

# How many top-scoring predictions are requested for each masked position.
TOPK = 50

# Use the full browser width for the three side-by-side columns.
st.set_page_config(layout="wide")


def _maskable_indices(token_list, language, restrictive):
    """Return the positions in ``token_list`` that may be replaced by [MASK].

    With ``restrictive`` set, only tokens longer than three characters
    qualify, plus - when ``language == "zh"`` - any alphabetic token (CJK
    ideographs are alphabetic for ``str.isalpha``), so short words and
    punctuation are skipped. If ``restrictive`` is false, or no token
    qualifies, every position is returned so the caller can always mask
    something instead of retrying forever.
    """
    if restrictive:
        eligible = [
            idx
            for idx, token in enumerate(token_list)
            if len(token) > 3 or (language == "zh" and token.isalpha())
        ]
        if eligible:
            return eligible
    return list(range(len(token_list)))


def mask_line(line, language="es", restrictive=True):
    """Replace one randomly chosen token of ``line`` with ``"[MASK]"``.

    The line is split with an ICU tokenizer for ``language`` and re-joined
    with single spaces, so the result is space-separated tokens.

    Args:
        line: the verse to mask; must yield at least one token.
        language: language code for the tokenizer (e.g. ``"es"``, ``"zh"``).
            It is also used for the Chinese-specific eligibility rule.
        restrictive: if true, prefer tokens longer than three characters
            (or alphabetic tokens for Chinese). When no token qualifies,
            any token may be masked. If false, any token may be masked.

    Returns:
        The space-joined tokens with exactly one replaced by ``"[MASK]"``.
        The choice is uniform over the eligible tokens.

    Raises:
        ValueError: if the tokenizer yields no tokens for ``line``.
    """
    token_list = Tokenizer(lang=language).tokenize(line)
    if not token_list:
        raise ValueError(f"cannot mask a line without tokens: {line!r}")
    # Picking directly from the eligible positions gives the same uniform
    # distribution as retry-until-eligible, but always terminates.
    mask_at = random.choice(_maskable_indices(token_list, language, restrictive))
    token_list[mask_at] = "[MASK]"
    return " ".join(token_list)


def filter_candidates(candidates, get_any_candidate=False):
    """Choose which fill-mask candidate sequence to display.

    Args:
        candidates: ranked candidate dicts, best first, each with at least
            ``"token_str"`` (the predicted token) and ``"sequence"`` keys.
        get_any_candidate: if true, take the top candidate unconditionally.
            Otherwise prefer the first candidate whose predicted token is a
            whole, purely alphabetic word (not a ``##`` word-piece, no digits
            or punctuation).

    Returns:
        The ``"sequence"`` of the chosen candidate. If no candidate passes
        the alphabetic filter, the top-ranked candidate is used.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    if not candidates:
        raise ValueError("no fill-mask candidates to choose from")
    if not get_any_candidate:
        for candidate in candidates:
            token = candidate["token_str"]
            if not token.startswith("##") and token.isalpha():
                return candidate["sequence"]
    # Nothing alphabetic (or caller asked for any): fall back to the best.
    return candidates[0]["sequence"]


def infer_candidates(nlp, line):
    """Run the fill-mask model on ``line`` and return its TOPK best fillings.

    A few typographic characters (en/em dashes, curly apostrophe, ellipsis)
    are first replaced with ASCII equivalents.

    Args:
        nlp: the ``transformers`` fill-mask pipeline built in
            ``rewrite_poem``. Its private ``_parse_and_tokenize`` /
            ``_forward`` helpers are called directly.
            NOTE(review): these private APIs tie this code to the
            transformers version the app was written against - confirm
            the pinned version.
        line: text containing exactly one ``nlp.tokenizer`` mask token.

    Returns:
        A list of TOPK dicts ordered by descending probability, each with
        these keys:

        - ``"sequence"``: the rebuilt line as HTML. The predicted word is
          wrapped in a red ``<b>`` tag and all other text is HTML-escaped.
        - ``"score"``, ``"token"`` (vocabulary id), ``"token_str"`` and
          ``"masked_index"``.

    Raises:
        ValueError: if the tokenized line does not hold exactly one mask
            token.
    """
    line = re.sub("–", "-", line)
    line = re.sub("—", "-", line)
    # Mojibake of the em dash (its UTF-8 bytes read as Windows-1253); kept
    # so text pasted from such a source is normalised as well.
    line = re.sub("β€”", "-", line)
    line = re.sub("’", "'", line)
    line = re.sub("…", "...", line)
    inputs = nlp._parse_and_tokenize(line)
    outputs = nlp._forward(inputs, return_tensors=True)
    input_ids = inputs["input_ids"][0]
    masked_index = torch.nonzero(input_ids == nlp.tokenizer.mask_token_id,
                                 as_tuple=False)
    if masked_index.numel() != 1:
        raise ValueError(
            f"expected exactly one mask token, found {masked_index.numel()} "
            f"in {line!r}"
        )
    mask_pos = masked_index.item()
    probs = outputs[0, mask_pos, :].softmax(dim=0)
    values, predictions = probs.topk(TOPK)

    pad_id = nlp.tokenizer.pad_token_id
    # ``Tensor.numpy()`` shares memory with the tensor; copy once so the
    # per-candidate writes below do not mutate ``inputs``.
    tokens = input_ids.numpy().copy()
    result = []
    for score, token_id in zip(values.tolist(), predictions.tolist()):
        tokens[mask_pos] = token_id
        # Drop padding; fancy indexing returns a fresh array each time.
        kept = tokens[np.where(tokens != pad_id)]
        words = []
        for idx, tok in enumerate(kept):
            piece = nlp.tokenizer.decode([tok], skip_special_tokens=True)
            if piece.startswith("##") and words:
                # Word-piece continuation: glue it onto the previous word.
                words[-1] += html.escape(piece[2:], quote=False)
            elif idx == mask_pos:
                words += ['<b style="color: #ff0000;">',
                          html.escape(piece, quote=False), "</b>"]
            else:
                # Escape user text: the sequence is later rendered as HTML.
                words.append(html.escape(piece, quote=False))
        result.append(
            {
                "sequence": " ".join(words).strip(),
                "score": score,
                "token": token_id,
                "token_str": nlp.tokenizer.decode(token_id),
                "masked_index": mask_pos,
            }
        )
    return result


def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
    """Optionally mask, then refill, every line of ``poem`` with ``ml_model``.

    Args:
        poem: list of verse strings. Blank or whitespace-only entries are
            kept as empty lines and never sent to the model.
        ml_model: model name passed to ``transformers.pipeline``.
        masking: if true, mask one word per line with ``mask_line``. If
            false, each non-blank line must already contain the model's mask
            token, as the ``masked_lines`` returned by a previous call do.
        language: language code forwarded to ``mask_line``.

    Returns:
        ``(unmasked_poem, masked_lines)``:

        - ``unmasked_poem``: all refilled lines, including blank ones,
          joined with ``"<br>"`` as HTML.
        - ``masked_lines``: the masked version of each input line.
    """
    nlp = pipeline("fill-mask", model=ml_model)
    unmasked_lines = []
    masked_lines = []
    for line in poem:
        if not line.strip():
            # Nothing to mask; keep the stanza break in both outputs.
            unmasked_lines.append("")
            masked_lines.append("")
            continue
        masked_line = mask_line(line, language) if masking else line
        masked_lines.append(masked_line)
        candidates = infer_candidates(nlp, masked_line)
        unmasked_lines.append(filter_candidates(candidates))
    # Join once, after the loop, so empty/all-blank poems are handled and
    # trailing blank lines are not dropped.
    return "<br>".join(unmasked_lines), masked_lines


st.sidebar.markdown(
    """# ALBERTI vs BERT 🥊

We present ALBERTI, our BERT-based multilingual model for poetry.""")

st.sidebar.markdown(
    """We have trained bert on a huge (for poetry, that is) corpus of
multilingual poetry to try to get a more 'poetic' model. This is the result
of our work.

You can find more information on the [project's site](https://huggingface.co/flax-community/alberti-bert-base-multilingual-cased)""")

sample_chooser = st.sidebar.selectbox(
    "Choose a poem",
    list(SAMPLE_POEMS.keys())
)

st.sidebar.markdown("""# How to use

You can choose from a list of example poems in Spanish, English, French, German,
Chinese and Arabic, but you can also paste a poem, or write it yourself!

Then click on 'Rewrite!' to do the masking and the fill-mask task on the chosen
poem, randomly masking one word per verse, and get the two new versions for each of the models.

The list of languages used on the training of ALBERTI are:

* Arabic
* Chinese
* Czech
* English
* Finnish
* French
* German
* Hungarian
* Italian
* Portuguese
* Russian
* Spanish""")

col1, col2, col3 = st.columns(3)

st.markdown(
    """
    <style>
        label {
        font-size: 1rem !important;
        font-weight: bold !important;
        }
        .block-container {
        padding-left: 1rem !important;
        padding-right: 1rem !important;
        }
    </style>
    """, unsafe_allow_html=True)

# The button is only created when a sample is selected; default the flag so
# the check below cannot raise NameError when ``sample_chooser`` is falsy.
rewrite_button = False
if sample_chooser:
    user_input = col1.text_area("Input poem",
                                "\n".join(SAMPLE_POEMS[sample_chooser]),
                                height=600)
    poem = user_input.split("\n")
    rewrite_button = col1.button("Rewrite!")
    if "[MASK]" in user_input or "<mask>" in user_input:
        col1.error("You don't have to mask the poem, we'll do it for you!")

if rewrite_button:
    lang = langid.classify(user_input)[0]
    unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
    col2.write(f"""<b>Output poem from ALBERTI</b>


{unmasked_poem}""", unsafe_allow_html=True)
    # mBERT refills the exact lines ALBERTI saw, so the outputs are comparable.
    unmasked_poem_2, _ = rewrite_poem(masked_poem, ml_model=MODELS["mBERT"],
                                      masking=False)
    col3.write(f"""<b>Output poem from mBERT</b>

{unmasked_poem_2}""", unsafe_allow_html=True)