Francesco committed on
Commit
e3c4cb8
1 Parent(s): 157ebae

Added a new way to sample songs that prevents duplicates

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -6,10 +6,7 @@ from langchain.chains import LLMChain
6
  from langchain.prompts import PromptTemplate
7
 
8
  load_dotenv()
9
- import json
10
  import os
11
- import random
12
- from enum import Enum
13
  from typing import List, Tuple
14
 
15
  import numpy as np
@@ -20,6 +17,7 @@ from langchain.schema import Document
20
  from data import load_db
21
  from names import DATASET_ID, MODEL_ID
22
  from storage import RedisStorage, UserInput
 
23
 
24
 
25
  class RetrievalType:
@@ -32,6 +30,7 @@ USE_STORAGE = os.environ.get("USE_STORAGE", "True").lower() in ("true", "t", "1"
32
 
33
  print("USE_STORAGE", USE_STORAGE)
34
 
 
35
  @st.cache_resource
36
  def init():
37
  embeddings = OpenAIEmbeddings(model=MODEL_ID)
@@ -139,7 +138,9 @@ def get_song(user_input: str, k: int = 20):
139
  docs, scores = zip(
140
  *normalize_scores_by_sum(filter_scores(matches, filter_threshold))
141
  )
142
- choosen_docs = np.random.choice(docs, size=number_of_displayed_songs, p=scores)
 
 
143
  return choosen_docs, emotions
144
 
145
 
@@ -176,5 +177,6 @@ def set_song(user_input):
176
  if not success_storage:
177
  print("[ERROR] was not able to store user_input")
178
 
 
179
  if run_btn:
180
  set_song(text_input)
 
6
  from langchain.prompts import PromptTemplate
7
 
8
  load_dotenv()
 
9
  import os
 
 
10
  from typing import List, Tuple
11
 
12
  import numpy as np
 
17
  from data import load_db
18
  from names import DATASET_ID, MODEL_ID
19
  from storage import RedisStorage, UserInput
20
+ from utils import weighted_random_sample
21
 
22
 
23
  class RetrievalType:
 
30
 
31
  print("USE_STORAGE", USE_STORAGE)
32
 
33
+
34
  @st.cache_resource
35
  def init():
36
  embeddings = OpenAIEmbeddings(model=MODEL_ID)
 
138
  docs, scores = zip(
139
  *normalize_scores_by_sum(filter_scores(matches, filter_threshold))
140
  )
141
+ choosen_docs = weighted_random_sample(
142
+ np.array(docs), np.array(scores), n=number_of_displayed_songs
143
+ ).tolist()
144
  return choosen_docs, emotions
145
 
146
 
 
177
  if not success_storage:
178
  print("[ERROR] was not able to store user_input")
179
 
180
+
181
  if run_btn:
182
  set_song(text_input)