# TSA / outline.py
# Last commit: 1c529f8 (QINGCHE) - "fix some bad code"
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster, dendrogram
import matplotlib.pyplot as plt
def passage_outline(matrix, sentences):
    """Cluster sentences into topics and render them as a text outline.

    Row ``i`` of ``matrix`` is used as the feature vector of ``sentences[i]``.
    The rows are clustered with average-linkage hierarchical clustering
    (scipy's ``linkage`` treats a 2-D input as observation vectors and
    compares them with its default Euclidean metric). The resulting tree is
    cut at the median of the matrix's off-diagonal entries.

    NOTE(review): the cut threshold is taken from the matrix *values*, while
    the tree heights are Euclidean distances between *rows*. If ``matrix`` is
    a similarity matrix (the demo uses one with 1.0 on the diagonal),
    clustering on ``1 - matrix`` in condensed form may be what was intended.
    Confirm with callers.

    Args:
        matrix: Square ``(n, n)`` array-like, one row per sentence.
        sentences: Iterable of ``n`` sentences, in the same order as the rows.

    Returns:
        A tuple ``(outline, outline_list)``.

        ``outline`` is a string with one ``"主题<k>:"`` header line per
        cluster, in ascending label order. Each header is followed by one
        ``"- <sentence>"`` line per member, in input order.

        ``outline_list`` holds the same headers and sentences as a flat list,
        without the ``"- "`` prefix or the newlines.

    Raises:
        ValueError: if ``matrix`` is not square, or if its size does not
            match the number of sentences.
    """
    sentences = list(sentences)
    matrix = np.asarray(matrix, dtype=float)
    if matrix.size == 0 and not sentences:
        return "", []
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"matrix must be square, got shape {matrix.shape}")
    n = matrix.shape[0]
    if n != len(sentences):
        raise ValueError(
            f"matrix has {n} rows but {len(sentences)} sentences were given"
        )

    if n == 1:
        # linkage() needs at least two observations; a lone sentence is one topic.
        labels = [1]
    else:
        Z = linkage(matrix, method="average")
        # np.median does not honour np.ma masks: it ignores them, and for odd
        # n it even returns a masked value. So select the off-diagonal
        # entries explicitly.
        off_diagonal = matrix[~np.eye(n, dtype=bool)]
        labels = fcluster(Z, t=np.median(off_diagonal), criterion="distance")

    # Build the article structure from the cluster labels and the sentences:
    # label -> member sentences, kept in input order.
    structure = {}
    for label, sentence in zip(labels, sentences):
        structure.setdefault(int(label), []).append(sentence)

    outline_list = []
    outline_lines = []
    for key in sorted(structure):
        header = f"主题{key}:"
        outline_list.append(header)
        outline_lines.append(f"{header}\n")
        for sentence in structure[key]:
            outline_list.append(sentence)
            outline_lines.append(f"- {sentence}\n")
    return "".join(outline_lines), outline_list
if __name__ == "__main__":
    # Demo: a symmetric 4x4 matrix whose large off-diagonal entries pair
    # sentences 1-2 and 3-4.
    demo_matrix = np.array(
        [
            [1.0, 0.8, 0.2, 0.1],
            [0.8, 1.0, 0.3, 0.2],
            [0.2, 0.3, 1.0, 0.9],
            [0.1, 0.2, 0.9, 1.0],
        ]
    )
    demo_sentences = ["主题句子1", "主题句子2", "主题句子3", "主题句子4"]
    outline_text, _ = passage_outline(demo_matrix, demo_sentences)
    print(outline_text)