GeoLLM / KNN_token.py
Ciallo0d00's picture
Upload folder using huggingface_hub
badcf3c verified
import json
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModel
import torch
from collections import defaultdict
class EntityLevelRetriever:
def __init__(self, model_name='bert-base-chinese'):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)
self.index = faiss.IndexFlatL2(768) # L2距离更适合BERT嵌入
self.entity_db = []
self.metadata = []
def _get_entity_span(self, text, entity):
"""通过精确匹配获取实体在文本中的位置"""
start = text.find(entity)
if start == -1:
return None
return (start, start + len(entity))
def _generate_entity_embedding(self, text, entity):
"""生成实体级上下文嵌入"""
# 通过BERT模型获取实体在文本中的上下文表示
# 核心实现:提取实体对应token的嵌入并平均
span = self._get_entity_span(text, entity)
if not span:
return None
inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
with torch.no_grad():
outputs = self.model(**inputs)
# 将字符位置转换为token位置
char_to_token = lambda x: inputs.char_to_token(x)
start_token = char_to_token(span[0])
end_token = char_to_token(span[1]-1)
if not start_token or not end_token:
return None
# 提取实体对应的token嵌入并平均
entity_embedding = outputs.last_hidden_state[0, start_token:end_token+1].mean(dim=0).numpy()
return entity_embedding.astype('float32')
def build_index(self, train_path):
"""构建实体索引"""
# 关键设计:为每个三元组中的头实体和尾实体分别建立索引
# 存储实体嵌入时同时保存关系类型和上下文信息
with open(train_path, 'r', encoding='utf-8') as f:
dataset = json.load(f)
# 仅处理500-1000索引的数据(演示用切片操作)
dataset = dataset[500:1000]
for item in dataset:
text = item['text']
for triple in item['triple_list']:
# 处理头实体和尾实体
for entity in [triple[0], triple[2]]:
embedding = self._generate_entity_embedding(text, entity)
if embedding is not None:
self.entity_db.append(embedding)
self.metadata.append({
'entity': entity,
'type': triple[1], # 保存关系类型
'context': text
})
print(f"实体数量检查 - 向量数: {len(self.entity_db)}, 元数据数: {len(self.metadata)}")
self.index.add(np.array(self.entity_db))
print(f"索引维度: {self.index.d}, 存储数量: {self.index.ntotal}")
def search_texts(self, test_path, top_k=3, score_mode='weighted'):
"""
基于实体聚合的文本级检索
:param score_mode: 评分模式,可选'simple'(简单累加)/'weighted'(带距离权重)
"""
# 通过以下方式实现实体级到文本级的检索转换
# 1. 对查询文本中的每个实体进行相似搜索
# 2. 聚合多个实体的匹配结果到上下文层面
# 3. 通过加权评分机制综合判断文本相似度
with open(test_path, 'r', encoding='utf-8') as f:
test_data = json.load(f)
results = []
for item in test_data:
text = item['text']
context_scores = defaultdict(float)
context_hits = defaultdict(int)
# 第一阶段:收集所有实体的匹配上下文
for triple in item['triple_list']:
for entity in [triple[0], triple[2]]:
embedding = self._generate_entity_embedding(text, entity)
if embedding is None:
continue
distances, indices = self.index.search(np.array([embedding]), top_k)
for j in range(top_k):
idx = indices[0][j]
if 0 <= idx < len(self.metadata):
ctx_info = self.metadata[idx]
distance = distances[0][j]
# 两种评分模式
if score_mode == 'simple':
context_scores[ctx_info['context']] += 1
elif score_mode == 'weighted':
context_scores[ctx_info['context']] += 1 / (1 + distance)
context_hits[ctx_info['context']] += 1
# 第二阶段:结果归一化处理
scored_contexts = []
for ctx, score in context_scores.items():
# 根据命中次数进行归一化
normalized_score = score / context_hits[ctx] if context_hits[ctx] > 0 else 0
scored_contexts.append((ctx, normalized_score))
# 按分数排序取前top_k
scored_contexts.sort(key=lambda x: x[1], reverse=True)
final_results = [{'context': ctx, 'score': float(score)}
for ctx, score in scored_contexts[:top_k]]
results.append({
'query_text': text,
'matched_texts': final_results,
'total_hits': sum(context_hits.values())
})
return results
# 使用示例
if __name__ == "__main__":
# 初始化检索系统
retriever = EntityLevelRetriever()
# 构建训练索引(约需2-5分钟,取决于数据量)
print("Building training index...")
retriever.build_index('./data/train_triples.json')
# 执行测试检索
print("\nSearching similar entities...")
# 执行改进后的检索
text_results = retriever.search_texts('./data/GT_500.json', top_k=3)
# 保存结果
with open('./data/text_retrieval_results.json', 'w', encoding='utf-8') as f:
json.dump(text_results, f, ensure_ascii=False, indent=2)
print("text_retrieval_results.json")