"""Streamlit demo comparing ALBERTI and multilingual BERT on poetry.

One word per poem line is replaced with ``[MASK]``; each model then fills
the mask and the two rewritten poems are shown side by side.
"""
import random
import re

import langid
import numpy as np
import streamlit as st
import torch
from icu_tokenizer import Tokenizer
from transformers import pipeline

from poems import SAMPLE_POEMS

# Display name -> Hugging Face model identifier.
MODELS = {
    "ALBERTI": "flax-community/alberti-bert-base-multilingual-cased",
    "mBERT": "bert-base-multilingual-cased",
}
# Number of top predictions requested for the masked position.
TOPK = 50

st.set_page_config(layout="wide")


def mask_line(line, language="es", restrictive=True):
    """Replace one randomly chosen token of *line* with ``[MASK]``.

    Tokens longer than three characters are preferred. For Chinese
    (``language == "zh"``) alphabetic tokens are also eligible, and
    ``restrictive=False`` allows any token. For every other language
    *restrictive* is ignored: any token may be masked only when the line
    has no token longer than three characters.

    Args:
        line: The text line to mask.
        language: Language code passed to the ICU tokenizer.
        restrictive: Only honoured for Chinese, see above.

    Returns:
        The tokens joined by single spaces with one of them replaced by
        ``[MASK]``, or *line* unchanged when it yields no tokens.
    """
    tokens = list(Tokenizer(lang=language).tokenize(line))
    if not tokens:
        # Nothing to mask (e.g. whitespace-only input).
        return line

    is_chinese = language == "zh"
    if is_chinese and not restrictive:
        eligible = []
    else:
        eligible = [
            index
            for index, token in enumerate(tokens)
            if len(token) > 3 or (is_chinese and token.isalpha())
        ]

    # Choosing directly among the eligible positions is equivalent to
    # retrying random picks until one qualifies, but cannot loop forever
    # when no token qualifies: in that case any token is masked.
    if eligible:
        target = random.choice(eligible)
    else:
        target = random.randrange(len(tokens))
    tokens[target] = "[MASK]"
    return " ".join(tokens)


def filter_candidates(candidates, get_any_candidate=False):
    """Return the sequence of the best acceptable candidate.

    A candidate is acceptable when its ``token_str`` is alphabetic and is
    not a ``##`` word-piece continuation. When no candidate is acceptable
    (or *get_any_candidate* is true) the first candidate is used.

    Args:
        candidates: Candidate dicts as produced by ``infer_candidates``,
            in the order they should be preferred.
        get_any_candidate: Skip the filter and accept the first candidate.

    Returns:
        The ``"sequence"`` string of the selected candidate.

    Raises:
        ValueError: If *candidates* is empty.
    """
    if not candidates:
        raise ValueError("no candidates to choose from")

    if not get_any_candidate:
        for candidate in candidates:
            token_str = candidate["token_str"]
            if not token_str.startswith("##") and token_str.isalpha():
                return candidate["sequence"]

    # Either filtering was disabled or nothing passed the filter.
    return candidates[0]["sequence"]


def infer_candidates(nlp, line):
    """Compute the ``TOPK`` fill-mask predictions for a masked line.

    Args:
        nlp: A fill-mask pipeline; its tokenizer and its private
            ``_parse_and_tokenize`` / ``_forward`` methods are used.
        line: Text containing exactly one mask token.

    Returns:
        A list of dicts with keys ``sequence``, ``score``, ``token``,
        ``token_str`` and ``masked_index``, in the order returned by
        ``topk`` over the softmax probabilities.

    Raises:
        ValueError: If *line* does not contain exactly one mask token.
    """
    # Replace typographic characters with ASCII equivalents.
    line = re.sub("’", "'", line)
    line = re.sub("…", "...", line)

    inputs = nlp._parse_and_tokenize(line)
    outputs = nlp._forward(inputs, return_tensors=True)
    input_ids = inputs["input_ids"][0]
    masked_index = torch.nonzero(
        input_ids == nlp.tokenizer.mask_token_id, as_tuple=False
    )
    if masked_index.numel() != 1:
        raise ValueError(
            "expected exactly one mask token in the line, "
            f"found {masked_index.shape[0]}"
        )
    mask_position = masked_index.item()

    logits = outputs[0, mask_position, :]
    probs = logits.softmax(dim=0)
    values, predictions = probs.topk(TOPK)

    pad_token_id = nlp.tokenizer.pad_token_id
    result = []
    for score, token_id in zip(values.tolist(), predictions.tolist()):
        # Work on a copy: ``Tensor.numpy()`` shares memory with the tensor,
        # so writing into it directly would modify the pipeline inputs.
        tokens = input_ids.numpy().copy()
        tokens[mask_position] = token_id
        # Filter padding out:
        tokens = tokens[np.where(tokens != pad_token_id)]

        decoded = [
            nlp.tokenizer.decode([token], skip_special_tokens=True)
            for token in tokens
        ]
        pieces = []
        for idx, token in enumerate(decoded):
            if token.startswith("##"):
                # Glue word-piece continuations onto the previous piece.
                pieces[-1] += token[2:]
            elif idx == mask_position:
                # The empty strings add extra spaces around the predicted
                # token once joined.
                # NOTE(review): they look like placeholders for highlight
                # markup around the prediction - confirm.
                pieces += ['', token, ""]
            else:
                pieces += [token]
        sequence = " ".join(pieces).strip()

        result.append(
            {
                "sequence": sequence,
                "score": score,
                "token": token_id,
                "token_str": nlp.tokenizer.decode(token_id),
                "masked_index": mask_position,
            }
        )
    return result


