Spaces:
Paused
Paused
changed metric based on testing
Browse files- ArtistCoherencyModel.py +1 -11
ArtistCoherencyModel.py
CHANGED
@@ -83,19 +83,9 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
|
|
83 |
self, artist_name: str, song_or_embedding: Union[str, torch.Tensor]
|
84 |
) -> float:
|
85 |
coherent_index = self.ffnn.label2id[f"{artist_name}-coherent"]
|
86 |
-
incoherent_index = self.ffnn.label2id[f"{artist_name}-incoherent"]
|
87 |
logits = self.generate_artist_coherency_logits(song_or_embedding)
|
88 |
coherent_score = logits[coherent_index]
|
89 |
-
|
90 |
-
score = (
|
91 |
-
100
|
92 |
-
* coherent_score
|
93 |
-
* (coherent_score - incoherent_score)
|
94 |
-
/ (coherent_score + incoherent_score)
|
95 |
-
)
|
96 |
-
print(f"coherent_score: {float(coherent_score)}")
|
97 |
-
print(f"incoherent_score: {float(incoherent_score)}")
|
98 |
-
return float(score)
|
99 |
|
100 |
def predict(
|
101 |
self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
|
|
|
83 |
self, artist_name: str, song_or_embedding: Union[str, torch.Tensor]
|
84 |
) -> float:
|
85 |
coherent_index = self.ffnn.label2id[f"{artist_name}-coherent"]
|
|
|
86 |
logits = self.generate_artist_coherency_logits(song_or_embedding)
|
87 |
coherent_score = logits[coherent_index]
|
88 |
+
return float(coherent_score)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def predict(
|
91 |
self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
|