mlara commited on
Commit
8950a9f
·
1 Parent(s): 750e9c4
Files changed (3) hide show
  1. aimakerspace/openai_utils/chatmodel.py +2 -2
  2. app.py +3 -2
  3. rag.py +2 -9
aimakerspace/openai_utils/chatmodel.py CHANGED
@@ -9,11 +9,11 @@ class ChatOpenAI:
9
  if self.openai_api_key is None:
10
  raise ValueError("OPENAI_API_KEY is not set")
11
 
12
- def run(self, client, messages, text_only: bool = True):
13
  if not isinstance(messages, list):
14
  raise ValueError("messages must be a list")
15
 
16
- # client = OpenAI()
17
  response = client.chat.completions.create(
18
  model=self.model_name, messages=messages
19
  )
 
9
  if self.openai_api_key is None:
10
  raise ValueError("OPENAI_API_KEY is not set")
11
 
12
+ def run(self, messages, text_only: bool = True):
13
  if not isinstance(messages, list):
14
  raise ValueError("messages must be a list")
15
 
16
+ client = OpenAI()
17
  response = client.chat.completions.create(
18
  model=self.model_name, messages=messages
19
  )
app.py CHANGED
@@ -8,6 +8,7 @@ import chainlit as cl # importing chainlit for our app
8
  from chainlit.prompt import Prompt, PromptMessage # importing prompt tools
9
  # from chainlit.playground.providers import ChatOpenAI # importing ChatOpenAI tools
10
  from aimakerspace.openai_utils.chatmodel import ChatOpenAI
 
11
  from dotenv import load_dotenv
12
 
13
  load_dotenv()
@@ -45,13 +46,13 @@ async def main(message: cl.Message):
45
  settings = cl.user_session.get("settings")
46
 
47
  # client = AsyncOpenAI()
48
- client = ChatOpenAI()
49
 
50
  print(message.content)
51
 
52
  vector_db = _build_vector_db()
53
  pipeline = RetrievalAugmentedQAPipeline(
54
- llm=client,
55
  vector_db_retriever=vector_db
56
  )
57
  response = pipeline.run_pipeline(message.content)
 
8
  from chainlit.prompt import Prompt, PromptMessage # importing prompt tools
9
  # from chainlit.playground.providers import ChatOpenAI # importing ChatOpenAI tools
10
  from aimakerspace.openai_utils.chatmodel import ChatOpenAI
11
+
12
  from dotenv import load_dotenv
13
 
14
  load_dotenv()
 
46
  settings = cl.user_session.get("settings")
47
 
48
  # client = AsyncOpenAI()
49
+ chat_openai = ChatOpenAI()
50
 
51
  print(message.content)
52
 
53
  vector_db = _build_vector_db()
54
  pipeline = RetrievalAugmentedQAPipeline(
55
+ llm=chat_openai,
56
  vector_db_retriever=vector_db
57
  )
58
  response = pipeline.run_pipeline(message.content)
rag.py CHANGED
@@ -41,7 +41,7 @@ class RetrievalAugmentedQAPipeline:
41
  self.llm = llm
42
  self.vector_db_retriever = vector_db_retriever
43
 
44
- def run_pipeline(self, client, user_query: str) -> str:
45
  context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
46
 
47
  context_prompt = ""
@@ -52,7 +52,7 @@ class RetrievalAugmentedQAPipeline:
52
 
53
  formatted_user_prompt = user_prompt.create_message(user_query=user_query)
54
 
55
- return self.llm.run(client, [formatted_system_prompt, formatted_user_prompt])
56
 
57
  def _split_documents():
58
  split_documents = []
@@ -71,10 +71,3 @@ def _build_vector_db():
71
  split_documents = _split_documents()
72
  vector_db = asyncio.run(vector_db.abuild_from_list(split_documents))
73
  return vector_db
74
-
75
- # def retrieval_augmented_qa_pipeline(client):
76
- # vector_db = _build_vector_db()
77
- # pipeline = RetrievalAugmentedQAPipeline(
78
- # llm=client,
79
- # vector_db_retriever=vector_db)
80
- # return pipeline
 
41
  self.llm = llm
42
  self.vector_db_retriever = vector_db_retriever
43
 
44
+ def run_pipeline(self, user_query: str) -> str:
45
  context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
46
 
47
  context_prompt = ""
 
52
 
53
  formatted_user_prompt = user_prompt.create_message(user_query=user_query)
54
 
55
+ return self.llm.run([formatted_system_prompt, formatted_user_prompt])
56
 
57
  def _split_documents():
58
  split_documents = []
 
71
  split_documents = _split_documents()
72
  vector_db = asyncio.run(vector_db.abuild_from_list(split_documents))
73
  return vector_db