Commit e8445ba by 0xdant
Parent: 7613242

Prompt improvements, replacement of deprecated methods, and chat history support

Files changed (2)
  1. requirements.txt +5 -4
  2. src/worker_huggingface.py +92 -58
requirements.txt CHANGED
@@ -3,12 +3,13 @@ transformers
 torch
 Pillow
 accelerate
-langchain==0.1.12
+langchain
 pypdf
-chromadb
-sentence-transformers==2.2.2
+sentence-transformers
 InstructorEmbedding
 flask
 flask_cors
 huggingface-hub
-langchain-community
+langchain-community
+langchain-huggingface
+faiss-cpu
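
The dependency changes track the new imports in src/worker_huggingface.py: faiss-cpu backs the FAISS vector store that replaces Chroma, langchain-huggingface provides HuggingFaceEndpoint, and sentence-transformers (now unpinned) supplies the embedding model. As a quick, non-authoritative sanity check that the updated requirements resolve (the import paths mirror what the worker uses; the print message is illustrative only):

# Sketch: confirm the packages added in requirements.txt expose the classes the worker now uses.
from langchain_community.vectorstores import FAISS                 # needs faiss-cpu when the index is built
from langchain_huggingface import HuggingFaceEndpoint              # from langchain-huggingface
from langchain_community.embeddings import HuggingFaceEmbeddings   # backed by sentence-transformers
                                                                    # (the worker imports it via langchain.embeddings)
print("updated requirements import cleanly")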
src/worker_huggingface.py CHANGED
@@ -1,12 +1,24 @@
 import os
 import torch
 from langchain_core.prompts import PromptTemplate
-from langchain.chains import RetrievalQA
 from langchain_community.embeddings import HuggingFaceInstructEmbeddings
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
-from langchain_community.llms import HuggingFaceHub
+from langchain_community.vectorstores import FAISS
+
+from langchain.chains import create_retrieval_chain
+from langchain_core.prompts import ChatPromptTemplate
+from langchain.chains.combine_documents import create_stuff_documents_chain
+
+from langchain_huggingface import HuggingFaceEndpoint
+from langchain_core.prompts import MessagesPlaceholder
+
+from langchain.chains import create_history_aware_retriever
+
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_community.chat_message_histories import ChatMessageHistory
+
 
 
 # Check for GPU availability and set the appropriate device for computation.
@@ -17,6 +29,7 @@ conversation_retrieval_chain = None
 chat_history = []
 llm_hub = None
 embeddings = None
+tokenizer = None
 
 # Function to initialize the language model and its embeddings
 def init_llm():
@@ -25,32 +38,27 @@ def init_llm():
     # Hugging Face API token
     # Setup environment variable HUGGINGFACEHUB_API_TOKEN
 
-    # repo name for the model
-    # model_id = "facebook/blenderbot-400M-distill"
-    # model_id = "tiiuae/falcon-7b-instruct"
     model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 
-    # Define the model parameters
-    model_kwargs = {
-        "temperature": 0.1,         # Lower temperature for more focused outputs
-        "top_k": 10,                # Use top-k sampling
-        "top_p": 0.9,               # Use top-p (nucleus) sampling
-        "max_length": 512,          # Limit the length of the response
-        "repetition_penalty": 1.2   # Penalize repeated phrases
-    }
-    # load the model into the HuggingFaceHub
-    llm_hub = HuggingFaceHub(repo_id=model_id, model_kwargs=model_kwargs)
-
-    # Initialize embeddings using a pre-trained model to represent the text data.
-    model_name = "sentence-transformers/all-MiniLM-L6-v2"
-    model_kwargs = {'device': DEVICE}
-    encode_kwargs = {'normalize_embeddings': False}
-    embeddings = HuggingFaceInstructEmbeddings(
-        model_name=model_name,
-        model_kwargs=model_kwargs,
-        encode_kwargs=encode_kwargs
+    llm_hub = HuggingFaceEndpoint(
+        repo_id=model_id,
+        task="text-generation",
+        max_new_tokens=200,
+        do_sample=False,
+        repetition_penalty=1.03,
+        return_full_text=False,
+        temperature=0.1,
     )
 
+    from langchain.embeddings import HuggingFaceEmbeddings
+    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+
+store = {}
+def get_session_history(session_id: str) -> BaseChatMessageHistory:
+    if session_id not in store:
+        store[session_id] = ChatMessageHistory()
+    return store[session_id]
+
 
 # Function to process a PDF document
 def process_document(document_path):
@@ -64,48 +72,74 @@ def process_document(document_path):
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
     texts = text_splitter.split_documents(documents)
 
-    # Create an embeddings database using Chroma from the split text chunks.
-    # text_embeddings = [embeddings.encode(text.content) for text in texts]
-    db = Chroma.from_documents(documents=texts, embedding=embeddings)
-
-    # --> Build the QA chain, which utilizes the LLM and retriever for answering questions.
-    # By default, the vectorstore retriever uses similarity search.
-    # If the underlying vectorstore support maximum marginal relevance search, you can specify that as the search type (search_type="mmr").
-    # You can also specify search kwargs like k to use when doing retrieval. k represent how many search results send to llm
-    conversation_retrieval_chain = RetrievalQA.from_chain_type(
-        llm=llm_hub,
-        chain_type="stuff",
-        retriever=db.as_retriever(search_type="mmr", search_kwargs={'k': 5, 'lambda_mult': 0.25}),
-        return_source_documents=False,
-        input_key = "question"
-        # chain_type_kwargs={"prompt": prompt} # if you are using prompt template, you need to uncomment this part
+    # Create an embeddings database using FAISS from the split text chunks.
+    db = FAISS.from_documents(documents=texts, embedding=embeddings)
+
+    system_prompt = """
+    <|start_header_id|>user<|end_header_id|>
+    You are an assistant for answering questions using provided context.
+    You are given the extracted parts of a long document, previous chat_history and a question. Provide a conversational answer.
+    If you don't know the answer, just say "I do not know." Don't make up an answer.
+    Question: {input}
+    Context: {context}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+    """
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", system_prompt),
+            ("human", "{input}"),
+        ]
+    )
+
+    retriever = db.as_retriever(search_type="similarity", search_kwargs={'k': 3, 'lambda_mult': 0.25})
+    question_answer_chain = create_stuff_documents_chain(llm_hub, prompt)
+    # conversation_retrieval_chain = create_retrieval_chain(retriever, question_answer_chain)
+
+    contextualize_q_system_prompt = (
+        "Given a chat history and the latest user question "
+        "which might reference context in the chat history, "
+        "formulate a standalone question which can be understood "
+        "without the chat history. Do NOT answer the question, "
+        "just reformulate it if needed and otherwise return it as is."
+    )
+    contextualize_q_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ]
+    )
+    history_aware_retriever = create_history_aware_retriever(
+        llm_hub, retriever, contextualize_q_prompt
+    )
+
+    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+    conversation_retrieval_chain = RunnableWithMessageHistory(
+        rag_chain,
+        get_session_history,
+        input_messages_key="input",
+        history_messages_key="chat_history",
+        output_messages_key="answer",
    )
 
 
 # Function to process a user prompt
 def process_prompt(prompt):
-    global conversation_retrieval_chain
-    global chat_history
+    # global conversation_retrieval_chain
+    global chat_history
 
-    improved_prompt = f"Use the given prompt to answer the question. If you don't know the answer, say you don't know. Use three sentence maximum and keep the answer concise. Prompt: {prompt}"
-
-    # Query the model
-    output = conversation_retrieval_chain.invoke({"question": improved_prompt, "chat_history": chat_history})
-    answer = output["result"]
+    # Query the model with history
+    output = conversation_retrieval_chain.invoke(
+        {"input": prompt},
+        config={
+            "configurable": {"session_id": "abc123"}
+        },  # constructs a key "abc123" in `store`.
+    )
+    answer = output["answer"]
     print(output)
-
-    # Extract the 'Helpful Answer:' part
-    helpful_answer_index = answer.find("Helpful Answer:")
-    if helpful_answer_index != -1:
-        helpful_answer = answer[helpful_answer_index + len("Helpful Answer:"):].strip()
-    else:
-        helpful_answer = answer
-
-    # Update the chat history
-    chat_history.append((prompt, helpful_answer))
 
     # Return the model's response
-    return helpful_answer
+    return answer
 
 # Initialize the language model
 init_llm()
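
For orientation, here is a minimal sketch of how the reworked worker could be driven end to end. The import path, sample.pdf, and the example questions are assumptions for illustration, not part of the commit; HUGGINGFACEHUB_API_TOKEN must already be set in the environment.

# Sketch: exercising the updated worker (assumes src/ is on sys.path).
# Importing the module runs init_llm() at the bottom of the file, so the
# HuggingFaceEndpoint client and the embeddings are created up front.
import worker_huggingface as worker

# Build the FAISS index and the history-aware RAG chain for one PDF (illustrative path).
worker.process_document("sample.pdf")

# The first question is retrieved and answered as-is.
print(worker.process_prompt("What is this document about?"))

# A follow-up can lean on chat history: RunnableWithMessageHistory records each turn
# in `store` under the session_id "abc123" hardcoded in process_prompt.
print(worker.process_prompt("Summarize that answer in one sentence."))

Because the session id is fixed to "abc123", every caller currently shares a single conversation history; passing a per-user session id through process_prompt would be the natural extension if the Flask layer needs isolated conversations.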