File size: 1,711 Bytes
46460db
9d46201
 
 
4ab88cc
0d61d97
4ab88cc
 
 
 
bfd1de9
4ab88cc
4a01ccd
 
 
 
a1b1972
4a01ccd
 
 
4ab88cc
292ac90
bfd1de9
9d46201
292ac90
4a01ccd
 
67f25dd
 
4cd22cd
 
9d46201
944da04
9d46201
 
183cde3
0502abd
46460db
 
9d46201
be51355
9d46201
 
 
 
f43cb7c
 
dc7fbd8
9d46201
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from os import CLD_CONTINUED
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline

@st.cache(allow_output_mutation=True)
def load_model(model_ckpt="flax-community/gpt2-rap-lyric-generator"):
    """Load the tokenizer and causal language model used for lyric generation.

    Both objects are loaded with ``from_flax=True``, i.e. from the
    checkpoint's Flax weights. ``st.cache`` memoizes the result per
    ``model_ckpt``, so the download/conversion is not repeated on every
    Streamlit rerun. ``allow_output_mutation=True`` tells Streamlit not to
    hash the returned objects to detect mutation.

    Args:
        model_ckpt: Hugging Face Hub id (or local path) of the checkpoint.
            Defaults to the rap-lyric GPT-2 model this app was built for.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt, from_flax=True)
    model = AutoModelForCausalLM.from_pretrained(model_ckpt, from_flax=True)
    return tokenizer, model

@st.cache()
def load_rappers(path="rappers.txt"):
    """Return the sorted list of artist names read from a text file.

    The file holds one artist name per line.

    Args:
        path: File to read. Defaults to ``rappers.txt`` relative to the
            current working directory.

    Returns:
        Artist names in ascending (case-sensitive) order, with surrounding
        whitespace removed and blank lines skipped.
    """
    # Context manager closes the handle promptly; assumes the file is UTF-8
    # so non-ASCII names do not depend on the platform's default encoding.
    with open(path, encoding="utf-8") as text_file:
        # strip() rather than slicing off the last character: slicing would
        # truncate the final name when the file has no trailing newline.
        names = [line.strip() for line in text_file]
    # Skip blank lines so the selectbox never offers an empty entry.
    return sorted(name for name in names if name)


# Show a placeholder title while the (cached) model loads, then replace it.
title = st.title("Loading model")
tokenizer, model = load_model()
text_generation = pipeline("text-generation", model=model, tokenizer=tokenizer)
title.title("Rap lyrics generator")

list_of_rappers = load_rappers()
if not list_of_rappers:
    # Nothing to choose from (and the default index below would be -1).
    st.error("No rappers found in rappers.txt")
    st.stop()

# Default selection is the last name of the sorted list.
artist = st.selectbox(
    "Choose your rapper",
    tuple(list_of_rappers),
    index=len(list_of_rappers) - 1,
)
song_name = st.text_input("Enter the desired song name", "Sadboys")

if st.button("Generate lyrics", help="The lyrics generation can last up to 2 minutes"):
    st.title(f"{artist}: {song_name}")
    # Prompt: <BOS>, song title, then the first verse header
    # (presumably mirrors the model's training format — confirm).
    song_prefix = f"<BOS>{song_name}"
    prefix_text = f"{song_prefix} [Verse 1:{artist}]"
    generated_song = text_generation(prefix_text, max_length=750, do_sample=True)[0]

    for count, line in enumerate(generated_song["generated_text"].split("\n")):
        # Lyrics may share a line with the end marker: keep the text before
        # it, render it, and only then stop.
        text, eos_marker, _ = line.partition("<EOS>")
        if eos_marker and not text.strip():
            break

        if count == 0:
            # The first line starts with the prompt; show only the verse
            # header. Search after the song title so a '[' inside the title
            # is not mistaken for the header.
            header_start = text.find("[", len(song_prefix))
            if header_start == -1:
                header = text.replace("<BOS>", "", 1)
            else:
                header = text[header_start:]
            st.markdown(f"**{header}**")
        elif "<BOS>" in text:
            st.write(text.replace("<BOS>", "", 1))
        elif text.startswith("["):
            # Section headers such as "[Chorus]" are rendered in bold.
            st.markdown(f"**{text}**")
        else:
            st.write(text)

        if eos_marker:
            break