import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import torch
def classify_by_topic(articles, central_topics):

    # Compute the similarity of each article to each central topic;
    # returns an (articles x topics) similarity matrix.
    def compute_similarity(articles, central_topics):
        model = AutoModel.from_pretrained("distilbert-base-multilingual-cased")
        tokenizer = AutoTokenizer.from_pretrained(
            "distilbert-base-multilingual-cased")

        # Embed one sentence, weighted by its surrounding context.
        def sentence_to_vector(sentence, context):
            # Repeat the target sentence four times so it dominates the
            # two preceding and two following context sentences.
            sentence = context[0] + context[1] + sentence * 4 + context[2] + context[3]
            tokens = tokenizer.encode_plus(
                sentence, add_special_tokens=True, return_tensors="pt",
                max_length=512, truncation=True)
            with torch.no_grad():
                outputs = model(**tokens)
            hidden_states = outputs.last_hidden_state
            # Mean-pool the token embeddings into one fixed-size vector.
            vector = np.squeeze(torch.mean(hidden_states, dim=1).numpy())
            return vector
        # Collect up to two sentences of context on each side of a sentence.
        def get_context(sentences, index):
            if index == 0:
                prev_sentence = ""
                pprev_sentence = ""
            elif index == 1:
                prev_sentence = sentences[index - 1]
                pprev_sentence = ""
            else:
                prev_sentence = sentences[index - 1]
                pprev_sentence = sentences[index - 2]
            if index == len(sentences) - 1:
                next_sentence = ""
                nnext_sentence = ""
            elif index == len(sentences) - 2:
                next_sentence = sentences[index + 1]
                nnext_sentence = ""
            else:
                next_sentence = sentences[index + 1]
                nnext_sentence = sentences[index + 2]
            return (pprev_sentence, prev_sentence, next_sentence, nnext_sentence)
        doc_vectors = [sentence_to_vector(sentence, get_context(articles, i))
                       for i, sentence in enumerate(articles)]
        topic_vectors = [sentence_to_vector(sentence, get_context(central_topics, i))
                         for i, sentence in enumerate(central_topics)]

        # Compute the cosine similarity matrix between articles and topics.
        cos_sim_matrix = cosine_similarity(doc_vectors, topic_vectors)
        return cos_sim_matrix
    # Pair each article with the central topic it is most similar to.
    def group_by_topic(articles, central_topics, similarity_matrix):
        group = []
        for article, similarity in zip(articles, similarity_matrix):
            max_index = int(np.argmax(similarity))
            group.append((article, central_topics[max_index]))
        return group
    # Run the classification end to end.
    similarity_matrix = compute_similarity(articles, central_topics)
    groups = group_by_topic(articles, central_topics, similarity_matrix)
    return groups
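
# A minimal usage sketch: the article and topic strings below are hypothetical
# placeholders added for illustration, not part of the original file. The
# pretrained model is downloaded on first run.
if __name__ == "__main__":
    articles = [
        "The stock market rallied after the central bank's announcement.",
        "The new smartphone ships with a faster chip and a larger battery.",
    ]
    central_topics = ["finance news", "consumer technology"]
    for article, topic in classify_by_topic(articles, central_topics):
        print(f"{topic}: {article}")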