import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import torch


def classify_by_topic(articles, central_topics):
    """Pair each article (str) with the most similar topic in central_topics."""

    # Compute the similarity of each article to each central topic;
    # returns an (n_articles x n_topics) matrix.
    def compute_similarity(articles, central_topics):

        model = AutoModel.from_pretrained("distilbert-base-multilingual-cased")
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased")
        model.eval()  # inference only: disable dropout for stable embeddings

        def sentence_to_vector(sentence, context):
            # Repeat the target sentence four times so it outweighs the
            # surrounding context (pprev, prev, next, nnext) in the embedding.
            sentence = context[0] + context[1] + sentence * 4 + context[2] + context[3]
            tokens = tokenizer.encode_plus(
                sentence, add_special_tokens=True, return_tensors="pt",
                max_length=512, truncation=True)

            with torch.no_grad():
                outputs = model(**tokens)
            hidden_states = outputs.last_hidden_state

            # Mean-pool the last hidden layer into a single fixed-size vector.
            vector = np.squeeze(torch.mean(hidden_states, dim=1).numpy())
            return vector

        # Collect up to two sentences of context on each side of the sentence
        # at `index`, padding with empty strings at the boundaries.
        def get_context(sentences, index):
            if index == 0:
                prev_sentence = ""
                pprev_sentence = ""
            elif index == 1:
                prev_sentence = sentences[index-1]
                pprev_sentence = ""
            else:
                prev_sentence = sentences[index-1]
                pprev_sentence = sentences[index-2]
            if index == len(sentences) - 1:
                next_sentence = ""
                nnext_sentence = ""
            elif index == len(sentences) - 2:
                next_sentence = sentences[index+1]
                nnext_sentence = ""
            else:
                next_sentence = sentences[index+1]
                nnext_sentence = sentences[index+2]
            return (pprev_sentence, prev_sentence, next_sentence, nnext_sentence)

        doc_vectors = [sentence_to_vector(sentence, get_context(
            articles, i)) for i, sentence in enumerate(articles)]
        topic_vectors = [sentence_to_vector(sentence, get_context(
            central_topics, i)) for i, sentence in enumerate(central_topics)]
        # Cosine similarity between every article vector and every topic vector
        cos_sim_matrix = cosine_similarity(doc_vectors, topic_vectors)

        return cos_sim_matrix

    # Assign each article to the topic it is most similar to
    def group_by_topic(articles, central_topics, similarity_matrix):
        group = []
        for article, similarity in zip(articles, similarity_matrix):
            # argmax picks the first topic with the highest similarity score
            max_index = int(np.argmax(similarity))
            group.append((article, central_topics[max_index]))

        return group

    # Run the pipeline: embed, score, then group
    similarity_matrix = compute_similarity(articles, central_topics)
    groups = group_by_topic(articles, central_topics, similarity_matrix)

    return groups
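

# Minimal usage sketch (not part of the original module): the sample
# sentences and topic labels below are illustrative assumptions; any two
# lists of strings work, since classify_by_topic only needs raw text.
if __name__ == "__main__":
    sample_articles = [
        "The central bank raised interest rates to curb inflation.",
        "The team open-sourced a new deep learning framework.",
    ]
    sample_topics = ["economy and finance", "software and technology"]

    for article, topic in classify_by_topic(sample_articles, sample_topics):
        print(f"[{topic}] {article}")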