tjl223 commited on
Commit
0d812a0
1 Parent(s): 7ed50b7

add required files for demo

Browse files
Files changed (5) hide show
  1. ArtistCoherencyModel.py +73 -0
  2. FFNN.py +89 -0
  3. app.py +12 -2
  4. artists.csv +21 -0
  5. requirements.txt +2 -1
ArtistCoherencyModel.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ from huggingface_hub import PyTorchModelHubMixin
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from typing import Union
8
+
9
+ from FFNN import FFNN
10
+
11
+
12
+ class ArtistCoherencyModel(nn.Module, PyTorchModelHubMixin):
13
+ def __init__(self, config: dict):
14
+ super().__init__()
15
+
16
+ coherency_model_repo_id = config["coherency_model_repo_id"]
17
+ artist_model_repo_id = config["artist_model_repo_id"]
18
+ ffnn_model_repo_id = config["ffnn_model_repo_id"]
19
+
20
+ self.coherency_model_tokenizer = AutoTokenizer.from_pretrained(
21
+ coherency_model_repo_id
22
+ )
23
+ self.artist_model_tokenizer = AutoTokenizer.from_pretrained(
24
+ artist_model_repo_id
25
+ )
26
+
27
+ self.coherency_model = AutoModelForSequenceClassification.from_pretrained(
28
+ coherency_model_repo_id
29
+ )
30
+ self.artist_model = AutoModelForSequenceClassification.from_pretrained(
31
+ artist_model_repo_id
32
+ )
33
+ self.ffnn = FFNN.from_pretrained(ffnn_model_repo_id)
34
+
35
+ def generate_artist_logits(self, song: str) -> torch.FloatTensor:
36
+ inputs = self.artist_model_tokenizer(
37
+ song, return_tensors="pt", max_length=512, truncation=True
38
+ )
39
+ with torch.no_grad():
40
+ return self.artist_model(**inputs).logits
41
+
42
+ def generate_coherency_logits(self, song: str) -> torch.FloatTensor:
43
+ inputs = self.coherency_model_tokenizer(
44
+ song, return_tensors="pt", max_length=512, truncation=True
45
+ )
46
+ with torch.no_grad():
47
+ return self.coherency_model(**inputs).logits
48
+
49
+ def generate_song_embedding(self, song: str) -> torch.FloatTensor:
50
+ with torch.no_grad():
51
+ artist_logits = self.generate_artist_logits(song)
52
+ coherency_logits = self.generate_coherency_logits(song)
53
+ return torch.hstack((artist_logits[0], coherency_logits[0]))
54
+
55
+ def forward(self, song_or_embedding: Union[str, torch.Tensor]):
56
+ if type(song_or_embedding) is str:
57
+ song_or_embedding = self.generate_song_embedding(song_or_embedding)
58
+
59
+ return self.ffnn(song_or_embedding)
60
+
61
+ def generate_artist_coherency_logits(
62
+ self, song_or_embedding: Union[str, torch.Tensor]
63
+ ) -> torch.FloatTensor:
64
+ with torch.no_grad():
65
+ return self.forward(song_or_embedding)
66
+
67
+ def predict(
68
+ self, song_or_embedding: Union[str, torch.Tensor], return_ids: bool = False
69
+ ) -> Union[list[str], torch.Tensor]:
70
+ if type(song_or_embedding) is str:
71
+ song_or_embedding = self.generate_song_embedding(song_or_embedding)
72
+
73
+ return self.ffnn.predict(song_or_embedding, return_ids=return_ids)
FFNN.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import PyTorchModelHubMixin
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from typing import Union
8
+
9
+
10
+ class FFNN(nn.Module, PyTorchModelHubMixin):
11
+ def __init__(self, config: dict) -> None:
12
+ super().__init__()
13
+
14
+ self.input_layer = nn.Linear(config["embedding_dim"], config["hidden_dim"])
15
+
16
+ self.hidden_layers = nn.ModuleList()
17
+ for layer_index in range(1, config["num_layers"]):
18
+ self.hidden_layers.append(
19
+ nn.Linear(config["hidden_dim"], config["hidden_dim"])
20
+ )
21
+
22
+ self.output_layer = nn.Linear(config["hidden_dim"], config["output_dim"])
23
+
24
+ self.id2label = config["id2label"]
25
+ self.label2id = config["label2id"]
26
+
27
+ def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
28
+ z = F.relu(self.input_layer(embeddings))
29
+ for hidden_layer in self.hidden_layers:
30
+ z = F.relu(hidden_layer(z))
31
+ output = self.output_layer(z)
32
+ return F.softmax(output, dim=0)
33
+
34
+ def convert_logits_to_top_ids(self, logits: torch.Tensor) -> list[int]:
35
+ if len(logits.shape) != 1 and len(logits.shape) != 2:
36
+ raise ValueError("logits must either be a 1 or 2 dimensional tensor")
37
+
38
+ if len(logits.shape) == 1:
39
+ logits = [logits]
40
+
41
+ return [logits_row.argmax().item() for logits_row in logits]
42
+
43
+ def convert_logits_to_labels(self, logits: torch.Tensor) -> list[str]:
44
+ if len(logits.shape) != 1 and len(logits.shape) != 2:
45
+ raise ValueError("logits must either be a 1 or 2 dimensional tensor")
46
+
47
+ if len(logits.shape) == 1:
48
+ logits = [logits]
49
+
50
+ labels = []
51
+ for logits_row in logits:
52
+ labels.append(self.id2label[str(logits_row.argmax().item())])
53
+
54
+ return labels
55
+
56
+ def predict(
57
+ self, embeddings: torch.Tensor, return_ids: bool = False
58
+ ) -> Union[list[str], list[int]]:
59
+ if len(embeddings.shape) != 1 and len(embeddings.shape) != 2:
60
+ raise ValueError("embeddings must either be a 1 or 2 dimensional tensor")
61
+
62
+ with torch.no_grad():
63
+ logits = self.forward(embeddings)
64
+
65
+ if return_ids:
66
+ return self.convert_logits_to_top_ids(logits)
67
+
68
+ return self.convert_logits_to_labels(logits)
69
+
70
+ def generate_labeled_logits(self, embeddings: torch.Tensor) -> dict[str, float]:
71
+ if len(embeddings.shape) != 1 and len(embeddings.shape) != 2:
72
+ raise ValueError("embeddings must either be a 1 or 2 dimensional tensor")
73
+
74
+ with torch.no_grad():
75
+ logits = self.forward(embeddings)
76
+
77
+ if len(logits.shape) == 1:
78
+ logits = [logits]
79
+
80
+ labeled_logits_list = []
81
+
82
+ for logits_row in logits:
83
+ labeled_logits = {}
84
+ for id, logit in enumerate(logits_row):
85
+ labeled_logits[self.id2label[str(id)]] = logit
86
+
87
+ labeled_logits_list.append(labeled_logits)
88
+
89
+ return labeled_logits_list
app.py CHANGED
@@ -1,4 +1,14 @@
 
