demiroz committed on
Commit
f257e56
1 Parent(s): 77c8a38

Upload 2 files

Browse files
Files changed (2) hide show
  1. chain.py +85 -0
  2. config.py +19 -0
chain.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module contains functions for loading a ConversationalRetrievalChain"""
2
+
3
+ import logging
4
+
5
+ import wandb
6
+ from langchain.chains import ConversationalRetrievalChain
7
+ from langchain.chat_models import ChatOpenAI
8
+ from langchain.embeddings import OpenAIEmbeddings
9
+ from langchain.vectorstores import Chroma
10
+ from prompts import load_chat_prompt
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def load_vector_store(wandb_run: wandb.run, openai_api_key: str) -> Chroma:
    """Load a Chroma vector store from a Weights & Biases artifact.

    Args:
        wandb_run (wandb.run): An active Weights & Biases run
        openai_api_key (str): The OpenAI API key to use for embedding
    Returns:
        Chroma: A chroma vector store object
    """
    # Download the serialized search index named in the run's config.
    vector_store_artifact_dir = wandb_run.use_artifact(
        wandb_run.config.vector_store_artifact, type="search_index"
    ).download()
    embedding_fn = OpenAIEmbeddings(openai_api_key=openai_api_key)
    # Re-open the persisted index; the embedding function must match the one
    # used when the index was built, or similarity search results degrade.
    vector_store = Chroma(
        embedding_function=embedding_fn, persist_directory=vector_store_artifact_dir
    )

    return vector_store
35
+
36
+
37
def load_chain(
    wandb_run: wandb.run, vector_store: Chroma, openai_api_key: str
) -> ConversationalRetrievalChain:
    """Load a ConversationalQA chain from a config and a vector store.

    Args:
        wandb_run (wandb.run): An active Weights & Biases run
        vector_store (Chroma): A Chroma vector store object
        openai_api_key (str): The OpenAI API key to use for the chat model
    Returns:
        ConversationalRetrievalChain: A ConversationalRetrievalChain object
    """
    retriever = vector_store.as_retriever()
    # Chat model parameters all come from the run config so they are tracked
    # alongside the artifacts used to build the chain.
    llm = ChatOpenAI(
        openai_api_key=openai_api_key,
        model_name=wandb_run.config.model_name,
        temperature=wandb_run.config.chat_temperature,
        max_retries=wandb_run.config.max_fallback_retries,
    )
    # The QA prompt is versioned as a W&B artifact; download and load it.
    chat_prompt_dir = wandb_run.use_artifact(
        wandb_run.config.chat_prompt_artifact, type="prompt"
    ).download()
    qa_prompt = load_chat_prompt(f"{chat_prompt_dir}/prompt.json")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
        return_source_documents=True,
    )
    return qa_chain
65
+
66
+
67
def get_answer(
    chain: ConversationalRetrievalChain,
    question: str,
    chat_history: list[tuple[str, str]],
):
    """Run one conversational turn through a ConversationalRetrievalChain.

    Args:
        chain (ConversationalRetrievalChain): A ConversationalRetrievalChain object
        question (str): The question to ask
        chat_history (list[tuple[str, str]]): A list of tuples of (question, answer)
    Returns:
        str: The answer to the question
    """
    chain_inputs = {"question": question, "chat_history": chat_history}
    # return_only_outputs drops the echoed inputs from the result dict.
    result = chain(inputs=chain_inputs, return_only_outputs=True)
    return f"Answer:\t{result['answer']}"
config.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Configuration for the LLM Apps Course"""
from types import SimpleNamespace

# Weights & Biases run destination (entity is None -> default entity).
TEAM = None
PROJECT = "llmapps"
JOB_TYPE = "production"

# Default run settings, kept in one dict so the keys read as a flat schema;
# exposed as a SimpleNamespace for attribute-style access (wandb.config-like).
_DEFAULTS = {
    "project": PROJECT,
    "entity": TEAM,
    "job_type": JOB_TYPE,
    "vector_store_artifact": "demiroz/llmapps/vector_store:v1",
    "chat_prompt_artifact": "demiroz/llmapps/chat_prompt:v0",
    "chat_temperature": 0.3,
    "max_fallback_retries": 1,
    "model_name": "gpt-3.5-turbo",
    "eval_model": "gpt-3.5-turbo",
    "eval_artifact": "demiroz/llmapps/vector_store:v1",
}

default_config = SimpleNamespace(**_DEFAULTS)