File size: 3,755 Bytes
0ba900e
 
 
 
ef1f66e
0ba900e
 
8384a73
0ba900e
 
 
 
 
 
 
 
 
 
8eb8892
0ba900e
 
 
 
 
 
8a66e28
 
 
0ba900e
43b8437
0ba900e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8384a73
0ba900e
 
 
 
 
 
 
 
 
 
 
 
 
8384a73
 
 
 
 
 
 
 
 
 
0ba900e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import copy
import logging
from typing import List

import torch
import streamlit as st
from transformers import BertTokenizer, TFAutoModelForMaskedLM
from transformers import CamembertModel, CamembertTokenizer

from rhyme_with_ai.utils import color_new_words, sanitize
from rhyme_with_ai.rhyme import query_rhyme_words
from rhyme_with_ai.rhyme_generator import RhymeGenerator


# First line used when the user has not supplied one.
DEFAULT_QUERY = "Machines will take over the world soon"
# Upper bound on the number of rhyme words requested and rhymed against.
N_RHYMES = 10

# Single source of truth for the supported languages, so the sidebar options
# and the per-language settings cannot drift apart.
# Maps language -> (pretrained model identifier, iteration factor). The
# iteration factor is multiplied by the number of words in the query to get
# the number of mutation rounds.
_LANGUAGE_SETTINGS = {
    "english": ("bert-large-cased-whole-word-masking", 5),
    "dutch": ("GroNLP/bert-base-dutch-cased", 10),  # Faster model
    "french": ("camembert-base", 5),
}

LANGUAGE = st.sidebar.radio("Language", list(_LANGUAGE_SETTINGS), 0)
if LANGUAGE not in _LANGUAGE_SETTINGS:
    raise NotImplementedError(
        f"Unsupported language ({LANGUAGE}) expected 'english', 'dutch' or 'french'."
    )
MODEL_PATH, ITER_FACTOR = _LANGUAGE_SETTINGS[LANGUAGE]

def main():
    """Render the page and produce rhyme suggestions for the entered line.

    Shows the credits and title, reads the query, looks up at most
    ``N_RHYMES`` rhyme words for the selected ``LANGUAGE`` and, if any were
    found, hands over to ``start_rhyming``. Otherwise a short notice is
    written to the page.
    """
    credits_html = (
        "<sup>Created with "
        "[Datamuse](https://www.datamuse.com/api/), "
        "[Mick's rijmwoordenboek](https://rijmwoordenboek.nl), "
        "[Hugging Face](https://huggingface.co/), "
        "[Streamlit](https://streamlit.io/) and "
        "[App Engine](https://cloud.google.com/appengine/)."
        " Read our [blog](https://blog.godatadriven.com/rhyme-with-ai) "
        "or check the "
        "[source](https://github.com/godatadriven/rhyme-with-ai).</sup>"
    )
    st.markdown(credits_html, unsafe_allow_html=True)
    st.title("Rhyme with AI")

    # Fall back to the default line if the input came back empty.
    query = get_query() or DEFAULT_QUERY
    rhyme_words_options = query_rhyme_words(
        query, n_rhymes=N_RHYMES, language=LANGUAGE
    )

    # Guard clause: nothing to rhyme with.
    if not rhyme_words_options:
        st.write("No rhyme words found")
        return

    logging.getLogger(__name__).info("Got rhyme words: %s", rhyme_words_options)
    start_rhyming(query, rhyme_words_options)


def get_query():
    """Return the sanitized first line typed by the user.

    The text input is pre-filled with ``DEFAULT_QUERY``; if the sanitized
    value is empty, ``DEFAULT_QUERY`` is returned instead.
    """
    raw_line = st.text_input(
        "Write your first line and press ENTER to rhyme:", DEFAULT_QUERY
    )
    cleaned_line = sanitize(raw_line)
    return cleaned_line if cleaned_line else DEFAULT_QUERY


def start_rhyming(query, rhyme_words_options):
    """Iteratively generate rhyming lines for ``query`` and render them live.

    Args:
        query: The user's first line. Its word count (whitespace split)
            times ``ITER_FACTOR`` gives the number of mutation rounds.
        rhyme_words_options: Candidate rhyme words; only the first
            ``N_RHYMES`` are used.

    The suggestions placeholder is refreshed after every mutation round and
    the progress bar is advanced to the fraction of rounds completed.
    """
    st.markdown("## My Suggestions:")

    progress_bar = st.progress(0)
    status_text = st.empty()
    max_iter = len(query.split()) * ITER_FACTOR

    rhyme_words = rhyme_words_options[:N_RHYMES]

    model, tokenizer = load_model(MODEL_PATH, LANGUAGE)
    sentence_generator = RhymeGenerator(model, tokenizer)
    sentence_generator.start(query, rhyme_words)

    # Placeholder "previous" sentences for the first round's diff highlighting.
    current_sentences = [" " for _ in range(N_RHYMES)]
    for i in range(max_iter):
        # Snapshot before mutating so the display can highlight what changed.
        previous_sentences = copy.deepcopy(current_sentences)
        current_sentences = sentence_generator.mutate()
        display_output(status_text, query, current_sentences, previous_sentences)
        # Completed fraction. The previous i / (max_iter - 1) divided by zero
        # when max_iter == 1; this form is safe for every max_iter >= 1 and
        # still ends at exactly 1.0.
        progress_bar.progress((i + 1) / max_iter)
    st.balloons()


@st.cache(allow_output_mutation=True)
def load_model(model_path, language):
    """Load the pretrained model and tokenizer for ``language``.

    Args:
        model_path: Identifier passed to ``from_pretrained`` for both the
            model and the tokenizer.
        language: ``"french"`` selects the Camembert classes; any other
            value selects ``TFAutoModelForMaskedLM`` with ``BertTokenizer``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    NOTE(review): the French branch loads ``CamembertModel`` whereas the
    other branch loads a masked-LM class -- confirm that ``RhymeGenerator``
    works with both kinds of model.
    """
    if language == "french":
        model_cls, tokenizer_cls = CamembertModel, CamembertTokenizer
    else:
        model_cls, tokenizer_cls = TFAutoModelForMaskedLM, BertTokenizer

    # Model first, tokenizer second.
    model = model_cls.from_pretrained(model_path)
    tokenizer = tokenizer_cls.from_pretrained(model_path)
    return model, tokenizer

def display_output(status_text, query, current_sentences, previous_sentences):
    """Render the current suggestions as an HTML list under the query.

    Each current sentence is paired with its previous version and passed
    through ``color_new_words``. From the result, the text between the first
    and second comma, minus its last two characters, becomes one ``<li>``.

    Args:
        status_text: Object whose ``markdown`` method receives the HTML.
        query: The user's first line, shown above the list.
        current_sentences: Sentences from the latest mutation round.
        previous_sentences: Sentences from the round before.
    """
    list_items = [
        f"<li>{color_new_words(current, previous).split(',')[1][:-2]}</li>"
        for current, previous in zip(current_sentences, previous_sentences)
    ]
    html = query + ",<br>" + "".join(list_items)
    status_text.markdown(html, unsafe_allow_html=True)



# Script entry point: build the page only when this module is run as a script.
if __name__ == "__main__":
    # Set the root logger to INFO so the rhyme-word log line in main() is emitted.
    logging.basicConfig(level=logging.INFO)
    main()