ak3ra committed on
Commit
669d93a
1 Parent(s): 1e49809

added a cache

Browse files
Files changed (2) hide show
  1. app.py +12 -9
  2. rag/rag_pipeline.py +37 -34
app.py CHANGED
@@ -1,22 +1,25 @@
1
- # app.py
2
-
3
  import gradio as gr
4
  import json
5
  from rag.rag_pipeline import RAGPipeline
6
  from utils.prompts import highlight_prompt, evidence_based_prompt
7
  from config import STUDY_FILES
8
 
 
 
9
 
10
- def load_rag_pipeline(study_name):
11
- study_file = STUDY_FILES.get(study_name)
12
- if study_file:
13
- return RAGPipeline(study_file)
14
- else:
15
- raise ValueError(f"Invalid study name: {study_name}")
 
 
 
16
 
17
 
18
  def query_rag(study_name, question, prompt_type):
19
- rag = load_rag_pipeline(study_name)
20
 
21
  if prompt_type == "Highlight":
22
  prompt = highlight_prompt
 
 
 
1
  import gradio as gr
2
  import json
3
  from rag.rag_pipeline import RAGPipeline
4
  from utils.prompts import highlight_prompt, evidence_based_prompt
5
  from config import STUDY_FILES
6
 
7
# Cache for RAG pipelines, keyed by study name. Populated lazily so a
# pipeline (and its vector index) is built at most once per process.
rag_cache = {}


def get_rag_pipeline(study_name):
    """Return the RAGPipeline for *study_name*, building and caching it on first use.

    Raises:
        ValueError: if *study_name* has no entry in STUDY_FILES.
    """
    pipeline = rag_cache.get(study_name)
    if pipeline is None:
        study_file = STUDY_FILES.get(study_name)
        if not study_file:
            raise ValueError(f"Invalid study name: {study_name}")
        pipeline = RAGPipeline(study_file)
        rag_cache[study_name] = pipeline
    return pipeline
19
 
20
 
21
  def query_rag(study_name, question, prompt_type):
22
+ rag = get_rag_pipeline(study_name)
23
 
24
  if prompt_type == "Highlight":
25
  prompt = highlight_prompt
rag/rag_pipeline.py CHANGED
@@ -4,59 +4,62 @@ import json
4
  from llama_index.core import Document, VectorStoreIndex
5
  from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
6
  from llama_index.core import PromptTemplate
7
- from typing import List
8
 
9
 
10
  class RAGPipeline:
11
  def __init__(self, study_json, use_semantic_splitter=False):
12
  self.study_json = study_json
13
- self.index = None
14
  self.use_semantic_splitter = use_semantic_splitter
15
- self.load_documents()
16
- self.build_index()
17
 
18
  def load_documents(self):
19
- with open(self.study_json, "r") as f:
20
- self.data = json.load(f)
 
21
 
22
- self.documents = []
23
 
24
- for index, doc_data in enumerate(self.data):
25
- doc_content = (
26
- f"Title: {doc_data['title']}\n"
27
- f"Authors: {', '.join(doc_data['authors'])}\n"
28
- f"Full Text: {doc_data['full_text']}"
29
- )
30
 
31
- metadata = {
32
- "title": doc_data.get("title"),
33
- "abstract": doc_data.get("abstract"),
34
- "authors": doc_data.get("authors", []),
35
- "year": doc_data.get("year"),
36
- "doi": doc_data.get("doi"),
37
- }
38
 
39
- self.documents.append(
40
- Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
41
- )
42
 
43
  def build_index(self):
44
- sentence_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=13)
 
 
45
 
46
- def _split(text: str) -> List[str]:
47
- return sentence_splitter.split_text(text)
48
 
49
- node_parser = SentenceWindowNodeParser.from_defaults(
50
- sentence_splitter=_split,
51
- window_size=3,
52
- window_metadata_key="window",
53
- original_text_metadata_key="original_text",
54
- )
55
 
56
- nodes = node_parser.get_nodes_from_documents(self.documents)
57
- self.index = VectorStoreIndex(nodes)
58
 
59
  def query(self, question, prompt_template=None):
 
 
60
  if prompt_template is None:
61
  prompt_template = PromptTemplate(
62
  "Context information is below.\n"
 
4
  from llama_index.core import Document, VectorStoreIndex
5
  from llama_index.core.node_parser import SentenceWindowNodeParser, SentenceSplitter
6
  from llama_index.core import PromptTemplate
 
7
 
8
 
9
  class RAGPipeline:
10
  def __init__(self, study_json, use_semantic_splitter=False):
11
  self.study_json = study_json
 
12
  self.use_semantic_splitter = use_semantic_splitter
13
+ self.documents = None
14
+ self.index = None
15
 
16
  def load_documents(self):
17
+ if self.documents is None:
18
+ with open(self.study_json, "r") as f:
19
+ self.data = json.load(f)
20
 
21
+ self.documents = []
22
 
23
+ for index, doc_data in enumerate(self.data):
24
+ doc_content = (
25
+ f"Title: {doc_data['title']}\n"
26
+ f"Authors: {', '.join(doc_data['authors'])}\n"
27
+ f"Full Text: {doc_data['full_text']}"
28
+ )
29
 
30
+ metadata = {
31
+ "title": doc_data.get("title"),
32
+ "abstract": doc_data.get("abstract"),
33
+ "authors": doc_data.get("authors", []),
34
+ "year": doc_data.get("year"),
35
+ "doi": doc_data.get("doi"),
36
+ }
37
 
38
+ self.documents.append(
39
+ Document(text=doc_content, id_=f"doc_{index}", metadata=metadata)
40
+ )
41
 
42
  def build_index(self):
43
+ if self.index is None:
44
+ self.load_documents()
45
+ sentence_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=13)
46
 
47
+ def _split(text: str) -> List[str]:
48
+ return sentence_splitter.split_text(text)
49
 
50
+ node_parser = SentenceWindowNodeParser.from_defaults(
51
+ sentence_splitter=_split,
52
+ window_size=3,
53
+ window_metadata_key="window",
54
+ original_text_metadata_key="original_text",
55
+ )
56
 
57
+ nodes = node_parser.get_nodes_from_documents(self.documents)
58
+ self.index = VectorStoreIndex(nodes)
59
 
60
  def query(self, question, prompt_template=None):
61
+ self.build_index() # This will only build the index if it hasn't been built yet
62
+
63
  if prompt_template is None:
64
  prompt_template = PromptTemplate(
65
  "Context information is below.\n"