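# rag-tool / vector_store_retriever.py
# (Web-page header captured together with the source; kept as comments so the module parses.)
# Chris4K's picture
# Update vector_store_retriever.py
# 67a4436
# raw
# history blame
# No virus
# 6.14 kB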
import json
import os
import gradio as gr
import time
import langchain
from pydantic import BaseModel, Field
from typing import Any, Optional, Dict, List, Union
from huggingface_hub import InferenceClient
from langchain.llms.base import LLM
#from langchain.Images import Images
from langchain.llms.base import LLM
#from langchain_core.embeddings import EmbeddingFunction, Embeddings
from langchain.embeddings import HuggingFaceInstructEmbeddings
#from langchain import [all]
#from langchain.Documents import Documents
from langchain.vectorstores import Chroma
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModel, Tool
# Load variables from a local .env file (when present) into the environment
# before anything below reads them.
load_dotenv()

# Base directory under which the persisted Chroma store is located.
path_work = "."

# Hugging Face access token taken from the "HF" environment variable;
# None when the variable is not set.
hf_token = os.environ.get("HF")
class HuggingFaceInstructEmbeddings(HuggingFaceInstructEmbeddings):
    """Mean-pooled transformer embeddings backed by ``AutoModel``/``AutoTokenizer``.

    The class deliberately shadows LangChain's ``HuggingFaceInstructEmbeddings``.
    The base class is resolved before the name is rebound. It replaces the
    parent's INSTRUCTOR client with a plain Hugging Face encoder, so a checkpoint
    such as ``sentence-transformers/all-MiniLM-L6-v2`` can be used.

    It provides the LangChain ``Embeddings`` methods ``embed_documents`` and
    ``embed_query``, plus a ``__call__(input)`` entry point. Embeddings are
    returned as plain ``List[List[float]]``.
    """

    def __init__(self, model_name: str, model_kwargs: Optional[Dict[str, Any]] = None):
        """Load the tokenizer and encoder for *model_name*.

        Args:
            model_name: Hugging Face Hub id or local path of the encoder.
            model_kwargs: Extra keyword arguments for ``AutoModel.from_pretrained``.
                A ``"device"`` entry is not forwarded. The loaded model is moved
                there with ``.to(device)`` instead, because ``from_pretrained``
                passes unknown keys on to the model constructor.
        """
        load_kwargs = dict(model_kwargs or {})
        device = load_kwargs.pop("device", None)
        model = AutoModel.from_pretrained(model_name, **load_kwargs)
        if device is not None:
            model = model.to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # The LangChain parent is a pydantic model. Its __init__ would try to build
        # an INSTRUCTOR client, and its __setattr__ rejects undeclared attributes
        # such as "model". Skip both and store the attributes directly.
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "tokenizer", tokenizer)

    def __call__(self, input: Union[str, List[Any]]) -> List[List[float]]:
        """Embed a string, or a list of strings / document-like objects.

        The parameter keeps its original name ``input``.
        NOTE(review): this presumably matches chromadb's ``EmbeddingFunction``
        protocol (``__call__(self, input)``); confirm before renaming it.

        Returns:
            One embedding (list of floats) per input item.

        Raises:
            TypeError: if an item is neither a string nor an object with a string
                ``page_content`` or ``text`` attribute.
        """
        items = [input] if isinstance(input, str) else list(input)
        return self._embed_text([self._as_text(item) for item in items])

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per text (LangChain ``Embeddings`` interface)."""
        return self._embed_text(list(texts))

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single query (LangChain ``Embeddings`` interface)."""
        return self._embed_text([text])[0]

    @staticmethod
    def _as_text(item: Any) -> str:
        """Return *item* if it is a string, else its ``page_content``/``text`` string."""
        if isinstance(item, str):
            return item
        for attr in ("page_content", "text"):
            value = getattr(item, attr, None)
            if isinstance(value, str):
                return value
        raise TypeError(f"Cannot embed object of type {type(item).__name__}")

    def _embed_text(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* and mean-pool the last hidden state over non-padding tokens."""
        import torch  # torch is not imported at module level in this file

        if not texts:
            return []
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        hidden = outputs.last_hidden_state
        # Exclude padding positions from the mean. Otherwise shorter texts in a
        # padded batch are pulled toward the pad-token vectors.
        # Adjust the pooling if the chosen model expects a different strategy.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return (summed / counts).tolist()
# Persisted Chroma store in "<path_work>/new_papers", embedded with a MiniLM
# sentence-transformer running on CPU.
vectordb = Chroma(
    persist_directory=f"{path_work}/new_papers",
    embedding_function=HuggingFaceInstructEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    ),
)

# Retriever used by the QA chain; returns the top k=2 documents per query.
retriever = vectordb.as_retriever(search_kwargs={"k": 2})
class KwArgsModel(BaseModel):
    """Pydantic mixin that carries a free-form dictionary of keyword arguments."""

    # Each instance gets its own fresh, empty dict unless one is supplied.
    kwargs: Dict[str, Any] = Field(default_factory=lambda: {})
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain LLM that generates text via ``InferenceClient.text_generation``.

    Fields:
        model_name: Hub model id the client is bound to.
        inference_client: ``InferenceClient`` created for ``model_name``.
        kwargs: (inherited from ``KwArgsModel``) default generation parameters
            forwarded to ``text_generation`` on every call.
    """

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        """Create the inference client and initialize the pydantic fields.

        Args:
            model_name: Hub model id to query.
            hf_token: Token used to authenticate the ``InferenceClient``. It is only
                handed to the client, not stored as a (serializable) field.
            kwargs: Default generation parameters; ``None`` means "none".
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            # ``kwargs`` is declared as a dict field, so None would fail validation.
            kwargs=dict(kwargs or {}),
            inference_client=inference_client,
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for *prompt*.

        Args:
            prompt: Text sent to the model.
            stop: Must be ``None``; stop sequences are not supported here.
            run_manager: LangChain callback manager; accepted so LangChain can pass
                it, but unused.
            **kwargs: Per-call generation parameters that override ``self.kwargs``.

        Returns:
            The generated text.

        Raises:
            ValueError: if ``stop`` is given.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        params = {**self.kwargs, **kwargs}
        # Always stream, and avoid a duplicate "stream" keyword if one was configured.
        params["stream"] = True
        return "".join(self.inference_client.text_generation(prompt, **params))

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        """Parameters that identify this LLM instance."""
        return {"model_name": self.model_name}
# Default generation settings handed to CustomInferenceClient by load_model().
kwargs = dict(
    max_new_tokens=256,
    temperature=0.9,
    top_p=0.6,
    repetition_penalty=1.3,
    do_sample=True,
)

# Hub model ids offered for load_model().
model_list = [
    "meta-llama/Llama-2-13b-chat-hf",
    "HuggingFaceH4/zephyr-7b-alpha",
    "meta-llama/Llama-2-70b-chat-hf",
    "tiiuae/falcon-180B-chat",
]

# RetrievalQA chain built by load_model(); None until a model has been loaded.
qa_chain = None
def load_model(model_selected):
    """Build the RetrievalQA chain for *model_selected*, store it globally and return it.

    The model is wrapped in a ``CustomInferenceClient`` using the module-level
    ``hf_token`` and ``kwargs``. It is combined with ``retriever`` using the
    "stuff" chain type, with source documents returned and verbose logging on.
    The result is assigned to the global ``qa_chain``.
    """
    global qa_chain
    chat_llm = CustomInferenceClient(model_name=model_selected, hf_token=hf_token, kwargs=kwargs)
    from langchain.chains import RetrievalQA  # lazy import, as in the original

    chain_options = dict(
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )
    qa_chain = RetrievalQA.from_chain_type(llm=chat_llm, **chain_options)
    return qa_chain
# Build the default chain at import time so predict() has a qa_chain to use.
# load_model() already stores the chain in the global; binding its return value
# here just makes that assignment visible at module level.
qa_chain = load_model("meta-llama/Llama-2-70b-chat-hf")

# --- Chat / streaming helpers -----------------------------------------------
def _format_response(llm_response):
    """Render a RetrievalQA result as the answer followed by its source documents."""
    res_result = llm_response['result']
    # Documents without a 'source' metadata entry are shown as None instead of raising KeyError.
    res_relevant_doc = [source.metadata.get('source') for source in llm_response["source_documents"]]
    return f"{res_result}\n\n[Answer Source Documents (Ctrl + Click!)] :\n \n {res_relevant_doc}"


def predict(message, temperature=0.9, max_new_tokens=512, top_p=0.6, repetition_penalty=1.3):
    """Answer *message* with the global ``qa_chain`` and stream the formatted reply.

    The reply is the chain's answer followed by the ``source`` metadata of each
    retrieved document. It is yielded as a growing prefix, one character at a
    time with a 10 ms pause, which gives a typing effect in the UI.
    Line breaks are preserved.

    Args:
        message: User question passed to ``qa_chain``.
        temperature, max_new_tokens, top_p, repetition_penalty: Accepted for
            interface compatibility. NOTE: they are not applied; generation uses
            the settings the chain's LLM was built with in ``load_model``.

    Yields:
        str: The reply received so far.

    Raises:
        RuntimeError: if no chain has been loaded yet.
    """
    if qa_chain is None:
        raise RuntimeError("No model loaded; call load_model() first.")
    llm_response = qa_chain(message)
    response = _format_response(llm_response)
    print("response: =====> \n", response, "\n\n")

    # Stream every character, newlines included. The previous split on '\n'
    # silently dropped all line breaks from the streamed text.
    partial_message = ""
    for char in response:
        partial_message += char
        yield partial_message
        time.sleep(0.01)
class TextGeneratorTool(Tool):
    """Agent tool that answers a prompt using the vector-store RAG chain (``predict``)."""

    name = "vector_retriever"
    description = "This tool searches in a vector store based on a given prompt."
    # NOTE(review): transformers' Tool documents inputs/outputs as modality names
    # ("text", "image", "audio"); "prompt" may need to be "text" — confirm.
    inputs = ["prompt"]
    outputs = ["text"]

    def __init__(self):
        # Run the Tool base-class initializer, which the original skipped.
        super().__init__()

    def __call__(self, prompt: str) -> str:
        """Return the complete RAG answer for *prompt* as a string.

        ``predict`` is a generator that yields progressively longer prefixes of
        the reply, so its last value is the full answer. The generator used to be
        returned unconsumed, which handed callers an iterator instead of text.
        Draining it also incurs predict's per-character delay.
        """
        answer = ""
        for answer in predict(prompt, 0.9, 512, 0.6, 1.4):
            pass
        return answer