tjl223 commited on
Commit
4181b0f
1 Parent(s): 4de628b

added predictors

Browse files
Files changed (2) hide show
  1. ArtistCoherencyModel.py +14 -0
  2. app.py +18 -5
ArtistCoherencyModel.py CHANGED
@@ -39,6 +39,13 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
39
  with torch.no_grad():
40
  return self.artist_model(**inputs).logits
41
 
 
 
 
 
 
 
 
42
  def generate_coherency_logits(self, song: str) -> torch.FloatTensor:
43
  inputs = self.coherency_model_tokenizer(
44
  song, return_tensors="pt", max_length=512, truncation=True
@@ -46,6 +53,13 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
46
  with torch.no_grad():
47
  return self.coherency_model(**inputs).logits
48
 
 
 
 
 
 
 
 
49
  def generate_song_embedding(self, song: str) -> torch.FloatTensor:
50
  with torch.no_grad():
51
  artist_logits = self.generate_artist_logits(song)
 
39
  with torch.no_grad():
40
  return self.artist_model(**inputs).logits
41
 
42
+ def predict_artist(self, song: str) -> tuple[str, float]:
43
+ logits = self.generate_artist_logits(song)
44
+ predicted_class_id = logits.argmax().item()
45
+ return self.artist_model.config.id2label[predicted_class_id], float(
46
+ logits[predicted_class_id]
47
+ )
48
+
49
  def generate_coherency_logits(self, song: str) -> torch.FloatTensor:
50
  inputs = self.coherency_model_tokenizer(
51
  song, return_tensors="pt", max_length=512, truncation=True
 
53
  with torch.no_grad():
54
  return self.coherency_model(**inputs).logits
55
 
56
+ def predict_coherency(self, song: str) -> tuple[str, float]:
57
+ logits = self.generate_artist_logits(song)
58
+ predicted_class_id = logits.argmax().item()
59
+ return self.coherency_model.config.id2label[predicted_class_id], float(
60
+ logits[predicted_class_id]
61
+ )
62
+
63
  def generate_song_embedding(self, song: str) -> torch.FloatTensor:
64
  with torch.no_grad():
65
  artist_logits = self.generate_artist_logits(song)
app.py CHANGED
@@ -62,13 +62,26 @@ evaluator_title = st.title("Lyric Evaluator")
62
  artist_name_for_evaluator = st.selectbox(
63
  "Artist", artist_names_list, key="genorator_select"
64
  )
65
- song_lyrics = st.text_area("Song Lyrics")
66
 
67
- if st.button("Submit", key="genorator_submit"):
68
  score = lyric_evaluator_model.generate_artist_coherency_score(
69
- artist_name_for_evaluator, song_lyrics.replace("\n\n", "\n")
70
  )
71
- print(song_lyrics)
72
  print(f"Score: {score}")
73
  st.write(f"Score: {score}")
74
- print(song_lyrics.replace("\n\n", "\n"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  artist_name_for_evaluator = st.selectbox(
63
  "Artist", artist_names_list, key="genorator_select"
64
  )
65
+ evaluator_song_lyrics = st.text_area("Song Lyrics")
66
 
67
+ if st.button("Submit", key="generator_submit"):
68
  score = lyric_evaluator_model.generate_artist_coherency_score(
69
+ artist_name_for_evaluator, evaluator_song_lyrics.replace("\n\n", "\n")
70
  )
 
71
  print(f"Score: {score}")
72
  st.write(f"Score: {score}")
73
+
74
+ predictor_title = st.title("Lyric Predictor")
75
+ predictor_song_lyrics = st.text_area("Song Lyrics", key="predictor song lyrics")
76
+
77
+ if st.button("Submit", key="predictor_submit"):
78
+ artist, artist_score = lyric_evaluator_model.predict_artist(
79
+ predictor_song_lyrics.replace("\n\n", "\n")
80
+ )
81
+ coherency, coherency_score = lyric_evaluator_model.predict_coherency(
82
+ predictor_song_lyrics.replace("\n\n", "\n")
83
+ )
84
+ print(f"Predicted {artist} with a score of {artist_score}")
85
+ st.write(f"Predicted {artist} with a score of {artist_score}")
86
+ print(f"Predicted {coherency} with a score of {coherency_score}")
87
+ st.write(f"Predicted {coherency} with a score of {coherency_score}")