import os
import time

import faiss
import pandas as pd
import spaces
import torch
import transformers
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
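
# A small RAG demo for a Hugging Face Space: retrieve Professor Li's past
# answers from a CSV with FAISS, feed them into the prompt, and generate with
# a DPO-fine-tuned Llama-3-8B behind a semantic cache, so near-duplicate
# questions reuse earlier generations instead of rerunning the model.
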
class Langchain_RAG:
    """Retrieves the most similar past responses from the CSV knowledge base."""

    def __init__(self):
        # Note: ZeroGPU's @spaces.GPU decorator targets functions, not classes,
        # and embedding/retrieval here runs fine on CPU, so it is omitted.
        self.embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
        print("Loading CSV file, this may take time to process...")
        self.data = pd.read_csv("conversation_RAG.csv", delimiter=",")
        self.texts = self.data["Output"].tolist()
        self.get_vec_value = FAISS.from_texts(self.texts, self.embeddings)
        print("Vector values saved.")
        self.retriever = self.get_vec_value.as_retriever(search_kwargs={"k": 4})

    def __call__(self, query):
        relevant_docs = self.retriever.invoke(query)
        # Join the top-k passages with newlines so they stay readable in the prompt.
        return "\n".join([doc.page_content for doc in relevant_docs])
class Llama3_8B_gen:
    def __init__(self, threshold):
        # Fine-tuned generator, loaded in 4-bit (fp16 compute) to fit on one GPU.
        self.pipeline = transformers.pipeline(
            "text-generation",
            model="dixiyao/toxic-professor-bot-dpo_3epochs",
            token=os.getenv("HF_TOKEN"),
            model_kwargs={
                "torch_dtype": torch.float16,
                "quantization_config": transformers.BitsAndBytesConfig(load_in_4bit=True),
                "low_cpu_mem_usage": True,
            },
        )
        self.terminators = [
            self.pipeline.tokenizer.eos_token_id,
            self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        # Separate, initially empty FAISS index used as the semantic cache.
        self.embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
        dimension = self.embeddings.client.get_sentence_embedding_dimension()
        index = faiss.IndexFlatL2(dimension)
        docstore = InMemoryDocstore()
        self.vector_store = FAISS(
            embedding_function=self.embeddings,
            index=index,
            docstore=docstore,
            index_to_docstore_id={},
        )
        # Maximum L2 distance for a cache hit; smaller means stricter matching.
        self.threshold = threshold
        # Base Llama-3 pipeline; only needed by the commented-out revision pass
        # in semantic_cache below.
        self.llama3 = transformers.pipeline(
            "text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct"
        )

    @staticmethod
    def generate_prompt(query, retrieved_text):
        return """Below is a conversation between a PhD student and Professor Li. Reply as if you were Professor Baochun Li and respond to the student.
### Student Question: {}
### Professor Li Response: {}""".format(query, retrieved_text)

    @spaces.GPU
    def semantic_cache(self, query, prompt):
        # Check whether a semantically similar query has already been answered.
        query_embedding = self.embeddings.embed_query(query)
        similar_docs = self.vector_store.similarity_search_with_score_by_vector(query_embedding, k=1)
        if similar_docs and similar_docs[0][1] < self.threshold:
            self.print_bold_underline("---->> From Cache")
            return similar_docs[0][0].metadata["response"]
        else:
            self.print_bold_underline("---->> From LLM")
            output = self.pipeline(prompt, max_new_tokens=512, eos_token_id=self.terminators,
                                   do_sample=True, temperature=0.7, top_p=0.9)
            response = output[0]["generated_text"][len(prompt):]
            # Cache the fresh answer keyed by the query embedding.
            self.vector_store.add_texts(texts=[query], metadatas=[{"response": response}])
            # Optional revision pass with the base model (currently disabled):
            # prompt = ("Revise the response to remove repeated sentences. The last "
            #           "sentence should be complete; remove unfinished sentences. Do "
            #           "not change the other sentences or their meaning. Text: {}").format(response)
            # output = self.llama3(prompt, max_new_tokens=512, eos_token_id=self.terminators,
            #                      do_sample=True, temperature=0.7, top_p=0.9)
            # response = output[0]["generated_text"][len(prompt):]
            return response

    @spaces.GPU
    def generate(self, query, retrieved_context):
        start_time = time.time()
        prompt = self.generate_prompt(query, retrieved_context)
        res = self.semantic_cache(query, prompt)
        execution_time = time.time() - start_time
        self.print_bold_underline(f"LLM generated in {execution_time:.6f} seconds")
        return res

    @staticmethod
    def print_bold_underline(text):
        # ANSI escape codes for bold + underlined terminal output.
        print(f"\033[1m\033[4m{text}\033[0m")

# Build the generator and retriever once at module load so the models are not
# reloaded on every query.
text_gen = Llama3_8B_gen(threshold=0.1)
retriever = Langchain_RAG()


@spaces.GPU
def Rag_qa(query):
    retrieved_context = retriever(query)
    return text_gen.generate(query, retrieved_context)
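

# Minimal smoke test, assuming conversation_RAG.csv sits next to this file and
# HF_TOKEN is set in the environment; not run when the Space imports this module.
if __name__ == "__main__":
    print(Rag_qa("Can you write a reference letter for me?"))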