Spaces:
Paused
Paused
added predictors
Browse files- ArtistCoherencyModel.py +2 -2
ArtistCoherencyModel.py
CHANGED
@@ -43,7 +43,7 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
|
|
43 |
def predict_artist(self, song: str) -> tuple[str, float]:
|
44 |
logits = F.softmax(self.generate_artist_logits(song)[0], dim=0)
|
45 |
predicted_class_id = logits.argmax().item()
|
46 |
-
return self.artist_model.config.id2label[predicted_class_id], float(
|
47 |
logits[predicted_class_id]
|
48 |
)
|
49 |
|
@@ -57,7 +57,7 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
|
|
57 |
def predict_coherency(self, song: str) -> tuple[str, float]:
|
58 |
logits = F.softmax(self.generate_coherency_logits(song)[0], dim=0)
|
59 |
predicted_class_id = logits.argmax().item()
|
60 |
-
return self.coherency_model.config.id2label[predicted_class_id], float(
|
61 |
logits[predicted_class_id]
|
62 |
)
|
63 |
|
|
|
43 |
def predict_artist(self, song: str) -> tuple[str, float]:
|
44 |
logits = F.softmax(self.generate_artist_logits(song)[0], dim=0)
|
45 |
predicted_class_id = logits.argmax().item()
|
46 |
+
return self.artist_model.config.id2label[predicted_class_id], 100 * float(
|
47 |
logits[predicted_class_id]
|
48 |
)
|
49 |
|
|
|
57 |
def predict_coherency(self, song: str) -> tuple[str, float]:
|
58 |
logits = F.softmax(self.generate_coherency_logits(song)[0], dim=0)
|
59 |
predicted_class_id = logits.argmax().item()
|
60 |
+
return self.coherency_model.config.id2label[predicted_class_id], 100 * float(
|
61 |
logits[predicted_class_id]
|
62 |
)
|
63 |
|