File size: 3,787 Bytes
0d812a0
 
 
 
 
634eb1a
0d812a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4181b0f
634eb1a
4181b0f
6db72ff
4181b0f
 
 
0d812a0
 
 
 
 
 
 
4181b0f
634eb1a
4181b0f
6db72ff
4181b0f
 
 
0d812a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f95071d
 
 
 
 
 
059c109
f95071d
0d812a0
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import PyTorchModelHubMixin

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Union

from FFNN import FFNN


class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
    """Combine an artist classifier and a coherency classifier with an FFNN head.

    Two sequence-classification models are run on the same song text. Their
    logits are concatenated into a single "song embedding", which is then fed
    to a feed-forward network (``FFNN``) loaded from its own repository.
    """

    def __init__(self, config: dict):
        """Load both tokenizers, both classifiers and the FFNN head.

        Args:
            config: Mapping that must contain the keys
                ``"coherency_model_repo_id"``, ``"artist_model_repo_id"`` and
                ``"ffnn_model_repo_id"``. Each value is handed to the matching
                ``from_pretrained`` call.

        Raises:
            KeyError: If one of the three keys is missing from ``config``.
        """
        super().__init__()

        coherency_model_repo_id = config["coherency_model_repo_id"]
        artist_model_repo_id = config["artist_model_repo_id"]
        ffnn_model_repo_id = config["ffnn_model_repo_id"]

        self.coherency_model_tokenizer = AutoTokenizer.from_pretrained(
            coherency_model_repo_id
        )
        self.artist_model_tokenizer = AutoTokenizer.from_pretrained(
            artist_model_repo_id
        )

        # Sub-module registration order is kept stable (coherency, artist,
        # ffnn) so that state-dict key ordering does not change.
        self.coherency_model = AutoModelForSequenceClassification.from_pretrained(
            coherency_model_repo_id
        )
        self.artist_model = AutoModelForSequenceClassification.from_pretrained(
            artist_model_repo_id
        )
        self.ffnn = FFNN.from_pretrained(ffnn_model_repo_id)

    @staticmethod
    def _generate_logits(tokenizer, model: nn.Module, song: str) -> torch.FloatTensor:
        """Tokenize ``song`` and return ``model``'s logits without tracking gradients."""
        inputs = tokenizer(song, return_tensors="pt", max_length=512, truncation=True)

        # The tokenizer always produces CPU tensors. Move them next to the
        # model's weights so inference also works after ``self.to(device)``.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            inputs = inputs.to(first_param.device)

        with torch.no_grad():
            return model(**inputs).logits

    @staticmethod
    def _predict_label(logits: torch.Tensor, id2label) -> tuple[str, float]:
        """Return the most probable label of the first row and its probability in percent."""
        probabilities = F.softmax(logits[0], dim=0)
        predicted_class_id = probabilities.argmax().item()
        confidence = 100 * float(probabilities[predicted_class_id])
        return id2label[predicted_class_id], confidence

    def _as_embedding(
        self, song_or_embedding: Union[str, torch.Tensor]
    ) -> torch.Tensor:
        """Turn a song string into its embedding; pass anything else through unchanged."""
        if isinstance(song_or_embedding, str):
            return self.generate_song_embedding(song_or_embedding)
        return song_or_embedding

    def generate_artist_logits(self, song: str) -> torch.FloatTensor:
        """Return the artist classifier's logits for ``song``.

        The text is truncated to at most 512 tokens. The batch dimension
        produced by the tokenizer is kept, so callers index ``[0]`` to get
        the logits of the single song.
        """
        return self._generate_logits(
            self.artist_model_tokenizer, self.artist_model, song
        )

    def predict_artist(self, song: str) -> tuple[str, float]:
        """Return ``(label, confidence)`` from the artist classifier.

        ``label`` is looked up in the artist model's ``config.id2label`` and
        ``confidence`` is the softmax probability of that label times 100.
        """
        return self._predict_label(
            self.generate_artist_logits(song), self.artist_model.config.id2label
        )

    def generate_coherency_logits(self, song: str) -> torch.FloatTensor:
        """Return the coherency classifier's logits for ``song``.

        The text is truncated to at most 512 tokens. The batch dimension
        produced by the tokenizer is kept, so callers index ``[0]`` to get
        the logits of the single song.
        """
        return self._generate_logits(
            self.coherency_model_tokenizer, self.coherency_model, song
        )

    def predict_coherency(self, song: str) -> tuple[str, float]:
        """Return ``(label, confidence)`` from the coherency classifier.

        ``label`` is looked up in the coherency model's ``config.id2label``
        and ``confidence`` is the softmax probability of that label times 100.
        """
        return self._predict_label(
            self.generate_coherency_logits(song),
            self.coherency_model.config.id2label,
        )

    def generate_song_embedding(self, song: str) -> torch.FloatTensor:
        """Return the artist logits followed by the coherency logits for ``song``.

        The first row of each classifier's logits is taken and the two are
        stacked horizontally, artist first. No gradients are tracked.
        """
        with torch.no_grad():
            artist_logits = self.generate_artist_logits(song)
            coherency_logits = self.generate_coherency_logits(song)
            return torch.hstack((artist_logits[0], coherency_logits[0]))

    def forward(self, song_or_embedding: Union[str, torch.Tensor]):
        """Run the FFNN head on a song or on a precomputed embedding.

        Args:
            song_or_embedding: Either raw song text, which is embedded with
                ``generate_song_embedding`` first, or a tensor that is passed
                to the FFNN as-is.

        Returns:
            Whatever ``self.ffnn`` returns for the embedding.
        """
        return self.ffnn(self._as_embedding(song_or_embedding))

    def generate_artist_coherency_logits(
        self, song_or_embedding: Union[str, torch.Tensor]
    ) -> torch.FloatTensor:
        """Return ``forward(song_or_embedding)`` computed without gradient tracking."""
        with torch.no_grad():
            return self.forward(song_or_embedding)

    def generate_artist_coherency_score(
        self, artist_name: str, song_or_embedding: Union[str, torch.Tensor]
    ) -> float:
        """Return 100 times the FFNN output for the ``"<artist_name>-coherent"`` label.

        The label index is read from ``self.ffnn.label2id`` before inference
        runs, so an unknown artist fails fast (``KeyError`` if ``label2id`` is
        a plain dict).

        NOTE(review): the FFNN output is indexed with a single integer, which
        presumably assumes an un-batched (1-D) output. Whether that output is
        already normalised to probabilities is not visible here -- confirm
        against ``FFNN``.
        """
        coherent_index = self.ffnn.label2id[f"{artist_name}-coherent"]
        outputs = self.generate_artist_coherency_logits(song_or_embedding)
        return 100 * float(outputs[coherent_index])

    def predict(
        self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
    ) -> Union[list[str], torch.Tensor]:
        """Delegate to ``self.ffnn.predict`` after embedding raw song text.

        Args:
            song_or_embedding: Raw song text or a precomputed embedding tensor.
            return_ids: Forwarded unchanged to ``self.ffnn.predict``.

        Returns:
            The result of ``self.ffnn.predict``.
        """
        return self.ffnn.predict(
            self._as_embedding(song_or_embedding), return_ids=return_ids
        )