Update app.py
Browse files
app.py
CHANGED
@@ -26,46 +26,14 @@ from utils import ClaudeLLM
|
|
26 |
from qdrant_client import models, QdrantClient
|
27 |
from sentence_transformers import SentenceTransformer
|
28 |
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
31 |
# db_yt = FAISS.load_local('db_yt', embeddings)
|
32 |
mp_docs = {}
|
33 |
|
34 |
-
# Remote Qdrant vector store holding the article chunks.
# NOTE(review): this block was removed in this commit in favour of local FAISS indices.
qdrant = QdrantClient(
    # Hosted Qdrant cluster endpoint (GCP us-east4).
    "https://0a1b865d-8291-41ef-8c29-ca6c35e26391.us-east4-0.gcp.cloud.qdrant.io:6333",
    prefer_grpc=True,
    # API key is read from the environment; None if unset (anonymous access will fail).
    api_key=os.environ.get('Qdrant_Api_Key')
)
# Sentence embedder used to encode queries before similarity search against Qdrant.
encoder = SentenceTransformer('BAAI/bge-large-en-v1.5')
|
40 |
-
def q_retrieve_thoughts(query, n, db = "articles"):
    """Retrieve article chunks from the Qdrant collection *db* relevant to *query*.

    Encodes the query with the module-level ``encoder``, fetches every vector in
    the collection (limit = collection vector count), then groups hits by source
    document and keeps documents whose mean similarity score exceeds 0.5.

    Parameters
    ----------
    query : str
        Free-text search query.
    n : int
        Unused; kept for signature compatibility with ``retrieve_thoughts``.
    db : str
        Name of the Qdrant collection to search (default ``"articles"``).

    Returns
    -------
    dict
        ``{'tier 1': DataFrame}`` with columns ``_id, title, url, author,
        content, score`` sorted by descending score.
    """
    # Total vectors in the collection — used as the search limit so every
    # chunk is scored and returned.
    v_len = qdrant.get_collection(db).dict()['vectors_count']
    hits = qdrant.search(
        # BUG FIX: previously hard-coded "articles", silently ignoring the
        # `db` parameter used one line above for get_collection().
        collection_name=db,
        query_vector=encoder.encode(query).tolist(),
        limit=v_len  # TO-DO: know the right number of thoughts existing maybe using get_collection
    )
    df = pd.DataFrame.from_records([dict(hit) for hit in hits])
    # Flatten the per-hit payload dicts into columns, keeping the score.
    payload = pd.DataFrame(list(df['payload'].values[:]))
    payload['score'] = df['score']
    del df
    payload.sort_values('score', ascending = False, inplace = True)

    tier_1 = payload

    # Re-assemble each source document: join its chunks (ordered by chunk id)
    # and average the chunk scores into one document-level score.
    chunks_1 = tier_1.groupby(['_id', ]).apply(lambda x: "\n...\n".join(x.sort_values('id')['page_content'].values)).values
    score = tier_1.groupby(['_id', ]).apply(lambda x: x['score'].mean()).values

    tier_1_adjusted = tier_1.groupby(['_id', ]).first().reset_index()[['_id', 'title', 'url', 'author']]
    tier_1_adjusted['content'] = list(chunks_1)
    tier_1_adjusted['score'] = score
    # Keep only documents with a meaningful similarity (threshold 0.5).
    tier_1_adjusted = tier_1_adjusted[tier_1_adjusted['score'] > 0.5]
    tier_1_adjusted.sort_values('score', ascending = False, inplace = True)
    return {'tier 1': tier_1_adjusted, }
