Improving RAG System Performance with a Semantic Cache in FAISS
Authored by: Pere Martra
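In this notebook, we will build a typical RAG system using an off-the-shelf model and the Chroma vector database. However, we will add a semantic cache system that stores users' questions and decides whether to enrich the prompt with information retrieved from the vector database or with information already cached for a previous question.
The purpose of the semantic cache is to detect user questions that are similar or identical to earlier ones. When a previously asked question is found, the system answers from the cache instead of querying the database again.
Because the system takes the meaning of a question into account, it can recognize that users are asking the same thing even when they phrase the question differently or make minor mistakes, such as spelling or sentence-structure errors.
For example, "What is the capital of France?", "Tell me the name of the capital of France?", and "What The capital of France is?" are worded differently, but they all ask the same thing.
Although the model's response may vary depending on how the question is phrased, the information retrieved from the database should essentially be the same. This is why we place the cache system between the user and the vector database, rather than between the user and the language model.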
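Most tutorials that guide you through creating a RAG system are designed for a single user working in a test environment: a notebook that interacts with a local vector database and makes API calls or uses a locally stored model.
This architecture quickly becomes insufficient when you try to move one of these systems into production, where it may face anywhere from dozens to thousands of recurring requests.
One way to improve performance is to add one or more semantic caches. A semantic cache keeps the results of previous requests and, before resolving a new request, checks whether a similar one has been received before. If so, instead of re-executing the whole process, it retrieves the information from the cache.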
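In a RAG system, there are two time-consuming points:
- Retrieving the information used to build the enriched prompt.
- Calling the large language model to obtain the response.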
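A semantic cache can be implemented at either point; we could even have two caches, one for each point.
Placing the cache at the model's response point may reduce the user's influence over the response. Our cache system could treat "Explain the French Revolution in 10 words" and "Explain the French Revolution in 100 words" as the same query. If the cache stored model responses, users might feel that their instructions were not followed accurately.
However, both requests need the same information to enrich the prompt. This is the main reason I chose to place the semantic cache between the user's request and the retrieval of information from the vector database.
That said, this is a design decision. Depending on the type of responses and the requests the system receives, the cache can be placed at one point or the other. Caching model responses clearly saves the most time, but, as I've explained, it comes at the cost of the user's influence over the response.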
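To make this design choice concrete, here is a minimal sketch in which the cache wraps only the retrieval step, so the model always receives the user's exact wording. It is a simplification under explicit assumptions: `crude_key`, `RetrievalCache`, `build_prompt` and `fake_retrieve` are hypothetical names, and `crude_key` (lowercasing and dropping digits) is only a stand-in for real semantic matching.

```python
import re
from typing import Callable, Dict


def crude_key(question: str) -> str:
    """Hypothetical stand-in for semantic matching: lowercase the question and
    drop digits, so requests that differ only in a number share one key."""
    return re.sub(r"\d+", "", question.lower()).strip()


class RetrievalCache:
    """Caches retrieved context (not model answers)."""

    def __init__(self, retrieve: Callable[[str], str]):
        self.retrieve = retrieve
        self.store: Dict[str, str] = {}
        self.db_calls = 0  # how many times the (fake) database was queried

    def context_for(self, question: str) -> str:
        key = crude_key(question)
        if key not in self.store:
            self.db_calls += 1
            self.store[key] = self.retrieve(question)
        return self.store[key]


def build_prompt(context: str, question: str) -> str:
    # The model always sees the user's exact question, so instructions such as
    # "in 10 words" vs "in 100 words" survive even when the context is cached.
    return f"Relevant context: {context}\n\n The user's question: {question}"


def fake_retrieve(question: str) -> str:
    # Hypothetical vector-database lookup.
    return "Context about the French Revolution."


cache = RetrievalCache(fake_retrieve)
q10 = "Explain the French Revolution in 10 words"
q100 = "Explain the French Revolution in 100 words"
p10 = build_prompt(cache.context_for(q10), q10)
p100 = build_prompt(cache.context_for(q100), q100)
```

Both prompts share a single database call but keep their own length instruction; caching the model's answers instead would have returned the same text to both users.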
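Import and load the libraries
To start, we need to install the necessary Python packages.
- sentence transformers. This library converts sentences into fixed-length vectors, also known as embeddings.
- xformers. A package that provides libraries and utilities for working with transformers models. We install it to avoid errors when working with the model and embeddings.
- chromadb. This is our vector database. ChromaDB is easy to use and open source, and it is one of the most widely used vector databases for storing embeddings.
- accelerate. Required to run the model on a GPU.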
!pip install -q transformers==4.38.1
!pip install -q accelerate==0.27.2
!pip install -q sentence-transformers==2.5.1
!pip install -q xformers==0.0.24
!pip install -q chromadb==0.4.24
!pip install -q datasets==2.17.1
import numpy as np
import pandas as pd
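Load the dataset
Since we are working in a free, limited environment with only a few GB of memory, I limited the number of rows used from the dataset with the MAX_ROWS variable.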
# Log in to Hugging Face. This is mandatory to use the Gemma model,
# and recommended for accessing public models and datasets.
from getpass import getpass

if 'hf_key' not in locals():
    hf_key = getpass("Your Hugging Face API Key: ")
!huggingface-cli login --token $hf_key
from datasets import load_dataset
data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split="train")
ChromaDB requires each record to have a unique identifier. We can use the following statement to create a new column called id.
data = data.to_pandas()
data["id"] = data.index
data.head(10)
MAX_ROWS = 15000
DOCUMENT = "Answer"
TOPIC = "qtype"
# Because this is just a sample, we select a small portion of the dataset.
subset_data = data.head(MAX_ROWS)
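Import and configure the vector database
To store the information, I've chosen ChromaDB, one of the best-known and most widely used open-source vector databases.
First we need to import ChromaDB.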
import chromadb
Now we only need to specify the path where the vector database will be stored.
chroma_client = chromadb.PersistentClient(path="/path/to/persist/directory")
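Filling and querying the ChromaDB database
Data in ChromaDB is stored in collections. If the collection already exists, we need to delete it.
In the next lines, we create the collection by calling the create_collection function of the chroma_client created above.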
collection_name = "news_collection"

# Delete the collection if it already exists (checking every collection,
# not just the first one) so that we start from a clean state.
if collection_name in [c.name for c in chroma_client.list_collections()]:
    chroma_client.delete_collection(name=collection_name)

collection = chroma_client.create_collection(name=collection_name)
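Now we are ready to add the data to the collection using the add function. This function requires three key pieces of information:
- In documents, we store the content of the Answer column of the dataset.
- In metadatas, we can provide a list of topics. I used the values of the qtype column.
- In ids, we need to provide a unique identifier for each row. I created the IDs from the range of MAX_ROWS.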
collection.add(
documents=subset_data[DOCUMENT].tolist(),
metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
ids=[f"id{x}" for x in range(MAX_ROWS)],
)
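`collection.add` expects `documents`, `metadatas` and `ids` to be parallel lists, where position i of each list describes the same record. Since the ids above are built from `range(MAX_ROWS)`, they only line up with the documents if the subset really has MAX_ROWS rows. A minimal sketch of a helper (hypothetical, not part of ChromaDB) that derives the ids from the actual number of rows instead; the two toy rows are made up for illustration:

```python
from typing import Dict, List, Tuple


def build_chroma_payload(
    rows: List[Dict[str, str]],
    document_col: str = "Answer",
    topic_col: str = "qtype",
) -> Tuple[List[str], List[Dict[str, str]], List[str]]:
    """Build parallel documents/metadatas/ids lists from row dicts."""
    documents = [row[document_col] for row in rows]
    metadatas = [{topic_col: row[topic_col]} for row in rows]
    # ids come from the real row count, so the three lists always match.
    ids = [f"id{i}" for i in range(len(rows))]
    return documents, metadatas, ids


# Toy rows shaped like the dataset's Answer / qtype columns.
rows = [
    {"Answer": "Vaccines train the immune system.", "qtype": "information"},
    {"Answer": "Sydenham chorea causes involuntary movements.", "qtype": "symptoms"},
]
documents, metadatas, ids = build_chroma_payload(rows)
```

With the DataFrame from this notebook, the rows could come from `subset_data.to_dict("records")`, and the three lists would then be passed to `collection.add(documents=..., metadatas=..., ids=...)`.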
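Once the information is in the database, we can query it and ask for the data that matches our needs. The search is performed over the content of the documents and does not look for exact words or phrases; the results are based on the similarity between the search terms and the document content.
Metadata isn't directly involved in the initial search; it can be used to filter or refine the results after retrieval, allowing further customization and precision.
Let's define a function to query the ChromaDB database.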
def query_database(query_text, n_results=10):
    # Similarity search over the collection; query_text is a list of query strings.
    results = collection.query(query_texts=query_text, n_results=n_results)
    return results
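`collection.query` returns a dictionary of nested lists, with one inner list per query text (for example, `results["documents"][0]` holds the documents for the first query). A minimal sketch of the post-retrieval filtering described above, using a hypothetical helper `filter_by_topic` and a hand-written result dict in place of a real query. Chroma's `query` also accepts a `where` argument (for example `where={"qtype": "symptoms"}`) if you prefer to filter inside the query itself.

```python
from typing import Any, Dict, List


def filter_by_topic(
    results: Dict[str, List[List[Any]]],
    topic: str,
    topic_key: str = "qtype",
    query_idx: int = 0,
) -> List[str]:
    """Keep the documents of one query whose metadata topic equals `topic`."""
    docs = results["documents"][query_idx]
    metas = results["metadatas"][query_idx]
    # `meta or {}` guards against records stored without metadata.
    return [doc for doc, meta in zip(docs, metas) if (meta or {}).get(topic_key) == topic]


# Hand-written example shaped like a Chroma query result for one query text.
results = {
    "ids": [["id3", "id7", "id9"]],
    "documents": [["doc A", "doc B", "doc C"]],
    "metadatas": [[{"qtype": "symptoms"}, {"qtype": "treatment"}, {"qtype": "symptoms"}]],
    "distances": [[0.21, 0.34, 0.40]],
}
symptom_docs = filter_by_topic(results, "symptoms")
```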
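Creating the semantic cache system
To implement the cache system, we will use Faiss, a library that allows storing embeddings in memory. It's quite similar to what Chroma does, but without its persistence.
For this purpose, we will create a class called semantic_cache that works with its own encoder and provides the functions the user needs to run queries.
In this class, we first query the cache implemented with Faiss, which contains previous requests. If the distance between the new question and the closest cached question is within a specified threshold, the class returns the cached content; otherwise, it fetches the results from the Chroma database. The cache is stored in a .json file.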
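Before looking at the Faiss version, the decision flow can be sketched in pure Python. This is a simplified illustration under assumptions: `SemanticCacheSketch` is a hypothetical class, `embed` and `retrieve` are injected callables standing in for the sentence-transformers encoder and `query_database`, and the nearest-neighbour search is a brute-force loop that, like Faiss's IndexFlatL2, compares squared Euclidean distances.

```python
from typing import Callable, List, Sequence, Tuple

Vector = Sequence[float]


def squared_l2(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


class SemanticCacheSketch:
    """Return a cached answer when the nearest stored question is within
    `threshold`; otherwise call `retrieve` and cache its result."""

    def __init__(self, embed: Callable[[str], Vector],
                 retrieve: Callable[[str], str], threshold: float = 0.35):
        self.embed = embed
        self.retrieve = retrieve
        self.threshold = threshold
        self.embeddings: List[Vector] = []
        self.answers: List[str] = []

    def ask(self, question: str) -> Tuple[str, str]:
        emb = self.embed(question)
        if self.embeddings:
            # Brute-force nearest neighbour, like an exhaustive (flat) index.
            dists = [squared_l2(emb, e) for e in self.embeddings]
            best = min(range(len(dists)), key=dists.__getitem__)
            if dists[best] <= self.threshold:
                return self.answers[best], "cache"
        # Cache miss: ask the "database" and remember the result.
        answer = self.retrieve(question)
        self.embeddings.append(emb)
        self.answers.append(answer)
        return answer, "database"


# Toy 2-D "embeddings" chosen by hand for illustration only.
toy_vectors = {"q1": [1.0, 0.0], "q1 reworded": [0.9, 0.1], "q2": [0.0, 1.0]}
sketch = SemanticCacheSketch(toy_vectors.__getitem__, lambda q: f"doc for {q}")
first = sketch.ask("q1")            # ('doc for q1', 'database')
second = sketch.ask("q2")           # ('doc for q2', 'database')
third = sketch.ask("q1 reworded")   # ('doc for q1', 'cache'), distance about 0.02
```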
!pip install -q faiss-cpu==1.8.0
import faiss
from sentence_transformers import SentenceTransformer
import time
import json
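The init_cache() function below initializes the semantic cache.
It uses the FlatL2 index (IndexFlatL2), which might not be the fastest but is ideal for small datasets. Depending on the characteristics and size of the data to be cached, other indexes such as HNSW or IVF could be used.
I chose this index because it fits the example well. It can be used with high-dimensional vectors, consumes minimal memory, and performs well with small datasets.
Below is an outline of the key features of the various indexes available in Faiss.
- FlatL2 or FlatIP. Well suited for small datasets; it may not be the fastest, but its memory consumption is not excessive.
- LSH. Works effectively on small datasets and is recommended for vectors with up to 128 dimensions.
- HNSW. Very fast but demands a substantial amount of RAM.
- IVF. Works well with large datasets without consuming much memory or compromising performance.
More information about the different indexes available in Faiss can be found at: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index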
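A flat index performs an exhaustive search: every query is compared against every stored vector, so results are exact but cost grows linearly with the number of cached questions. A pure-Python sketch (not Faiss itself) that mimics the shape of `IndexFlatL2.search`, returning a distance matrix D of squared L2 distances and a label matrix I. Like Faiss, it fills missing results with the label -1 when fewer than k vectors are stored, which is why the ask method further down checks that `I[0][0] >= 0`; for the distance of those slots the sketch uses infinity, where Faiss uses a very large sentinel value.

```python
import math
from typing import List, Sequence, Tuple


def flat_l2_search(
    database: List[Sequence[float]],
    queries: List[Sequence[float]],
    k: int,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Exhaustive k-nearest-neighbour search returning (D, I) like a flat L2
    index: D holds squared L2 distances, I holds row ids. Slots that cannot be
    filled (fewer than k stored vectors) get distance inf and label -1."""
    D: List[List[float]] = []
    I: List[List[int]] = []
    for q in queries:
        # Compare the query with every stored vector, then keep the k closest.
        scored = sorted(
            (sum((a - b) ** 2 for a, b in zip(q, v)), row)
            for row, v in enumerate(database)
        )[:k]
        scored += [(math.inf, -1)] * (k - len(scored))
        D.append([float(d) for d, _ in scored])
        I.append([row for _, row in scored])
    return D, I


stored = [[0.0, 0.0], [3.0, 4.0]]
D, I = flat_l2_search(stored, [[0.0, 1.0]], k=3)
# D == [[1.0, 18.0, inf]], I == [[0, 1, -1]]
```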
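def init_cache():
    # Flat (exhaustive) index using Euclidean (L2) distance.
    # 768 is the embedding size produced by all-mpnet-base-v2.
    index = faiss.IndexFlatL2(768)
    if index.is_trained:
        print("Index trained")

    # Initialize Sentence Transformer model
    encoder = SentenceTransformer("all-mpnet-base-v2")

    return index, encoder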
In the retrieve_cache function, the .json file is loaded from disk so the cache can be reused across sessions.
def retrieve_cache(json_file):
    try:
        with open(json_file, "r") as file:
            cache = json.load(file)
    except FileNotFoundError:
        # No cache on disk yet: start with empty lists.
        cache = {"questions": [], "embeddings": [], "answers": [], "response_text": []}

    return cache
The store_cache function saves the file containing the cache data to disk.
def store_cache(json_file, cache):
    with open(json_file, "w") as file:
        json.dump(cache, file)
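The JSON file and the Faiss index must stay aligned: row i of the index has to correspond to position i of each list in the cache dictionary. When a cache file from an earlier session is loaded, the stored embeddings therefore need to be added back to a fresh index before new questions are appended. A minimal sketch of that reload step, using a plain Python list as a stand-in for the Faiss index (with Faiss, the equivalent would be `index.add(np.array(cache["embeddings"], dtype=np.float32))`); `load_cache` and `rebuild_index` are hypothetical names.

```python
import json
import os
import tempfile
from typing import Dict, List


def empty_cache() -> Dict[str, list]:
    return {"questions": [], "embeddings": [], "answers": [], "response_text": []}


def load_cache(path: str) -> Dict[str, list]:
    # Same behaviour as retrieve_cache: empty cache if the file does not exist.
    try:
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return empty_cache()


def rebuild_index(cache: Dict[str, list]) -> List[List[float]]:
    # Stand-in for a fresh Faiss index filled with the stored embeddings:
    # index row i now matches cache["response_text"][i].
    return [list(e) for e in cache["embeddings"]]


# Simulate a previous session that saved two cached questions.
path = os.path.join(tempfile.mkdtemp(), "cache.json")
saved = empty_cache()
saved["questions"] = ["q1", "q2"]
saved["embeddings"] = [[1.0, 0.0], [0.0, 1.0]]
saved["answers"] = [{}, {}]
saved["response_text"] = ["answer 1", "answer 2"]
with open(path, "w") as f:
    json.dump(saved, f)

loaded = load_cache(path)
rebuilt_index = rebuild_index(loaded)
```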
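These functions are used within the semantic_cache class, which includes the search function and its initialization function.
Even though the ask function contains a substantial amount of code, its purpose is quite straightforward. It looks in the cache for the question closest to the one the user has just asked.
It then checks whether that question is within the specified threshold. If so, it returns the response directly from the cache; otherwise, it calls the query_database function to retrieve the data from ChromaDB.
I've used Euclidean distance instead of cosine distance, which is widely used in vector comparisons. This choice is based on the fact that Euclidean distance is the default metric used by Faiss (IndexFlatL2 actually reports squared Euclidean distances, and that is the value compared against the threshold). Although cosine distance can also be computed, doing so adds complexity that may not contribute significantly to the final result.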
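For unit-length vectors, squared Euclidean distance and cosine similarity carry the same information: ||a - b||^2 = 2 - 2 cos(a, b). So, assuming the embeddings are normalized (sentence-transformers can enforce this with `encode(..., normalize_embeddings=True)`), a squared-L2 threshold of 0.35 corresponds to a cosine similarity of at least 1 - 0.35/2 = 0.825. A small pure-Python check of the identity; the helper names are illustrative:

```python
import math
import random
from typing import List, Sequence


def normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_unit(a: Sequence[float], b: Sequence[float]) -> float:
    # The dot product equals cosine similarity only because a and b are unit-norm.
    return sum(x * y for x, y in zip(a, b))


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def l2_threshold_to_cosine(sq_l2_threshold: float) -> float:
    # From ||a - b||^2 = 2 - 2 cos(a, b) for unit vectors.
    return 1 - sq_l2_threshold / 2


rng = random.Random(0)
a = normalize([rng.gauss(0, 1) for _ in range(768)])
b = normalize([rng.gauss(0, 1) for _ in range(768)])
print(squared_l2(a, b), 2 - 2 * cosine_unit(a, b))  # equal up to rounding
print(l2_threshold_to_cosine(0.35))                 # about 0.825
```

By the same identity, the distances 0.028 and 0.228 observed in the tests that follow would correspond to cosine similarities of roughly 0.986 and 0.886, again only if the embeddings are unit-normalized.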
class semantic_cache:
    def __init__(self, json_file="cache_file.json", threshold=0.35):
        # Initialize Faiss index with Euclidean distance
        self.index, self.encoder = init_cache()

        # Set Euclidean distance threshold:
        # a distance of 0 means identical sentences.
        # We only return from cache sentences under this threshold.
        self.euclidean_threshold = threshold

        self.json_file = json_file
        self.cache = retrieve_cache(self.json_file)

        # Re-add embeddings loaded from disk so that index row i
        # matches position i in the cache lists.
        if self.cache["embeddings"]:
            self.index.add(np.array(self.cache["embeddings"], dtype=np.float32))

    def ask(self, question: str) -> str:
        # Method to retrieve an answer from the cache or generate a new one
        start_time = time.time()
        try:
            # First we obtain the embeddings corresponding to the user question
            embedding = self.encoder.encode([question])

            # Search for the nearest neighbor in the index.
            # (nprobe only applies to IVF indexes, so it is not set for a flat index.)
            D, I = self.index.search(embedding, 1)

            # I[0][0] is -1 when the index is still empty.
            if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
                row_id = int(I[0][0])

                print("Answer recovered from Cache. ")
                print(f"{D[0][0]:.3f} smaller than {self.euclidean_threshold}")
                print(f"Found cache in row: {row_id} with score {D[0][0]:.3f}")
                print(f"response_text: {self.cache['response_text'][row_id]}")

                end_time = time.time()
                elapsed_time = end_time - start_time
                print(f"Time taken: {elapsed_time:.3f} seconds")
                return self.cache["response_text"][row_id]

            # Cache miss (empty index or distance above the threshold):
            # ask ChromaDB.
            answer = query_database([question], 1)
            response_text = answer["documents"][0][0]

            self.cache["questions"].append(question)
            self.cache["embeddings"].append(embedding[0].tolist())
            self.cache["answers"].append(answer)
            self.cache["response_text"].append(response_text)

            print("Answer recovered from ChromaDB. ")
            print(f"response_text: {response_text}")

            self.index.add(embedding)
            store_cache(self.json_file, self.cache)
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"Time taken: {elapsed_time:.3f} seconds")

            return response_text
        except Exception as e:
            raise RuntimeError(f"Error during 'ask' method: {e}") from e
Testing the semantic_cache class.
>>> # Initialize the cache.
>>> cache = semantic_cache("4cache.json")
Index trained
>>> results = cache.ask("How do vaccines work?")
Answer recovered from ChromaDB.
response_text: Summary : Shots may hurt a little, but the diseases they can prevent are a lot worse. Some are even life-threatening. Immunization shots, or vaccinations, are essential. They protect against things like measles, mumps, rubella, hepatitis B, polio, tetanus, diphtheria, and pertussis (whooping cough). Immunizations are important for adults as well as children. Your immune system helps your body fight germs by producing substances to combat them. Once it does, the immune system "remembers" the germ and can fight it again. Vaccines contain germs that have been killed or weakened. When given to a healthy person, the vaccine triggers the immune system to respond and thus build immunity. Before vaccines, people became immune only by actually getting a disease and surviving it. Immunizations are an easier and less risky way to become immune. NIH: National Institute of Allergy and Infectious Diseases
Time taken: 0.057 seconds
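As expected, this response was obtained from ChromaDB. The class then stores it in the cache.
Now, if we send a completely different question, the response should also be retrieved from ChromaDB. This is because the question stored previously is so different from the current one that its Euclidean distance exceeds the specified threshold.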
>>> results = cache.ask("Explain briefly what is a Sydenham chorea")
Answer recovered from ChromaDB.
response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.
Time taken: 0.082 seconds
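Perfect, the semantic cache system is behaving as expected.
Let's continue testing it with a question very similar to the one we just asked.
In this case, the response should come directly from the cache, without accessing the ChromaDB database.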
>>> results = cache.ask("Briefly explain me what is a Sydenham chorea.")
Answer recovered from Cache.
0.028 smaller than 0.35
Found cache in row: 1 with score 0.028
response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.
Time taken: 0.019 seconds
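The two questions are so similar that their Euclidean distance is very small, almost as if they were identical.
Now, let's try another question, this time a bit more different, and observe how the system behaves.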
>>> question_def = "Write in 20 words what is a Sydenham chorea."
>>> results = cache.ask(question_def)
Answer recovered from Cache.
0.228 smaller than 0.35
Found cache in row: 1 with score 0.228
response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.
Time taken: 0.016 seconds
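We observe that the Euclidean distance has increased, but it is still within the specified threshold. Therefore, the response continues to be returned directly from the cache.
Loading the model and creating the prompt
Time to use the transformers library, the best-known Hugging Face library for working with language models.
We are importing:
- AutoTokenizer: a utility class for tokenizing text inputs compatible with various pretrained language models.
- AutoModelForCausalLM: provides an interface to pretrained language models specifically for language generation tasks using causal language modeling (e.g., GPT models), such as the model used in this notebook, Gemma-2b-it. Feel free to test different models; you need to look for NLP models trained for text generation.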
!pip install torch
import torch
from torch import cuda

# On Apple Silicon Macs, use the 'mps' device instead:
# device = torch.device('mps')
device = f"cuda:{cuda.current_device()}" if cuda.is_available() else "cpu"

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load the model on the device selected above instead of hard-coding "cuda".
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)
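Creating the extended prompt
To create the prompt, we use the result of querying the semantic_cache class together with the question asked by the user.
The prompt has two parts: the relevant context, which is the information recovered from the database, and the user's question.
We only need to put the two parts together to create the prompt and then send it to the model.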
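The concatenation can be wrapped in a small helper. Instruction-tuned models such as Gemma are usually prompted with their chat format; the optional branch below wraps the prompt between `<start_of_turn>user` and `<end_of_turn>`, followed by `<start_of_turn>model`, which is my understanding of Gemma's turn markers and an assumption to check against the model card. In practice, `tokenizer.apply_chat_template` is the safer way to obtain that format. `build_prompt` is a hypothetical name.

```python
def build_prompt(context: str, question: str, gemma_turns: bool = False) -> str:
    """Join the retrieved context and the user's question into one prompt."""
    prompt = f"Relevant context: {context}\n\n The user's question: {question}"
    if gemma_turns:
        # Assumed Gemma chat turn markers; verify against the model card or
        # use tokenizer.apply_chat_template instead.
        prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
    return prompt


plain = build_prompt("Sydenham chorea is a neurological disorder.",
                     "Write in 20 words what is a Sydenham chorea.")
chat = build_prompt("Sydenham chorea is a neurological disorder.",
                    "Write in 20 words what is a Sydenham chorea.", gemma_turns=True)
```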
prompt_template = f"Relevant context: {results}\n\n The user's question: {question_def}"
prompt_template
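# Tokenize the prompt and move it to the same device as the model.
input_ids = tokenizer(prompt_template, return_tensors="pt").to(model.device)
Now all that remains is to send the prompt to the model and wait for its response!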
>>> outputs = model.generate(**input_ids, max_new_tokens=256)
>>> print(tokenizer.decode(outputs[0]))
Relevant context: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations. The user's question: Write in 20 words what is a Sydenham chorea. Sure, here is a 20-word answer: Sydenham chorea is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS).
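Conclusion
Retrieving data directly from the cache instead of querying ChromaDB cut retrieval time by 50%. In larger projects, however, this difference grows, and the performance gain can reach 90-95%.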
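The percentages can be checked against the timings printed in the tests above. Single notebook runs are noisy and each timing covers a different question, so treat this only as a rough comparison: the slowest cache hit (0.019 s) against the fastest ChromaDB lookup (0.057 s), and the fastest cache hit (0.016 s) against the slowest ChromaDB lookup (0.082 s).

```python
def reduction(baseline: float, improved: float) -> float:
    """Fraction of time saved relative to `baseline`."""
    return 1 - improved / baseline


chromadb_times = [0.057, 0.082]  # "Answer recovered from ChromaDB" runs
cache_times = [0.019, 0.016]     # "Answer recovered from Cache" runs

least_saving = reduction(min(chromadb_times), max(cache_times))
most_saving = reduction(max(chromadb_times), min(cache_times))
print(f"{least_saving:.0%} to {most_saving:.0%}")  # prints 67% to 80%
```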
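We have very little data in Chroma and only a single instance of the cache class. Typically, the data behind a cache system is much larger and may come not only from vector database queries but from a variety of sources.
It is also common to have multiple instances of the cache class, usually based on user type, since questions tend to repeat more among users who share common characteristics.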
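One way to organize several cache instances is a registry that lazily creates one cache per user type. A minimal sketch; `CacheRegistry` and the user-type labels are hypothetical, and in this notebook the factory could be, for example, `lambda user_type: semantic_cache(f"cache_{user_type}.json")`, so that each group keeps its own JSON file and Faiss index.

```python
from typing import Callable, Dict, Generic, TypeVar

C = TypeVar("C")


class CacheRegistry(Generic[C]):
    """Holds one cache instance per user type, created on first use."""

    def __init__(self, factory: Callable[[str], C]):
        self.factory = factory
        self.caches: Dict[str, C] = {}

    def for_user_type(self, user_type: str) -> C:
        if user_type not in self.caches:
            self.caches[user_type] = self.factory(user_type)
        return self.caches[user_type]


# Toy factory: a dict stands in for a semantic_cache instance.
registry = CacheRegistry(lambda user_type: {"owner": user_type})
doctors = registry.for_user_type("doctors")
patients = registry.for_user_type("patients")
```

Keeping one cache per group means a question from one group never returns context cached for another, at the cost of more memory and fewer shared hits.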
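In summary, we have created a very simple RAG system and enhanced it with a semantic cache layer placed between the user's question and the retrieval of the information needed to build the enriched prompt.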