Spaces:
Paused
Paused
add required files for demo
Browse files- ArtistCoherencyModel.py +73 -0
- FFNN.py +89 -0
- app.py +12 -2
- artists.csv +21 -0
- requirements.txt +2 -1
ArtistCoherencyModel.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
2 |
+
from huggingface_hub import PyTorchModelHubMixin
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
from FFNN import FFNN
|
10 |
+
|
11 |
+
|
12 |
+
class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
    """Ensemble of two sequence classifiers feeding a feed-forward head.

    A song's text is run through an "artist" and a "coherency"
    sequence-classification model. The first row of each model's logits is
    concatenated into a single embedding vector, which is then passed to the
    ``FFNN`` head.
    """

    def __init__(self, config: dict):
        """Load both tokenizers, both classifiers and the FFNN head.

        Args:
            config: Must contain the keys ``"coherency_model_repo_id"``,
                ``"artist_model_repo_id"`` and ``"ffnn_model_repo_id"``. Each
                value is passed to the corresponding ``from_pretrained`` call.

        Raises:
            KeyError: If one of the required keys is missing from ``config``.
        """
        super().__init__()

        coherency_model_repo_id = config["coherency_model_repo_id"]
        artist_model_repo_id = config["artist_model_repo_id"]
        ffnn_model_repo_id = config["ffnn_model_repo_id"]

        self.coherency_model_tokenizer = AutoTokenizer.from_pretrained(
            coherency_model_repo_id
        )
        self.artist_model_tokenizer = AutoTokenizer.from_pretrained(
            artist_model_repo_id
        )

        self.coherency_model = AutoModelForSequenceClassification.from_pretrained(
            coherency_model_repo_id
        )
        self.artist_model = AutoModelForSequenceClassification.from_pretrained(
            artist_model_repo_id
        )
        self.ffnn = FFNN.from_pretrained(ffnn_model_repo_id)

    @staticmethod
    def _sequence_logits(tokenizer, model, song: str) -> torch.Tensor:
        """Tokenize ``song`` and return ``model``'s logits without gradients."""
        inputs = tokenizer(
            song, return_tensors="pt", max_length=512, truncation=True
        )
        # The tokenizer always produces CPU tensors; move them next to the
        # model's weights so inference also works after ``.to("cuda")``.
        inputs = inputs.to(model.device)
        with torch.no_grad():
            return model(**inputs).logits

    def generate_artist_logits(self, song: str) -> torch.Tensor:
        """Return the artist classifier's logits for ``song``.

        The text is truncated to at most 512 tokens before classification.
        """
        return self._sequence_logits(
            self.artist_model_tokenizer, self.artist_model, song
        )

    def generate_coherency_logits(self, song: str) -> torch.Tensor:
        """Return the coherency classifier's logits for ``song``.

        The text is truncated to at most 512 tokens before classification.
        """
        return self._sequence_logits(
            self.coherency_model_tokenizer, self.coherency_model, song
        )

    def generate_song_embedding(self, song: str) -> torch.Tensor:
        """Build the 1-D embedding consumed by the FFNN head.

        The embedding is the first row of the artist logits followed by the
        first row of the coherency logits. Both helper calls already run under
        ``torch.no_grad()``, so the result carries no gradient history.
        """
        artist_logits = self.generate_artist_logits(song)
        coherency_logits = self.generate_coherency_logits(song)
        return torch.hstack((artist_logits[0], coherency_logits[0]))

    def forward(self, song_or_embedding: Union[str, torch.Tensor]):
        """Run the FFNN head on a song's text or on a precomputed embedding.

        Args:
            song_or_embedding: Raw song text (embedded first) or a tensor that
                is passed to the FFNN head unchanged.
        """
        # isinstance (rather than an exact type check) also accepts str
        # subclasses, which would otherwise be fed to the FFNN as if they
        # were tensors.
        if isinstance(song_or_embedding, str):
            song_or_embedding = self.generate_song_embedding(song_or_embedding)

        return self.ffnn(song_or_embedding)

    def generate_artist_coherency_logits(
        self, song_or_embedding: Union[str, torch.Tensor]
    ) -> torch.Tensor:
        """Return ``forward``'s output computed with gradients disabled."""
        with torch.no_grad():
            return self.forward(song_or_embedding)

    def predict(
        self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
    ) -> Union[list[str], list[int]]:
        """Delegate prediction to the FFNN head.

        Args:
            song_or_embedding: Raw song text (embedded first) or a precomputed
                embedding tensor.
            return_ids: Forwarded to ``FFNN.predict`` as ``return_ids``.

        Returns:
            Whatever ``FFNN.predict`` returns for the embedding.
        """
        if isinstance(song_or_embedding, str):
            song_or_embedding = self.generate_song_embedding(song_or_embedding)

        return self.ffnn.predict(song_or_embedding, return_ids=return_ids)
