# Hugging Face Space (status: Paused)
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
from huggingface_hub import PyTorchModelHubMixin | |
import torch | |
import torch.nn as nn | |
from typing import Union | |
from FFNN import FFNN | |
class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
    """Combine an artist classifier and a coherency classifier with an FFNN head.

    A song's text is run through two Hugging Face sequence classifiers; their
    logits are concatenated into a "song embedding" which is then passed to an
    ``FFNN`` head whose labels look like ``"<artist>-coherent"`` /
    ``"<artist>-incoherent"``.
    """

    # Maximum number of tokens fed to either classifier; longer input is truncated.
    _MAX_TOKENS = 512

    def __init__(self, config: dict):
        """Load both tokenizers/classifiers and the FFNN head from the Hub.

        Args:
            config: Mapping with the Hub repo ids under the keys
                ``"coherency_model_repo_id"``, ``"artist_model_repo_id"`` and
                ``"ffnn_model_repo_id"``.

        Raises:
            KeyError: If any of the required keys is missing from ``config``.
        """
        super().__init__()
        coherency_model_repo_id = config["coherency_model_repo_id"]
        artist_model_repo_id = config["artist_model_repo_id"]
        ffnn_model_repo_id = config["ffnn_model_repo_id"]
        self.coherency_model_tokenizer = AutoTokenizer.from_pretrained(
            coherency_model_repo_id
        )
        self.artist_model_tokenizer = AutoTokenizer.from_pretrained(
            artist_model_repo_id
        )
        self.coherency_model = AutoModelForSequenceClassification.from_pretrained(
            coherency_model_repo_id
        )
        self.artist_model = AutoModelForSequenceClassification.from_pretrained(
            artist_model_repo_id
        )
        self.ffnn = FFNN.from_pretrained(ffnn_model_repo_id)

    def _classify(self, tokenizer, model, song: str) -> torch.FloatTensor:
        """Tokenize ``song`` and return ``model``'s logits (batch of one), without gradients."""
        inputs = tokenizer(
            song, return_tensors="pt", max_length=self._MAX_TOKENS, truncation=True
        )
        # The tokenizer builds CPU tensors; move them next to the model so the
        # wrapper keeps working after e.g. ``self.to("cuda")``.
        inputs = inputs.to(model.device)
        with torch.no_grad():
            return model(**inputs).logits

    def generate_artist_logits(self, song: str) -> torch.FloatTensor:
        """Return the artist classifier's logits for ``song`` (leading batch dim of 1)."""
        return self._classify(self.artist_model_tokenizer, self.artist_model, song)

    def generate_coherency_logits(self, song: str) -> torch.FloatTensor:
        """Return the coherency classifier's logits for ``song`` (leading batch dim of 1)."""
        return self._classify(
            self.coherency_model_tokenizer, self.coherency_model, song
        )

    def generate_song_embedding(self, song: str) -> torch.FloatTensor:
        """Build the 1-D embedding: artist logits followed by coherency logits.

        Both helpers already run under ``torch.no_grad``, so the result does
        not require gradients.
        """
        artist_logits = self.generate_artist_logits(song)
        coherency_logits = self.generate_coherency_logits(song)
        # ``[0]`` drops the batch dimension of the single-song batch.
        return torch.hstack((artist_logits[0], coherency_logits[0]))

    def _as_embedding(
        self, song_or_embedding: Union[str, torch.Tensor]
    ) -> torch.Tensor:
        """Return an embedding, computing it first when given raw song text."""
        if isinstance(song_or_embedding, str):
            return self.generate_song_embedding(song_or_embedding)
        return song_or_embedding

    def forward(self, song_or_embedding: Union[str, torch.Tensor]):
        """Run the FFNN head on a song (text) or a precomputed embedding."""
        return self.ffnn(self._as_embedding(song_or_embedding))

    def generate_artist_coherency_logits(
        self, song_or_embedding: Union[str, torch.Tensor]
    ) -> torch.FloatTensor:
        """Return the FFNN head's logits without tracking gradients."""
        with torch.no_grad():
            return self.forward(song_or_embedding)

    def generate_artist_coherency_score(
        self, artist_name: str, song_or_embedding: Union[str, torch.Tensor]
    ) -> float:
        """Score ``song_or_embedding`` against ``artist_name``'s coherent/incoherent labels.

        Computes ``(c + i) * (c / i)`` where ``c`` and ``i`` are the FFNN logits
        of the ``"<artist>-coherent"`` and ``"<artist>-incoherent"`` labels.

        Raises:
            KeyError: If ``artist_name`` has no such labels in ``ffnn.label2id``.
        """
        try:
            coherent_index = self.ffnn.label2id[f"{artist_name}-coherent"]
            incoherent_index = self.ffnn.label2id[f"{artist_name}-incoherent"]
        except KeyError as err:
            raise KeyError(
                f"Unknown artist {artist_name!r}: no coherent/incoherent labels "
                "found in ffnn.label2id"
            ) from err
        logits = self.generate_artist_coherency_logits(song_or_embedding)
        # NOTE(review): assumes ``logits`` is 1-D (single song) so indexing
        # yields scalars — confirm for batched embeddings.
        coherent_score = logits[coherent_index]
        incoherent_score = logits[incoherent_index]
        # NOTE(review): dividing by a raw logit yields inf/nan when it is 0 and
        # flips sign when it is negative; kept as-is to preserve existing scores.
        score = (coherent_score + incoherent_score) * (
            coherent_score / incoherent_score
        )
        return float(score)

    def predict(
        self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
    ) -> Union[list[str], torch.Tensor]:
        """Delegate to ``FFNN.predict`` on the song's embedding.

        ``return_ids`` is forwarded unchanged; its exact semantics are defined
        by ``FFNN.predict`` (not visible here).
        """
        return self.ffnn.predict(
            self._as_embedding(song_or_embedding), return_ids=return_ids
        )