File size: 2,889 Bytes
fdffdf0
 
8ba144e
fdffdf0
8ba144e
 
 
 
 
 
 
fdffdf0
 
8ba144e
 
 
 
 
 
 
 
 
 
 
 
1c529f8
bff547d
8ba144e
 
 
 
 
fdffdf0
8ba144e
 
 
 
 
 
 
 
 
bff547d
8ba144e
 
bff547d
 
8ba144e
bff547d
8ba144e
 
 
 
 
bff547d
8ba144e
 
 
 
 
 
fdffdf0
8ba144e
 
fdffdf0
8ba144e
 
02d932f
fdffdf0
 
8ba144e
02d932f
fdffdf0
 
02d932f
 
1c529f8
8ba144e
 
 
 
 
 
1c529f8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

def find_parent(matrix, node):
    """Return the strongest parent of ``node``, or ``None`` if it has none.

    The candidate parents are the entries of column ``node`` of ``matrix``.
    The row index holding the largest entry is returned, provided that entry
    is strictly positive.

    Args:
        matrix: 2-D NumPy array; column ``node`` holds the incoming weights.
        node: Column index of the node whose parent is wanted.

    Returns:
        The row index of the largest entry in the column, or ``None`` when
        that entry is not greater than zero.
    """
    incoming = matrix[:, node]
    strongest = np.argmax(incoming)
    # A non-positive maximum means no row points at this node at all.
    return strongest if incoming[strongest] > 0 else None


def find_tree(matrix, node, depth=1, children=None, max_depth=1, visited=None):
    """Collect the edges that link ``node`` to its ancestors and children.

    The strongest parent of ``node`` is looked up with ``find_parent``. When
    a parent exists and has not been visited, the edge ``[parent, node]`` is
    recorded, followed by one ``[node, child]`` edge for every entry of
    ``children``. While ``depth < max_depth`` the walk then continues upwards
    from that parent (ancestors are added without children of their own).

    Args:
        matrix: 2-D array handed to ``find_parent`` to resolve parents.
        node: Index of the node to start from.
        depth: Depth of the current call; 1 for the initial call.
        children: Iterable of child indices to attach below ``node``.
            Defaults to no children.
        max_depth: Maximum depth the upward walk may reach.
        visited: Set of node indices already expanded on this walk, used to
            stop on cycles. A fresh set is created when omitted; a set
            supplied by the caller is updated in place.

    Returns:
        A list of ``[parent, child]`` pairs. It is empty when ``node`` has no
        parent or its parent was already visited.
    """
    # Defaults are created per call: a shared default set would remember
    # nodes from earlier, unrelated calls and silently prune later trees.
    if children is None:
        children = ()
    if visited is None:
        visited = set()

    result = []
    parent = find_parent(matrix, node)
    if parent is None or parent in visited:
        return result

    result.append([parent, node])
    result.extend([node, child] for child in children)
    if depth < max_depth:
        visited.add(node)
        # Forward max_depth so the recursion honours the caller's limit
        # instead of falling back to the default of 1.
        result.extend(
            find_tree(
                matrix,
                parent,
                depth + 1,
                max_depth=max_depth,
                visited=visited,
            )
        )
    # Return the collected edge list.
    return result

# 
def find_prob(tree, matrix):
    """Return the product of the matrix weights along every edge of ``tree``.

    Args:
        tree: Iterable of ``(parent, child)`` index pairs.
        matrix: Indexable as ``matrix[parent][child]`` to obtain an edge weight.

    Returns:
        The product of all edge weights; ``1`` for an empty ``tree``.
    """
    product = 1
    for edge in tree:
        source, target = edge
        product = product * matrix[source][target]
    return product

def find_forests(matrix, k):
    """Pick the highest-scoring tree rooted around any node of ``matrix``.

    For every row index ``i`` the columns with a positive weight are treated
    as the children of ``i``. A tree is built with ``find_tree`` and scored
    with ``find_prob``. The tree with the largest score wins; on ties the one
    found first is kept.

    Args:
        matrix: Square 2-D array of edge weights, indexed ``[parent][child]``.
        k: Currently unused; kept so existing callers keep working.

    Returns:
        A ``(result, prob)`` tuple. ``result`` maps each parent index of the
        winning tree to the list of its child indices, and ``prob`` is that
        tree's score. When no node produces a tree (for example an all-zero
        or empty matrix) ``({}, 0.0)`` is returned.
    """
    forests = {}
    for i, row in enumerate(matrix):
        child_list = [j for j, weight in enumerate(row) if weight > 0]
        # Edges become tuples so the whole tree can be used as a dict key.
        tree = tuple(
            tuple(edge) for edge in find_tree(matrix, i, children=child_list)
        )
        if not tree:
            continue
        prob = find_prob(tree, matrix)
        if tree in forests:
            forests[tree] += prob
        else:
            forests[tree] = prob

    # Nothing to rank: avoid indexing into an empty ranking.
    if not forests:
        return {}, 0.0

    # max() keeps the first of equal maxima, like a stable descending sort.
    forest, prob = max(forests.items(), key=lambda item: item[1])

    result = {}
    # Walk every edge of the winning tree and group children by parent.
    for parent, child in forest:
        result.setdefault(parent, []).append(child)
    return result, prob
def passage_outline(matrix, sentences):
    """Build a textual outline from the best tree found in ``matrix``.

    The winning tree from ``find_forests`` is printed together with its
    score, then every parent (in ascending index order) contributes a topic
    header followed by the sentences of its children.

    Args:
        matrix: Edge-weight matrix passed on to ``find_forests``.
        sentences: Sequence indexed by node number that supplies the text.

    Returns:
        A ``(outl, outline_list)`` tuple: ``outl`` holds display lines (the
        header and ``"- sentence"`` bullets, each ending in a newline) and
        ``outline_list`` holds the same items as bare strings.
    """
    result, prob = find_forests(matrix, 1)
    print(result, prob)

    # Replace child indices by their sentences, keeping the parent index as key.
    structure = {
        parent: [sentences[idx] for idx in kids]
        for parent, kids in result.items()
    }

    outl = []
    outline_list = []
    for parent in sorted(structure):
        outline_list.append("主题:")
        outl.append("主题:\n")
        members = structure[parent]
        outline_list.extend(members)
        outl.extend(f"- {sentence}\n" for sentence in members)
    return outl, outline_list
if __name__ == "__main__":
    # Demo: a 5x5 weight matrix (row = parent, column = child) and one
    # placeholder sentence per node.
    matrix = np.array(
        [
            [0.0, 0.02124888, 0.10647043, 0.09494194, 0.0689209],
            [0.01600688, 0.0, 0.05879448, 0.0331325, 0.0155093],
            [0.01491911, 0.01652437, 0.0, 0.04714563, 0.04577385],
            [0.01699071, 0.0313585, 0.040299, 0.0, 0.014933],
            [0.02308992, 0.02791895, 0.06547201, 0.08517842, 0.0],
        ]
    )
    sentences = ["主题句子1", "主题句子2", "主题句子3", "主题句子4", "主题句子5"]
    # Only the display lines (first element of the returned pair) are printed.
    display_lines, _ = passage_outline(matrix, sentences)
    print(display_lines)