Spaces:
Runtime error
Runtime error
added new way to sample songs that prevents duplicates
Browse files
app.py
CHANGED
@@ -6,10 +6,7 @@ from langchain.chains import LLMChain
|
|
6 |
from langchain.prompts import PromptTemplate
|
7 |
|
8 |
load_dotenv()
|
9 |
-
import json
|
10 |
import os
|
11 |
-
import random
|
12 |
-
from enum import Enum
|
13 |
from typing import List, Tuple
|
14 |
|
15 |
import numpy as np
|
@@ -20,6 +17,7 @@ from langchain.schema import Document
|
|
20 |
from data import load_db
|
21 |
from names import DATASET_ID, MODEL_ID
|
22 |
from storage import RedisStorage, UserInput
|
|
|
23 |
|
24 |
|
25 |
class RetrievalType:
|
@@ -32,6 +30,7 @@ USE_STORAGE = os.environ.get("USE_STORAGE", "True").lower() in ("true", "t", "1"
|
|
32 |
|
33 |
print("USE_STORAGE", USE_STORAGE)
|
34 |
|
|
|
35 |
@st.cache_resource
|
36 |
def init():
|
37 |
embeddings = OpenAIEmbeddings(model=MODEL_ID)
|
@@ -139,7 +138,9 @@ def get_song(user_input: str, k: int = 20):
|
|
139 |
docs, scores = zip(
|
140 |
*normalize_scores_by_sum(filter_scores(matches, filter_threshold))
|
141 |
)
|
142 |
-
choosen_docs =
|
|
|
|
|
143 |
return choosen_docs, emotions
|
144 |
|
145 |
|
@@ -176,5 +177,6 @@ def set_song(user_input):
|
|
176 |
if not success_storage:
|
177 |
print("[ERROR] was not able to store user_input")
|
178 |
|
|
|
179 |
if run_btn:
|
180 |
set_song(text_input)
|
|
|
6 |
from langchain.prompts import PromptTemplate
|
7 |
|
8 |
load_dotenv()
|
|
|
9 |
import os
|
|
|
|
|
10 |
from typing import List, Tuple
|
11 |
|
12 |
import numpy as np
|
|
|
17 |
from data import load_db
|
18 |
from names import DATASET_ID, MODEL_ID
|
19 |
from storage import RedisStorage, UserInput
|
20 |
+
from utils import weighted_random_sample
|
21 |
|
22 |
|
23 |
class RetrievalType:
|
|
|
30 |
|
31 |
print("USE_STORAGE", USE_STORAGE)
|
32 |
|
33 |
+
|
34 |
@st.cache_resource
|
35 |
def init():
|
36 |
embeddings = OpenAIEmbeddings(model=MODEL_ID)
|
|
|
138 |
docs, scores = zip(
|
139 |
*normalize_scores_by_sum(filter_scores(matches, filter_threshold))
|
140 |
)
|
141 |
+
choosen_docs = weighted_random_sample(
|
142 |
+
np.array(docs), np.array(scores), n=number_of_displayed_songs
|
143 |
+
).tolist()
|
144 |
return choosen_docs, emotions
|
145 |
|
146 |
|
|
|
177 |
if not success_storage:
|
178 |
print("[ERROR] was not able to store user_input")
|
179 |
|
180 |
+
|
181 |
if run_btn:
|
182 |
set_song(text_input)
|