Spaces:
Paused
Paused
added predictors
Browse files- ArtistCoherencyModel.py +1 -1
ArtistCoherencyModel.py
CHANGED
@@ -54,7 +54,7 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
|
|
54 |
return self.coherency_model(**inputs).logits
|
55 |
|
56 |
def predict_coherency(self, song: str) -> tuple[str, float]:
|
57 |
-
logits = self.
|
58 |
predicted_class_id = logits.argmax().item()
|
59 |
return self.coherency_model.config.id2label[predicted_class_id], float(
|
60 |
logits[predicted_class_id]
|
|
|
54 |
return self.coherency_model(**inputs).logits
|
55 |
|
56 |
def predict_coherency(self, song: str) -> tuple[str, float]:
|
57 |
+
logits = self.generate_coherency_logits(song)[0]
|
58 |
predicted_class_id = logits.argmax().item()
|
59 |
return self.coherency_model.config.id2label[predicted_class_id], float(
|
60 |
logits[predicted_class_id]
|