import gensim
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import torch


def classify_by_topic(articles, central_topics):
    """Assign each article to its most similar central topic.

    Embeds every article and every topic with multilingual DistilBERT
    (mean-pooled last hidden states, with neighboring sentences as
    context), then matches each article to the topic with the highest
    cosine similarity.

    Parameters
    ----------
    articles : list[str]
        Sentences/articles to classify.
    central_topics : list[str]
        Candidate topic sentences.

    Returns
    -------
    list[tuple[str, str]]
        One ``(article, best_matching_topic)`` pair per article, in
        input order.
    """

    # Compute the article-vs-topic similarity matrix.
    def compute_similarity(articles, central_topics):
        model = AutoModel.from_pretrained("distilbert-base-multilingual-cased")
        tokenizer = AutoTokenizer.from_pretrained(
            "distilbert-base-multilingual-cased")

        def sentence_to_vector(sentence, context):
            # Emphasize the target sentence (repeated 4x) and surround it
            # with up to two sentences of context on each side before
            # embedding.
            sentence = context[0] + context[1] + sentence * 4 + context[2] + context[3]
            tokens = tokenizer.encode_plus(
                sentence,
                add_special_tokens=True,
                return_tensors="pt",
                max_length=512,
                truncation=True,
            )
            # Inference only: disable autograd instead of building a
            # gradient graph and detaching afterwards.
            with torch.no_grad():
                outputs = model(**tokens)
            hidden_states = outputs.last_hidden_state
            # Mean-pool the token embeddings into one sentence vector.
            return np.squeeze(torch.mean(hidden_states, dim=1).numpy())

        # Get up to two sentences of context on each side of `index`,
        # using "" past either end of the list.
        def get_context(sentences, index):
            last = len(sentences) - 1
            pprev_sentence = sentences[index - 2] if index >= 2 else ""
            prev_sentence = sentences[index - 1] if index >= 1 else ""
            next_sentence = sentences[index + 1] if index <= last - 1 else ""
            nnext_sentence = sentences[index + 2] if index <= last - 2 else ""
            return (pprev_sentence, prev_sentence, next_sentence, nnext_sentence)

        doc_vectors = [
            sentence_to_vector(sentence, get_context(articles, i))
            for i, sentence in enumerate(articles)
        ]
        topic_vectors = [
            sentence_to_vector(sentence, get_context(central_topics, i))
            for i, sentence in enumerate(central_topics)
        ]

        # Cosine similarity matrix: rows = articles, columns = topics.
        return cosine_similarity(doc_vectors, topic_vectors)

    # Pair each article with its highest-similarity topic.
    def group_by_topic(articles, central_topics, similarity_matrix):
        # np.argmax returns the first index of the maximum, which matches
        # the original list.index(max(...)) tie-breaking in one pass.
        return [
            (article, central_topics[int(np.argmax(row))])
            for article, row in zip(articles, similarity_matrix)
        ]

    similarity_matrix = compute_similarity(articles, central_topics)
    groups = group_by_topic(articles, central_topics, similarity_matrix)
    return groups