tjl223 commited on
Commit
3a18893
1 Parent(s): 6db72ff

changed metric based on testing

Browse files
Files changed (1) hide show
  1. ArtistCoherencyModel.py +1 -11
ArtistCoherencyModel.py CHANGED
@@ -83,19 +83,9 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
83
  self, artist_name: str, song_or_embedding: Union[str, torch.Tensor]
84
  ) -> float:
85
  coherent_index = self.ffnn.label2id[f"{artist_name}-coherent"]
86
- incoherent_index = self.ffnn.label2id[f"{artist_name}-incoherent"]
87
  logits = self.generate_artist_coherency_logits(song_or_embedding)
88
  coherent_score = logits[coherent_index]
89
- incoherent_score = logits[incoherent_index]
90
- score = (
91
- 100
92
- * coherent_score
93
- * (coherent_score - incoherent_score)
94
- / (coherent_score + incoherent_score)
95
- )
96
- print(f"coherent_score: {float(coherent_score)}")
97
- print(f"incoherent_score: {float(incoherent_score)}")
98
- return float(score)
99
 
100
  def predict(
101
  self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
 
83
  self, artist_name: str, song_or_embedding: Union[str, torch.Tensor]
84
  ) -> float:
85
  coherent_index = self.ffnn.label2id[f"{artist_name}-coherent"]
 
86
  logits = self.generate_artist_coherency_logits(song_or_embedding)
87
  coherent_score = logits[coherent_index]
88
+ return float(coherent_score)
 
 
 
 
 
 
 
 
 
89
 
90
  def predict(
91
  self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False