tjl223 commited on
Commit
6db72ff
1 Parent(s): 634eb1a

added predictors

Browse files
Files changed (1) hide show
  1. ArtistCoherencyModel.py +2 -2
ArtistCoherencyModel.py CHANGED
@@ -43,7 +43,7 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
43
  def predict_artist(self, song: str) -> tuple[str, float]:
44
  logits = F.softmax(self.generate_artist_logits(song)[0], dim=0)
45
  predicted_class_id = logits.argmax().item()
46
- return self.artist_model.config.id2label[predicted_class_id], float(
47
  logits[predicted_class_id]
48
  )
49
 
@@ -57,7 +57,7 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
57
  def predict_coherency(self, song: str) -> tuple[str, float]:
58
  logits = F.softmax(self.generate_coherency_logits(song)[0], dim=0)
59
  predicted_class_id = logits.argmax().item()
60
- return self.coherency_model.config.id2label[predicted_class_id], float(
61
  logits[predicted_class_id]
62
  )
63
 
 
43
  def predict_artist(self, song: str) -> tuple[str, float]:
44
  logits = F.softmax(self.generate_artist_logits(song)[0], dim=0)
45
  predicted_class_id = logits.argmax().item()
46
+ return self.artist_model.config.id2label[predicted_class_id], 100 * float(
47
  logits[predicted_class_id]
48
  )
49
 
 
57
  def predict_coherency(self, song: str) -> tuple[str, float]:
58
  logits = F.softmax(self.generate_coherency_logits(song)[0], dim=0)
59
  predicted_class_id = logits.argmax().item()
60
+ return self.coherency_model.config.id2label[predicted_class_id], 100 * float(
61
  logits[predicted_class_id]
62
  )
63