import json

import numpy as np
import faiss
import torch
from transformers import AutoTokenizer, AutoModel

class EntityLevelRetriever:
    def __init__(self, model_name='bert-base-chinese'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()  # inference only; disables dropout
        # L2 distance works reasonably well for raw BERT embeddings;
        # hidden_size is 768 for bert-base-chinese
        self.index = faiss.IndexFlatL2(self.model.config.hidden_size)
        self.entity_db = []
        self.metadata = []

    def _get_entity_span(self, text, entity):
        """Locate the entity in the text via exact string matching."""
        start = text.find(entity)
        if start == -1:
            return None
        return (start, start + len(entity))
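
    # Sketch (not used by the pipeline above): str.find() returns only the
    # first occurrence, so repeated mentions of an entity all map to the same
    # span. A hypothetical variant enumerating every occurrence could look like:
    def _get_all_entity_spans(self, text, entity):
        import re
        return [(m.start(), m.end()) for m in re.finditer(re.escape(entity), text)]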

    def _generate_entity_embedding(self, text, entity):
        """Generate a contextual entity-level embedding by mean-pooling the
        token vectors that cover the entity span."""
        span = self._get_entity_span(text, entity)
        if span is None:
            return None
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Map character positions to token positions
        start_token = inputs.char_to_token(span[0])
        end_token = inputs.char_to_token(span[1] - 1)
        # char_to_token returns None for characters removed by truncation;
        # compare with None explicitly, since a token index of 0 is falsy
        if start_token is None or end_token is None:
            return None
        # Average the token embeddings over the entity span
        entity_embedding = outputs.last_hidden_state[0, start_token:end_token + 1].mean(dim=0).numpy()
        return entity_embedding.astype('float32')
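
    # Alternative pooling (a sketch; the method above mean-pools, as in the
    # original). Max pooling over the span is a common alternative that can
    # sharpen short entity spans; this helper is illustrative, not part of
    # the original design.
    def _max_pool_entity_embedding(self, hidden_states, start_token, end_token):
        span = hidden_states[0, start_token:end_token + 1]
        return span.max(dim=0).values.numpy().astype('float32')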

    def build_index(self, train_path):
        """Build the FAISS entity index from the training triples."""
        with open(train_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        dataset = dataset[500:1000]  # index only a slice of the data to keep the demo fast
        embeddings = []
        meta_info = []
        for idx, item in enumerate(dataset):
            if idx % 100 == 0:
                print(f"Indexing progress: {idx}/{len(dataset)}")
            text = item['text']
            for triple in item['triple_list']:
                # triple = (subject, relation, object); embed both entities
                for entity in [triple[0], triple[2]]:
                    try:
                        embedding = self._generate_entity_embedding(text, entity)
                        if embedding is not None:
                            embeddings.append(embedding)
                            meta_info.append({
                                'entity': entity,
                                'type': triple[1],
                                'context': text
                            })
                    except Exception as e:
                        print(f"Error while processing entity {entity}: {e}")
        if embeddings:
            self.entity_db = embeddings
            self.metadata = meta_info
            self.index.add(np.array(embeddings))
            print(f"Index built - vectors: {len(self.entity_db)}, metadata entries: {len(self.metadata)}")
            print(f"Index dimension: {self.index.d}, stored vectors: {self.index.ntotal}")
        else:
            print("Warning: no valid entity embeddings were added to the index")

    def search_entities(self, test_path, top_k=3, batch_size=32):
        """Retrieve the top-k nearest indexed entities for each test entity,
        batching the FAISS queries for efficiency."""
        with open(test_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        results = []
        for item_idx, item in enumerate(test_data):
            if item_idx % 10 == 0:
                print(f"Search progress: {item_idx}/{len(test_data)}")
            text = item['text']
            entity_results = {}
            batch_embeddings = []
            batch_entities = []
            for triple in item['triple_list']:
                for entity in [triple[0], triple[2]]:
                    embedding = self._generate_entity_embedding(text, entity)
                    if embedding is not None:
                        batch_embeddings.append(embedding)
                        batch_entities.append(entity)
                    if len(batch_embeddings) >= batch_size:
                        self._process_batch(batch_embeddings, batch_entities, entity_results, top_k)
                        batch_embeddings = []
                        batch_entities = []
            # Flush any remaining entities in the last partial batch
            if batch_embeddings:
                self._process_batch(batch_embeddings, batch_entities, entity_results, top_k)
            results.append({
                'text': text,
                'entity_matches': entity_results
            })
        return results

    def _process_batch(self, embeddings, entities, entity_results, top_k):
        """Run one batched FAISS search and attach neighbor metadata."""
        distances, indices = self.index.search(np.array(embeddings), top_k)
        for entity, dist, ind in zip(entities, distances, indices):
            neighbors = []
            for distance, index in zip(dist, ind):
                if 0 <= index < len(self.metadata):
                    neighbors.append({
                        'entity': self.metadata[index]['entity'],
                        'relation': self.metadata[index]['type'],
                        'context': self.metadata[index]['context'],
                        'distance': float(distance)
                    })
            entity_results[entity] = neighbors
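
# Cosine-similarity variant (a sketch; the class above searches raw L2
# distances). BERT embedding norms vary with token frequency, so L2-normalizing
# the vectors and switching to an inner-product index is a common alternative:
#
#     index = faiss.IndexFlatIP(768)
#     faiss.normalize_L2(vectors)          # in-place; vectors: float32, shape (n, 768)
#     index.add(vectors)
#     faiss.normalize_L2(queries)
#     scores, ids = index.search(queries, top_k)   # scores = cosine similarities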

# Usage example
if __name__ == "__main__":
    # Initialize the retrieval system
    retriever = EntityLevelRetriever()
    # Build the training index (roughly 2-5 minutes, depending on data volume)
    print("Building training index...")
    retriever.build_index('./data/train_triples.json')
    # Run retrieval on the test set
    print("\nSearching similar entities...")
    results = retriever.search_entities('./data/test_triples.json')
    # Save the results
    with open('./data/entity_search_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print("Search complete! Results saved to entity_search_results.json")
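
    # Quick look at the output (illustrative, not part of the original): print
    # the closest indexed match for each query entity in the first test sentence.
    if results:
        first = results[0]
        print(first['text'])
        for entity, neighbors in first['entity_matches'].items():
            if neighbors:  # FAISS returns neighbors sorted by ascending distance
                best = neighbors[0]
                print(f"  {entity} -> {best['entity']} ({best['relation']}, distance={best['distance']:.3f})")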