tjl223 committed on
Commit
f95071d
1 Parent(s): 0e2132d

testing lyric generation

Browse files
Files changed (3) hide show
  1. ArtistCoherencyModel.py +13 -0
  2. LyricGeneratorModel.py +23 -0
  3. app.py +12 -6
ArtistCoherencyModel.py CHANGED
@@ -64,6 +64,19 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
64
  with torch.no_grad():
65
  return self.forward(song_or_embedding)
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def predict(
68
  self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
69
  ) -> Union[list[str], torch.Tensor]:
 
64
  with torch.no_grad():
65
  return self.forward(song_or_embedding)
66
 
67
def generate_artist_coherency_score(
    self, artist_name: str, song_or_embedding: Union[str, torch.Tensor]
) -> float:
    """Combine an artist's coherent/incoherent logits into a single score.

    The two class ids are looked up in ``self.ffnn.label2id`` under the keys
    ``"<artist_name>-coherent"`` and ``"<artist_name>-incoherent"``. The
    matching entries of the output of ``generate_artist_coherency_logits``
    are then combined as ``(coherent + incoherent) * (coherent / incoherent)``.

    Args:
        artist_name: Artist whose two labels are looked up.
        song_or_embedding: Forwarded unchanged to
            ``generate_artist_coherency_logits``.

    Returns:
        The combined score converted with ``float``.

    NOTE(review): a missing label surfaces as whatever ``label2id[...]``
    raises (a ``KeyError`` if it is a plain dict) -- confirm against the
    model's label mapping.
    """
    label_ids = self.ffnn.label2id
    # Resolve both labels before running the model so an unknown artist
    # fails before any inference work is done.
    pos_id = label_ids[f"{artist_name}-coherent"]
    neg_id = label_ids[f"{artist_name}-incoherent"]

    all_logits = self.generate_artist_coherency_logits(song_or_embedding)
    pos_logit = all_logits[pos_id]
    neg_logit = all_logits[neg_id]

    magnitude = pos_logit + neg_logit
    ratio = pos_logit / neg_logit
    return float(magnitude * ratio)
79
+
80
  def predict(
81
  self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
82
  ) -> Union[list[str], torch.Tensor]:
LyricGeneratorModel.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+
3
+
4
class LyricGeneratorModel:
    """Pair a causal language model with its tokenizer to generate lyrics."""

    def __init__(self, repo_id: str):
        """Load the model and tokenizer published under ``repo_id``.

        The tokenizer is configured to truncate and pad on the right and to
        reuse its end-of-sequence token as the padding token.

        Args:
            repo_id: Identifier passed to both ``from_pretrained`` calls.
        """
        self.model = AutoModelForCausalLM.from_pretrained(repo_id)

        self.tokenizer = AutoTokenizer.from_pretrained(repo_id)
        self.tokenizer.truncation_side = "right"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

    def generate_lyrics(self, prompt: str, max_length: int) -> str:
        """Sample a continuation of ``prompt`` and return the decoded text.

        Args:
            prompt: Text the model continues.
            max_length: Passed to ``generate`` as ``max_length``.

        Returns:
            The first decoded sequence produced by ``batch_decode``.
        """
        encoded = self.tokenizer(prompt, return_tensors="pt")

        # Follow the model's own device instead of hard-coding "cuda": the
        # model is never moved in __init__, so forcing the inputs onto CUDA
        # mismatches the weights (and fails outright on CPU-only hosts).
        device = self.model.device
        input_ids = encoded["input_ids"].to(device)

        # Pad and EOS share a token id, so the mask cannot be inferred
        # reliably from the ids; forward the tokenizer's mask when present.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        output_tokens = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=True,
            max_length=max_length,
        )

        output_text = self.tokenizer.batch_decode(output_tokens)[0]

        return output_text
app.py CHANGED
@@ -2,6 +2,8 @@ from ArtistCoherencyModel import ArtistCoherencyModel
2
  import streamlit as st
3
  import pandas as pd
4
 
 
 
5
  artists_df = pd.read_csv("artists.csv")
6
  artist_names_list = list(artists_df["name"])
7
 
@@ -11,11 +13,15 @@ song_title = st.text_input("Song Title")
11
  song_description = st.text_area("Song Description")
12
  submit_button = st.button("Submit")
13
 
14
- if submit_button:
15
- st.write(
16
- f"[Song Title] {song_title}\n[Song Artist] {artist_name_input}\n[Song Description] {song_description}"
17
- )
18
-
19
- ensemble_model = ArtistCoherencyModel.from_pretrained(
20
  "tjl223/artist-coherency-ensemble"
21
  )
 
 
 
 
 
 
 
 
 
 
2
  import streamlit as st
3
  import pandas as pd
4
 
5
+ from LyricGeneratorModel import LyricGeneratorModel
6
+
7
  artists_df = pd.read_csv("artists.csv")
8
  artist_names_list = list(artists_df["name"])
9
 
 
13
  song_description = st.text_area("Song Description")
14
  submit_button = st.button("Submit")
15
 
16
# Streamlit re-executes this script from the top on every widget interaction.
# Loading the models through st.cache_resource keeps one shared instance per
# server process instead of reloading both models on each rerun.
@st.cache_resource
def _load_lyric_evaluator_model():
    """Load the artist-coherency model once and reuse it across reruns."""
    return ArtistCoherencyModel.from_pretrained("tjl223/artist-coherency-ensemble")


@st.cache_resource
def _load_lyric_generator_model():
    """Load the lyric generator once and reuse it across reruns."""
    return LyricGeneratorModel(
        "tjl223/testllama2-qlora-lyric-generator-with-description"
    )


lyric_evaluator_model = _load_lyric_evaluator_model()

lyric_generator_model = _load_lyric_generator_model()

if submit_button:
    # Build the prompt from the three form fields and show the generated text.
    prompt = f"[Song Title] {song_title}\n[Song Artist] {artist_name_input}\n[Song Description] {song_description}"
    lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
    st.write(lyrics)