1
  import streamlit as st
 
2
 
3
- x = st.slider("Select a value")
4
- st.write(x, "squared is", x * x)
 
 
 
 
 
 
 
 
 
1
+ from ArtistCoherencyModel import ArtistCoherencyModel
2
  import streamlit as st
3
+ import pandas as pd
4
 
5
+ artists_df = pd.read_csv("artists.csv")
6
+ artist_names_list = list(artists_df["name"])
7
+
8
+
9
+ artist_name_input = st.selectbox("Artist", artist_names_list)
10
+ st.write(artist_name_input)
11
+
12
+ ensemble_model = ArtistCoherencyModel.from_pretrained(
13
+ "tjl223/artist-coherency-ensemble"
14
+ )
artists.csv ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name, id
2
+ Taylor Swift, taylor-swift
3
+ Morgan Wallen, morgan-wallen
4
+ Megan Thee Stallion, megan-thee-stallion
5
+ Drake, drake
6
+ Nicki Minaj, nicki-minaj
7
+ Zach Bryan, zach-bryan
8
+ Grateful Dead, grateful-dead
9
+ Luke Combs, luke-combs
10
+ 21 Savage, 21-savage
11
+ SZA, sza
12
+ Olivia Rodrigo, olivia-rodrigo
13
+ Chris Stapleton, chris-stapleton
14
+ The Smile, the-smile
15
+ Doja Cat, doja-cat
16
+ Jack Harlow, jack-harlow
17
+ Noah Kahan, noah-kahan
18
+ Travis Scott, travis-scott
19
+ Jelly Roll, jelly-roll
20
+ The Weeknd, the-weeknd
21
+ Dua Lipa, dua-lipa
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  transformers==4.40.0
2
  huggingface_hub==0.22.2
3
  torch==2.2.2
4
- numpy==1.26.4
 
 
1
  transformers==4.40.0
2
  huggingface_hub==0.22.2
3
  torch==2.2.2
4
+ numpy==1.26.4
5
+ pandas==2.2.2