from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import PyTorchModelHubMixin
import torch
import torch.nn as nn
from typing import Union
from FFNN import FFNN


class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
    """Combine an "artist" classifier, a "coherency" classifier and an FFNN head.

    A song is scored by both sequence-classification models.
    Their logits for the song are concatenated into a single embedding.
    That embedding is fed to ``self.ffnn``, whose labels include
    ``"<artist>-coherent"`` / ``"<artist>-incoherent"`` entries
    (see ``generate_artist_coherency_score``).
    """

    # Maximum token length passed to both tokenizers; longer songs are truncated.
    max_input_length: int = 512

    def __init__(self, config: dict):
        """Load both tokenizers/models and the FFNN head from their repos.

        Args:
            config: mapping that must contain ``"coherency_model_repo_id"``,
                ``"artist_model_repo_id"`` and ``"ffnn_model_repo_id"``.

        Raises:
            KeyError: if any of the three repo ids is missing from ``config``.
        """
        super().__init__()
        coherency_model_repo_id = config["coherency_model_repo_id"]
        artist_model_repo_id = config["artist_model_repo_id"]
        ffnn_model_repo_id = config["ffnn_model_repo_id"]

        self.coherency_model_tokenizer = AutoTokenizer.from_pretrained(
            coherency_model_repo_id
        )
        self.artist_model_tokenizer = AutoTokenizer.from_pretrained(
            artist_model_repo_id
        )
        self.coherency_model = AutoModelForSequenceClassification.from_pretrained(
            coherency_model_repo_id
        )
        self.artist_model = AutoModelForSequenceClassification.from_pretrained(
            artist_model_repo_id
        )
        self.ffnn = FFNN.from_pretrained(ffnn_model_repo_id)

    def _sequence_logits(self, tokenizer, model, song: str) -> torch.FloatTensor:
        """Tokenize ``song`` for ``model`` and return its logits, without gradients."""
        inputs = tokenizer(
            song,
            return_tensors="pt",
            max_length=self.max_input_length,
            truncation=True,
        )
        # Tokenizers always produce CPU tensors; move them to the model's device
        # so string inputs still work after e.g. ``self.to("cuda")``.
        inputs = inputs.to(model.device)
        with torch.no_grad():
            return model(**inputs).logits

    def generate_artist_logits(self, song: str) -> torch.FloatTensor:
        """Return the artist model's logits for ``song`` (batch of one)."""
        return self._sequence_logits(
            self.artist_model_tokenizer, self.artist_model, song
        )

    def generate_coherency_logits(self, song: str) -> torch.FloatTensor:
        """Return the coherency model's logits for ``song`` (batch of one)."""
        return self._sequence_logits(
            self.coherency_model_tokenizer, self.coherency_model, song
        )

    def generate_song_embedding(self, song: str) -> torch.FloatTensor:
        """Build the FFNN input for ``song``.

        The embedding is the artist logits concatenated with the coherency
        logits, both taken for the first (and only) sequence in the batch.
        """
        artist_logits = self.generate_artist_logits(song)
        coherency_logits = self.generate_coherency_logits(song)
        return torch.hstack((artist_logits[0], coherency_logits[0]))

    def _as_embedding(
        self, song_or_embedding: Union[str, torch.Tensor]
    ) -> torch.Tensor:
        """Return ``song_or_embedding`` unchanged, or embed it first if it is a string."""
        if isinstance(song_or_embedding, str):
            return self.generate_song_embedding(song_or_embedding)
        return song_or_embedding

    def forward(self, song_or_embedding: Union[str, torch.Tensor]) -> torch.Tensor:
        """Run the FFNN head on a song (embedded first) or a precomputed embedding."""
        return self.ffnn(self._as_embedding(song_or_embedding))

    def generate_artist_coherency_logits(
        self, song_or_embedding: Union[str, torch.Tensor]
    ) -> torch.FloatTensor:
        """Return the FFNN output for the input without tracking gradients."""
        with torch.no_grad():
            return self.forward(song_or_embedding)

    def generate_artist_coherency_score(
        self, artist_name: str, song_or_embedding: Union[str, torch.Tensor]
    ) -> float:
        """Combine the ``<artist>-coherent`` and ``<artist>-incoherent`` outputs.

        Args:
            artist_name: prefix of the FFNN labels to look up.
            song_or_embedding: raw song text, or an embedding from
                ``generate_song_embedding``.

        Returns:
            ``(coherent + incoherent) * (coherent / incoherent)`` as a float.

        Raises:
            KeyError: if either label is missing from ``self.ffnn.label2id``.
        """
        coherent_index = self.ffnn.label2id[f"{artist_name}-coherent"]
        incoherent_index = self.ffnn.label2id[f"{artist_name}-incoherent"]
        # NOTE(review): indexing assumes a single, 1-D output (one song).
        # A batched embedding would select rows instead of labels.
        logits = self.generate_artist_coherency_logits(song_or_embedding)
        coherent_score = logits[coherent_index]
        incoherent_score = logits[incoherent_index]
        # NOTE(review): if these are raw logits they may be negative or zero.
        # A zero incoherent value makes this inf/nan. Confirm whether
        # probabilities (e.g. softmax) were intended before relying on the scale.
        score = (coherent_score + incoherent_score) * (
            coherent_score / incoherent_score
        )
        return float(score)

    def predict(
        self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
    ) -> Union[list[str], torch.Tensor]:
        """Delegate to ``self.ffnn.predict`` after embedding string inputs."""
        return self.ffnn.predict(
            self._as_embedding(song_or_embedding), return_ids=return_ids
        )