Rajat.bans commited on
Commit
0b952f2
1 Parent(s): 3da4cba

Updated the code with comments and type definitions

Browse files
Files changed (2) hide show
  1. rag.ipynb +0 -0
  2. rag.py +290 -206
rag.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
rag.py CHANGED
@@ -11,14 +11,18 @@ import time
11
  import os
12
  from langchain_community.embeddings import HuggingFaceEmbeddings
13
  from dotenv import load_dotenv
14
- import pandas as pd
15
  from typing import List, Tuple, Dict, Any
 
 
 
 
 
16
 
17
 
18
  class CLUSTERING:
19
  def cluster_embeddings(
20
  self,
21
- embeddings: List[List[float]],
22
  clustering_algo: str,
23
  no_of_clusters: int,
24
  no_of_points: int,
@@ -27,7 +31,7 @@ class CLUSTERING:
27
  Clusters embeddings using the specified clustering algorithm and returns the indices of points in each cluster.
28
 
29
  Parameters:
30
- embeddings (List[List[float]]): The input embeddings to cluster.
31
  clustering_algo (str): The clustering algorithm to use ("kmeans-cc", "kmeans-sp", or "spectral").
32
  no_of_clusters (int): The number of clusters to form.
33
  no_of_points (int): The maximum number of points to include in each cluster.
@@ -123,7 +127,7 @@ class VECTOR_DB:
123
 
124
  def queryVectorDB(
125
  self, page_information: str, threshold: float = None
126
- ) -> Tuple[List[List[Tuple]], float]:
127
  """
128
  Query the vector database and cluster the retrieved documents.
129
 
@@ -196,7 +200,7 @@ class FAISS_DB:
196
  metadata: List[Dict[str, Any]],
197
  CHUNK_SIZE: int = 2048,
198
  CHUNK_OVERLAP: int = 512,
199
- ) -> List[Dict[str, Any]]:
200
  """
201
  Split the provided content into chunks with metadata.
202
 
@@ -207,7 +211,7 @@ class FAISS_DB:
207
  CHUNK_OVERLAP (int): The overlap between chunks. Default is 512.
208
 
209
  Returns:
210
- List[Dict[str, Any]]: The split documents with metadata.
211
  """
212
  text_splitter = RecursiveCharacterTextSplitter(
213
  chunk_size=CHUNK_SIZE,
@@ -218,13 +222,13 @@ class FAISS_DB:
218
  return split_docs
219
 
220
  def createDBFromDocs(
221
- self, split_docs: List[Dict[str, Any]], embeddings_model: HuggingFaceEmbeddings
222
  ) -> FAISS:
223
  """
224
  Create a FAISS database from the provided documents and embeddings model.
225
 
226
  Parameters:
227
- split_docs (List[Dict[str, Any]]): The split documents.
228
  embeddings_model (HuggingFaceEmbeddings): The embeddings model to use.
229
 
230
  Returns:
@@ -235,7 +239,7 @@ class FAISS_DB:
235
 
236
  def createAndSaveDBInChunks(
237
  self,
238
- split_docs: List[Dict[str, Any]],
239
  embeddings_model: HuggingFaceEmbeddings,
240
  DB_FAISS_PATH: str,
241
  chunk_size: int = 1000,
@@ -244,7 +248,7 @@ class FAISS_DB:
244
  Create and save the FAISS database in chunks.
245
 
246
  Parameters:
247
- split_docs (List[Dict[str, Any]]): The split documents.
248
  embeddings_model (HuggingFaceEmbeddings): The embeddings model to use.
249
  DB_FAISS_PATH (str): The path to save the FAISS database.
250
  chunk_size (int): The size of each chunk. Default is 1000.
@@ -347,7 +351,7 @@ class FAISS_DB:
347
  cv = fl[6:-6]
348
  ind = max(ind, int(cv))
349
 
350
- all_dbs = []
351
  for i in range(0, ind + 1, 2):
352
  print(i)
353
  db1 = FAISS.load_local(
@@ -389,12 +393,13 @@ class FAISS_DB:
389
  class ADS_RAG:
390
  def __init__(
391
  self,
392
- db,
393
- qa_model_name,
394
- relation_check_best_value_thresh,
395
- bestRelationSystemPrompt,
396
- bestQuestionSystemPrompt,
397
- ):
 
398
  self.client = OpenAI()
399
  self.db = db
400
  self.qa_model_name = qa_model_name
@@ -402,7 +407,18 @@ class ADS_RAG:
402
  self.bestRelationSystemPrompt = bestRelationSystemPrompt
403
  self.bestQuestionSystemPrompt = bestQuestionSystemPrompt
404
 
405
- def callOpenAiApi(self, messages):
 
 
 
 
 
 
 
 
 
 
 
406
  while True:
407
  try:
408
  response = self.client.chat.completions.create(
@@ -423,12 +439,25 @@ class ADS_RAG:
423
 
424
  def getBestQuestionOnTheBasisOfPageInformationAndAdsData(
425
  self,
426
- page_information,
427
- adsData,
428
- relationSystemPrompt,
429
- questionSystemPrompt,
430
- bestRetreivedAdValue,
431
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  if adsData == "":
433
  return ({"reasoning": "No ads data present", "classification": 0}, 0), (
434
  {"reasoning": "", "question": "", "options": []},
@@ -454,11 +483,9 @@ class ADS_RAG:
454
  }
455
  ]
456
  )
457
-
458
- tokens_used_question = 0
459
  else:
460
  relation_answer["reasoning"] = (
461
- "First retreived document value less than threshold so no need to check relation"
462
  )
463
 
464
  if relation_answer["classification"] != 0:
@@ -483,7 +510,18 @@ class ADS_RAG:
483
  "tokens_used_question": tokens_used_question,
484
  }
485
 
486
- def convertDocumentsClustersToStringForApiCall(self, documents_clusters):
 
 
 
 
 
 
 
 
 
 
 
487
  key_counter = count(1)
488
  res = json.dumps(
489
  {
@@ -497,14 +535,30 @@ class ADS_RAG:
497
  return res
498
 
499
  def getRagResponse(
500
- self, page_information, threshold=None, RelationPrompt=None, QuestionPrompt=None
501
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
  curr_relation_prompt = self.bestRelationSystemPrompt
503
- if RelationPrompt != None and len(RelationPrompt):
504
  curr_relation_prompt = RelationPrompt
505
 
506
  curr_question_prompt = self.bestQuestionSystemPrompt
507
- if QuestionPrompt != None and len(QuestionPrompt):
508
  curr_question_prompt = QuestionPrompt
509
 
510
  documents_clusters, best_value = self.db.queryVectorDB(
@@ -520,7 +574,18 @@ class ADS_RAG:
520
 
521
  return answer, documents_clusters
522
 
523
- def changeDocumentsToPrintableString(self, documents_clusters):
 
 
 
 
 
 
 
 
 
 
 
524
  res = ""
525
  i = 0
526
  for ind, documents_cluster in enumerate(documents_clusters):
@@ -531,7 +596,19 @@ class ADS_RAG:
531
  res += "\n"
532
  return res
533
 
534
- def changeResponseToPrintableString(self, response, task):
 
 
 
 
 
 
 
 
 
 
 
 
535
  if task == "relation":
536
  return f"Reasoning: {response['reasoning']}\n\nClassification: {response['classification']}\n"
537
  res = f"Reasoning: {response['reasoning']}\n\nQuestion: {response['question']}\n\nOptions: \n"
@@ -543,8 +620,21 @@ class ADS_RAG:
543
  return res
544
 
545
  def logResult(
546
- self, curr_relation_prompt, curr_question_prompt, page_information, answer
547
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
548
  print(
549
  "**************************************************************************************************\n",
550
  # curr_relation_prompt,
@@ -555,13 +645,32 @@ class ADS_RAG:
555
  )
556
 
557
  def getRagGradioResponse(
558
- self, page_information, RelationPrompt, QuestionPrompt, threshold
559
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560
  answer, documents_clusters = self.getRagResponse(
561
  page_information, threshold, RelationPrompt, QuestionPrompt
562
  )
 
563
  self.logResult(RelationPrompt, QuestionPrompt, page_information, answer)
564
 
 
565
  docs_info = self.changeDocumentsToPrintableString(documents_clusters)
566
  relation_answer_string = self.changeResponseToPrintableString(
567
  answer["relation_answer"], "relation"
@@ -569,79 +678,70 @@ class ADS_RAG:
569
  question_answer_string = self.changeResponseToPrintableString(
570
  answer["question_answer"], "question"
571
  )
 
572
  question_tokens = answer["tokens_used_question"]
573
  relation_tokens = answer["tokens_used_relation"]
574
- full_response = f"**ANSWER**: \n Relation answer:\n {relation_answer_string}\n Question answer:\n {question_answer_string}\n\n**RETREIVED DOCUMENTS CLUSTERS**:\n{docs_info}\n\n**TOKENS USED**:\nQuestion api call: {question_tokens}\nRelation api call: {relation_tokens}"
 
 
 
 
 
 
 
 
 
575
  return full_response
576
 
577
 
578
- class VARIABLE_MANAGER:
579
- def __init__(self):
 
580
  load_dotenv(override=True)
581
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
582
 
583
- self.embedding_model_hf = "BAAI/bge-m3"
584
- # embedding_model_hf = "sentence-transformers/all-mpnet-base-v2"
585
- self.DB_FAISS_PATH = (
586
- "./vectorstore/db_faiss_ads_Jun_facty_activebeat_Health_dupRemoved0.85"
587
- )
 
 
 
 
 
 
 
 
588
 
589
- def getRag(self):
590
- # embeddings_oa = OpenAIEmbeddings(model=embedding_model_oa)
591
- # embeddings_hf = HuggingFaceEmbeddings(model_name = embedding_model_hf, show_progress = True)
592
- embeddings_hf = HuggingFaceEmbeddings(model_name=self.embedding_model_hf)
593
  vector_db = VECTOR_DB(
594
- 0.75, 50, "kmeans-cc", 3, 6, self.DB_FAISS_PATH, embeddings_hf
 
 
 
 
 
 
595
  )
 
 
596
  rag = ADS_RAG(
597
- vector_db,
598
- "gpt-3.5-turbo",
599
- 0.6,
600
- self.getRelationSystemPrompt(),
601
- self.getQuestionSystemPrompt(),
602
  )
603
  return rag
604
 
605
- def QnAAdsSampleGenerationPreProcessing(self):
606
- data_file_path = (
607
- "./data/148_facty_activebeat_24Jun-30Jun_top1000each_urlsContent.tsv"
608
- )
609
- data = pd.read_csv(data_file_path, sep="\t")
610
- data.dropna(axis=0, how="any", inplace=True)
611
-
612
- # data.drop_duplicates(subset = ['ad_title', 'ad_desc'], inplace=True)
613
- # ad_title_content = list(data["ad_title"].values)
614
- def get_core_content(row):
615
- url_content = row["url_content"]
616
- url_title = row["url_title"]
617
- return (
618
- "Page Title -: "
619
- + url_title
620
- + "\nPage Content -: "
621
- + ". ".join(url_content.split(". ")[:7])
622
- )
623
-
624
- data["core_content"] = data.apply(get_core_content, axis=1)
625
- # for i in range(len(data)):
626
- # print(data.loc[i, 'url'])
627
- # print(data.loc[i, 'url_content'])
628
- # print(data.loc[i, 'core_content'])
629
- # print()
630
- # if(i > 10):
631
- # break
632
- return data
633
-
634
- def GradioRagPreProcessing(self):
635
- data_file_path = (
636
- "./data/149_adclick_Jun_facty_activeBeat_Health_dupRemoved0.85_campaign.tsv"
637
- )
638
- data = pd.read_csv(data_file_path, sep="\t")
639
- # data.dropna(axis=0, how="any", inplace=True)
640
- data.drop_duplicates(subset=["ad_title", "ad_desc"], inplace=True)
641
- ad_title_content = list(data["ad_title"].values)
642
- return ad_title_content
643
 
644
- def getQuestionSystemPrompt(self):
 
 
645
  bestQuestionSystemPrompt = """1. You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to form a relevant QUESTION to ask the user visiting the webpage. This question should help identify the user's intent behind visiting the webpage and should be highly attractive.
646
  2. Now form a highly attractive/lucrative and diverse/mutually exclusive OPTION which should be both the answer for the QUESTION and related to ads in this cluster.
647
  3. Try to generate intelligent creatives for advertising and keep QUESTION within 70 characters and either 2, 3 or 4 options with each OPTION within 4 to 6 words.
@@ -699,7 +799,13 @@ The ADS_DATA provided to you is as follows:
699
  # """
700
  return bestQuestionSystemPrompt
701
 
702
- def getRelationSystemPrompt(self):
 
 
 
 
 
 
703
  bestRelationSystemPrompt = """You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to determine whether there are some relevant ADS to INPUT are present in ADS_DATA. ADS WHICH DON'T MATCH USER'S INTENT SHOULD BE CONSIDERED IRRELEVANT
704
 
705
  ---------------------------------------
@@ -727,129 +833,82 @@ The ADS_DATA provided to you is as follows:
727
  return bestRelationSystemPrompt
728
 
729
 
730
- # *********************** DB GENERATION ******************************
731
- # df = pd.read_csv(data_file_path, sep="\t")
732
-
733
- # --------------------------------
734
- # WEB DATA PROCESSING
735
- # from urllib.parse import urlparse
736
- # import re
737
- # def get_cleaned_url(url):
738
- # path = urlparse(url).path.strip()
739
- # cleaned_path = re.sub(r'[^a-zA-Z0-9\s-]', ' ', path).replace('/', '')
740
- # cleaned_path = re.sub(r'[^a-zA-Z0-9\s]', ' ', path).replace('-', '')
741
- # return cleaned_path.strip()
742
-
743
- # df['cleaned_url'] = df['url'].map(get_cleaned_url)
744
- # df.dropna(subset=['cleaned_url', 'url_content', 'url_title'], inplace=True)
745
- # df['combined'] = df['cleaned_url'] + ". " + df['url_title'] + ". " + df['url_content']
746
- # content = df["combined"].tolist()
747
- # metadata = [
748
- # {"title": row["url_title"], "url": row["url"]}
749
- # for _, row in df.iterrows()
750
- # ]
751
- # ------------------------------
752
- # ADS DATA PROCESSING
753
- # # df.dropna(axis=0, how='any', inplace=True)
754
- # df.drop_duplicates(subset = ['ad_title', 'ad_desc'], inplace=True)
755
- # dfRPC = df[df['RPC'] > 0]
756
- # dfRPC.dropna(how = 'any', inplace=True)
757
- # dfCampaign = df[df['type'] == 'campaign']
758
- # dfCampaign.fillna('', inplace=True)
759
- # df = pd.concat([dfRPC, dfCampaign])
760
- # df
761
-
762
- # content = (df["ad_title"] + ". " + df["ad_desc"]).tolist()
763
- # metadata = [
764
- # {"publisher_url": row["publisher_url"], "keyword_term": row["keyword_term"], "ad_display_url": row["ad_display_url"], "revenue": row["revenue"], "ad_click_count": row["ad_click_count"], "RPC": row["RPC"], "Type": row["type"]}
765
- # # {"revenue": row["revenue"], "ad_click_count": row["ad_click_count"]}
766
- # for _, row in df.iterrows()
767
- # ]
768
- # --------------------------------
769
-
770
- # faiss_db = FAISS_DB()
771
- # db = faiss_db.createDBFromDocs(content, metadata)
772
- # faiss_db.saveDB(db, '.')
773
-
774
- # ************************************************************************
775
- # PARALLELY CREATING DB - BACKUP FOR FUTURE USE
776
- # import time
777
- # import threading
778
- # import os
779
- # one_db_docs_size = 1000
780
- # starting_i = 0
781
- # parallel_processes = 3
782
- # def split_list(lst, n):
783
- # k, m = divmod(len(lst), n)
784
- # return (lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))
785
- # def createDBForIndexes(inds):
786
- # for i in inds:
787
- # ctime = time.time()
788
- # print(f"Processing {i}")
789
- # if not os.path.exists(DB_FAISS_PATH + "/index_{int(i/one_db_docs_size)}.faiss"):
790
- # db = FAISS.from_documents(split_docs[i:i+one_db_docs_size], embeddings_hf)
791
- # db.save_local(DB_FAISS_PATH, index_name = f"index_{int(i/one_db_docs_size)}")
792
- # ctime = time.time() - ctime
793
- # print(f"{i})Time taken", ctime)
794
- # indexes = split_list(range(starting_i, len(split_docs), one_db_docs_size), parallel_processes)
795
- # threads = []
796
- # for i, one_process_indexes in enumerate(indexes):
797
- # thread = threading.Thread(target=createDBForIndexes, args=(one_process_indexes,))
798
- # thread.start()
799
- # threads.append(thread)
800
- # for thread in threads:
801
- # thread.join()
802
- # print("All threads completed.")
803
- # ************************************************************************
804
 
805
- if __name__ == "__main__":
806
- import pandas as pd
807
- import gradio as gr
808
- import random
809
-
810
- vm = VARIABLE_MANAGER()
811
- rag = vm.getRag()
812
- ad_title_content = vm.GradioRagPreProcessing()
813
-
814
- with gr.Blocks() as demo:
815
- gr.Markdown("# RAG on ads data")
816
- with gr.Row():
817
- RelationPrompt = gr.Textbox(
818
- vm.getRelationSystemPrompt(),
819
- lines=1,
820
- placeholder="Enter the relation system prompt for relation check",
821
- label="Relation System prompt",
822
- )
823
- QuestionPrompt = gr.Textbox(
824
- vm.getQuestionSystemPrompt(),
825
- lines=1,
826
- placeholder="Enter the question system prompt for question formulation",
827
- label="Question System prompt",
828
- )
829
- page_information = gr.Textbox(
830
- lines=1,
831
- placeholder="Enter the page information",
832
- label="Page Information",
833
- )
834
- threshold = gr.Number(
835
- value=rag.db.default_threshold, label="Threshold", interactive=True
836
- )
 
 
 
 
 
 
 
 
 
 
 
 
837
  output = gr.Textbox(label="Output")
 
 
838
  submit_btn = gr.Button("Submit")
839
 
 
840
  submit_btn.click(
841
- rag.getRagGradioResponse,
842
  inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
843
  outputs=[output],
844
  )
 
 
845
  page_information.submit(
846
- rag.getRagGradioResponse,
847
  inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
848
  outputs=[output],
849
  )
 
 
850
  with gr.Accordion("Ad Titles", open=False):
851
  ad_titles = gr.Markdown()
852
 
 
 
 
 
853
  demo.load(
854
  lambda: "<br>".join(
855
  random.sample(
@@ -861,5 +920,30 @@ if __name__ == "__main__":
861
  ad_titles,
862
  )
863
 
864
- gr.close_all()
865
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  import os
12
  from langchain_community.embeddings import HuggingFaceEmbeddings
13
  from dotenv import load_dotenv
 
14
  from typing import List, Tuple, Dict, Any
15
+ from numpy import ndarray
16
+ from langchain_core.documents import Document
17
+ import gradio as gr
18
+ import random
19
+ import pandas as pd
20
 
21
 
22
  class CLUSTERING:
23
  def cluster_embeddings(
24
  self,
25
+ embeddings: ndarray,
26
  clustering_algo: str,
27
  no_of_clusters: int,
28
  no_of_points: int,
 
31
  Clusters embeddings using the specified clustering algorithm and returns the indices of points in each cluster.
32
 
33
  Parameters:
34
+ embeddings (ndarray): The input embeddings to cluster.
35
  clustering_algo (str): The clustering algorithm to use ("kmeans-cc", "kmeans-sp", or "spectral").
36
  no_of_clusters (int): The number of clusters to form.
37
  no_of_points (int): The maximum number of points to include in each cluster.
 
127
 
128
  def queryVectorDB(
129
  self, page_information: str, threshold: float = None
130
+ ) -> Tuple[List[List[Tuple[Document, float]]], float]:
131
  """
132
  Query the vector database and cluster the retrieved documents.
133
 
 
200
  metadata: List[Dict[str, Any]],
201
  CHUNK_SIZE: int = 2048,
202
  CHUNK_OVERLAP: int = 512,
203
+ ) -> List[Document]:
204
  """
205
  Split the provided content into chunks with metadata.
206
 
 
211
  CHUNK_OVERLAP (int): The overlap between chunks. Default is 512.
212
 
213
  Returns:
214
+ List[Document]: The split documents with metadata.
215
  """
216
  text_splitter = RecursiveCharacterTextSplitter(
217
  chunk_size=CHUNK_SIZE,
 
222
  return split_docs
223
 
224
  def createDBFromDocs(
225
+ self, split_docs: List[Document], embeddings_model: HuggingFaceEmbeddings
226
  ) -> FAISS:
227
  """
228
  Create a FAISS database from the provided documents and embeddings model.
229
 
230
  Parameters:
231
+ split_docs (List[Document]): The split documents.
232
  embeddings_model (HuggingFaceEmbeddings): The embeddings model to use.
233
 
234
  Returns:
 
239
 
240
  def createAndSaveDBInChunks(
241
  self,
242
+ split_docs: List[Document],
243
  embeddings_model: HuggingFaceEmbeddings,
244
  DB_FAISS_PATH: str,
245
  chunk_size: int = 1000,
 
248
  Create and save the FAISS database in chunks.
249
 
250
  Parameters:
251
+ split_docs (List[Document]): The split documents.
252
  embeddings_model (HuggingFaceEmbeddings): The embeddings model to use.
253
  DB_FAISS_PATH (str): The path to save the FAISS database.
254
  chunk_size (int): The size of each chunk. Default is 1000.
 
351
  cv = fl[6:-6]
352
  ind = max(ind, int(cv))
353
 
354
+ all_dbs: List[FAISS] = []
355
  for i in range(0, ind + 1, 2):
356
  print(i)
357
  db1 = FAISS.load_local(
 
393
  class ADS_RAG:
394
  def __init__(
395
  self,
396
+ db: VECTOR_DB,
397
+ qa_model_name: str,
398
+ relation_check_best_value_thresh: float,
399
+ bestRelationSystemPrompt: str,
400
+ bestQuestionSystemPrompt: str,
401
+ ) -> None:
402
+ """Initialize the ADS_RAG class with the given parameters."""
403
  self.client = OpenAI()
404
  self.db = db
405
  self.qa_model_name = qa_model_name
 
407
  self.bestRelationSystemPrompt = bestRelationSystemPrompt
408
  self.bestQuestionSystemPrompt = bestQuestionSystemPrompt
409
 
410
+ def callOpenAiApi(
411
+ self, messages: List[Dict[str, str]]
412
+ ) -> Tuple[Dict[str, Any], int]:
413
+ """
414
+ Call the OpenAI API with the given messages and return the response.
415
+
416
+ Parameters:
417
+ messages (List[Dict[str, str]]): The messages to send to the OpenAI API.
418
+
419
+ Returns:
420
+ Tuple[Dict[str, Any], int]: The response from the OpenAI API and the number of tokens used.
421
+ """
422
  while True:
423
  try:
424
  response = self.client.chat.completions.create(
 
439
 
440
  def getBestQuestionOnTheBasisOfPageInformationAndAdsData(
441
  self,
442
+ page_information: str,
443
+ adsData: str,
444
+ relationSystemPrompt: str,
445
+ questionSystemPrompt: str,
446
+ bestRetreivedAdValue: float,
447
+ ) -> Dict[str, Any]:
448
+ """
449
+ Get the best question based on page information and ads data.
450
+
451
+ Parameters:
452
+ page_information (str): The information about the page.
453
+ adsData (str): The data about the ads.
454
+ relationSystemPrompt (str): The system prompt for relation checking.
455
+ questionSystemPrompt (str): The system prompt for question generation.
456
+ bestRetreivedAdValue (float): The best retrieved ad value.
457
+
458
+ Returns:
459
+ Dict[str, Any]: The relation and question answers along with token usage information.
460
+ """
461
  if adsData == "":
462
  return ({"reasoning": "No ads data present", "classification": 0}, 0), (
463
  {"reasoning": "", "question": "", "options": []},
 
483
  }
484
  ]
485
  )
 
 
486
  else:
487
  relation_answer["reasoning"] = (
488
+ "First retrieved document value less than threshold so no need to check relation"
489
  )
490
 
491
  if relation_answer["classification"] != 0:
 
510
  "tokens_used_question": tokens_used_question,
511
  }
512
 
513
+ def convertDocumentsClustersToStringForApiCall(
514
+ self, documents_clusters: List[List[Tuple[Document, float]]]
515
+ ) -> str:
516
+ """
517
+ Convert document clusters to a string format suitable for API calls.
518
+
519
+ Parameters:
520
+ documents_clusters (List[List[Tuple[Document, float]]]): The document clusters.
521
+
522
+ Returns:
523
+ str: The document clusters converted to a string.
524
+ """
525
  key_counter = count(1)
526
  res = json.dumps(
527
  {
 
535
  return res
536
 
537
  def getRagResponse(
538
+ self,
539
+ page_information: str,
540
+ threshold: float = None,
541
+ RelationPrompt: str = None,
542
+ QuestionPrompt: str = None,
543
+ ) -> Tuple[Dict[str, Any], List[List[Tuple[Document, float]]]]:
544
+ """
545
+ Get the RAG response based on the page information and optional prompts.
546
+
547
+ Parameters:
548
+ page_information (str): The information about the page.
549
+ threshold (float): The threshold for querying the database. Default is None.
550
+ RelationPrompt (str): The prompt for relation checking. Default is None.
551
+ QuestionPrompt (str): The prompt for question generation. Default is None.
552
+
553
+ Returns:
554
+ Tuple[Dict[str, Any], List[List[Tuple[Document, float]]]]: The RAG response and the document clusters.
555
+ """
556
  curr_relation_prompt = self.bestRelationSystemPrompt
557
+ if RelationPrompt is not None and len(RelationPrompt):
558
  curr_relation_prompt = RelationPrompt
559
 
560
  curr_question_prompt = self.bestQuestionSystemPrompt
561
+ if QuestionPrompt is not None and len(QuestionPrompt):
562
  curr_question_prompt = QuestionPrompt
563
 
564
  documents_clusters, best_value = self.db.queryVectorDB(
 
574
 
575
  return answer, documents_clusters
576
 
577
+ def changeDocumentsToPrintableString(
578
+ self, documents_clusters: List[List[Tuple[Document, float]]]
579
+ ) -> str:
580
+ """
581
+ Convert document clusters to a printable string format.
582
+
583
+ Parameters:
584
+ documents_clusters (List[List[Tuple[Document, float]]]): The document clusters.
585
+
586
+ Returns:
587
+ str: The document clusters converted to a printable string.
588
+ """
589
  res = ""
590
  i = 0
591
  for ind, documents_cluster in enumerate(documents_clusters):
 
596
  res += "\n"
597
  return res
598
 
599
+ def changeResponseToPrintableString(
600
+ self, response: Dict[str, Any], task: str
601
+ ) -> str:
602
+ """
603
+ Convert the response to a printable string format.
604
+
605
+ Parameters:
606
+ response (Dict[str, Any]): The response to convert.
607
+ task (str): The task type ('relation' or 'question').
608
+
609
+ Returns:
610
+ str: The response converted to a printable string.
611
+ """
612
  if task == "relation":
613
  return f"Reasoning: {response['reasoning']}\n\nClassification: {response['classification']}\n"
614
  res = f"Reasoning: {response['reasoning']}\n\nQuestion: {response['question']}\n\nOptions: \n"
 
620
  return res
621
 
622
  def logResult(
623
+ self,
624
+ curr_relation_prompt: str,
625
+ curr_question_prompt: str,
626
+ page_information: str,
627
+ answer: Dict[str, Any],
628
+ ) -> None:
629
+ """
630
+ Log the result of the RAG response.
631
+
632
+ Parameters:
633
+ curr_relation_prompt (str): The current relation prompt.
634
+ curr_question_prompt (str): The current question prompt.
635
+ page_information (str): The information about the page.
636
+ answer (Dict[str, Any]): The RAG response.
637
+ """
638
  print(
639
  "**************************************************************************************************\n",
640
  # curr_relation_prompt,
 
645
  )
646
 
647
  def getRagGradioResponse(
648
+ self,
649
+ page_information: str,
650
+ RelationPrompt: str,
651
+ QuestionPrompt: str,
652
+ threshold: float,
653
+ ) -> str:
654
+ """
655
+ Get the RAG response in a format suitable for Gradio.
656
+
657
+ Parameters:
658
+ page_information (str): The information about the page.
659
+ RelationPrompt (str): The prompt for relation checking.
660
+ QuestionPrompt (str): The prompt for question generation.
661
+ threshold (float): The threshold for querying the database.
662
+
663
+ Returns:
664
+ str: The full response formatted for Gradio.
665
+ """
666
+ # Get the RAG response and document clusters
667
  answer, documents_clusters = self.getRagResponse(
668
  page_information, threshold, RelationPrompt, QuestionPrompt
669
  )
670
+ # Log the result
671
  self.logResult(RelationPrompt, QuestionPrompt, page_information, answer)
672
 
673
+ # Convert documents and responses to printable strings
674
  docs_info = self.changeDocumentsToPrintableString(documents_clusters)
675
  relation_answer_string = self.changeResponseToPrintableString(
676
  answer["relation_answer"], "relation"
 
678
  question_answer_string = self.changeResponseToPrintableString(
679
  answer["question_answer"], "question"
680
  )
681
+ # Get token usage information
682
  question_tokens = answer["tokens_used_question"]
683
  relation_tokens = answer["tokens_used_relation"]
684
+
685
+ # Format the full response
686
+ full_response = (
687
+ f"**ANSWER**: \n Relation answer:\n {relation_answer_string}\n "
688
+ f"Question answer:\n {question_answer_string}\n\n"
689
+ f"**RETRIEVED DOCUMENTS CLUSTERS**:\n{docs_info}\n\n"
690
+ f"**TOKENS USED**:\nQuestion api call: {question_tokens}\n"
691
+ f"Relation api call: {relation_tokens}"
692
+ )
693
+
694
  return full_response
695
 
696
 
697
+ class Helper:
698
+ def __init__(self, DB_FAISS_PATH: str) -> None:
699
+ """Initialize the Helper class and set environment variables."""
700
  load_dotenv(override=True)
701
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
702
 
703
+ self.DB_FAISS_PATH = DB_FAISS_PATH
704
+
705
+ def getRag(self) -> ADS_RAG:
706
+ """
707
+ Create and return an instance of the ADS_RAG class.
708
+
709
+ Returns:
710
+ ADS_RAG: An instance of the ADS_RAG class.
711
+ """
712
+ # Initialize embeddings using HuggingFace
713
+ embeddings_hf = HuggingFaceEmbeddings(
714
+ model_name="BAAI/bge-m3"
715
+ ) # "sentence-transformers/all-mpnet-base-v2"
716
 
717
+ # Create a VECTOR_DB instance
 
 
 
718
  vector_db = VECTOR_DB(
719
+ default_threshold=0.75,
720
+ number_of_ads_to_fetch_from_db=50,
721
+ clustering_algo="kmeans-cc",
722
+ no_of_clusters=3,
723
+ no_of_ads_in_each_cluster=6,
724
+ DB_FAISS_PATH=self.DB_FAISS_PATH,
725
+ embeddings_hf=embeddings_hf,
726
  )
727
+
728
+ # Create and return an ADS_RAG instance
729
  rag = ADS_RAG(
730
+ db=vector_db,
731
+ qa_model_name="gpt-3.5-turbo",
732
+ relation_check_best_value_thresh=0.6,
733
+ bestRelationSystemPrompt=self.getRelationSystemPrompt(),
734
+ bestQuestionSystemPrompt=self.getQuestionSystemPrompt(),
735
  )
736
  return rag
737
 
738
+ def getQuestionSystemPrompt(self) -> str:
739
+ """
740
+ Return the system prompt for question generation.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
741
 
742
+ Returns:
743
+ str: The question system prompt.
744
+ """
745
  bestQuestionSystemPrompt = """1. You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to form a relevant QUESTION to ask the user visiting the webpage. This question should help identify the user's intent behind visiting the webpage and should be highly attractive.
746
  2. Now form a highly attractive/lucrative and diverse/mutually exclusive OPTION which should be both the answer for the QUESTION and related to ads in this cluster.
747
  3. Try to generate intelligent creatives for advertising and keep QUESTION within 70 characters and either 2, 3 or 4 options with each OPTION within 4 to 6 words.
 
799
  # """
800
  return bestQuestionSystemPrompt
801
 
802
+ def getRelationSystemPrompt(self) -> str:
803
+ """
804
+ Return the system prompt for relation checking.
805
+
806
+ Returns:
807
+ str: The relation system prompt.
808
+ """
809
  bestRelationSystemPrompt = """You are an advertising concierge for text ads on websites. Given an INPUT and the available ad inventory (ADS_DATA), your task is to determine whether there are some relevant ADS to INPUT are present in ADS_DATA. ADS WHICH DON'T MATCH USER'S INTENT SHOULD BE CONSIDERED IRRELEVANT
810
 
811
  ---------------------------------------
 
833
  return bestRelationSystemPrompt
834
 
835
 
836
+ class RAGGradioApp:
837
+ def __init__(self, helper: Helper) -> None:
838
+ """
839
+ Initialize the RAGGradioApp with an instance of ADS_RAG and Helper.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
840
 
841
+ Args:
842
+ rag (ADS_RAG): An instance of ADS_RAG for handling RAG functionality.
843
+ helper (Helper): An instance of Helper for configuration and prompts.
844
+ """
845
+ self.rag = helper.getRag()
846
+ self.relationSystemPrompt = helper.getRelationSystemPrompt()
847
+ self.questionSystempPrompt = helper.getQuestionSystemPrompt()
848
+
849
+ def get_interface(self, ad_title_content: List[str]) -> gr.Blocks:
850
+ """
851
+ Construct the Gradio interface for RAG functionality.
852
+
853
+ Returns:
854
+ gr.Blocks: Gradio Blocks object containing the constructed interface.
855
+ """
856
+ # Textbox for Relation System prompt
857
+ RelationPrompt = gr.Textbox(
858
+ self.getRelationSystemPrompt(),
859
+ lines=1,
860
+ placeholder="Enter the relation system prompt for relation check",
861
+ label="Relation System prompt",
862
+ )
863
+
864
+ # Textbox for Question System prompt
865
+ QuestionPrompt = gr.Textbox(
866
+ self.getQuestionSystemPrompt(),
867
+ lines=1,
868
+ placeholder="Enter the question system prompt for question formulation",
869
+ label="Question System prompt",
870
+ )
871
+
872
+ # Textbox for Page Information input
873
+ page_information = gr.Textbox(
874
+ lines=1,
875
+ placeholder="Enter the page information",
876
+ label="Page Information",
877
+ )
878
+
879
+ # Number input for Threshold
880
+ threshold = gr.Number(
881
+ value=self.rag.db.default_threshold, label="Threshold", interactive=True
882
+ )
883
+
884
+ # Textbox for displaying output
885
  output = gr.Textbox(label="Output")
886
+
887
+ # Button for submitting the form
888
  submit_btn = gr.Button("Submit")
889
 
890
+ # Define behavior on button click
891
  submit_btn.click(
892
+ self.rag.getRagGradioResponse,
893
  inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
894
  outputs=[output],
895
  )
896
+
897
+ # Define behavior on form submission by pressing enter
898
  page_information.submit(
899
+ self.rag.getRagGradioResponse,
900
  inputs=[page_information, RelationPrompt, QuestionPrompt, threshold],
901
  outputs=[output],
902
  )
903
+
904
+ # Accordion to display Ad Titles
905
  with gr.Accordion("Ad Titles", open=False):
906
  ad_titles = gr.Markdown()
907
 
908
+ # Create a Gradio Blocks object for structured layout
909
+ demo = gr.Blocks()
910
+
911
+ # Load ad titles into the accordion
912
  demo.load(
913
  lambda: "<br>".join(
914
  random.sample(
 
920
  ad_titles,
921
  )
922
 
923
+ return demo
924
+
925
+ def launch(self, example_content: List) -> None:
926
+ """
927
+ Launch the Gradio interface for RAG functionality.
928
+ """
929
+ gr.close_all() # Close any existing Gradio instances
930
+ interface = self.get_interface(example_content) # Get the constructed interface
931
+ interface.launch() # Launch the Gradio interface
932
+
933
+
934
+ if __name__ == "__main__":
935
+ helper = Helper(
936
+ "./vectorstore/db_faiss_ads_Jun_facty_activebeat_Health_dupRemoved0.85"
937
+ )
938
+ rag_gradio_app = RAGGradioApp(helper)
939
+
940
+ data = pd.read_csv(
941
+ "./data/149_adclick_Jun_facty_activeBeat_Health_dupRemoved0.85_campaign.tsv",
942
+ sep="\t",
943
+ )
944
+ # data.dropna(axis=0, how="any", inplace=True)
945
+ ad_title_content = list(
946
+ data.drop_duplicates(subset=["ad_title", "ad_desc"])["ad_title"].values
947
+ )
948
+
949
+ rag_gradio_app.launch(ad_title_content)