"""Streamlit app comparing ALBERTI and mBERT fill-mask rewrites of poems."""
import random
import re

from poems import SAMPLE_POEMS

import langid
import numpy as np
import streamlit as st
import torch
from icu_tokenizer import Tokenizer
from transformers import pipeline

MODELS = {
    "ALBERTI": "flax-community/alberti-bert-base-multilingual-cased",
    "mBERT": "bert-base-multilingual-cased"
}
# Number of top predictions requested from the model for each masked position.
TOPK = 50

st.set_page_config(layout="wide")


def mask_line(line, language="es", restrictive=True):
    """Replace one randomly chosen token of *line* with ``[MASK]``.

    The line is tokenized with ``icu_tokenizer.Tokenizer`` for *language* and
    the tokens are re-joined with single spaces.

    Args:
        line: A single verse to mask.
        language: Language code handed to the tokenizer (the app passes the
            code returned by ``langid.classify``).
        restrictive: Only honoured when ``language == "zh"``. For every other
            language it is recomputed: masking is restricted to tokens longer
            than three characters whenever the line contains at least one.

    Returns:
        The space-joined tokens with exactly one token replaced by ``[MASK]``.

    Raises:
        ValueError: If the line produces no tokens (e.g. whitespace only).
    """
    tokens = Tokenizer(lang=language).tokenize(line)
    if not tokens:
        raise ValueError(f"cannot mask a line without tokens: {line!r}")

    if language != "zh":
        restrictive = any(len(token) > 3 for token in tokens)

    candidates = range(len(tokens))
    if restrictive:
        eligible = [
            i for i, token in enumerate(tokens)
            if len(token) > 3 or (language == "zh" and token.isalpha())
        ]
        # Choosing uniformly among eligible tokens matches the old
        # "re-sample until eligible" loop. When nothing qualifies (possible
        # for "zh"), fall back to any token instead of recursing forever.
        if eligible:
            candidates = eligible

    tokens[random.choice(candidates)] = "[MASK]"
    return " ".join(tokens)


def filter_candidates(candidates, get_any_candidate=False):
    """Pick the sequence of the best acceptable fill-mask candidate.

    Args:
        candidates: Candidate dicts in model-score order, each with at least
            ``"sequence"`` and ``"token_str"`` keys (as built by
            ``infer_candidates``).
        get_any_candidate: If true, skip filtering and take the first
            candidate.

    Returns:
        The ``"sequence"`` of the first candidate whose predicted token is a
        whole alphabetic word (not a ``##`` sub-word piece). If none
        qualifies, or filtering is disabled, returns the first candidate's
        ``"sequence"``.

    Raises:
        ValueError: If *candidates* is empty.
    """
    if not candidates:
        raise ValueError("no fill-mask candidates to choose from")

    if not get_any_candidate:
        for candidate in candidates:
            token_str = candidate["token_str"]
            if not token_str.startswith("##") and token_str.isalpha():
                return candidate["sequence"]

    return candidates[0]["sequence"]


def infer_candidates(nlp, line):
    """Run the fill-mask model on *line* and decode the top-``TOPK`` fills.

    Typographic dashes, apostrophes and ellipses are normalised to ASCII
    before tokenization.

    Args:
        nlp: A transformers ``fill-mask`` pipeline.
        line: Text containing exactly one mask token of ``nlp.tokenizer``.

    Returns:
        One dict per prediction, in descending probability order, with keys
        ``"sequence"`` (decoded line with the predicted word wrapped in
        ``<b>``), ``"score"``, ``"token"`` (token id), ``"token_str"`` and
        ``"masked_index"``.

    Raises:
        ValueError: If the tokenized line does not contain exactly one mask
            token.
    """
    line = re.sub("–", "-", line)
    line = re.sub("—", "-", line)
    line = re.sub("’", "'", line)
    line = re.sub("…", "...", line)

    # NOTE(review): relies on the private pipeline methods
    # _parse_and_tokenize / _forward(return_tensors=True), whose signatures
    # changed across transformers releases; confirm the pinned version.
    inputs = nlp._parse_and_tokenize(line)
    outputs = nlp._forward(inputs, return_tensors=True)
    input_ids = inputs["input_ids"][0]

    masked_index = torch.nonzero(
        input_ids == nlp.tokenizer.mask_token_id, as_tuple=False)
    if masked_index.numel() != 1:
        raise ValueError(
            f"expected exactly one {nlp.tokenizer.mask_token} in {line!r}, "
            f"found {masked_index.numel()}")
    mask_pos = masked_index.item()

    # presumably outputs is the (batch, seq_len, vocab) logits tensor;
    # the indexing below assumes that layout.
    logits = outputs[0, mask_pos, :]
    probs = logits.softmax(dim=0)
    values, predictions = probs.topk(TOPK)

    # .numpy() shares memory with input_ids; copy per prediction below so
    # the model inputs are never mutated.
    base_tokens = input_ids.numpy()
    pad_id = nlp.tokenizer.pad_token_id

    result = []
    for score, token_id in zip(values.tolist(), predictions.tolist()):
        tokens = base_tokens.copy()
        tokens[mask_pos] = token_id
        # Filter padding out:
        tokens = tokens[np.where(tokens != pad_id)]

        words = []
        token_list = [nlp.tokenizer.decode([token], skip_special_tokens=True)
                      for token in tokens]
        for idx, token in enumerate(token_list):
            if token.startswith("##"):
                # Glue WordPiece continuations onto the previous word.
                words[-1] += token[2:]
            elif idx == mask_pos:
                # Highlight the predicted word; the sequence is rendered
                # with unsafe_allow_html=True.
                words.append(f"<b>{token}</b>")
            else:
                words.append(token)
        sequence = " ".join(words).strip()

        result.append(
            {
                "sequence": sequence,
                "score": score,
                "token": token_id,
                "token_str": nlp.tokenizer.decode(token_id),
                "masked_index": mask_pos,
            }
        )
    return result