def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
    """Mask one word per line of *poem* and let a model fill it back in.

    Args:
        poem: Iterable of text lines.
        ml_model: Model identifier handed to the fill-mask pipeline.
        masking: When false, the lines are assumed to be masked already.
        language: Language code used for tokenization while masking.

    Returns:
        A ``(unmasked_poem, masked_lines)`` tuple: the rewritten lines
        joined by newlines, and the list of masked lines given to the
        model. Blank lines are kept as empty strings in both.
    """
    nlp = pipeline("fill-mask", model=ml_model)
    unmasked_lines = []
    masked_lines = []
    for line in poem:
        # Blank (including whitespace-only) lines have nothing to mask.
        if not line.strip():
            unmasked_lines.append("")
            masked_lines.append("")
            continue
        masked_line = mask_line(line, language) if masking else line
        masked_lines.append(masked_line)
        candidates = infer_candidates(nlp, masked_line)
        unmasked_lines.append(filter_candidates(candidates))
    # NOTE(review): Markdown renders a single newline as a soft break, so
    # the poem may be displayed as one paragraph - confirm the separator.
    unmasked_poem = "\n".join(unmasked_lines)
    return unmasked_poem, masked_lines


# Markdown headings and bullet items must sit on their own lines to render.
instructions_text_0 = st.sidebar.markdown(
    """# ALBERTI vs BERT 🥊

We present ALBERTI, our BERT-based multilingual model for poetry.""")

instructions_text_1 = st.sidebar.markdown(
    "We have trained bert on a huge (for poetry, that is) corpus of "
    "multilingual poetry to try to get a more 'poetic' model. This is the "
    "result of our work. You can find more information on the "
    "[project's site](https://huggingface.co/flax-community/"
    "alberti-bert-base-multilingual-cased)")

sample_chooser = st.sidebar.selectbox(
    "Choose a poem",
    list(SAMPLE_POEMS.keys())
)

# NOTE(review): "Portuguese" is listed twice below - confirm whether another
# language was intended.
instructions_text_2 = st.sidebar.markdown("""# How to use

You can choose from a list of example poems in Spanish, English, French, German, Chinese and Arabic, but you can also paste a poem, or write it yourself!
Then click on 'Rewrite!' to do the masking and the fill-mask task on the chosen poem.

The list of languages used on the training of ALBERTI are:

* Arabic
* Chinese
* Czech
* English
* Finnish
* French
* German
* Hungarian
* Italian
* Portuguese
* Portuguese
* Russian
* Spanish""")

# ``st.beta_columns`` was renamed to ``st.columns``; support both.
_make_columns = getattr(st, "columns", None) or st.beta_columns
col1, col2, col3 = _make_columns(3)

st.markdown(
    """ """,
    unsafe_allow_html=True)

if sample_chooser:
    model_list = set(MODELS.values())
    user_input = col1.text_area(
        "Input poem",
        "\n".join(SAMPLE_POEMS[sample_chooser]),
        height=600)
    poem = user_input.split("\n")
    rewrite_button = col1.button("Rewrite!")
    # Warn only when the user actually typed a mask token.
    if "[MASK]" in user_input:
        col1.error("You don't have to mask the poem, we'll do it for you!")

    if rewrite_button:
        lang = langid.classify(user_input)[0]
        unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
        user_input_2 = col2.write(
            f"""Output poem from ALBERTI {unmasked_poem}""",
            unsafe_allow_html=True)
        # Feed mBERT the same masked lines so both models fill the same gaps.
        unmasked_poem_2, _ = rewrite_poem(
            masked_poem, ml_model=MODELS["mBERT"], masking=False)
        user_input_3 = col3.write(
            f"""Output poem from mBERT {unmasked_poem_2}""",
            unsafe_allow_html=True)