tjl223 commited on
Commit
634eb1a
1 Parent(s): d40f3b2

added predictors

Browse files
Files changed (1) hide show
  1. ArtistCoherencyModel.py +3 -2
ArtistCoherencyModel.py CHANGED
@@ -3,6 +3,7 @@ from huggingface_hub import PyTorchModelHubMixin
3
 
4
  import torch
5
  import torch.nn as nn
 
6
 
7
  from typing import Union
8
 
@@ -40,7 +41,7 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
40
  return self.artist_model(**inputs).logits
41
 
42
  def predict_artist(self, song: str) -> tuple[str, float]:
43
- logits = self.generate_artist_logits(song)[0]
44
  predicted_class_id = logits.argmax().item()
45
  return self.artist_model.config.id2label[predicted_class_id], float(
46
  logits[predicted_class_id]
@@ -54,7 +55,7 @@ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
54
  return self.coherency_model(**inputs).logits
55
 
56
  def predict_coherency(self, song: str) -> tuple[str, float]:
57
- logits = self.generate_coherency_logits(song)[0]
58
  predicted_class_id = logits.argmax().item()
59
  return self.coherency_model.config.id2label[predicted_class_id], float(
60
  logits[predicted_class_id]
 
3
 
4
  import torch
5
  import torch.nn as nn
6
+ import torch.nn.functional as F
7
 
8
  from typing import Union
9
 
 
41
  return self.artist_model(**inputs).logits
42
 
43
  def predict_artist(self, song: str) -> tuple[str, float]:
44
+ logits = F.softmax(self.generate_artist_logits(song)[0], dim=0)
45
  predicted_class_id = logits.argmax().item()
46
  return self.artist_model.config.id2label[predicted_class_id], float(
47
  logits[predicted_class_id]
 
55
  return self.coherency_model(**inputs).logits
56
 
57
  def predict_coherency(self, song: str) -> tuple[str, float]:
58
+ logits = F.softmax(self.generate_coherency_logits(song)[0], dim=0)
59
  predicted_class_id = logits.argmax().item()
60
  return self.coherency_model.config.id2label[predicted_class_id], float(
61
  logits[predicted_class_id]