|
69 |
|
70 |
def retrieve_thoughts(query, n, db):
|
71 |
|
@@ -120,6 +88,27 @@ def qa_retrieve_art(query,):
|
|
120 |
return {'Reference': reference}
|
121 |
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
def qa_retrieve_yt(query,):
|
124 |
|
125 |
docs = ""
|
@@ -145,13 +134,15 @@ def flush():
|
|
145 |
return None
|
146 |
|
147 |
|
148 |
-
|
149 |
-
|
|
|
|
|
150 |
inputs=gr.inputs.Textbox(lines=5, label="what would you like to learn about?"),
|
151 |
outputs = gr.components.JSON(label="articles"))
|
152 |
# ref_yt = gr.Interface(fn=qa_retrieve_yt, label="Youtube",
|
153 |
# inputs=gr.inputs.Textbox(lines=5, label="what would you like to learn about?"),
|
154 |
# outputs = gr.components.JSON(label="youtube"),title = "youtube", examples=examples)
|
155 |
-
demo = gr.Parallel(
|
156 |
|
157 |
demo.launch()
|
|
|
26 |
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer

# Two embedders: the HuggingFaceEmbeddings default model, and BGE-large for
# the second (comparison) index.
embeddings = HuggingFaceEmbeddings()
embeddings_1 = HuggingFaceEmbeddings(model_name = "BAAI/bge-large-en-v1.5")

# Local FAISS indices over the article corpus, one per embedder.
# NOTE(review): load_local assumes the 'db_art' / 'db_art_1' folders exist
# alongside app.py — confirm they ship with the Space.
db_art = FAISS.load_local('db_art', embeddings)
db_art_1 = FAISS.load_local('db_art_1', embeddings_1)
# db_yt = FAISS.load_local('db_yt', embeddings)
# Cache of the most recent non-empty retrieval; used as a fallback when a
# query returns nothing (see qa_retrieve_* functions).
mp_docs = {}
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def retrieve_thoughts(query, n, db):
|
39 |
|
|
|
88 |
return {'Reference': reference}
|
89 |
|
90 |
|
91 |
+
def qa_retrieve_bge(query,):
    """Retrieve article references for *query* using the BGE FAISS index.

    Queries ``db_art_1`` via ``retrieve_thoughts``; on an empty result, falls
    back to the module-level ``mp_docs`` cache (the last successful
    retrieval), and on a successful result refreshes that cache.

    Parameters
    ----------
    query : str
        Free-text search query from the Gradio textbox.

    Returns
    -------
    dict
        ``{'Reference': [record, ...]}`` where each record carries
        ``_id, url, author, title, chunks, score``.
    """
    global db_art_1
    global mp_docs

    thoughts = retrieve_thoughts(query, 0, db_art_1)
    if not(thoughts):
        if mp_docs:
            # Reuse the last non-empty retrieval rather than failing.
            thoughts = mp_docs
        else:
            # BUG FIX: previously fell through with an empty `thoughts` and
            # crashed on thoughts['tier 1']; return an empty reference list.
            return {'Reference': []}
    else:
        mp_docs = thoughts

    tier_1 = thoughts['tier 1']

    reference = tier_1[['_id', 'url', 'author', 'title', 'chunks', 'score']].to_dict('records')

    return {'Reference': reference}
|
111 |
+
|
112 |
def qa_retrieve_yt(query,):
|
113 |
|
114 |
docs = ""
|
|
|
134 |
return None
|
135 |
|
136 |
|
137 |
+
# One Gradio tab per retriever; both run in parallel on the same query so the
# bge and mpnet indices can be compared side by side.
# BUG FIX: the bge tab previously pointed at qa_retrieve_art, leaving the new
# qa_retrieve_bge function unreachable from the UI.
ref_art_1 = gr.Interface(fn=qa_retrieve_bge, label="bge Articles",
                         inputs=gr.inputs.Textbox(lines=5, label="what would you like to learn about?"),
                         outputs = gr.components.JSON(label="articles"))
ref_art = gr.Interface(fn=qa_retrieve_art, label="mpnet Articles",
                       inputs=gr.inputs.Textbox(lines=5, label="what would you like to learn about?"),
                       outputs = gr.components.JSON(label="articles"))
# ref_yt = gr.Interface(fn=qa_retrieve_yt, label="Youtube",
#                       inputs=gr.inputs.Textbox(lines=5, label="what would you like to learn about?"),
#                       outputs = gr.components.JSON(label="youtube"),title = "youtube", examples=examples)

# BUG FIX: was gr.Parallel(ref_art_1, ref_art_1) — the same interface twice,
# while ref_art was built and never used.
demo = gr.Parallel(ref_art_1, ref_art)

demo.launch()
|