def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
    """Rewrite every verse of *poem* by filling one masked word per line.

    Args:
        poem: List of verse strings; blank or whitespace-only entries are kept
            as empty lines.
        ml_model: Hugging Face model id for the ``fill-mask`` pipeline.
        masking: If true, mask one word per line with ``mask_line``;
            otherwise each line must already contain a mask token.
        language: Language code forwarded to ``mask_line``.

    Returns:
        A tuple ``(unmasked_poem, masked_lines)``. ``unmasked_poem`` is the
        rewritten verses joined with ``<br>`` for HTML rendering.
        ``masked_lines`` is the list of masked verses that were fed to the
        model.
    """
    nlp = pipeline("fill-mask", model=ml_model)
    unmasked_lines = []
    masked_lines = []
    for line in poem:
        if not line.strip():
            # Stanza breaks: nothing to mask, keep them aligned in both outputs.
            unmasked_lines.append("")
            masked_lines.append("")
            continue
        masked_line = mask_line(line, language) if masking else line
        masked_lines.append(masked_line)
        candidates = infer_candidates(nlp, masked_line)
        unmasked_lines.append(filter_candidates(candidates))
    # <br> so each verse renders on its own line in HTML-enabled markdown.
    unmasked_poem = "<br>".join(unmasked_lines)
    return unmasked_poem, masked_lines


instructions_text_0 = st.sidebar.markdown(
    """# ALBERTI vs BERT 🥊

We present ALBERTI, our BERT-based multilingual model for poetry.""")

instructions_text_1 = st.sidebar.markdown(
    """We have trained bert on a huge (for poetry, that is) corpus of multilingual poetry to try to get a more 'poetic' model. This is the result of our work. You can find more information on the [project's site](https://huggingface.co/flax-community/alberti-bert-base-multilingual-cased)""")

sample_chooser = st.sidebar.selectbox(
    "Choose a poem",
    list(SAMPLE_POEMS.keys())
)

instructions_text_2 = st.sidebar.markdown("""# How to use

You can choose from a list of example poems in Spanish, English, French, German, Chinese and Arabic, but you can also paste a poem, or write it yourself!

Then click on 'Rewrite!' to do the masking and the fill-mask task on the chosen poem, randomly masking one word per verse, and get the two new versions for each of the models.

The list of languages used on the training of ALBERTI are:

* Arabic
* Chinese
* Czech
* English
* Finnish
* French
* German
* Hungarian
* Italian
* Portuguese
* Russian
* Spanish""")

col1, col2, col3 = st.columns(3)

# NOTE(review): this injects nothing as written; confirm whether custom
# HTML/CSS was meant to go here.
st.markdown(
    """ """,
    unsafe_allow_html=True)

if sample_chooser:
    user_input = col1.text_area("Input poem",
                                "\n".join(SAMPLE_POEMS[sample_chooser]),
                                height=600)
    poem = user_input.split("\n")
    rewrite_button = col1.button("Rewrite!")
    if "[MASK]" in user_input or "<mask>" in user_input:
        # A pre-masked poem would end up with two masks in a line and break
        # inference, so only warn in that case.
        col1.error("You don't have to mask the poem, we'll do it for you!")
    elif rewrite_button:
        lang = langid.classify(user_input)[0]
        unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
        col2.write(f"""Output poem from ALBERTI<br>{unmasked_poem}""",
                   unsafe_allow_html=True)
        # mBERT fills the very same masks ALBERTI saw, for a fair comparison.
        unmasked_poem_2, _ = rewrite_poem(masked_poem,
                                          ml_model=MODELS["mBERT"],
                                          masking=False)
        col3.write(f"""Output poem from mBERT<br>{unmasked_poem_2}""",
                   unsafe_allow_html=True)