|
FFNN.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import PyTorchModelHubMixin
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
|
10 |
+
class FFNN(nn.Module, PyTorchModelHubMixin):
    """Feed-forward classifier: ReLU hidden layers followed by a softmax.

    Note that ``forward`` returns softmax-normalised scores. Several methods
    and variables call these values "logits" for historical reasons.
    """

    def __init__(self, config: dict) -> None:
        """Build the layers described by ``config``.

        Args:
            config: Must contain ``"embedding_dim"``, ``"hidden_dim"``,
                ``"num_layers"``, ``"output_dim"``, ``"id2label"`` and
                ``"label2id"``. ``num_layers`` counts the input layer, so
                ``num_layers - 1`` hidden-to-hidden layers are created.
                ``id2label`` is looked up with *string* class indices.
        """
        super().__init__()

        self.input_layer = nn.Linear(config["embedding_dim"], config["hidden_dim"])

        self.hidden_layers = nn.ModuleList(
            [
                nn.Linear(config["hidden_dim"], config["hidden_dim"])
                for _ in range(1, config["num_layers"])
            ]
        )

        self.output_layer = nn.Linear(config["hidden_dim"], config["output_dim"])

        self.id2label = config["id2label"]
        self.label2id = config["label2id"]

    @staticmethod
    def _require_1d_or_2d(tensor: torch.Tensor, name: str) -> None:
        """Raise ``ValueError`` unless ``tensor`` has one or two dimensions."""
        if len(tensor.shape) not in (1, 2):
            raise ValueError(f"{name} must either be a 1 or 2 dimensional tensor")

    @staticmethod
    def _as_rows(tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` as a 2-D batch, wrapping a 1-D vector in one row."""
        return tensor.unsqueeze(0) if len(tensor.shape) == 1 else tensor

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return per-class softmax scores for ``embeddings``.

        Args:
            embeddings: A single embedding vector or a batch of them, with the
                feature axis last.
        """
        z = F.relu(self.input_layer(embeddings))
        for hidden_layer in self.hidden_layers:
            z = F.relu(hidden_layer(z))
        output = self.output_layer(z)
        # Normalise over the class axis (the last one). Using dim=0 is only
        # correct for a single 1-D vector; for a batch it would normalise
        # each class across samples instead of each sample across classes.
        return F.softmax(output, dim=-1)

    def convert_logits_to_top_ids(self, logits: torch.Tensor) -> list[int]:
        """Return the index of the highest score in each row of ``logits``.

        Raises:
            ValueError: If ``logits`` is not 1- or 2-dimensional.
        """
        self._require_1d_or_2d(logits, "logits")
        return self._as_rows(logits).argmax(dim=-1).tolist()

    def convert_logits_to_labels(self, logits: torch.Tensor) -> list[str]:
        """Return the ``id2label`` entry of the top class in each row.

        Raises:
            ValueError: If ``logits`` is not 1- or 2-dimensional.
        """
        return [
            self.id2label[str(top_id)]
            for top_id in self.convert_logits_to_top_ids(logits)
        ]

    def predict(
        self, embeddings: torch.Tensor, return_ids: bool = False
    ) -> Union[list[str], list[int]]:
        """Predict the top class for each embedding, without gradients.

        Args:
            embeddings: A 1-D embedding or a 2-D batch of embeddings.
            return_ids: When true, return integer class indices instead of
                label strings.

        Returns:
            One entry per input row (a single-element list for 1-D input).

        Raises:
            ValueError: If ``embeddings`` is not 1- or 2-dimensional.
        """
        self._require_1d_or_2d(embeddings, "embeddings")

        with torch.no_grad():
            logits = self.forward(embeddings)

        if return_ids:
            return self.convert_logits_to_top_ids(logits)

        return self.convert_logits_to_labels(logits)

    def generate_labeled_logits(
        self, embeddings: torch.Tensor
    ) -> list[dict[str, torch.Tensor]]:
        """Map every label to its score, one dict per input row.

        Returns:
            A list with one ``{label: score}`` dict per row (a single-element
            list for 1-D input). Each score is the 0-dimensional tensor taken
            from ``forward``'s output, not a Python float.

        Raises:
            ValueError: If ``embeddings`` is not 1- or 2-dimensional.
        """
        self._require_1d_or_2d(embeddings, "embeddings")

        with torch.no_grad():
            logits = self.forward(embeddings)

        return [
            {
                self.id2label[str(class_index)]: score
                for class_index, score in enumerate(row)
            }
            for row in self._as_rows(logits)
        ]
|
app.py
CHANGED
@@ -1,4 +1,14 @@
|
|
|
|
1 |
import streamlit as st
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ArtistCoherencyModel import ArtistCoherencyModel
|
2 |
import streamlit as st
|
3 |
+
import pandas as pd
|
4 |
|
5 |
+
@st.cache_resource
def _load_ensemble_model() -> ArtistCoherencyModel:
    """Load the ensemble model once and reuse it across Streamlit reruns.

    Streamlit re-executes this script on every widget interaction. Without
    caching, each rerun would call ``from_pretrained`` again and reload the
    model.
    """
    return ArtistCoherencyModel.from_pretrained("tjl223/artist-coherency-ensemble")


# The dropdown options are the values of the "name" column in artists.csv.
artists_df = pd.read_csv("artists.csv")
artist_names_list = list(artists_df["name"])


artist_name_input = st.selectbox("Artist", artist_names_list)
# Echo the current selection back to the page.
st.write(artist_name_input)

ensemble_model = _load_ensemble_model()
|
artists.csv
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name, id
|
2 |
+
Taylor Swift, taylor-swift
|
3 |
+
Morgan Wallen, morgan-wallen
|
4 |
+
Megan Thee Stallion, megan-thee-stallion
|
5 |
+
Drake, drake
|
6 |
+
Nicki Minaj, nicki-minaj
|
7 |
+
Zach Bryan, zach-bryan
|
8 |
+
Grateful Dead, grateful-dead
|
9 |
+
Luke Combs, luke-combs
|
10 |
+
21 Savage, 21-savage
|
11 |
+
SZA, sza
|
12 |
+
Olivia Rodrigo, olivia-rodrigo
|
13 |
+
Chris Stapleton, chris-stapleton
|
14 |
+
The Smile, the-smile
|
15 |
+
Doja Cat, doja-cat
|
16 |
+
Jack Harlow, jack-harlow
|
17 |
+
Noah Kahan, noah-kahan
|
18 |
+
Travis Scott, travis-scott
|
19 |
+
Jelly Roll, jelly-roll
|
20 |
+
The Weeknd, the-weeknd
|
21 |
+
Dua Lipa, dua-lipa
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
transformers==4.40.0
|
2 |
huggingface_hub==0.22.2
|
3 |
torch==2.2.2
|
4 |
-
numpy==1.26.4
|
|
|
|
1 |
transformers==4.40.0
|
2 |
huggingface_hub==0.22.2
|
3 |
torch==2.2.2
|
4 |
+
numpy==1.26.4
|
5 |
+
pandas==2.2.2
|