Chris4K commited on
Commit
a6c5429
1 Parent(s): 8e6f6b5

Update vector_store_retriever.py

Browse files
Files changed (1) hide show
  1. vector_store_retriever.py +31 -22
vector_store_retriever.py CHANGED
@@ -1,31 +1,40 @@
1
  import gradio as gr
2
- from langchain.document_loaders import TextLoader
3
  from langchain.vectorstores import Chroma
4
  from langchain.chains import RetrievalQA
5
- from langchain.embeddings import HuggingFaceInstructEmbeddings
6
  from langchain.agents import Tool
 
 
 
7
 
8
- # Initialize the HuggingFaceInstructEmbeddings
9
- hf = HuggingFaceInstructEmbeddings(
10
- model_name="hkunlp/instructor-large",
11
- embed_instruction="Represent the document for retrieval: ",
12
- query_instruction="Represent the query for retrieval: "
13
- )
14
 
15
- # Example texts for the vector store
16
- texts=["The meaning of life is to love","The meaning of vacation is to relax","Roses are red.","Hack the planet!"]
 
17
 
18
- # Create a Chroma vector store from the example texts
19
- db = Chroma.from_texts(texts, hf, collection_name="my-collection")
20
 
21
- # Create a RetrievalQA chain
22
- llm = LLM.from_model("lgaalves/gpt2-dolly") # Replace with the appropriate LLM model
23
- docsearcher = RetrievalQA.from_chain_type(
24
- llm=llm,
25
- chain_type="stuff", # Replace with the appropriate chain type
26
- return_source_documents=False,
27
- retriever=db.as_retriever(search_type="similarity", search_kwargs={"k": 1})
28
- )
 
 
 
 
 
 
 
 
29
 
30
  class VectorStoreRetrieverTool(Tool):
31
  name = "vectorstore_retriever"
@@ -36,8 +45,8 @@ class VectorStoreRetrieverTool(Tool):
36
 
37
  def __call__(self, query: str):
38
  # Run the query through the RetrievalQA chain
39
- response = docsearcher.run(query)
40
- return response
41
 
42
  # Create the Gradio interface using the HuggingFaceTool
43
  tool = gr.Interface(
 
1
  import gradio as gr
2
+ from langchain.document_loaders import DirectoryLoader, PyPDFLoader
3
  from langchain.vectorstores import Chroma
4
  from langchain.chains import RetrievalQA
5
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
6
  from langchain.agents import Tool
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain.llms import HuggingFacePipeline
9
+ from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline
10
 
11
+ # Load and process the text files
12
+ loader = DirectoryLoader('./new_papers/new_papers/', glob="./*.pdf", loader_cls=PyPDFLoader)
13
+ documents = loader.load()
 
 
 
14
 
15
+ # Splitting the text into chunks
16
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
17
+ texts = text_splitter.split_documents(documents)
18
 
19
+ # HF Instructor Embeddings
20
+ instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl", model_kwargs={"device": "cuda"})
21
 
22
+ # Embed and store the texts
23
+ persist_directory = 'db'
24
+ embedding = instructor_embeddings
25
+ vectordb = Chroma.from_documents(documents=texts, embedding=embedding, persist_directory=persist_directory)
26
+
27
+ # Make a retriever
28
+ retriever = vectordb.as_retriever(search_kwargs={"k": 3})
29
+
30
+ # Setup LLM for text generation
31
+ tokenizer = LlamaTokenizer.from_pretrained("TheBloke/wizardLM-7B-HF")
32
+ model = LlamaForCausalLM.from_pretrained("TheBloke/wizardLM-7B-HF", load_in_8bit=True, device_map='auto', torch_dtype=torch.float16, low_cpu_mem_usage=True)
33
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=1024, temperature=0, top_p=0.95, repetition_penalty=1.15)
34
+ local_llm = HuggingFacePipeline(pipeline=pipe)
35
+
36
+ # Make a chain
37
+ qa_chain = RetrievalQA.from_chain_type(llm=local_llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
38
 
39
  class VectorStoreRetrieverTool(Tool):
40
  name = "vectorstore_retriever"
 
45
 
46
  def __call__(self, query: str):
47
  # Run the query through the RetrievalQA chain
48
+ llm_response = qa_chain(query)
49
+ return llm_response['result']
50
 
51
  # Create the Gradio interface using the HuggingFaceTool
52
  tool = gr.Interface(