ImranzamanML commited on
Commit
be1b078
·
verified ·
1 Parent(s): e40592d

Update ai_assistant.py

Browse files
Files changed (1) hide show
  1. ai_assistant.py +27 -24
ai_assistant.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from langchain.document_loaders.csv_loader import CSVLoader
2
  from langchain.embeddings.openai import OpenAIEmbeddings
3
  from langchain.embeddings import CacheBackedEmbeddings
@@ -5,49 +6,51 @@ from langchain_community.vectorstores import FAISS
5
  from langchain.storage import LocalFileStore
6
  from langchain.chains import RetrievalQA
7
  from langchain_openai import ChatOpenAI
8
- import os
9
 
10
  def create_index():
11
- # load the data
12
- dir = os.path.dirname(__file__)
13
- df_path = dir + '/data/train.csv'
14
- loader = CSVLoader(file_path = df_path)
15
- data = loader.load()
16
 
17
- # create the embeddings model
18
  embeddings_model = OpenAIEmbeddings()
19
 
20
- # create the cache backed embeddings in vector store
21
  store = LocalFileStore("./cache")
22
- cached_embeder = CacheBackedEmbeddings.from_bytes_store(
23
  embeddings_model, store, namespace=embeddings_model.model
24
  )
 
 
25
  vector_store = FAISS.from_documents(data, embeddings_model)
26
 
27
  return vector_store.as_retriever()
28
 
29
- def setup(openai_key):
30
- # Set the API key for OpenAI
31
  os.environ["OPENAI_API_KEY"] = openai_key
32
- retriver = create_index()
33
- llm = ChatOpenAI(model="gpt-4")
34
- return retriver, llm
35
-
36
- def ai_doctor(openai_key,query):
 
 
 
37
 
38
- # Setup
39
- retriever,llm = setup(openai_key)
 
 
40
  # Create the QA chain
41
  handler = StdOutCallbackHandler()
42
-
43
  qa_with_sources_chain = RetrievalQA.from_chain_type(
44
- llm=llm,
45
  retriever=retriever,
46
  callbacks=[handler],
47
  return_source_documents=True
48
  )
49
 
50
- # Ask a question
51
- res = qa_with_sources_chain({"query":query})
52
- return (res['result'])
53
-
 
1
+ import os
2
  from langchain.document_loaders.csv_loader import CSVLoader
3
  from langchain.embeddings.openai import OpenAIEmbeddings
4
  from langchain.embeddings import CacheBackedEmbeddings
 
6
  from langchain.storage import LocalFileStore
7
  from langchain.chains import RetrievalQA
8
  from langchain_openai import ChatOpenAI
 
9
 
10
  def create_index():
11
+ # Load the data from CSV file
12
+ data_loader = CSVLoader(file_path="train.csv")
13
+ data = data_loader.load()
 
 
14
 
15
+ # Create the embeddings model
16
  embeddings_model = OpenAIEmbeddings()
17
 
18
+ # Create the cache backed embeddings in vector store
19
  store = LocalFileStore("./cache")
20
+ cached_embedder = CacheBackedEmbeddings.from_bytes_store(
21
  embeddings_model, store, namespace=embeddings_model.model
22
  )
23
+
24
+ # Create FAISS vector store from documents
25
  vector_store = FAISS.from_documents(data, embeddings_model)
26
 
27
  return vector_store.as_retriever()
28
 
29
+ def setup_openai(openai_key):
30
+ # Set the API key for OpenAI
31
  os.environ["OPENAI_API_KEY"] = openai_key
32
+
33
+ # Create index retriever
34
+ retriever = create_index()
35
+
36
+ # Initialize ChatOpenAI model
37
+ chat_openai_model = ChatOpenAI(model="gpt-4")
38
+
39
+ return retriever, chat_openai_model
40
 
41
+ def ai_doctor_chat(openai_key, query):
42
+ # Setup OpenAI environment
43
+ retriever, chat_model = setup_openai(openai_key)
44
+
45
  # Create the QA chain
46
  handler = StdOutCallbackHandler()
 
47
  qa_with_sources_chain = RetrievalQA.from_chain_type(
48
+ llm=chat_model,
49
  retriever=retriever,
50
  callbacks=[handler],
51
  return_source_documents=True
52
  )
53
 
54
+ # Ask a question/query
55
+ res = qa_with_sources_chain({"query": query})
56
+ return res['result']