momegas commited on
Commit
f4e889a
β€’
1 Parent(s): 0eed37b

πŸ˜‡ Test your code people

Browse files
.gitignore CHANGED
@@ -4,4 +4,5 @@ __pycache__
4
  qnabot.egg-info
5
  dist
6
  build
7
- **.pickle
 
 
4
  qnabot.egg-info
5
  dist
6
  build
7
+ **.pickle
8
+ .env
README.md CHANGED
@@ -31,6 +31,7 @@ bot = QnABot(directory="./mydata", index="index.pickle")
31
  - [x] Save / load index to reduce costs (Open AI embedings are used to create the index)
32
  - [x] Local data source (directory of documents) or S3 data source
33
  - [x] FAISS for storing vectors / index
 
34
  - [ ] Customise prompt
35
  - [ ] Expose API
36
  - [ ] Support for LLaMA model
 
31
  - [x] Save / load index to reduce costs (Open AI embedings are used to create the index)
32
  - [x] Local data source (directory of documents) or S3 data source
33
  - [x] FAISS for storing vectors / index
34
+ - [ ] Support for other vector databases (e.g. Weaviate, Pinecone)
35
  - [ ] Customise prompt
36
  - [ ] Expose API
37
  - [ ] Support for LLaMA model
examples/example.ipynb CHANGED
@@ -18,7 +18,7 @@
18
  "from qnabot.QnABot import QnABot\n",
19
  "import os, sys\n",
20
  "\n",
21
- "os.environ[\"OPENAI_API_KEY\"] = \"you api key\"\n",
22
  "\n",
23
  "bot = QnABot(directory=\"./files\")"
24
  ]
@@ -40,7 +40,7 @@
40
  }
41
  ],
42
  "source": [
43
- "bot.print_answer(\"what was the first roster of angers in comics?\")\n",
44
  "bot.print_answer(\"Who is Vision?\")"
45
  ]
46
  }
 
18
  "from qnabot.QnABot import QnABot\n",
19
  "import os, sys\n",
20
  "\n",
21
+ "os.environ[\"OPENAI_API_KEY\"] = \"your api key\"\n",
22
  "\n",
23
  "bot = QnABot(directory=\"./files\")"
24
  ]
 
40
  }
41
  ],
42
  "source": [
43
+ "bot.print_answer(\"what was the first roster of avengers in comics?\")\n",
44
  "bot.print_answer(\"Who is Vision?\")"
45
  ]
46
  }
qnabot/QnABot.py CHANGED
@@ -73,7 +73,7 @@ class QnABot:
73
  )["output_text"]
74
  )
75
 
76
- def get_answer(self, question, k=1):
77
  # Retrieve the answer to the given question and return it
78
  input_documents = self.search_index.similarity_search(question, k=k)
79
  return self.chain(
 
73
  )["output_text"]
74
  )
75
 
76
+ def get_answer(self, question, k=1) -> str:
77
  # Retrieve the answer to the given question and return it
78
  input_documents = self.search_index.similarity_search(question, k=k)
79
  return self.chain(
qnabot/tests/__init__.py ADDED
File without changes
qnabot/tests/test_QnABot.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from qnabot import QnABot
4
+ import pickle
5
+ from langchain.vectorstores.faiss import FAISS
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+
10
+ # Define test data
11
+ test_directory = "./examples/files"
12
+ test_question = "what was the first roster of avengers in comics?"
13
+ correct_answer = "Iron Man, Thor, Hulk, Ant-Man"
14
+ sources = "SOURCES:"
15
+
16
+
17
+ def test_get_answer():
18
+ bot = QnABot(directory=test_directory)
19
+ answer = bot.get_answer(test_question)
20
+
21
+ # Assert that the answer contains the correct answer
22
+ assert correct_answer in answer
23
+ # Assert that the answer contains the sources
24
+ assert sources in answer
25
+
26
+
27
+ def test_save_load_index():
28
+ # Create a temporary directory and file path for the test index
29
+ with tempfile.TemporaryDirectory() as temp_dir:
30
+ index_path = os.path.join(temp_dir, "test_index.pkl")
31
+
32
+ # Create a bot and save the index to the temporary file path
33
+ bot = QnABot(directory=test_directory, index=index_path)
34
+ bot.save_index(index_path)
35
+
36
+ # Load the saved index and assert that it is the same as the original index
37
+ with open(index_path, "rb") as f:
38
+ saved_index = pickle.load(f)
39
+ assert isinstance(saved_index, FAISS)
40
+
41
+ bot_with_predefined_index = QnABot(directory=test_directory, index=index_path)
42
+
43
+ # Assert that the bot returns the correct answer to the test question
44
+ assert correct_answer in bot_with_predefined_index.get_answer(test_question)