Open-Source AI Cookbook documentation
使用向量嵌入和 Qdrant 进行代码搜索
使用向量嵌入和 Qdrant 进行代码搜索
作者:Qdrant 团队
在这个 Notebook 中,我们演示了如何使用向量嵌入来导航代码库并找到相关的代码片段。我们将使用自然语义查询来搜索代码库,并根据相似的逻辑查找代码。
你可以查看这个方法的实时部署,该部署通过一个网页界面提供 Qdrant 代码库的搜索功能。
方法
为了实现我们的目标,我们需要两个模型:
用于自然语言处理(NLP)的通用神经编码器, 在我们的案例中使用的是 sentence-transformers/all-MiniLM-L6-v2 模型。我们将称之为 NLP 模型。
用于代码间相似度搜索的专门嵌入模型。我们将使用 jinaai/jina-embeddings-v2-base-code 模型来完成此任务。它支持英语和 30 种广泛使用的编程语言,且具有 8192 的序列长度。我们将称之为 代码模型。
为了准备 NLP 模型所需的代码,我们需要将代码预处理成一种更接近自然语言的格式。由于代码模型已经支持多种标准编程语言,因此无需对代码片段进行预处理。我们可以直接使用原始代码。
安装依赖项
让我们安装将要使用的包。
- inflection - 一个字符串转换库。它可以将英文单词的复数转为单数,单数转为复数,并将驼峰命名法(CamelCase)转换为下划线分隔的字符串。
- fastembed - 一个轻量级的库,用于生成向量嵌入,优先支持 CPU。支持 GPU。
- qdrant-client - 官方的 Python 库,用于与 Qdrant 服务器进行交互。
%pip install inflection qdrant-client fastembed
数据准备
将应用程序源代码分割成更小的部分是一个复杂的任务。通常,函数、类方法、结构体、枚举以及所有其他语言特定的构造体都是很好的分块候选。它们足够大,能包含一些有意义的信息,但又足够小,适合被嵌入模型处理,因为这些模型有一个有限的上下文窗口。你还可以使用文档字符串(docstrings)、注释和其他元数据来丰富这些块,增加额外的信息。

