|
import random |
|
import re |
|
from poems import SAMPLE_POEMS |
|
|
|
import langid |
|
import numpy as np |
|
import streamlit as st |
|
import torch |
|
|
|
from icu_tokenizer import Tokenizer |
|
from transformers import pipeline |
|
|
|
# Display name -> Hugging Face model identifier for the two models compared.
MODELS = {
    "ALBERTI": "linhd-postdata/alberti-bert-base-multilingual-cased",
    "mBERT": "bert-base-multilingual-cased",
}

# How many top-scoring fill-mask predictions are requested per masked token.
TOPK = 50

# Ask Streamlit for the "wide" page layout.
st.set_page_config(layout="wide")
|
|
|
|
|
def mask_line(line, language="es", restrictive=True):
    """Replace one randomly chosen token of ``line`` with ``[MASK]``.

    The line is tokenized with the ICU tokenizer for ``language`` and the
    tokens are re-joined with single spaces, so the spacing of the result may
    differ from the input.

    Args:
        line: A single verse.
        language: Language code handed to the tokenizer. ``"zh"`` gets special
            treatment because Chinese words are routinely shorter than four
            characters.
        restrictive: When true, prefer "content" tokens: longer than three
            characters, or (for Chinese) alphabetic tokens of any length.
            For every language other than ``"zh"`` this flag is recomputed
            from the line itself and the passed value is ignored.

    Returns:
        The line with exactly one token replaced by ``[MASK]``, or ``line``
        unchanged when tokenization yields no tokens.
    """
    token_list = Tokenizer(lang=language).tokenize(line)
    if not token_list:
        # Nothing to mask (e.g. whitespace-only input); sampling an index
        # from an empty range would raise.
        return line

    is_chinese = language == "zh"
    if not is_chinese:
        # Short tokens are only acceptable when the line offers nothing longer.
        restrictive = any(len(token) > 3 for token in token_list)

    candidate_indices = range(len(token_list))
    if restrictive:
        eligible = [
            index
            for index, token in enumerate(token_list)
            if len(token) > 3 or (is_chinese and token.isalpha())
        ]
        # Sampling from the eligible indices is uniform over them, like the
        # old retry-until-eligible loop, but it always terminates. When no
        # token qualifies, any token may be masked.
        if eligible:
            candidate_indices = eligible

    token_list[random.choice(candidate_indices)] = "[MASK]"
    return " ".join(token_list)
|
|
|
|
|
def filter_candidates(candidates, get_any_candidate=False):
    """Return the ``"sequence"`` of the first acceptable candidate.

    A candidate is acceptable when its ``"token_str"`` is a whole alphabetic
    word, i.e. not a ``##`` word-piece continuation and not punctuation or
    digits. If no candidate qualifies, the first candidate is used instead.

    Args:
        candidates: Iterable of dicts with at least the keys ``"sequence"``
            and ``"token_str"``, in order of preference.
        get_any_candidate: When true, skip the filter and take the first
            candidate.

    Returns:
        The ``"sequence"`` value of the selected candidate.

    Raises:
        ValueError: If ``candidates`` is empty. The previous implementation
            recursed until ``RecursionError`` in that case.
    """
    fallback = None
    for candidate in candidates:
        if get_any_candidate:
            return candidate["sequence"]
        token_str = candidate["token_str"]
        if not token_str.startswith("##") and token_str.isalpha():
            return candidate["sequence"]
        if fallback is None:
            # Remember the first candidate in case none passes the filter.
            fallback = candidate

    if fallback is None:
        raise ValueError("filter_candidates() needs at least one candidate")
    return fallback["sequence"]
|
|
|
|
|
def infer_candidates(nlp, line):
    """Predict the ``TOPK`` most probable fillers for the mask in ``line``.

    Args:
        nlp: A fill-mask pipeline. Besides ``nlp.tokenizer`` this uses the
            pipeline's ``_parse_and_tokenize`` and ``_forward`` helpers.
            NOTE(review): those are private methods, presumably tied to the
            transformers version the app was written against; verify before
            upgrading.
        line: Text containing exactly one mask token.

    Returns:
        A list of ``TOPK`` dicts in the order produced by ``topk``, each with:
        ``"sequence"`` (the line as an HTML fragment with the predicted token
        wrapped in a red ``<b>`` tag), ``"score"`` (softmax probability),
        ``"token"`` (predicted token id), ``"token_str"`` (decoded token) and
        ``"masked_index"`` (position of the mask in the input ids).

    Raises:
        ValueError: If ``line`` does not contain exactly one mask token.
    """
    # Map typographic punctuation to ASCII. Escapes are used so the patterns
    # cannot be damaged by an encoding round-trip of this file.
    # NOTE(review): code points reconstructed from the replacement targets
    # ("-", "-", "'", "..."); confirm they are the intended characters.
    for fancy, plain in (
        ("\u2013", "-"),    # en dash
        ("\u2014", "-"),    # em dash
        ("\u2019", "'"),    # right single quotation mark
        ("\u2026", "..."),  # horizontal ellipsis
    ):
        line = line.replace(fancy, plain)

    tokenizer = nlp.tokenizer
    inputs = nlp._parse_and_tokenize(line)
    outputs = nlp._forward(inputs, return_tensors=True)
    input_ids = inputs["input_ids"][0]

    mask_positions = torch.nonzero(input_ids == tokenizer.mask_token_id,
                                   as_tuple=False)
    if mask_positions.numel() != 1:
        raise ValueError(
            "expected exactly one mask token in the line, found "
            f"{mask_positions.numel()}"
        )
    masked_index = mask_positions.item()

    logits = outputs[0, masked_index, :]
    probs = logits.softmax(dim=0)
    values, predictions = probs.topk(TOPK)

    base_tokens = input_ids.numpy()
    pad_token_id = tokenizer.pad_token_id
    result = []
    for score, token_id in zip(values.tolist(), predictions.tolist()):
        # Work on a copy: ``numpy()`` shares memory with ``input_ids``.
        tokens = base_tokens.copy()
        tokens[masked_index] = token_id

        tokens = tokens[np.where(tokens != pad_token_id)]
        token_list = [tokenizer.decode([token], skip_special_tokens=True)
                      for token in tokens]

        pieces = []
        for idx, token in enumerate(token_list):
            if token.startswith("##"):
                # Word-piece continuation: glue it onto the previous piece.
                pieces[-1] += token[2:]
            elif idx == masked_index:
                # Highlight the predicted token in the rendered output.
                pieces += ['<b style="color: #ff0000;">', token, "</b>"]
            else:
                pieces.append(token)

        result.append(
            {
                "sequence": " ".join(pieces).strip(),
                "score": score,
                "token": token_id,
                "token_str": tokenizer.decode(token_id),
                "masked_index": masked_index,
            }
        )
    return result
