import os
import time

import faiss
import pandas as pd
import spaces
import torch
import transformers
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
class Langchain_RAG:
    def __init__(self):
        # Embed the reference conversations and index them with FAISS.
        self.embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
        print("Loading CSV file, this may take time to process...")
        self.data = pd.read_csv("conversation_RAG.csv", delimiter=",")
        self.texts = self.data["Output"].tolist()
        self.get_vec_value = FAISS.from_texts(self.texts, self.embeddings)
        print("Vector values saved.")
        self.retriever = self.get_vec_value.as_retriever(search_kwargs={"k": 4})

    def __call__(self, query):
        # Return the concatenated contents of the top-k retrieved documents.
        relevant_docs = self.retriever.invoke(query)
        return "".join([doc.page_content for doc in relevant_docs])
class Llama3_8B_gen:
    def __init__(self, threshold):
        # Fine-tuned generation model, loaded in 4-bit to reduce GPU memory use.
        self.pipeline = transformers.pipeline(
            "text-generation",
            model="dixiyao/toxic-professor-bot-dpo_3epochs",
            token=os.getenv("HF_TOKEN"),
            model_kwargs={
                "torch_dtype": torch.float16,
                "quantization_config": {"load_in_4bit": True},
                "low_cpu_mem_usage": True,
            },
        )
        # Stop generation at either the EOS token or Llama 3's end-of-turn token.
        self.terminators = [
            self.pipeline.tokenizer.eos_token_id,
            self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        # Empty FAISS index that serves as a semantic cache of (query, response) pairs.
        self.embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
        dimension = self.embeddings.client.get_sentence_embedding_dimension()
        index = faiss.IndexFlatL2(dimension)
        docstore = InMemoryDocstore()
        self.vector_store = FAISS(
            embedding_function=self.embeddings,
            index=index,
            docstore=docstore,
            index_to_docstore_id={},
        )
        self.threshold = threshold
        # Base instruction model, used only by the (currently disabled) revision step below.
        self.llama3 = transformers.pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct")
    @staticmethod
    def generate_prompt(query, retrieved_text):
        messages = """Below is a conversation between a PhD student and Professor Li. Reply as if you were Professor Baochun Li and respond to the student.
        ### Student Question: {}
        ### Professor Li Response: {}""".format(query, retrieved_text)
        return messages
    def semantic_cache(self, query, prompt):
        # If a previously seen query is within the distance threshold,
        # reuse its cached response instead of calling the LLM again.
        query_embedding = self.embeddings.embed_documents([query])
        similar_docs = self.vector_store.similarity_search_with_score_by_vector(query_embedding[0], k=1)
        if similar_docs and similar_docs[0][1] < self.threshold:
            self.print_bold_underline("---->> From Cache")
            return similar_docs[0][0].metadata['response']
        else:
            self.print_bold_underline("---->> From LLM")
            output = self.pipeline(prompt, max_new_tokens=512, eos_token_id=self.terminators, do_sample=True, temperature=0.7, top_p=0.9)
            response = output[0]["generated_text"][len(prompt):]
            # Cache the new (query, response) pair for future lookups.
            self.vector_store.add_texts(texts=[query],
                                        metadatas=[{'response': response}])
            # Optional revision pass to strip repeated or unfinished sentences:
            # prompt = "Revise the response to remove replicated sentences. The last sentence should be complete; remove unfinished sentences. Do not change the other sentences or their meanings. Text: {}".format(response)
            # output = self.llama3(prompt, max_new_tokens=512, eos_token_id=self.terminators, do_sample=True, temperature=0.7, top_p=0.9)
            # response = output[0]["generated_text"][len(prompt):]
            return response
    def generate(self, query, retrieved_context):
        start_time = time.time()
        prompt = self.generate_prompt(query, retrieved_context)
        res = self.semantic_cache(query, prompt)
        end_time = time.time()
        execution_time = end_time - start_time
        self.print_bold_underline(f"LLM generated in {execution_time:.6f} seconds")
        return res
    @staticmethod
    def print_bold_underline(text):
        # ANSI escape codes for bold + underlined terminal output.
        print(f"\033[1m\033[4m{text}\033[0m")
def Rag_qa(query):
    # Note: this rebuilds the models and the index on every call; in a long-running
    # app, construct them once and reuse them across queries.
    text_gen = Llama3_8B_gen(threshold=0.1)
    retriever = Langchain_RAG()
    retriever_context = retriever(query)
    result = text_gen.generate(query, retriever_context)
    return result

#print(Rag_qa("Can you write a reference letter for me?"))
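
# Minimal usage sketch (assumes conversation_RAG.csv is present, the model
# weights are accessible, and HF_TOKEN is set in the environment); run this
# file directly to try a single query end to end:
if __name__ == "__main__":
    answer = Rag_qa("Can you write a reference letter for me?")
    print(answer)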