基于文本的搜索通常是基于函数签名的,但代码搜索可能会返回更小的部分,比如循环。因此,如果我们从 NLP 模型收到一个特定的函数签名,并且从代码模型收到该函数部分实现的代码,我们将合并这些结果。
解析代码库
我们将使用 Qdrant 代码库 来进行演示。虽然该代码库使用的是 Rust,但你可以使用此方法处理任何其他编程语言。你可以使用 语言服务器协议(LSP) 工具来构建代码库的图谱,然后提取代码块。我们使用了 rust-analyzer 完成这项工作。我们将解析后的代码库导出为 LSIF 格式,这是一个用于代码智能的数据标准。接下来,我们利用 LSIF 数据来导航代码库并提取代码块。
你可以对其他编程语言使用相同的方法。市面上有大量的实现可供选择。
接下来,我们将把代码块导出为 JSON 文档,文档中不仅包含代码本身,还会包括代码在项目中的位置信息(上下文)。
你可以在我们的 Google Cloud Storage 存储桶中查看解析为 JSON 格式的 Qdrant 结构,文件名为 structures.jsonl 文件。下载该文件,并将其用作我们的代码搜索数据源。
!wget https://storage.googleapis.com/tutorial-attachments/code-search/structures.jsonl
接下来,加载文件并将每一行解析为字典列表:
import json
structures = []
with open("structures.jsonl", "r") as fp:
for i, row in enumerate(fp):
entry = json.loads(row)
structures.append(entry)
我们来看一下一个条目的结构是怎样的
structures[0]
{'name': 'InvertedIndexRam',
'signature': '# [doc = " Inverted flatten index from dimension id to posting list"] # [derive (Debug , Clone , PartialEq)] pub struct InvertedIndexRam { # [doc = " Posting lists for each dimension flattened (dimension id -> posting list)"] # [doc = " Gaps are filled with empty posting lists"] pub postings : Vec < PostingList > , # [doc = " Number of unique indexed vectors"] # [doc = " pre-computed on build and upsert to avoid having to traverse the posting lists."] pub vector_count : usize , }',
'code_type': 'Struct',
'docstring': '= " Inverted flatten index from dimension id to posting list"',
'line': 15,
'line_from': 13,
'line_to': 22,
'context': {'module': 'inverted_index',
'file_path': 'lib/sparse/src/index/inverted_index/inverted_index_ram.rs',
'file_name': 'inverted_index_ram.rs',
'struct_name': None,
'snippet': '/// Inverted flatten index from dimension id to posting list\n#[derive(Debug, Clone, PartialEq)]\npub struct InvertedIndexRam {\n /// Posting lists for each dimension flattened (dimension id -> posting list)\n /// Gaps are filled with empty posting lists\n pub postings: Vec<PostingList>,\n /// Number of unique indexed vectors\n /// pre-computed on build and upsert to avoid having to traverse the posting lists.\n pub vector_count: usize,\n}\n'}}
代码到自然语言的转换
每种编程语言都有其特定的语法,而这些语法并不是自然语言的一部分。因此,一个通用模型可能无法直接理解代码。然而,我们可以通过去除代码特有的部分并加入额外的上下文(如模块、类、函数和文件名)来规范化数据。我们采取以下步骤:
- 提取函数、方法或其他代码结构的签名。
- 将驼峰命名法(CamelCase)和下划线命名法(snake_case)中的名称拆分为单独的单词。
- 获取文档字符串(docstring)、注释和其他重要的元数据。
- 使用预定义的模板根据提取的数据构建句子。
- 移除特殊字符并用空格替代。
我们现在可以定义 textify
函数,利用 inflection
库来执行我们的转换操作:
import inflection
import re
from typing import Dict, Any
def textify(chunk: Dict[str, Any]) -> str:
# Get rid of all the camel case / snake case
# - inflection.underscore changes the camel case to snake case
# - inflection.humanize converts the snake case to human readable form
name = inflection.humanize(inflection.underscore(chunk["name"]))
signature = inflection.humanize(inflection.underscore(chunk["signature"]))
# Check if docstring is provided
docstring = ""
if chunk["docstring"]:
docstring = f"that does {chunk['docstring']} "
# Extract the location of that snippet of code
context = f"module {chunk['context']['module']} " f"file {chunk['context']['file_name']}"
if chunk["context"]["struct_name"]:
struct_name = inflection.humanize(inflection.underscore(chunk["context"]["struct_name"]))
context = f"defined in struct {struct_name} {context}"
# Combine all the bits and pieces together
text_representation = f"{chunk['code_type']} {name} " f"{docstring}" f"defined as {signature} " f"{context}"
# Remove any special characters and concatenate the tokens
tokens = re.split(r"\W", text_representation)
tokens = filter(lambda x: x, tokens)
return " ".join(tokens)
现在我们可以使用 textify
函数将所有的代码块转换为文本表示
text_representations = list(map(textify, structures))
让我们看看其中一个转换后的自然语言表示是什么样子的:
text_representations[1000]
'Function Hnsw discover precision that does Checks discovery search precision when using hnsw index this is different from the tests in defined as Fn hnsw discover precision module integration file hnsw_discover_test rs'
自然语言嵌入
from fastembed import TextEmbedding
batch_size = 5
nlp_model = TextEmbedding("sentence-transformers/all-MiniLM-L6-v2", threads=0)
nlp_embeddings = nlp_model.embed(text_representations, batch_size=batch_size)
代码嵌入
code_snippets = [structure["context"]["snippet"] for structure in structures]
code_model = TextEmbedding("jinaai/jina-embeddings-v2-base-code")
code_embeddings = code_model.embed(code_snippets, batch_size=batch_size)
构建 Qdrant 集合
Qdrant 支持多种部署模式,包括内存模式(用于原型开发)、Docker 和 Qdrant Cloud。你可以参考 安装指南 了解更多信息。
我们将在此教程中使用内存实例继续操作。
提示
内存模式只能用于快速原型开发和测试。它是 Qdrant 服务器方法的 Python 实现。
现在,让我们创建一个集合来存储我们的向量。
from qdrant_client import QdrantClient, models
COLLECTION_NAME = "qdrant-sources"
client = QdrantClient(":memory:") # Use in-memory storage
# client = QdrantClient("http://locahost:6333") # For Qdrant server
client.create_collection(
COLLECTION_NAME,
vectors_config={
"text": models.VectorParams(
size=384,
distance=models.Distance.COSINE,
),
"code": models.VectorParams(
size=768,
distance=models.Distance.COSINE,
),
},
)
我们新创建的集合已经准备好接受数据。接下来,让我们上传嵌入向量:
from tqdm import tqdm
points = []
total = len(structures)
print("Number of points to upload: ", total)
for id, (text_embedding, code_embedding, structure) in tqdm(
enumerate(zip(nlp_embeddings, code_embeddings, structures)), total=total
):
# FastEmbed returns generators. Embeddings are computed as consumed.
points.append(
models.PointStruct(
id=id,
vector={
"text": text_embedding,
"code": code_embedding,
},
payload=structure,
)
)
# Upload points in batches
if len(points) >= batch_size:
client.upload_points(COLLECTION_NAME, points=points, wait=True)
points = []
# Ensure any remaining points are uploaded
if points:
client.upload_points(COLLECTION_NAME, points=points)
print(f"Total points in collection: {client.count(COLLECTION_NAME).count}")
上传的点立即可以用于搜索。接下来,查询集合以找到相关的代码片段。
查询代码库
我们使用其中一个模型通过 Qdrant 的新 查询 API 搜索集合。首先使用文本嵌入。运行以下查询:“如何计算集合中的点数?”。然后查看查询结果。
query = "How do I count points in a collection?"
hits = client.query_points(
COLLECTION_NAME,
query=next(nlp_model.query_embed(query)).tolist(),
using="text",
limit=3,
).points
现在,查看查询结果。以下表格列出了模块、文件名和得分。每一行都包含一个指向签名的链接。
模块 | 文件名 | 得分 | 签名 |
---|---|---|---|
operations | types.rs | 0.5493385 | pub struct CountRequestInternal |
map_index | types.rs | 0.49973965 | fn get_points_with_value_count |
map_index | mutable_map_index.rs | 0.49941066 | pub fn get_points_with_value_count |
看起来我们已经能够找到一些相关的代码结构。接下来,让我们使用代码嵌入(code embeddings)再试一次。
hits = client.query_points(
COLLECTION_NAME,
query=next(code_model.query_embed(query)).tolist(),
using="code",
limit=3,
).points
输出结果:
模块 | 文件名 | 得分 | 签名 |
---|---|---|---|
field_index | geo_index.rs | 0.7217579 | fn count_indexed_points |
numeric_index | mod.rs | 0.7113214 | fn count_indexed_points |
full_text_index | text_index.rs | 0.6993165 | fn count_indexed_points |
虽然不同模型得到的得分不可直接比较,但我们可以看到,结果是不同的。代码嵌入和文本嵌入能够捕捉到代码库的不同方面。我们可以使用这两种模型查询集合,然后将结果结合起来,以获取最相关的代码片段。
from qdrant_client import models
hits = client.query_points(
collection_name=COLLECTION_NAME,
prefetch=[
models.Prefetch(
query=next(nlp_model.query_embed(query)).tolist(),
using="text",
limit=5,
),
models.Prefetch(
query=next(code_model.query_embed(query)).tolist(),
using="code",
limit=5,
),
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
).points
>>> for hit in hits:
... print(
... "| ",
... hit.payload["context"]["module"],
... " | ",
... hit.payload["context"]["file_path"],
... " | ",
... hit.score,
... " | `",
... hit.payload["signature"],
... "` |",
... )
| operations | lib/collection/src/operations/types.rs | 0.5 | ` # [doc = " Count Request"] # [doc = " Counts the number of points which satisfy the given filter."] # [doc = " If filter is not provided, the count of all points in the collection will be returned."] # [derive (Debug , Deserialize , Serialize , JsonSchema , Validate)] # [serde (rename_all = "snake_case")] pub struct CountRequestInternal { # [doc = " Look only for points which satisfies this conditions"] # [validate] pub filter : Option < Filter > , # [doc = " If true, count exact number of points. If false, count approximate number of points faster."] # [doc = " Approximate count might be unreliable during the indexing process. Default: true"] # [serde (default = "default_exact_count")] pub exact : bool , } ` | | field_index | lib/segment/src/index/field_index/geo_index.rs | 0.5 | ` fn count_indexed_points (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.33333334 | ` fn get_points_with_value_count < Q > (& self , value : & Q) -> Option < usize > where Q : ? Sized , N : std :: borrow :: Borrow < Q > , Q : Hash + Eq , ` | | numeric_index | lib/segment/src/index/field_index/numeric_index/mod.rs | 0.33333334 | ` fn count_indexed_points (& self) -> usize ` | | fixtures | lib/segment/src/fixtures/payload_context_fixture.rs | 0.25 | ` fn total_point_count (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mutable_map_index.rs | 0.25 | ` fn get_points_with_value_count < Q > (& self , value : & Q) -> Option < usize > where Q : ? Sized , N : std :: borrow :: Borrow < Q > , Q : Hash + Eq , ` | | id_tracker | lib/segment/src/id_tracker/simple_id_tracker.rs | 0.2 | ` fn total_point_count (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.2 | ` fn count_indexed_points (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.16666667 | ` fn count_indexed_points (& self) -> usize ` | | field_index | lib/segment/src/index/field_index/stat_tools.rs | 0.16666667 | ` fn number_of_selected_points (points : usize , values : usize) -> usize ` |
这是一个如何融合不同模型结果的示例。在实际场景中,你可能会进行一些重新排序(reranking)和去重(deduplication),以及对结果进行额外的处理。
对结果进行分组
你可以通过按负载属性对搜索结果进行分组,从而改进搜索结果。在我们的例子中,我们可以按模块对结果进行分组。如果我们使用代码嵌入,可能会看到来自 map_index
模块的多个结果。让我们对结果进行分组,并假设每个模块只显示一个结果:
results = client.query_points_groups(
COLLECTION_NAME,
query=next(code_model.query_embed(query)).tolist(),
using="code",
group_by="context.module",
limit=5,
group_size=1,
)
>>> for group in results.groups:
... for hit in group.hits:
... print(
... "| ",
... hit.payload["context"]["module"],
... " | ",
... hit.payload["context"]["file_name"],
... " | ",
... hit.score,
... " | `",
... hit.payload["signature"],
... "` |",
... )
| field_index | geo_index.rs | 0.7217579 | ` fn count_indexed_points (& self) -> usize ` | | numeric_index | mod.rs | 0.7113214 | ` fn count_indexed_points (& self) -> usize ` | | fixtures | payload_context_fixture.rs | 0.6993165 | ` fn total_point_count (& self) -> usize ` | | map_index | mod.rs | 0.68385994 | ` fn count_indexed_points (& self) -> usize ` | | full_text_index | text_index.rs | 0.6660142 | ` fn count_indexed_points (& self) -> usize ` |
这就结束了我们的教程。感谢你花时间跟随我们完成这个过程。我们刚刚开始探索使用向量嵌入的可能性以及如何改进它。请随意进行实验,或许你能构建出非常酷的东西!如果你有任何创意,欢迎与我们分享 🙏 我们的联系方式在 这里。
< > Update on GitHub