Pranit committed on
Commit
1459715
1 Parent(s): 0aef4ff

adds populate db command

Browse files
Files changed (2) hide show
  1. app.py +6 -3
  2. populate_database.py +1 -5
app.py CHANGED
@@ -1,10 +1,10 @@
1
  from fastapi import FastAPI
2
- import argparse
3
  from langchain.vectorstores.chroma import Chroma
4
  from langchain.prompts import ChatPromptTemplate
5
  from langchain_community.llms import LlamaCpp
6
  from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
7
  from get_embedding_function import get_embedding_function
 
8
 
9
  CHROMA_PATH = "chroma"
10
 
@@ -17,13 +17,16 @@ Answer the question based only on the following context:
17
 
18
  Answer the question based on the above context: {question}
19
  """
 
 
 
20
  embedding_function = get_embedding_function()
21
  db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
22
 
23
  callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
24
  model = LlamaCpp(
25
  model_path="mistral-7b-instruct-v0.2.Q4_K_M.gguf",
26
- temperature=0.75,
27
  max_tokens=2000,
28
  top_p=1,
29
  callback_manager=callback_manager,
@@ -44,4 +47,4 @@ async def getAnswer():
44
  response_text = model.invoke(prompt)
45
  sources = [doc.metadata.get("id", None) for doc, _score in results]
46
  formatted_response = f"Response: {response_text}\nSources: {sources}"
47
- return response_text
 
1
  from fastapi import FastAPI
 
2
  from langchain.vectorstores.chroma import Chroma
3
  from langchain.prompts import ChatPromptTemplate
4
  from langchain_community.llms import LlamaCpp
5
  from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
6
  from get_embedding_function import get_embedding_function
7
+ from populate_database import populate_database
8
 
9
  CHROMA_PATH = "chroma"
10
 
 
17
 
18
  Answer the question based on the above context: {question}
19
  """
20
+
21
+ populate_database()
22
+
23
  embedding_function = get_embedding_function()
24
  db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
25
 
26
  callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
27
  model = LlamaCpp(
28
  model_path="mistral-7b-instruct-v0.2.Q4_K_M.gguf",
29
+ temperature=0.1,
30
  max_tokens=2000,
31
  top_p=1,
32
  callback_manager=callback_manager,
 
47
  response_text = model.invoke(prompt)
48
  sources = [doc.metadata.get("id", None) for doc, _score in results]
49
  formatted_response = f"Response: {response_text}\nSources: {sources}"
50
+ return formatted_response
populate_database.py CHANGED
@@ -12,7 +12,7 @@ CHROMA_PATH = "chroma"
12
  DATA_PATH = "data"
13
 
14
 
15
- def main():
16
 
17
  # Check if the database should be cleared (using the --clear flag).
18
  parser = argparse.ArgumentParser()
@@ -104,7 +104,3 @@ def calculate_chunk_ids(chunks):
104
  def clear_database():
105
  if os.path.exists(CHROMA_PATH):
106
  shutil.rmtree(CHROMA_PATH)
107
-
108
-
109
- if __name__ == "__main__":
110
- main()
 
12
  DATA_PATH = "data"
13
 
14
 
15
+ def populate_database():
16
 
17
  # Check if the database should be cleared (using the --clear flag).
18
  parser = argparse.ArgumentParser()
 
104
  def clear_database():
105
  if os.path.exists(CHROMA_PATH):
106
  shutil.rmtree(CHROMA_PATH)