Spaces:
Paused
Paused
testing lyric generation
Browse files- ArtistCoherencyModel.py +13 -0
- LyricGeneratorModel.py +23 -0
- app.py +12 -6
ArtistCoherencyModel.py
CHANGED
@@ -64,6 +64,19 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
|
|
64 |
with torch.no_grad():
|
65 |
return self.forward(song_or_embedding)
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
def predict(
|
68 |
self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
|
69 |
) -> Union[list[str], torch.Tensor]:
|
|
|
64 |
with torch.no_grad():
|
65 |
return self.forward(song_or_embedding)
|
66 |
|
67 |
+
def generate_artist_coherency_score(
|
68 |
+
self, artist_name: str, song_or_embedding: Union[str, torch.Tensor]
|
69 |
+
) -> float:
|
70 |
+
coherent_index = self.ffnn.label2id[f"{artist_name}-coherent"]
|
71 |
+
incoherent_index = self.ffnn.label2id[f"{artist_name}-incoherent"]
|
72 |
+
logits = self.generate_artist_coherency_logits(song_or_embedding)
|
73 |
+
coherent_score = logits[coherent_index]
|
74 |
+
incoherent_score = logits[incoherent_index]
|
75 |
+
score = (coherent_score + incoherent_score) * (
|
76 |
+
coherent_score / incoherent_score
|
77 |
+
)
|
78 |
+
return float(score)
|
79 |
+
|
80 |
def predict(
|
81 |
self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
|
82 |
) -> Union[list[str], torch.Tensor]:
|
LyricGeneratorModel.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
2 |
+
|
3 |
+
|
4 |
+
class LyricGeneratorModel:
|
5 |
+
def __init__(self, repo_id: str):
|
6 |
+
self.model = AutoModelForCausalLM.from_pretrained(repo_id)
|
7 |
+
|
8 |
+
self.tokenizer = AutoTokenizer.from_pretrained(repo_id)
|
9 |
+
self.tokenizer.truncation_side = "right"
|
10 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
11 |
+
self.tokenizer.padding_side = "right"
|
12 |
+
|
13 |
+
def generate_lyrics(self, prompt: str, max_length: int):
|
14 |
+
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
15 |
+
input_ids = input_ids.to("cuda")
|
16 |
+
|
17 |
+
output_tokens = self.model.generate(
|
18 |
+
input_ids, do_sample=True, max_length=max_length
|
19 |
+
)
|
20 |
+
|
21 |
+
output_text = self.tokenizer.batch_decode(output_tokens)[0]
|
22 |
+
|
23 |
+
return output_text
|
app.py
CHANGED
@@ -2,6 +2,8 @@ from ArtistCoherencyModel import ArtistCoherencyModel
|
|
2 |
import streamlit as st
|
3 |
import pandas as pd
|
4 |
|
|
|
|
|
5 |
artists_df = pd.read_csv("artists.csv")
|
6 |
artist_names_list = list(artists_df["name"])
|
7 |
|
@@ -11,11 +13,15 @@ song_title = st.text_input("Song Title")
|
|
11 |
song_description = st.text_area("Song Description")
|
12 |
submit_button = st.button("Submit")
|
13 |
|
14 |
-
|
15 |
-
st.write(
|
16 |
-
f"[Song Title] {song_title}\n[Song Artist] {artist_name_input}\n[Song Description] {song_description}"
|
17 |
-
)
|
18 |
-
|
19 |
-
ensemble_model = ArtistCoherencyModel.from_pretrained(
|
20 |
"tjl223/artist-coherency-ensemble"
|
21 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import streamlit as st
|
3 |
import pandas as pd
|
4 |
|
5 |
+
from LyricGeneratorModel import LyricGeneratorModel
|
6 |
+
|
7 |
artists_df = pd.read_csv("artists.csv")
|
8 |
artist_names_list = list(artists_df["name"])
|
9 |
|
|
|
13 |
song_description = st.text_area("Song Description")
|
14 |
submit_button = st.button("Submit")
|
15 |
|
16 |
+
lyric_evaluator_model = ArtistCoherencyModel.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
17 |
"tjl223/artist-coherency-ensemble"
|
18 |
)
|
19 |
+
|
20 |
+
lyric_generator_model = LyricGeneratorModel(
|
21 |
+
"tjl223/testllama2-qlora-lyric-generator-with-description"
|
22 |
+
)
|
23 |
+
|
24 |
+
if submit_button:
|
25 |
+
prompt = f"[Song Title] {song_title}\n[Song Artist] {artist_name_input}\n[Song Description] {song_description}"
|
26 |
+
lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
|
27 |
+
st.write(lyrics)
|