"""Build a simple outline of a passage from a sentence-to-sentence weight matrix.

``matrix[i][j]`` is treated as the weight of a directed edge from sentence
``i`` (parent) to sentence ``j`` (child); a weight of 0 means "no edge".
"""

import matplotlib.pyplot as plt  # noqa: F401  (unused here; kept as in the original file)
import numpy as np


def find_parent(matrix, node):
    """Return the row index of the strongest incoming edge into *node*.

    Looks at column ``node`` of *matrix* (every edge pointing at *node*) and
    returns the index of the largest weight, or ``None`` when no incoming
    weight is positive.
    """
    incoming = matrix[:, node]
    best = np.argmax(incoming)
    return best if incoming[best] > 0 else None


def find_tree(matrix, node, depth=1, children=None, max_depth=1, visited=None):
    """Collect the ``[parent, child]`` edges of the local tree around *node*.

    The result contains, in order:

    * the edge from *node*'s strongest parent (see ``find_parent``) to *node*,
      unless that parent has already been visited;
    * one edge from *node* to every entry of *children*;
    * while ``depth < max_depth`` and a parent exists, the edges found by
      repeating the walk from that parent one level further up.

    Args:
        matrix: square weight matrix indexed ``matrix[parent, child]``.
        node: index of the sentence to start from.
        depth: current recursion depth; callers normally leave it at 1.
        children: child indices to attach to *node*; defaults to none.
        max_depth: number of ancestor levels to walk (1 = direct parent only).
        visited: nodes already expanded during this walk, so an edge from a
            parent back into an expanded node is not repeated. A fresh set is
            created for every top-level call.

    Returns:
        A list of ``[parent, child]`` pairs.
    """
    # Fresh containers per call: mutable defaults would be shared between
    # calls and leak visited nodes from one tree into the next.
    if children is None:
        children = []
    if visited is None:
        visited = set()

    edges = []
    parent = find_parent(matrix, node)
    if parent is not None and parent not in visited:
        edges.append([parent, node])
    edges.extend([node, child] for child in children)

    # Only walk upward when there is a parent to walk to (find_parent(None)
    # would index matrix[:, None]), and propagate max_depth so depths beyond
    # 2 are actually honoured.
    if depth < max_depth and parent is not None:
        visited.add(node)
        edges.extend(
            find_tree(matrix, parent, depth + 1, max_depth=max_depth, visited=visited)
        )
    return edges


def find_prob(tree, matrix):
    """Return the product of ``matrix[parent][child]`` over every edge of *tree*.

    An empty *tree* yields 1.
    """
    prob = 1
    for parent, child in tree:
        prob *= matrix[parent][child]
    return prob


def find_forests(matrix, k):
    """Pick the highest-scoring local tree and return it as an adjacency dict.

    For every node ``i``, a local tree is built from its strongest parent and
    all of its children (the positive entries of row ``i``). When different
    starting nodes produce the same tree, their probabilities are summed. The
    tree with the largest total wins; ties go to the tree found first.

    Args:
        matrix: square weight matrix indexed ``matrix[parent, child]``.
        k: currently unused; kept for API compatibility.

    Returns:
        ``(adjacency, prob)``. *adjacency* maps each parent index to the list
        of its child indices in the chosen tree, and *prob* is that tree's
        summed probability. Returns ``({}, 0.0)`` when no node has any edge.
    """
    forests = {}
    for i, row in enumerate(matrix):
        child_list = [j for j, weight in enumerate(row) if weight > 0]
        tree = tuple(tuple(edge) for edge in find_tree(matrix, i, children=child_list))
        if tree:
            forests[tree] = forests.get(tree, 0) + find_prob(tree, matrix)

    if not forests:
        # No positive weights anywhere: there is nothing to outline.
        return {}, 0.0

    # max() returns the first maximal item, which matches the tie-breaking of
    # a stable descending sort followed by taking element 0.
    forest, prob = max(forests.items(), key=lambda item: item[1])

    # Group the chosen tree's edges by parent.
    adjacency = {}
    for parent, child in forest:
        adjacency.setdefault(parent, []).append(child)
    return adjacency, prob


def passage_outline(matrix, sentences):
    """Render the best tree from ``find_forests`` as a textual outline.

    Each parent index, in ascending order, starts a "主题:" ("topic:")
    heading, followed by the sentences of its children.

    Args:
        matrix: square weight matrix indexed ``matrix[parent, child]``.
        sentences: sentence texts, indexed like the rows/columns of *matrix*.

    Returns:
        ``(outl, outline_list)``. *outl* holds newline-terminated display
        lines, with children prefixed by "- ". *outline_list* holds the same
        entries without that formatting.
    """
    result, prob = find_forests(matrix, 1)
    print(result, prob)  # debug output, kept for existing scripts
    structure = {
        parent: [sentences[i] for i in kids] for parent, kids in result.items()
    }

    outl = []
    outline_list = []
    for key in sorted(structure):
        # NOTE(review): the heading does not name the topic itself. The original
        # used an f-string with no placeholder here, so sentences[key] may have
        # been intended; confirm before changing the emitted text.
        outline_list.append("主题:")
        outl.append("主题:\n")
        for sentence in structure[key]:
            outline_list.append(sentence)
            outl.append(f"- {sentence}\n")
    return outl, outline_list


if __name__ == "__main__":
    matrix = np.array(
        [
            [0.0, 0.02124888, 0.10647043, 0.09494194, 0.0689209],
            [0.01600688, 0.0, 0.05879448, 0.0331325, 0.0155093],
            [0.01491911, 0.01652437, 0.0, 0.04714563, 0.04577385],
            [0.01699071, 0.0313585, 0.040299, 0.0, 0.014933],
            [0.02308992, 0.02791895, 0.06547201, 0.08517842, 0.0],
        ]
    )
    sentences = ["主题句子1", "主题句子2", "主题句子3", "主题句子4", "主题句子5"]
    print(passage_outline(matrix, sentences)[0])