Spaces:
Paused
Paused
added predictors
Browse files- ArtistCoherencyModel.py +3 -2
ArtistCoherencyModel.py
CHANGED
@@ -3,6 +3,7 @@ from huggingface_hub import PyTorchModelHubMixin
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn as nn
|
|
|
6 |
|
7 |
from typing import Union
|
8 |
|
@@ -40,7 +41,7 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
|
|
40 |
return self.artist_model(**inputs).logits
|
41 |
|
42 |
def predict_artist(self, song: str) -> tuple[str, float]:
|
43 |
-
logits = self.generate_artist_logits(song)[0]
|
44 |
predicted_class_id = logits.argmax().item()
|
45 |
return self.artist_model.config.id2label[predicted_class_id], float(
|
46 |
logits[predicted_class_id]
|
@@ -54,7 +55,7 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
|
|
54 |
return self.coherency_model(**inputs).logits
|
55 |
|
56 |
def predict_coherency(self, song: str) -> tuple[str, float]:
|
57 |
-
logits = self.generate_coherency_logits(song)[0]
|
58 |
predicted_class_id = logits.argmax().item()
|
59 |
return self.coherency_model.config.id2label[predicted_class_id], float(
|
60 |
logits[predicted_class_id]
|
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
|
8 |
from typing import Union
|
9 |
|
|
|
41 |
return self.artist_model(**inputs).logits
|
42 |
|
43 |
def predict_artist(self, song: str) -> tuple[str, float]:
|
44 |
+
logits = F.softmax(self.generate_artist_logits(song)[0], dim=0)
|
45 |
predicted_class_id = logits.argmax().item()
|
46 |
return self.artist_model.config.id2label[predicted_class_id], float(
|
47 |
logits[predicted_class_id]
|
|
|
55 |
return self.coherency_model(**inputs).logits
|
56 |
|
57 |
def predict_coherency(self, song: str) -> tuple[str, float]:
|
58 |
+
logits = F.softmax(self.generate_coherency_logits(song)[0], dim=0)
|
59 |
predicted_class_id = logits.argmax().item()
|
60 |
return self.coherency_model.config.id2label[predicted_class_id], float(
|
61 |
logits[predicted_class_id]
|