Shea committed on
Commit
58a182b
1 Parent(s): 542351c

try new interface

Browse files
Files changed (1) hide show
  1. app.py +31 -31
app.py CHANGED
@@ -1,41 +1,41 @@
1
 
2
- import numpy as np
3
  import gradio as gr
 
 
 
4
  from sentence_transformers import SentenceTransformer
 
5
 
 
 
 
 
 
 
6
 
7
  minilm = SentenceTransformer('all-MiniLM-L12-v2')
8
  #roberta = SentenceTransformer('all-distilroberta-v1')
9
  #glove = SentenceTransformer('average_word_embeddings_glove.840B.300d')
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- labels = ["contradiction", "entailment", "neutral"]
13
-
14
- def predict(sentence1, sentence2):
15
- sentence_pairs = np.array([[str(sentence1), str(sentence2)]])
16
- print(sentence1)
17
- print(sentence2)
18
- # test_data = BertSemanticDataGenerator(
19
- # sentence_pairs, labels=None, batch_size=1, shuffle=False, include_targets=False,
20
- # )
21
- # probs = model.predict(test_data[0])[0]
22
-
23
- # labels_probs = {labels[i]: float(probs[i]) for i, _ in enumerate(labels)}
24
- # return labels_probs
25
-
26
- examples = [["Two women are observing something together.", "Two women are standing with their eyes closed."],
27
- ["A smiling costumed woman is holding an umbrella", "A happy woman in a fairy costume holds an umbrella"],
28
- ["A soccer game with multiple males playing", "Some men are playing a sport"],
29
- ]
30
-
31
- gr.Interface(
32
- fn=predict,
33
- title="Semantic Song Search",
34
- description = "Search for songs based on the meaning in the song's lyrics using a variety of embeddings",
35
- inputs=["text", "text"],
36
- examples=examples,
37
- #outputs=gr.Textbox(label='Prediction'),
38
- outputs=gr.outputs.Label(num_top_classes=3, label='Semantic similarity'),
39
- cache_examples=True,
40
- article = "Author: @sheacon",
41
- ).launch(debug=True, enable_queue=True)
 
1
 
2
+
3
  import gradio as gr
4
+ import pandas as pd
5
+ import numpy as np
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
  from sentence_transformers import SentenceTransformer
8
+ from datasets import load_dataset
9
 
10
+ dataset = load_dataset(
11
+ "sheacon/song_lyrics",
12
+ revision="main" # tag name, or branch name, or commit hash
13
+ )
14
+
15
+ df = dataset.to_pandas()
16
 
17
  minilm = SentenceTransformer('all-MiniLM-L12-v2')
18
  #roberta = SentenceTransformer('all-distilroberta-v1')
19
  #glove = SentenceTransformer('average_word_embeddings_glove.840B.300d')
20
 
21
+ # Read the song embeddings already stored in the dataset's "embedding" column
22
+ song_embeddings = df["embedding"].tolist()
23
+
24
def search_songs(text, top_n=5):
    """Return the songs whose embeddings are most similar to ``text``.

    The query is embedded with the module-level ``minilm`` model and scored
    by cosine similarity against the module-level ``song_embeddings``.
    NOTE(review): this presumes ``song_embeddings`` were produced by the same
    model, so the vectors are comparable -- confirm against the dataset.

    Args:
        text: Free-text query to search with.
        top_n: Maximum number of songs to return.

    Returns:
        A list of ``{"title": ..., "lyrics": ...}`` dicts taken from ``df``,
        ordered from most to least similar.
    """
    # SentenceTransformer embeds raw strings through ``encode``; calling the
    # model object directly runs ``forward``, which expects tokenized
    # features rather than a list of strings.
    text_embedding = minilm.encode([text])[0]

    # One query row against every song row; keep the single row of scores.
    similarities = cosine_similarity([text_embedding], song_embeddings)[0]

    # argsort is ascending, so reverse it to put the best matches first.
    top_indices = similarities.argsort()[::-1][:top_n]
    return [
        {"title": df.iloc[i]["title"], "lyrics": df.iloc[i]["lyrics"]}
        for i in top_indices
    ]
36
+
37
+ # Define the Gradio interface
38
+ iface = gr.Interface(search_songs, "textbox", "text", examples=[["I'm feeling lonely tonight"]])
39
 
40
+ # Launch the interface
41
+ iface.launch()