|
|
|
|
|
def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
    """Mask one word per verse and let a fill-mask model restore it.

    Args:
        poem: Iterable of verse strings, one per line.
        ml_model: Model identifier passed to ``pipeline("fill-mask", ...)``.
        masking: When true, each verse is masked with ``mask_line``. When
            false, the verses are used as given and must already contain
            their mask token.
        language: Language code forwarded to ``mask_line``; unused when
            ``masking`` is false.

    Returns:
        A tuple ``(unmasked_poem, masked_lines)``: the rewritten verses joined
        with ``<br>``, and the list of masked verses that were fed to the
        model. Blank verses appear as ``""`` in both.
    """
    nlp = pipeline("fill-mask", model=ml_model)
    unmasked_lines = []
    masked_lines = []
    for line in poem:
        # Whitespace-only verses are treated like empty ones: they have no
        # word to mask and cannot go through masking or inference.
        if not line.strip():
            unmasked_lines.append("")
            masked_lines.append("")
            continue

        masked_line = mask_line(line, language) if masking else line
        masked_lines.append(masked_line)

        line_candidates = infer_candidates(nlp, masked_line)
        unmasked_lines.append(filter_candidates(line_candidates))

    return "<br>".join(unmasked_lines), masked_lines
|
|
|
|
|
# Sidebar: title and one-line pitch. The emoji is written as an escape
# (U+1F525) so it cannot be damaged by an encoding round-trip of this file.
instructions_text_0 = st.sidebar.markdown(
    """# ALBERTI vs BERT \U0001F525

We present ALBERTI, our BERT-based multilingual model for poetry.""")

# Sidebar: short project description with links.
instructions_text_1 = st.sidebar.markdown(
    """We have trained bert on a huge (for poetry, that is) corpus of
multilingual poetry to try to get a more 'poetic' model. This is the result
of our work.

You can find more information on the [project's site](https://huggingface.co/linhd-postdata/alberti). See also https://arxiv.org/abs/2307.01387""")

# Sidebar: selector over the bundled sample poems (keys of SAMPLE_POEMS).
sample_chooser = st.sidebar.selectbox(
    "Choose a poem",
    list(SAMPLE_POEMS.keys())
)

# Sidebar: usage instructions and the list of training languages.
instructions_text_2 = st.sidebar.markdown("""# How to use

You can choose from a list of example poems in Spanish, English, French, German,
Chinese and Arabic, but you can also paste a poem, or write it yourself!

Then click on 'Rewrite!' to do the masking and the fill-mask task on the chosen
poem, randomly masking one word per verse, and get the two new versions for each of the models.

The list of languages used on the training of ALBERTI are:

* Chinese
* Czech
* English
* Finnish
* French
* German
* Hungarian
* Italian
* Portuguese
* Russian
* Spanish
""")
|
|
|
# Main area: three columns, used below for the input poem and the two outputs.
col1, col2, col3 = st.columns(3)

# Inject page-level CSS: bold, 1rem widget labels and narrower side padding.
st.markdown(
    """
<style>
label {
font-size: 1rem !important;
font-weight: bold !important;
}
.block-container {
padding-left: 1rem !important;
padding-right: 1rem !important;
}
</style>
""",
    unsafe_allow_html=True,
)
|
|
|
# The editor and the rewrite action are only drawn once a sample is selected.
if sample_chooser:
    model_list = set(MODELS.values())
    user_input = col1.text_area(
        "Input poem",
        "\n".join(SAMPLE_POEMS[sample_chooser]),
        height=600,
    )
    poem = user_input.split("\n")
    rewrite_button = col1.button("Rewrite!")

    # The app does the masking itself, so flag text that already has masks.
    if any(marker in user_input for marker in ("[MASK]", "<mask>")):
        col1.error("You don't have to mask the poem, we'll do it for you!")

    if rewrite_button:
        # Language guessed from the whole text; first item of langid's result.
        lang = langid.classify(user_input)[0]

        # First pass: mask the verses and let ALBERTI fill them in.
        unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
        user_input_2 = col2.write(f"""<b>Output poem from ALBERTI</b>


{unmasked_poem}""", unsafe_allow_html=True)

        # Second pass: reuse the same masked verses so mBERT fills the same gaps.
        unmasked_poem_2, _ = rewrite_poem(
            masked_poem,
            ml_model=MODELS["mBERT"],
            masking=False,
        )
        user_input_3 = col3.write(f"""<b>Output poem from mBERT</b>

{unmasked_poem_2}""", unsafe_allow_html=True)
|
|