"""Artist/coherency scoring model built from two sequence classifiers and an FFNN head."""

from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from FFNN import FFNN

# Token budget used when tokenizing a song; longer inputs are truncated.
_MAX_TOKENS = 512


class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
    """Combine an artist classifier and a coherency classifier through an FFNN.

    The logits of both sequence classifiers for a song are concatenated into a
    single "song embedding", which is then fed to the FFNN head.
    """

    def __init__(self, config: dict):
        """Load both tokenizers, both classifiers and the FFNN head.

        Args:
            config: Mapping that must contain the keys
                ``"coherency_model_repo_id"``, ``"artist_model_repo_id"`` and
                ``"ffnn_model_repo_id"``; each value is passed to the
                corresponding ``from_pretrained`` call.
        """
        super().__init__()
        coherency_repo_id = config["coherency_model_repo_id"]
        artist_repo_id = config["artist_model_repo_id"]
        ffnn_repo_id = config["ffnn_model_repo_id"]

        self.coherency_model_tokenizer = AutoTokenizer.from_pretrained(
            coherency_repo_id
        )
        self.artist_model_tokenizer = AutoTokenizer.from_pretrained(artist_repo_id)
        self.coherency_model = AutoModelForSequenceClassification.from_pretrained(
            coherency_repo_id
        )
        self.artist_model = AutoModelForSequenceClassification.from_pretrained(
            artist_repo_id
        )
        self.ffnn = FFNN.from_pretrained(ffnn_repo_id)

    @staticmethod
    def _compute_logits(tokenizer, model, song: str) -> torch.FloatTensor:
        """Tokenize ``song`` and return ``model``'s raw logits without gradients."""
        inputs = tokenizer(
            song, return_tensors="pt", max_length=_MAX_TOKENS, truncation=True
        )
        # Tokenizer output is created on the CPU; move it next to the model so
        # inference also works after the module has been moved to another
        # device (a no-op when the model is on the CPU).
        inputs = inputs.to(model.device)
        with torch.no_grad():
            return model(**inputs).logits

    @staticmethod
    def _top_prediction(model, logits: torch.Tensor) -> tuple[str, float]:
        """Return the most likely label of ``model`` and its probability in percent.

        Only the first row of ``logits`` is considered.
        """
        probabilities = F.softmax(logits[0], dim=0)
        best_class_id = probabilities.argmax().item()
        label = model.config.id2label[best_class_id]
        return label, 100 * float(probabilities[best_class_id])

    def _as_embedding(
        self, song_or_embedding: Union[str, torch.Tensor]
    ) -> torch.Tensor:
        """Turn a song string into its embedding; pass anything else through."""
        # isinstance (rather than an exact type check) so that str subclasses
        # are embedded too instead of reaching the FFNN as raw text.
        if isinstance(song_or_embedding, str):
            return self.generate_song_embedding(song_or_embedding)
        return song_or_embedding

    def generate_artist_logits(self, song: str) -> torch.FloatTensor:
        """Return the artist classifier's logits for ``song`` (no gradients)."""
        return self._compute_logits(
            self.artist_model_tokenizer, self.artist_model, song
        )

    def predict_artist(self, song: str) -> tuple[str, float]:
        """Return the predicted artist label and its softmax probability in percent."""
        return self._top_prediction(
            self.artist_model, self.generate_artist_logits(song)
        )

    def generate_coherency_logits(self, song: str) -> torch.FloatTensor:
        """Return the coherency classifier's logits for ``song`` (no gradients)."""
        return self._compute_logits(
            self.coherency_model_tokenizer, self.coherency_model, song
        )

    def predict_coherency(self, song: str) -> tuple[str, float]:
        """Return the predicted coherency label and its softmax probability in percent."""
        return self._top_prediction(
            self.coherency_model, self.generate_coherency_logits(song)
        )

    def generate_song_embedding(self, song: str) -> torch.FloatTensor:
        """Concatenate the artist and coherency logits of ``song`` into one tensor.

        The first row of each classifier's logits is taken, artist logits
        first, and the two are joined with ``torch.hstack``.
        """
        with torch.no_grad():
            artist_logits = self.generate_artist_logits(song)
            coherency_logits = self.generate_coherency_logits(song)
            return torch.hstack((artist_logits[0], coherency_logits[0]))

    def forward(self, song_or_embedding: Union[str, torch.Tensor]):
        """Run the FFNN head on a song string or on a precomputed embedding."""
        return self.ffnn(self._as_embedding(song_or_embedding))

    def generate_artist_coherency_logits(
        self, song_or_embedding: Union[str, torch.Tensor]
    ) -> torch.FloatTensor:
        """Return the FFNN output for the input, computed without gradients."""
        with torch.no_grad():
            return self.forward(song_or_embedding)

    def generate_artist_coherency_score(
        self, artist_name: str, song_or_embedding: Union[str, torch.Tensor]
    ) -> float:
        """Return 100 times the FFNN output for the ``"<artist_name>-coherent"`` label.

        The label index is looked up in ``self.ffnn.label2id``; an artist with
        no such label presumably raises ``KeyError`` (label2id is indexed like
        a mapping -- confirm against FFNN).
        """
        coherent_index = self.ffnn.label2id[f"{artist_name}-coherent"]
        outputs = self.generate_artist_coherency_logits(song_or_embedding)
        # NOTE(review): indexing with a single label id assumes the FFNN
        # returns a 1-D tensor for a single embedding -- confirm against FFNN.
        return 100 * float(outputs[coherent_index])

    def predict(
        self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
    ) -> Union[list[str], torch.Tensor]:
        """Delegate to ``self.ffnn.predict`` on a song string or an embedding.

        Args:
            song_or_embedding: Raw song text (embedded first) or a precomputed
                embedding tensor.
            return_ids: Forwarded unchanged to ``self.ffnn.predict``.
        """
        return self.ffnn.predict(
            self._as_embedding(song_or_embedding), return_ids=return_ids
        )