Spaces:
Running
Running
Rajat.bans
commited on
Commit
•
0b952f2
1
Parent(s):
3da4cba
Updated the code with comments and type definitions
Browse files
rag.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
rag.py
CHANGED
@@ -11,14 +11,18 @@ import time
|
|
11 |
import os
|
12 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
13 |
from dotenv import load_dotenv
|
14 |
-
import pandas as pd
|
15 |
from typing import List, Tuple, Dict, Any
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
class CLUSTERING:
|
19 |
def cluster_embeddings(
|
20 |
self,
|
21 |
-
embeddings:
|
22 |
clustering_algo: str,
|
23 |
no_of_clusters: int,
|
24 |
no_of_points: int,
|
@@ -27,7 +31,7 @@ class CLUSTERING:
|
|
27 |
Clusters embeddings using the specified clustering algorithm and returns the indices of points in each cluster.
|
28 |
|
29 |
Parameters:
|
30 |
-
embeddings (
|
31 |
clustering_algo (str): The clustering algorithm to use ("kmeans-cc", "kmeans-sp", or "spectral").
|
32 |
no_of_clusters (int): The number of clusters to form.
|
33 |
no_of_points (int): The maximum number of points to include in each cluster.
|
@@ -123,7 +127,7 @@ class VECTOR_DB:
|
|
123 |
|
124 |
def queryVectorDB(
|
125 |
self, page_information: str, threshold: float = None
|
126 |
-
) -> Tuple[List[List[Tuple]], float]:
|
127 |
"""
|
128 |
Query the vector database and cluster the retrieved documents.
|
129 |
|
@@ -196,7 +200,7 @@ class FAISS_DB:
|
|
196 |
metadata: List[Dict[str, Any]],
|
197 |
CHUNK_SIZE: int = 2048,
|
198 |
CHUNK_OVERLAP: int = 512,
|
199 |
-
) -> List[
|
200 |
"""
|
201 |
Split the provided content into chunks with metadata.
|
202 |
|
@@ -207,7 +211,7 @@ class FAISS_DB:
|
|
207 |
CHUNK_OVERLAP (int): The overlap between chunks. Default is 512.
|
208 |
|
209 |
Returns:
|
210 |
-
List[
|
211 |
"""
|
212 |
text_splitter = RecursiveCharacterTextSplitter(
|
213 |
chunk_size=CHUNK_SIZE,
|
@@ -218,13 +222,13 @@ class FAISS_DB:
|
|
218 |
return split_docs
|
219 |
|
220 |
def createDBFromDocs(
|
221 |
-
self, split_docs: List[
|
222 |
) -> FAISS:
|
223 |
"""
|
224 |
Create a FAISS database from the provided documents and embeddings model.
|
225 |
|
226 |
Parameters:
|
227 |
-
split_docs (List[
|
228 |
embeddings_model (HuggingFaceEmbeddings): The embeddings model to use.
|
229 |
|
230 |
Returns:
|
@@ -235,7 +239,7 @@ class FAISS_DB:
|
|
235 |
|
236 |
def createAndSaveDBInChunks(
|
237 |
self,
|
238 |
-
split_docs: List[
|
239 |
embeddings_model: HuggingFaceEmbeddings,
|
240 |
DB_FAISS_PATH: str,
|
241 |
chunk_size: int = 1000,
|
@@ -244,7 +248,7 @@ class FAISS_DB:
|
|
244 |
Create and save the FAISS database in chunks.
|
245 |
|
246 |
Parameters:
|
247 |
-
split_docs (List[
|
248 |
embeddings_model (HuggingFaceEmbeddings): The embeddings model to use.
|
249 |
DB_FAISS_PATH (str): The path to save the FAISS database.
|
250 |
chunk_size (int): The size of each chunk. Default is 1000.
|
@@ -347,7 +351,7 @@ class FAISS_DB:
|
|
347 |
cv = fl[6:-6]
|
348 |
ind = max(ind, int(cv))
|
349 |
|
350 |
-
all_dbs = []
|
351 |
for i in range(0, ind + 1, 2):
|
352 |
print(i)
|
353 |
db1 = FAISS.load_local(
|
@@ -389,12 +393,13 @@ class FAISS_DB:
|
|
389 |
class ADS_RAG:
|
390 |
def __init__(
|
391 |
self,
|
392 |
-
db,
|
393 |
-
qa_model_name,
|
394 |
-
relation_check_best_value_thresh,
|
395 |
-
bestRelationSystemPrompt,
|
396 |
-
bestQuestionSystemPrompt,
|
397 |
-
):
|
|
|
398 |
self.client = OpenAI()
|
399 |
self.db = db
|
400 |
self.qa_model_name = qa_model_name
|
@@ -402,7 +407,18 @@ class ADS_RAG:
|
|
402 |
self.bestRelationSystemPrompt = bestRelationSystemPrompt
|
403 |
self.bestQuestionSystemPrompt = bestQuestionSystemPrompt
|
404 |
|
405 |
-
def callOpenAiApi(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
while True:
|
407 |
try:
|
408 |
response = self.client.chat.completions.create(
|
@@ -423,12 +439,25 @@ class ADS_RAG:
|
|
423 |
|
424 |
def getBestQuestionOnTheBasisOfPageInformationAndAdsData(
|
425 |
self,
|
426 |
-
page_information,
|
427 |
-
adsData,
|
428 |
-
relationSystemPrompt,
|
429 |
-
questionSystemPrompt,
|
430 |
-
bestRetreivedAdValue,
|
431 |
-
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
if adsData == "":
|
433 |
return ({"reasoning": "No ads data present", "classification": 0}, 0), (
|
434 |
{"reasoning": "", "question": "", "options": []},
|
@@ -454,11 +483,9 @@ class ADS_RAG:
|
|
454 |
}
|
455 |
]
|
456 |
)
|
457 |
-
|
458 |
-
tokens_used_question = 0
|
459 |
else:
|
460 |
relation_answer["reasoning"] = (
|
461 |
-
"First
|
462 |
)
|
463 |
|
464 |
if relation_answer["classification"] != 0:
|
@@ -483,7 +510,18 @@ class ADS_RAG:
|
|
483 |
"tokens_used_question": tokens_used_question,
|
484 |
}
|
485 |
|
486 |
-
def convertDocumentsClustersToStringForApiCall(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
key_counter = count(1)
|
488 |
res = json.dumps(
|
489 |
{
|
@@ -497,14 +535,30 @@ class ADS_RAG:
|
|
497 |
return res
|
498 |
|
499 |
def getRagResponse(
|
500 |
-
self,
|
501 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
502 |
curr_relation_prompt = self.bestRelationSystemPrompt
|
503 |
-
if RelationPrompt
|
504 |
curr_relation_prompt = RelationPrompt
|
505 |
|
506 |
curr_question_prompt = self.bestQuestionSystemPrompt
|
507 |
-
if QuestionPrompt
|
508 |
curr_question_prompt = QuestionPrompt
|
509 |
|
510 |
documents_clusters, best_value = self.db.queryVectorDB(
|
@@ -520,7 +574,18 @@ class ADS_RAG:
|
|
520 |
|
521 |
return answer, documents_clusters
|
522 |
|
523 |
-
def changeDocumentsToPrintableString(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
524 |
res = ""
|
525 |
i = 0
|
526 |
for ind, documents_cluster in enumerate(documents_clusters):
|
@@ -531,7 +596,19 @@ class ADS_RAG:
|
|
531 |
res += "\n"
|
532 |
return res
|
533 |
|
534 |
-
def changeResponseToPrintableString(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
535 |
if task == "relation":
|
536 |
return f"Reasoning: {response['reasoning']}\n\nClassification: {response['classification']}\n"
|
537 |
res = f"Reasoning: {response['reasoning']}\n\nQuestion: {response['question']}\n\nOptions: \n"
|
@@ -543,8 +620,21 @@ class ADS_RAG:
|
|
543 |
return res
|
544 |
|
545 |
def logResult(
|
546 |
-
self,
|
547 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
548 |
print(
|
549 |
"**************************************************************************************************\n",
|
550 |
# curr_relation_prompt,
|
@@ -555,13 +645,32 @@ class ADS_RAG:
|
|
555 |
)
|
556 |
|
557 |
def getRagGradioResponse(
|
558 |
-
self,
|
559 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
560 |
answer, documents_clusters = self.getRagResponse(
|
561 |
page_information, threshold, RelationPrompt, QuestionPrompt
|
562 |
)
|
|
|
563 |
self.logResult(RelationPrompt, QuestionPrompt, page_information, answer)
|
564 |
|
|
|
565 |
docs_info = self.changeDocumentsToPrintableString(documents_clusters)
|
566 |
relation_answer_string = self.changeResponseToPrintableString(
|
567 |
answer["relation_answer"], "relation"
|
@@ -569,79 +678,70 @@ class ADS_RAG:
|
|
569 |
question_answer_string = self.changeResponseToPrintableString(
|
570 |
answer["question_answer"], "question"
|
571 |
)
|
|
|
572 |
question_tokens = answer["tokens_used_question"]
|
573 |
relation_tokens = answer["tokens_used_relation"]
|
574 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
575 |
return full_response
|
576 |
|
577 |
|
578 |
-
class
|
579 |
-
def __init__(self):
|
|
|
580 |
load_dotenv(override=True)
|
581 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
582 |
|
583 |
-
self.
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
588 |
|
589 |
-
|
590 |
-
# embeddings_oa = OpenAIEmbeddings(model=embedding_model_oa)
|
591 |
-
# embeddings_hf = HuggingFaceEmbeddings(model_name = embedding_model_hf, show_progress = True)
|
592 |
-
embeddings_hf = HuggingFaceEmbeddings(model_name=self.embedding_model_hf)
|
593 |
vector_db = VECTOR_DB(
|
594 |
-
0.75,
|
|
|
|
|
|
|
|
|
|
|
|
|
595 |
)
|
|
|
|
|
596 |
rag = ADS_RAG(
|
597 |
-
vector_db,
|
598 |
-
"gpt-3.5-turbo",
|
599 |
-
0.6,
|
600 |
-
self.getRelationSystemPrompt(),
|
601 |
-
self.getQuestionSystemPrompt(),
|
602 |
)
|
603 |
return rag
|
604 |
|
605 |
-
def
|
606 |
-
|
607 |
-
|
608 |
-
)
|
609 |
-
data = pd.read_csv(data_file_path, sep="\t")
|
610 |
-
data.dropna(axis=0, how="any", inplace=True)
|
611 |
-
|
612 |
-
# data.drop_duplicates(subset = ['ad_title', 'ad_desc'], inplace=True)
|
613 |
-
# ad_title_content = list(data["ad_title"].values)
|
614 |
-
def get_core_content(row):
|
615 |
-
url_content = row["url_content"]
|
616 |
-
url_title = row["url_title"]
|
617 |
-
return (
|
618 |
-
"Page Title -: "
|
619 |
-
+ url_title
|
620 |
-
+ "\nPage Content -: "
|
621 |
-
+ ". ".join(url_content.split(". ")[:7])
|
622 |
-
)
|
623 |
-
|
624 |
-
data["core_content"] = data.apply(get_core_content, axis=1)
|
625 |
-
# for i in range(len(data)):
|
626 |
-
# print(data.loc[i, 'url'])
|
627 |
-
# print(data.loc[i, 'url_content'])
|
628 |
-
# print(data.loc[i, 'core_content'])
|
629 |
-
# print()
|
630 |
-
# if(i > 10):
|
631 |
-
# break
|
632 |
-
return data
|
633 |
-
|
634 |
-
def GradioRagPreProcessing(self):
|
635 |
-
data_file_path = (
|
636 |
-
"./data/149_adclick_Jun_facty_activeBeat_Health_dupRemoved0.85_campaign.tsv"
|
637 |
-
)
|
638 |
-
data = pd.read_csv(data_file_path, sep="\t")
|
639 |
-
# data.dropna(axis=0, how="any", inplace=True)
|
640 |
-
data.drop_duplicates(subset=["ad_title", "ad_desc"], inplace=True)
|
641 |
-
ad_title_content = list(data["ad_title"].values)
|
642 |
-
return ad_title_content
|
643 |
|
644 |
-
|
|
|
|
|
645 |
bestQuestionSystemPrompt = """1. You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to form a relevant QUESTION to ask the user visiting the webpage. This question should help identify the user's intent behind visiting the webpage and should be highly attractive.
|
646 |
2. Now form a highly attractive/lucrative and diverse/mutually exclusive OPTION which should be both the answer for the QUESTION and related to ads in this cluster.
|
647 |
3. Try to generate intelligent creatives for advertising and keep QUESTION within 70 characters and either 2, 3 or 4 options with each OPTION within 4 to 6 words.
|
@@ -699,7 +799,13 @@ The ADS_DATA provided to you is as follows:
|
|
699 |
# """
|
700 |
return bestQuestionSystemPrompt
|
701 |
|
702 |
-
def getRelationSystemPrompt(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
703 |
bestRelationSystemPrompt = """You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to determine whether there are some relevant ADS to INPUT are present in ADS_DATA. ADS WHICH DON'T MATCH USER'S INTENT SHOULD BE CONSIDERED IRRELEVANT
|
704 |
|
705 |
---------------------------------------
|
@@ -727,129 +833,82 @@ The ADS_DATA provided to you is as follows:
|
|
727 |
return bestRelationSystemPrompt
|
728 |
|
729 |
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
# WEB DATA PROCESSING
|
735 |
-
# from urllib.parse import urlparse
|
736 |
-
# import re
|
737 |
-
# def get_cleaned_url(url):
|
738 |
-
# path = urlparse(url).path.strip()
|
739 |
-
# cleaned_path = re.sub(r'[^a-zA-Z0-9\s-]', ' ', path).replace('/', '')
|
740 |
-
# cleaned_path = re.sub(r'[^a-zA-Z0-9\s]', ' ', path).replace('-', '')
|
741 |
-
# return cleaned_path.strip()
|
742 |
-
|
743 |
-
# df['cleaned_url'] = df['url'].map(get_cleaned_url)
|
744 |
-
# df.dropna(subset=['cleaned_url', 'url_content', 'url_title'], inplace=True)
|
745 |
-
# df['combined'] = df['cleaned_url'] + ". " + df['url_title'] + ". " + df['url_content']
|
746 |
-
# content = df["combined"].tolist()
|
747 |
-
# metadata = [
|
748 |
-
# {"title": row["url_title"], "url": row["url"]}
|
749 |
-
# for _, row in df.iterrows()
|
750 |
-
# ]
|
751 |
-
# ------------------------------
|
752 |
-
# ADS DATA PROCESSING
|
753 |
-
# # df.dropna(axis=0, how='any', inplace=True)
|
754 |
-
# df.drop_duplicates(subset = ['ad_title', 'ad_desc'], inplace=True)
|
755 |
-
# dfRPC = df[df['RPC'] > 0]
|
756 |
-
# dfRPC.dropna(how = 'any', inplace=True)
|
757 |
-
# dfCampaign = df[df['type'] == 'campaign']
|
758 |
-
# dfCampaign.fillna('', inplace=True)
|
759 |
-
# df = pd.concat([dfRPC, dfCampaign])
|
760 |
-
# df
|
761 |
-
|
762 |
-
# content = (df["ad_title"] + ". " + df["ad_desc"]).tolist()
|
763 |
-
# metadata = [
|
764 |
-
# {"publisher_url": row["publisher_url"], "keyword_term": row["keyword_term"], "ad_display_url": row["ad_display_url"], "revenue": row["revenue"], "ad_click_count": row["ad_click_count"], "RPC": row["RPC"], "Type": row["type"]}
|
765 |
-
# # {"revenue": row["revenue"], "ad_click_count": row["ad_click_count"]}
|
766 |
-
# for _, row in df.iterrows()
|
767 |
-
# ]
|
768 |
-
# --------------------------------
|
769 |
-
|
770 |
-
# faiss_db = FAISS_DB()
|
771 |
-
# db = faiss_db.createDBFromDocs(content, metadata)
|
772 |
-
# faiss_db.saveDB(db, '.')
|
773 |
-
|
774 |
-
# ************************************************************************
|
775 |
-
# PARALLELY CREATING DB - BACKUP FOR FUTURE USE
|
776 |
-
# import time
|
777 |
-
# import threading
|
778 |
-
# import os
|
779 |
-
# one_db_docs_size = 1000
|
780 |
-
# starting_i = 0
|
781 |
-
# parallel_processes = 3
|
782 |
-
# def split_list(lst, n):
|
783 |
-
# k, m = divmod(len(lst), n)
|
784 |
-
# return (lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))
|
785 |
-
# def createDBForIndexes(inds):
|
786 |
-
# for i in inds:
|
787 |
-
# ctime = time.time()
|
788 |
-
# print(f"Processing {i}")
|
789 |
-
# if not os.path.exists(DB_FAISS_PATH + "/index_{int(i/one_db_docs_size)}.faiss"):
|
790 |
-
# db = FAISS.from_documents(split_docs[i:i+one_db_docs_size], embeddings_hf)
|
791 |
-
# db.save_local(DB_FAISS_PATH, index_name = f"index_{int(i/one_db_docs_size)}")
|
792 |
-
# ctime = time.time() - ctime
|
793 |
-
# print(f"{i})Time taken", ctime)
|
794 |
-
# indexes = split_list(range(starting_i, len(split_docs), one_db_docs_size), parallel_processes)
|
795 |
-
# threads = []
|
796 |
-
# for i, one_process_indexes in enumerate(indexes):
|
797 |
-
# thread = threading.Thread(target=createDBForIndexes, args=(one_process_indexes,))
|
798 |
-
# thread.start()
|
799 |
-
# threads.append(thread)
|
800 |
-
# for thread in threads:
|
801 |
-
# thread.join()
|
802 |
-
# print("All threads completed.")
|
803 |
-
# ************************************************************************
|
804 |
|
805 |
-
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
|
821 |
-
|
822 |
-
)
|
823 |
-
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
837 |
output = gr.Textbox(label="Output")
|
|
|
|
|
838 |
submit_btn = gr.Button("Submit")
|
839 |
|
|
|
840 |
submit_btn.click(
|
841 |
-
rag.getRagGradioResponse,
|
842 |
inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
|
843 |
outputs=[output],
|
844 |
)
|
|
|
|
|
845 |
page_information.submit(
|
846 |
-
rag.getRagGradioResponse,
|
847 |
inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
|
848 |
outputs=[output],
|
849 |
)
|
|
|
|
|
850 |
with gr.Accordion("Ad Titles", open=False):
|
851 |
ad_titles = gr.Markdown()
|
852 |
|
|
|
|
|
|
|
|
|
853 |
demo.load(
|
854 |
lambda: "<br>".join(
|
855 |
random.sample(
|
@@ -861,5 +920,30 @@ if __name__ == "__main__":
|
|
861 |
ad_titles,
|
862 |
)
|
863 |
|
864 |
-
|
865 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
import os
|
12 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
13 |
from dotenv import load_dotenv
|
|
|
14 |
from typing import List, Tuple, Dict, Any
|
15 |
+
from numpy import ndarray
|
16 |
+
from langchain_core.documents import Document
|
17 |
+
import gradio as gr
|
18 |
+
import random
|
19 |
+
import pandas as pd
|
20 |
|
21 |
|
22 |
class CLUSTERING:
|
23 |
def cluster_embeddings(
|
24 |
self,
|
25 |
+
embeddings: ndarray,
|
26 |
clustering_algo: str,
|
27 |
no_of_clusters: int,
|
28 |
no_of_points: int,
|
|
|
31 |
Clusters embeddings using the specified clustering algorithm and returns the indices of points in each cluster.
|
32 |
|
33 |
Parameters:
|
34 |
+
embeddings (ndarray): The input embeddings to cluster.
|
35 |
clustering_algo (str): The clustering algorithm to use ("kmeans-cc", "kmeans-sp", or "spectral").
|
36 |
no_of_clusters (int): The number of clusters to form.
|
37 |
no_of_points (int): The maximum number of points to include in each cluster.
|
|
|
127 |
|
128 |
def queryVectorDB(
|
129 |
self, page_information: str, threshold: float = None
|
130 |
+
) -> Tuple[List[List[Tuple[Document, float]]], float]:
|
131 |
"""
|
132 |
Query the vector database and cluster the retrieved documents.
|
133 |
|
|
|
200 |
metadata: List[Dict[str, Any]],
|
201 |
CHUNK_SIZE: int = 2048,
|
202 |
CHUNK_OVERLAP: int = 512,
|
203 |
+
) -> List[Document]:
|
204 |
"""
|
205 |
Split the provided content into chunks with metadata.
|
206 |
|
|
|
211 |
CHUNK_OVERLAP (int): The overlap between chunks. Default is 512.
|
212 |
|
213 |
Returns:
|
214 |
+
List[Document]: The split documents with metadata.
|
215 |
"""
|
216 |
text_splitter = RecursiveCharacterTextSplitter(
|
217 |
chunk_size=CHUNK_SIZE,
|
|
|
222 |
return split_docs
|
223 |
|
224 |
def createDBFromDocs(
|
225 |
+
self, split_docs: List[Document], embeddings_model: HuggingFaceEmbeddings
|
226 |
) -> FAISS:
|
227 |
"""
|
228 |
Create a FAISS database from the provided documents and embeddings model.
|
229 |
|
230 |
Parameters:
|
231 |
+
split_docs (List[Document]): The split documents.
|
232 |
embeddings_model (HuggingFaceEmbeddings): The embeddings model to use.
|
233 |
|
234 |
Returns:
|
|
|
239 |
|
240 |
def createAndSaveDBInChunks(
|
241 |
self,
|
242 |
+
split_docs: List[Document],
|
243 |
embeddings_model: HuggingFaceEmbeddings,
|
244 |
DB_FAISS_PATH: str,
|
245 |
chunk_size: int = 1000,
|
|
|
248 |
Create and save the FAISS database in chunks.
|
249 |
|
250 |
Parameters:
|
251 |
+
split_docs (List[Document]): The split documents.
|
252 |
embeddings_model (HuggingFaceEmbeddings): The embeddings model to use.
|
253 |
DB_FAISS_PATH (str): The path to save the FAISS database.
|
254 |
chunk_size (int): The size of each chunk. Default is 1000.
|
|
|
351 |
cv = fl[6:-6]
|
352 |
ind = max(ind, int(cv))
|
353 |
|
354 |
+
all_dbs: List[FAISS] = []
|
355 |
for i in range(0, ind + 1, 2):
|
356 |
print(i)
|
357 |
db1 = FAISS.load_local(
|
|
|
393 |
class ADS_RAG:
|
394 |
def __init__(
|
395 |
self,
|
396 |
+
db: VECTOR_DB,
|
397 |
+
qa_model_name: str,
|
398 |
+
relation_check_best_value_thresh: float,
|
399 |
+
bestRelationSystemPrompt: str,
|
400 |
+
bestQuestionSystemPrompt: str,
|
401 |
+
) -> None:
|
402 |
+
"""Initialize the ADS_RAG class with the given parameters."""
|
403 |
self.client = OpenAI()
|
404 |
self.db = db
|
405 |
self.qa_model_name = qa_model_name
|
|
|
407 |
self.bestRelationSystemPrompt = bestRelationSystemPrompt
|
408 |
self.bestQuestionSystemPrompt = bestQuestionSystemPrompt
|
409 |
|
410 |
+
def callOpenAiApi(
|
411 |
+
self, messages: List[Dict[str, str]]
|
412 |
+
) -> Tuple[Dict[str, Any], int]:
|
413 |
+
"""
|
414 |
+
Call the OpenAI API with the given messages and return the response.
|
415 |
+
|
416 |
+
Parameters:
|
417 |
+
messages (List[Dict[str, str]]): The messages to send to the OpenAI API.
|
418 |
+
|
419 |
+
Returns:
|
420 |
+
Tuple[Dict[str, Any], int]: The response from the OpenAI API and the number of tokens used.
|
421 |
+
"""
|
422 |
while True:
|
423 |
try:
|
424 |
response = self.client.chat.completions.create(
|
|
|
439 |
|
440 |
def getBestQuestionOnTheBasisOfPageInformationAndAdsData(
|
441 |
self,
|
442 |
+
page_information: str,
|
443 |
+
adsData: str,
|
444 |
+
relationSystemPrompt: str,
|
445 |
+
questionSystemPrompt: str,
|
446 |
+
bestRetreivedAdValue: float,
|
447 |
+
) -> Dict[str, Any]:
|
448 |
+
"""
|
449 |
+
Get the best question based on page information and ads data.
|
450 |
+
|
451 |
+
Parameters:
|
452 |
+
page_information (str): The information about the page.
|
453 |
+
adsData (str): The data about the ads.
|
454 |
+
relationSystemPrompt (str): The system prompt for relation checking.
|
455 |
+
questionSystemPrompt (str): The system prompt for question generation.
|
456 |
+
bestRetreivedAdValue (float): The best retrieved ad value.
|
457 |
+
|
458 |
+
Returns:
|
459 |
+
Dict[str, Any]: The relation and question answers along with token usage information.
|
460 |
+
"""
|
461 |
if adsData == "":
|
462 |
return ({"reasoning": "No ads data present", "classification": 0}, 0), (
|
463 |
{"reasoning": "", "question": "", "options": []},
|
|
|
483 |
}
|
484 |
]
|
485 |
)
|
|
|
|
|
486 |
else:
|
487 |
relation_answer["reasoning"] = (
|
488 |
+
"First retrieved document value less than threshold so no need to check relation"
|
489 |
)
|
490 |
|
491 |
if relation_answer["classification"] != 0:
|
|
|
510 |
"tokens_used_question": tokens_used_question,
|
511 |
}
|
512 |
|
513 |
+
def convertDocumentsClustersToStringForApiCall(
|
514 |
+
self, documents_clusters: List[List[Tuple[Document, float]]]
|
515 |
+
) -> str:
|
516 |
+
"""
|
517 |
+
Convert document clusters to a string format suitable for API calls.
|
518 |
+
|
519 |
+
Parameters:
|
520 |
+
documents_clusters (List[List[Tuple[Document, float]]]): The document clusters.
|
521 |
+
|
522 |
+
Returns:
|
523 |
+
str: The document clusters converted to a string.
|
524 |
+
"""
|
525 |
key_counter = count(1)
|
526 |
res = json.dumps(
|
527 |
{
|
|
|
535 |
return res
|
536 |
|
537 |
def getRagResponse(
|
538 |
+
self,
|
539 |
+
page_information: str,
|
540 |
+
threshold: float = None,
|
541 |
+
RelationPrompt: str = None,
|
542 |
+
QuestionPrompt: str = None,
|
543 |
+
) -> Tuple[Dict[str, Any], List[List[Tuple[Document, float]]]]:
|
544 |
+
"""
|
545 |
+
Get the RAG response based on the page information and optional prompts.
|
546 |
+
|
547 |
+
Parameters:
|
548 |
+
page_information (str): The information about the page.
|
549 |
+
threshold (float): The threshold for querying the database. Default is None.
|
550 |
+
RelationPrompt (str): The prompt for relation checking. Default is None.
|
551 |
+
QuestionPrompt (str): The prompt for question generation. Default is None.
|
552 |
+
|
553 |
+
Returns:
|
554 |
+
Tuple[Dict[str, Any], List[List[Tuple[Document, float]]]]: The RAG response and the document clusters.
|
555 |
+
"""
|
556 |
curr_relation_prompt = self.bestRelationSystemPrompt
|
557 |
+
if RelationPrompt is not None and len(RelationPrompt):
|
558 |
curr_relation_prompt = RelationPrompt
|
559 |
|
560 |
curr_question_prompt = self.bestQuestionSystemPrompt
|
561 |
+
if QuestionPrompt is not None and len(QuestionPrompt):
|
562 |
curr_question_prompt = QuestionPrompt
|
563 |
|
564 |
documents_clusters, best_value = self.db.queryVectorDB(
|
|
|
574 |
|
575 |
return answer, documents_clusters
|
576 |
|
577 |
+
def changeDocumentsToPrintableString(
|
578 |
+
self, documents_clusters: List[List[Tuple[Document, float]]]
|
579 |
+
) -> str:
|
580 |
+
"""
|
581 |
+
Convert document clusters to a printable string format.
|
582 |
+
|
583 |
+
Parameters:
|
584 |
+
documents_clusters (List[List[Tuple[Document, float]]]): The document clusters.
|
585 |
+
|
586 |
+
Returns:
|
587 |
+
str: The document clusters converted to a printable string.
|
588 |
+
"""
|
589 |
res = ""
|
590 |
i = 0
|
591 |
for ind, documents_cluster in enumerate(documents_clusters):
|
|
|
596 |
res += "\n"
|
597 |
return res
|
598 |
|
599 |
+
def changeResponseToPrintableString(
|
600 |
+
self, response: Dict[str, Any], task: str
|
601 |
+
) -> str:
|
602 |
+
"""
|
603 |
+
Convert the response to a printable string format.
|
604 |
+
|
605 |
+
Parameters:
|
606 |
+
response (Dict[str, Any]): The response to convert.
|
607 |
+
task (str): The task type ('relation' or 'question').
|
608 |
+
|
609 |
+
Returns:
|
610 |
+
str: The response converted to a printable string.
|
611 |
+
"""
|
612 |
if task == "relation":
|
613 |
return f"Reasoning: {response['reasoning']}\n\nClassification: {response['classification']}\n"
|
614 |
res = f"Reasoning: {response['reasoning']}\n\nQuestion: {response['question']}\n\nOptions: \n"
|
|
|
620 |
return res
|
621 |
|
622 |
def logResult(
|
623 |
+
self,
|
624 |
+
curr_relation_prompt: str,
|
625 |
+
curr_question_prompt: str,
|
626 |
+
page_information: str,
|
627 |
+
answer: Dict[str, Any],
|
628 |
+
) -> None:
|
629 |
+
"""
|
630 |
+
Log the result of the RAG response.
|
631 |
+
|
632 |
+
Parameters:
|
633 |
+
curr_relation_prompt (str): The current relation prompt.
|
634 |
+
curr_question_prompt (str): The current question prompt.
|
635 |
+
page_information (str): The information about the page.
|
636 |
+
answer (Dict[str, Any]): The RAG response.
|
637 |
+
"""
|
638 |
print(
|
639 |
"**************************************************************************************************\n",
|
640 |
# curr_relation_prompt,
|
|
|
645 |
)
|
646 |
|
647 |
def getRagGradioResponse(
|
648 |
+
self,
|
649 |
+
page_information: str,
|
650 |
+
RelationPrompt: str,
|
651 |
+
QuestionPrompt: str,
|
652 |
+
threshold: float,
|
653 |
+
) -> str:
|
654 |
+
"""
|
655 |
+
Get the RAG response in a format suitable for Gradio.
|
656 |
+
|
657 |
+
Parameters:
|
658 |
+
page_information (str): The information about the page.
|
659 |
+
RelationPrompt (str): The prompt for relation checking.
|
660 |
+
QuestionPrompt (str): The prompt for question generation.
|
661 |
+
threshold (float): The threshold for querying the database.
|
662 |
+
|
663 |
+
Returns:
|
664 |
+
str: The full response formatted for Gradio.
|
665 |
+
"""
|
666 |
+
# Get the RAG response and document clusters
|
667 |
answer, documents_clusters = self.getRagResponse(
|
668 |
page_information, threshold, RelationPrompt, QuestionPrompt
|
669 |
)
|
670 |
+
# Log the result
|
671 |
self.logResult(RelationPrompt, QuestionPrompt, page_information, answer)
|
672 |
|
673 |
+
# Convert documents and responses to printable strings
|
674 |
docs_info = self.changeDocumentsToPrintableString(documents_clusters)
|
675 |
relation_answer_string = self.changeResponseToPrintableString(
|
676 |
answer["relation_answer"], "relation"
|
|
|
678 |
question_answer_string = self.changeResponseToPrintableString(
|
679 |
answer["question_answer"], "question"
|
680 |
)
|
681 |
+
# Get token usage information
|
682 |
question_tokens = answer["tokens_used_question"]
|
683 |
relation_tokens = answer["tokens_used_relation"]
|
684 |
+
|
685 |
+
# Format the full response
|
686 |
+
full_response = (
|
687 |
+
f"**ANSWER**: \n Relation answer:\n {relation_answer_string}\n "
|
688 |
+
f"Question answer:\n {question_answer_string}\n\n"
|
689 |
+
f"**RETRIEVED DOCUMENTS CLUSTERS**:\n{docs_info}\n\n"
|
690 |
+
f"**TOKENS USED**:\nQuestion api call: {question_tokens}\n"
|
691 |
+
f"Relation api call: {relation_tokens}"
|
692 |
+
)
|
693 |
+
|
694 |
return full_response
|
695 |
|
696 |
|
697 |
+
class Helper:
|
698 |
+
def __init__(self, DB_FAISS_PATH: str) -> None:
|
699 |
+
"""Initialize the Helper class and set environment variables."""
|
700 |
load_dotenv(override=True)
|
701 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
702 |
|
703 |
+
self.DB_FAISS_PATH = DB_FAISS_PATH
|
704 |
+
|
705 |
+
def getRag(self) -> ADS_RAG:
|
706 |
+
"""
|
707 |
+
Create and return an instance of the ADS_RAG class.
|
708 |
+
|
709 |
+
Returns:
|
710 |
+
ADS_RAG: An instance of the ADS_RAG class.
|
711 |
+
"""
|
712 |
+
# Initialize embeddings using HuggingFace
|
713 |
+
embeddings_hf = HuggingFaceEmbeddings(
|
714 |
+
model_name="BAAI/bge-m3"
|
715 |
+
) # "sentence-transformers/all-mpnet-base-v2"
|
716 |
|
717 |
+
# Create a VECTOR_DB instance
|
|
|
|
|
|
|
718 |
vector_db = VECTOR_DB(
|
719 |
+
default_threshold=0.75,
|
720 |
+
number_of_ads_to_fetch_from_db=50,
|
721 |
+
clustering_algo="kmeans-cc",
|
722 |
+
no_of_clusters=3,
|
723 |
+
no_of_ads_in_each_cluster=6,
|
724 |
+
DB_FAISS_PATH=self.DB_FAISS_PATH,
|
725 |
+
embeddings_hf=embeddings_hf,
|
726 |
)
|
727 |
+
|
728 |
+
# Create and return an ADS_RAG instance
|
729 |
rag = ADS_RAG(
|
730 |
+
db=vector_db,
|
731 |
+
qa_model_name="gpt-3.5-turbo",
|
732 |
+
relation_check_best_value_thresh=0.6,
|
733 |
+
bestRelationSystemPrompt=self.getRelationSystemPrompt(),
|
734 |
+
bestQuestionSystemPrompt=self.getQuestionSystemPrompt(),
|
735 |
)
|
736 |
return rag
|
737 |
|
738 |
+
def getQuestionSystemPrompt(self) -> str:
|
739 |
+
"""
|
740 |
+
Return the system prompt for question generation.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
741 |
|
742 |
+
Returns:
|
743 |
+
str: The question system prompt.
|
744 |
+
"""
|
745 |
bestQuestionSystemPrompt = """1. You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to form a relevant QUESTION to ask the user visiting the webpage. This question should help identify the user's intent behind visiting the webpage and should be highly attractive.
|
746 |
2. Now form a highly attractive/lucrative and diverse/mutually exclusive OPTION which should be both the answer for the QUESTION and related to ads in this cluster.
|
747 |
3. Try to generate intelligent creatives for advertising and keep QUESTION within 70 characters and either 2, 3 or 4 options with each OPTION within 4 to 6 words.
|
|
|
799 |
# """
|
800 |
return bestQuestionSystemPrompt
|
801 |
|
802 |
+
def getRelationSystemPrompt(self) -> str:
|
803 |
+
"""
|
804 |
+
Return the system prompt for relation checking.
|
805 |
+
|
806 |
+
Returns:
|
807 |
+
str: The relation system prompt.
|
808 |
+
"""
|
809 |
bestRelationSystemPrompt = """You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to determine whether there are some relevant ADS to INPUT are present in ADS_DATA. ADS WHICH DON'T MATCH USER'S INTENT SHOULD BE CONSIDERED IRRELEVANT
|
810 |
|
811 |
---------------------------------------
|
|
|
833 |
return bestRelationSystemPrompt
|
834 |
|
835 |
|
836 |
+
class RAGGradioApp:
|
837 |
+
def __init__(self, helper: Helper) -> None:
|
838 |
+
"""
|
839 |
+
Initialize the RAGGradioApp with an instance of ADS_RAG and Helper.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
840 |
|
841 |
+
Args:
|
842 |
+
rag (ADS_RAG): An instance of ADS_RAG for handling RAG functionality.
|
843 |
+
helper (Helper): An instance of Helper for configuration and prompts.
|
844 |
+
"""
|
845 |
+
self.rag = helper.getRag()
|
846 |
+
self.relationSystemPrompt = helper.getRelationSystemPrompt()
|
847 |
+
self.questionSystempPrompt = helper.getQuestionSystemPrompt()
|
848 |
+
|
849 |
+
def get_interface(self, ad_title_content: List[str]) -> gr.Blocks:
|
850 |
+
"""
|
851 |
+
Construct the Gradio interface for RAG functionality.
|
852 |
+
|
853 |
+
Returns:
|
854 |
+
gr.Blocks: Gradio Blocks object containing the constructed interface.
|
855 |
+
"""
|
856 |
+
# Textbox for Relation System prompt
|
857 |
+
RelationPrompt = gr.Textbox(
|
858 |
+
self.getRelationSystemPrompt(),
|
859 |
+
lines=1,
|
860 |
+
placeholder="Enter the relation system prompt for relation check",
|
861 |
+
label="Relation System prompt",
|
862 |
+
)
|
863 |
+
|
864 |
+
# Textbox for Question System prompt
|
865 |
+
QuestionPrompt = gr.Textbox(
|
866 |
+
self.getQuestionSystemPrompt(),
|
867 |
+
lines=1,
|
868 |
+
placeholder="Enter the question system prompt for question formulation",
|
869 |
+
label="Question System prompt",
|
870 |
+
)
|
871 |
+
|
872 |
+
# Textbox for Page Information input
|
873 |
+
page_information = gr.Textbox(
|
874 |
+
lines=1,
|
875 |
+
placeholder="Enter the page information",
|
876 |
+
label="Page Information",
|
877 |
+
)
|
878 |
+
|
879 |
+
# Number input for Threshold
|
880 |
+
threshold = gr.Number(
|
881 |
+
value=self.rag.db.default_threshold, label="Threshold", interactive=True
|
882 |
+
)
|
883 |
+
|
884 |
+
# Textbox for displaying output
|
885 |
output = gr.Textbox(label="Output")
|
886 |
+
|
887 |
+
# Button for submitting the form
|
888 |
submit_btn = gr.Button("Submit")
|
889 |
|
890 |
+
# Define behavior on button click
|
891 |
submit_btn.click(
|
892 |
+
self.rag.getRagGradioResponse,
|
893 |
inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
|
894 |
outputs=[output],
|
895 |
)
|
896 |
+
|
897 |
+
# Define behavior on form submission by pressing enter
|
898 |
page_information.submit(
|
899 |
+
self.rag.getRagGradioResponse,
|
900 |
inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
|
901 |
outputs=[output],
|
902 |
)
|
903 |
+
|
904 |
+
# Accordion to display Ad Titles
|
905 |
with gr.Accordion("Ad Titles", open=False):
|
906 |
ad_titles = gr.Markdown()
|
907 |
|
908 |
+
# Create a Gradio Blocks object for structured layout
|
909 |
+
demo = gr.Blocks()
|
910 |
+
|
911 |
+
# Load ad titles into the accordion
|
912 |
demo.load(
|
913 |
lambda: "<br>".join(
|
914 |
random.sample(
|
|
|
920 |
ad_titles,
|
921 |
)
|
922 |
|
923 |
+
return demo
|
924 |
+
|
925 |
+
def launch(self, example_content: List) -> None:
|
926 |
+
"""
|
927 |
+
Launch the Gradio interface for RAG functionality.
|
928 |
+
"""
|
929 |
+
gr.close_all() # Close any existing Gradio instances
|
930 |
+
interface = self.get_interface(example_content) # Get the constructed interface
|
931 |
+
interface.launch() # Launch the Gradio interface
|
932 |
+
|
933 |
+
|
934 |
+
if __name__ == "__main__":
|
935 |
+
helper = Helper(
|
936 |
+
"./vectorstore/db_faiss_ads_Jun_facty_activebeat_Health_dupRemoved0.85"
|
937 |
+
)
|
938 |
+
rag_gradio_app = RAGGradioApp(helper)
|
939 |
+
|
940 |
+
data = pd.read_csv(
|
941 |
+
"./data/149_adclick_Jun_facty_activeBeat_Health_dupRemoved0.85_campaign.tsv",
|
942 |
+
sep="\t",
|
943 |
+
)
|
944 |
+
# data.dropna(axis=0, how="any", inplace=True)
|
945 |
+
ad_title_content = list(
|
946 |
+
data.drop_duplicates(subset=["ad_title", "ad_desc"])["ad_title"].values
|
947 |
+
)
|
948 |
+
|
949 |
+
rag_gradio_app.launch(ad_title_content)
|