GeoLLM / KNN_token_old.py
Ciallo0d00's picture
Upload folder using huggingface_hub
badcf3c verified
import json
import numpy as np
import faiss
from transformers import AutoTokenizer, AutoModel
import torch
class EntityLevelRetriever:
    """Entity-level k-NN retriever over contextual transformer embeddings.

    Each entity mention is embedded by running its whole sentence through a
    transformer encoder and mean-pooling the last-layer hidden states of the
    tokens that cover the entity's characters. The vectors are stored in an
    exact (flat) FAISS L2 index; ``self.metadata[i]`` describes FAISS id ``i``.
    """

    def __init__(self, model_name='bert-base-chinese'):
        """Load the tokenizer/encoder and create an empty flat L2 index.

        Args:
            model_name: Hugging Face model id or local path passed to both
                ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Size the index from the encoder rather than hard-coding 768, so
        # encoders with a different hidden size also work.
        dim = getattr(self.model.config, 'hidden_size', 768)
        self.index = faiss.IndexFlatL2(dim)  # exact (brute-force) L2 search
        self.entity_db = []  # float32 vectors, parallel to self.metadata
        self.metadata = []   # dicts with 'entity', 'type' (relation), 'context'

    def _get_entity_span(self, text, entity):
        """Locate the first exact occurrence of *entity* in *text*.

        Returns:
            ``(start, end)`` character offsets with ``end`` exclusive, or
            ``None`` if *entity* is empty, not a string, or absent from *text*.
        """
        # An empty string would "match" at 0 with a zero-width span (nothing
        # to pool); a non-string (e.g. None in a malformed triple) would make
        # str.find raise.
        if not isinstance(entity, str) or not entity:
            return None
        start = text.find(entity)
        if start == -1:
            return None
        return (start, start + len(entity))

    def _encode_text(self, text):
        """Run the encoder once over *text* without tracking gradients.

        Returns:
            ``(inputs, hidden)``: the tokenizer output (needed for the
            char -> token mapping) and ``last_hidden_state[0]``, i.e. the
            hidden states of the single encoded sequence.
        """
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        return inputs, outputs.last_hidden_state[0]

    def _pool_entity(self, inputs, hidden, text, entity):
        """Mean-pool the hidden states of the tokens covering *entity*.

        Returns:
            A float32 numpy vector, or ``None`` if the entity is not found or
            its first/last character maps to no token (for example because
            truncation cut it off).
        """
        span = self._get_entity_span(text, entity)
        if span is None:
            return None
        start_token = inputs.char_to_token(span[0])
        end_token = inputs.char_to_token(span[1] - 1)
        # char_to_token returns None for unmapped characters; 0 is a valid
        # token index, so compare against None explicitly.
        if start_token is None or end_token is None:
            return None
        entity_embedding = hidden[start_token:end_token + 1].mean(dim=0).numpy()
        return entity_embedding.astype('float32')

    def _generate_entity_embedding(self, text, entity):
        """Return the contextual embedding of *entity* in *text*, or ``None``."""
        # Skip the (expensive) encoder pass when the entity is not in the text.
        if self._get_entity_span(text, entity) is None:
            return None
        inputs, hidden = self._encode_text(text)
        return self._pool_entity(inputs, hidden, text, entity)

    def _embed_entities(self, text, entities):
        """Embed every entity of *entities* against one encoding of *text*.

        The encoder runs lazily, at most once per successful call. Failures
        are printed and yield ``None`` for that entity, so one bad mention
        does not abort the whole run.

        Returns:
            A list aligned with *entities*: a float32 vector or ``None``.
        """
        encoded = None
        vectors = []
        for entity in entities:
            embedding = None
            try:
                if self._get_entity_span(text, entity) is not None:
                    if encoded is None:
                        encoded = self._encode_text(text)
                    inputs, hidden = encoded
                    embedding = self._pool_entity(inputs, hidden, text, entity)
            except Exception as e:
                print(f"处理实体 {entity} 时出错: {str(e)}")
            vectors.append(embedding)
        return vectors

    def build_index(self, train_path, start=500, end=1000):
        """Build the entity index from a JSON list of training items.

        Each item is expected to have ``'text'`` and ``'triple_list'``; for
        every triple, elements 0 and 2 are indexed as entities, with element 1
        stored as the metadata ``'type'``.

        Args:
            train_path: Path to the UTF-8 JSON training file.
            start, end: Slice of the dataset to index; the defaults reproduce
                the previously hard-coded ``[500:1000]`` subset.
        """
        with open(train_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        dataset = dataset[start:end]
        embeddings = []
        meta_info = []
        for idx, item in enumerate(dataset):
            if idx % 100 == 0:
                print(f"处理进度: {idx}/{len(dataset)}")
            text = item['text']
            mentions = [
                (entity, triple[1])
                for triple in item['triple_list']
                for entity in (triple[0], triple[2])
            ]
            vectors = self._embed_entities(text, [entity for entity, _ in mentions])
            for (entity, relation), embedding in zip(mentions, vectors):
                if embedding is None:
                    continue
                embeddings.append(embedding)
                meta_info.append({
                    'entity': entity,
                    'type': relation,
                    'context': text
                })
        if embeddings:
            self.entity_db = embeddings
            self.metadata = meta_info
            # Start from an empty index: appending to vectors from a previous
            # call would leave FAISS ids out of step with self.metadata.
            self.index.reset()
            self.index.add(np.asarray(embeddings, dtype='float32'))
            print(f"索引构建完成 - 向量数: {len(self.entity_db)}, 元数据数: {len(self.metadata)}")
            print(f"索引维度: {self.index.d}, 存储数量: {self.index.ntotal}")
        else:
            print("警告:没有有效的实体嵌入被添加到索引中")

    def search_entities(self, test_path, top_k=3, batch_size=32):
        """Retrieve the ``top_k`` nearest indexed entities for each test entity.

        Args:
            test_path: Path to a UTF-8 JSON list of items shaped like the
                training data (``'text'`` and ``'triple_list'``).
            top_k: Number of neighbours requested per entity.
            batch_size: Number of query vectors sent to FAISS per search call.

        Returns:
            One dict per item: ``{'text': ..., 'entity_matches': {entity: [neighbour, ...]}}``.
        """
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        results = []
        for item_idx, item in enumerate(test_data):
            if item_idx % 10 == 0:
                print(f"检索进度: {item_idx}/{len(test_data)}")
            text = item['text']
            entity_results = {}
            entities = [
                entity
                for triple in item['triple_list']
                for entity in (triple[0], triple[2])
            ]
            batch_embeddings = []
            batch_entities = []
            for entity, embedding in zip(entities, self._embed_entities(text, entities)):
                if embedding is None:
                    continue
                batch_embeddings.append(embedding)
                batch_entities.append(entity)
                if len(batch_embeddings) >= batch_size:
                    self._process_batch(batch_embeddings, batch_entities, entity_results, top_k)
                    batch_embeddings = []
                    batch_entities = []
            # Flush the remaining, partially filled batch.
            if batch_embeddings:
                self._process_batch(batch_embeddings, batch_entities, entity_results, top_k)
            results.append({
                'text': text,
                'entity_matches': entity_results
            })
        return results

    def _process_batch(self, embeddings, entities, entity_results, top_k):
        """Search a batch of query vectors and store neighbours per entity.

        Mutates *entity_results* in place: ``entity_results[entity]`` is set
        to the list of neighbour dicts for that entity.
        """
        distances, indices = self.index.search(np.asarray(embeddings, dtype='float32'), top_k)
        for entity, dist_row, idx_row in zip(entities, distances, indices):
            entity_results[entity] = self._neighbors_for(dist_row, idx_row)

    def _neighbors_for(self, dist_row, idx_row):
        """Convert one row of FAISS search output into neighbour dicts.

        Ids without metadata are skipped; FAISS reports missing results
        (fewer than ``top_k`` indexed vectors) as id -1.
        """
        neighbors = []
        for distance, index in zip(dist_row, idx_row):
            if not 0 <= index < len(self.metadata):
                continue
            meta = self.metadata[index]
            neighbors.append({
                'entity': meta['entity'],
                'relation': meta['type'],
                'context': meta['context'],
                'distance': float(distance)
            })
        return neighbors
# Usage example: index the training split, query it with the test split,
# and dump the neighbour lists to JSON.
def main():
    """Build the entity index, run retrieval on the test set, save results."""
    retriever = EntityLevelRetriever()

    # Original note: indexing takes roughly 2-5 minutes depending on data size.
    print("Building training index...")
    retriever.build_index('./data/train_triples.json')

    print("\nSearching similar entities...")
    matches = retriever.search_entities('./data/test_triples.json')

    with open('./data/entity_search_results.json', 'w', encoding='utf-8') as out_file:
        json.dump(matches, out_file, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()