Chris4K committed on
Commit
12c661c
1 Parent(s): 66a2c7e

Update vector_store_retriever.py

Files changed (1)
  1. vector_store_retriever.py +168 -23
vector_store_retriever.py CHANGED
@@ -1,37 +1,182 @@
  import gradio as gr
- from langchain.vectorstores import Chroma
- from langchain.document_loaders import PyPDFLoader
  from langchain.embeddings import HuggingFaceInstructEmbeddings
- from langchain.text_splitter import RecursiveCharacterTextSplitter

- # Initialize the HuggingFaceInstructEmbeddings
- hf = HuggingFaceInstructEmbeddings(
-     model_name="sentence-transformers/all-MiniLM-L6-v2",
      model_kwargs={"device": "cpu"}
  )

- # Load and process the PDF files
  from langchain.document_loaders import PyPDFDirectoryLoader

- loader = PyPDFDirectoryLoader("new_papers/")

- documents = loader.load()

- #loader = PyPDFLoader('./new_papers/', glob="./*.pdf")
- #documents = loader.load()

- #splitting the text into
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
- texts = text_splitter.split_documents(documents)

- # Create a Chroma vector store from the PDF documents
- db = Chroma.from_documents(texts, hf, collection_name="my-collection")

- class VectoreStoreRetrievalTool:
-     def __init__(self):
-         self.retriever = db.as_retriever(search_kwargs={"k": 1})

-     def __call__(self, query):
-         # Run the query through the retriever
-         response = self.retriever.run(query)
-         return response['result']
 
+ import json
+ import os
  import gradio as gr
+ import time
+ from pydantic import BaseModel, Field
+ from typing import Any, Optional, Dict, List
+ from huggingface_hub import InferenceClient
+ from langchain.llms.base import LLM
  from langchain.embeddings import HuggingFaceInstructEmbeddings
+ from langchain.vectorstores import Chroma
+ from transformers import AutoTokenizer
+ from transformers import Tool
+ from dotenv import load_dotenv
+
+ load_dotenv()

+ path_work = "."
+ hf_token = os.getenv("HF")
+
+ embeddings = HuggingFaceInstructEmbeddings(
+     model_name="sentence-transformers/all-MiniLM-L6-v2",
      model_kwargs={"device": "cpu"}
  )

+ vectordb = Chroma(
+     persist_directory=path_work + '/new_papers',
+     embedding_function=embeddings
+ )
+
+ retriever = vectordb.as_retriever(search_kwargs={"k": 2})  #5
+
+ class KwArgsModel(BaseModel):
+     kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ class CustomInferenceClient(LLM, KwArgsModel):
+     model_name: str
+     inference_client: InferenceClient
+
+     def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
+         inference_client = InferenceClient(model=model_name, token=hf_token)
+         super().__init__(
+             model_name=model_name,
+             hf_token=hf_token,
+             kwargs=kwargs,
+             inference_client=inference_client
+         )
+
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None
+     ) -> str:
+         if stop is not None:
+             raise ValueError("stop kwargs are not permitted.")
+         response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True)
+         response = ''.join(response_gen)
+         return response
+
+     @property
+     def _llm_type(self) -> str:
+         return "custom"
+
+     @property
+     def _identifying_params(self) -> dict:
+         return {"model_name": self.model_name}
+
+ kwargs = {"max_new_tokens": 256, "temperature": 0.9, "top_p": 0.6, "repetition_penalty": 1.3, "do_sample": True}
+
+ model_list = [
+     "meta-llama/Llama-2-13b-chat-hf",
+     "HuggingFaceH4/zephyr-7b-alpha",
+     "meta-llama/Llama-2-70b-chat-hf",
+     "tiiuae/falcon-180B-chat"
+ ]
+
+ qa_chain = None
+
+ def load_model(model_selected):
+     global qa_chain
+     model_name = model_selected
+     llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
+
+     from langchain.chains import RetrievalQA
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type="stuff",
+         retriever=retriever,
+         return_source_documents=True,
+         verbose=True,
+     )
+     return qa_chain
+
+ load_model("meta-llama/Llama-2-70b-chat-hf")
+
+ ##########
+ #####
+ #########
+
  from langchain.document_loaders import PyPDFDirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import Chroma
+
+ def load_and_process_pdfs(directory_path: str, chunk_size: int = 500, chunk_overlap: int = 200, collection_name: str = "my-collection"):
+     # Load PDF files from the specified directory
+     loader = PyPDFDirectoryLoader(directory_path)
+     documents = loader.load()
+
+     # Split the text into chunks
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+     texts = text_splitter.split_documents(documents)
+
+     # Create a Chroma vector store from the processed texts
+     db = Chroma.from_documents(texts, embeddings, collection_name=collection_name)
+
+     return db  # You can return the Chroma vector store if needed
+
+ # Call the function with the desired directory path and parameters
+ load_and_process_pdfs("new_papers/")
+
+ ###
+ ###
+ ###
+
+ def predict(message, temperature=0.9, max_new_tokens=512, top_p=0.6, repetition_penalty=1.3):
+     temperature = float(temperature)
+     if temperature < 1e-2: temperature = 1e-2
+     top_p = float(top_p)
+
+     llm_response = qa_chain(message)
+     res_result = llm_response['result']
+
+     res_relevant_doc = [source.metadata['source'] for source in llm_response["source_documents"]]
+     response = f"{res_result}" + "\n\n" + "[Answer Source Documents (Ctrl + Click!)] :" + "\n" + f" \n {res_relevant_doc}"
+     print("response: =====> \n", response, "\n\n")
+
+     tokens = response.split('\n')
+     token_list = []
+     for idx, token in enumerate(tokens):
+         token_dict = {"id": idx + 1, "text": token}
+         token_list.append(token_dict)
+     response = {"data": {"token": token_list}}
+     response = json.dumps(response, indent=4)
+
+     response = json.loads(response)
+     data_dict = response.get('data', {})
+     token_list = data_dict.get('token', [])
+
+     partial_message = ""
+     for token_entry in token_list:
+         if token_entry:
+             try:
+                 token_id = token_entry.get('id', None)
+                 token_text = token_entry.get('text', None)
+
+                 if token_text:
+                     for char in token_text:
+                         partial_message += char
+                         yield partial_message
+                         time.sleep(0.01)
+                 else:
+                     print(f"[[Warning]] ==> The key 'text' does not exist or is None in this token entry: {token_entry}")
+                     pass

+             except KeyError as e:
+                 gr.Warning(f"KeyError: {e} occurred for token entry: {token_entry}")
+                 continue

+ class TextGeneratorTool(Tool):
+     name = "vector_retriever"
+     description = "This tool searches in a vector store based on a given prompt."
+     inputs = ["prompt"]
+     outputs = ["generated_text"]

+     def __init__(self):
+         super().__init__()
+         # self.retriever = db.as_retriever(search_kwargs={"k": 1})
+
+     def __call__(self, prompt: str):
+         result = predict(prompt, 0.9, 512, 0.6, 1.4)
+         return result
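
Not part of the commit: a minimal usage sketch of the updated tool. It assumes the file above is importable as the module vector_store_retriever and that the HF environment variable (or a .env file) holds a valid Hugging Face token, since importing the module builds the vector store and loads the default model over the Inference API.

from vector_store_retriever import TextGeneratorTool

tool = TextGeneratorTool()
# predict() is a generator, so calling the tool yields progressively longer partial answers.
for partial in tool("What are the new papers about?"):
    print(partial, end="\r")
print()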
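
Because the class now subclasses transformers.Tool, it could also be handed to the transformers agents API. A hedged sketch only: the HfAgent endpoint below is an example assumption, not something this commit configures.

from transformers import HfAgent
from vector_store_retriever import TextGeneratorTool

agent = HfAgent(
    "https://api-inference.huggingface.co/models/bigcode/starcoder",  # example endpoint, an assumption
    additional_tools=[TextGeneratorTool()],
)
# The agent may choose the vector_retriever tool when a task calls for document lookup.
print(agent.run("Summarize what the papers in new_papers/ cover."))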