# sentcluster / app.py
# Author: strongeryongchao
# Last commit: "Update app.py" (963a084, verified)
import io
import json
import os
import random
import tempfile
from collections import defaultdict

import gradio as gr
import hdbscan
import numpy as np
import pandas as pd
import umap
from sentence_transformers import SentenceTransformer
from sklearn.metrics import silhouette_score, davies_bouldin_score
from sklearn.preprocessing import MinMaxScaler
# Load the embedding model once at module level so every request reuses it
# instead of reloading the weights.
#
# The UI is built for Chinese input, but all-MiniLM-L6-v2 is an English-only
# model, so default to a multilingual sentence-transformers checkpoint.
# Set the SENTCLUSTER_MODEL environment variable to use a different model.
MODEL_NAME = os.environ.get(
    "SENTCLUSTER_MODEL",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
)
model = SentenceTransformer(MODEL_NAME)
def color_for_label(label):
    """Return a deterministic CSS ``rgb(...)`` colour string for a cluster label.

    Negative labels (HDBSCAN noise), and values that cannot be converted to
    ``int``, map to grey ``"rgb(150,150,150)"``. Every non-negative label maps
    to the same colour on every call, with each channel drawn from 50-200.

    A private ``random.Random`` seeded per label is used, so picking a colour
    no longer reseeds the module-level ``random`` generator that other code
    may rely on. The colours produced are the same as seeding the global
    generator with the same value.
    """
    try:
        label_int = int(label)
    except (TypeError, ValueError, OverflowError):
        label_int = -1
    if label_int < 0:
        return "rgb(150,150,150)"  # noise points
    rng = random.Random(label_int + 1000)
    return f"rgb({rng.randint(50, 200)}, {rng.randint(50, 200)}, {rng.randint(50, 200)})"
def cluster_sentences(sentences):
    """Embed *sentences*, cluster them with HDBSCAN and score the clustering.

    Args:
        sentences: list of strings to cluster.

    Returns:
        ``(labels, embeddings, scores)``: ``labels`` is the array returned by
        ``HDBSCAN.fit_predict`` (``-1`` marks noise), ``embeddings`` is the
        output of ``model.encode``, and ``scores`` is
        ``{"silhouette": ..., "db": ...}``. Both scores are computed on the
        non-noise points only, and are ``-1`` when they are undefined.
    """
    embeddings = model.encode(sentences)
    clusterer = hdbscan.HDBSCAN(min_cluster_size=2, metric="euclidean")
    labels = clusterer.fit_predict(embeddings)

    valid = labels != -1
    n_valid = int(np.sum(valid))
    n_clusters = len(np.unique(labels[valid]))
    # Both sklearn metrics require 2 <= n_clusters <= n_samples - 1 and raise
    # ValueError otherwise (e.g. when every non-noise point is in one cluster).
    if 2 <= n_clusters < n_valid:
        silhouette = float(silhouette_score(embeddings[valid], labels[valid]))
        db = float(davies_bouldin_score(embeddings[valid], labels[valid]))
    else:
        silhouette, db = -1, -1
    return labels, embeddings, {"silhouette": silhouette, "db": db}
def generate_force_graph(sentences, labels, max_edges_per_node=10):
    """Build force-graph data that links sentences sharing a cluster.

    Args:
        sentences: sentence strings. They are used both as node names and as
            link endpoints, so they should be unique (``process`` dedupes them).
        labels: cluster label per sentence. ``-1`` (HDBSCAN noise) means the
            sentence belongs to no cluster, so it gets a node but no links.
        max_edges_per_node: maximum number of links each node starts towards
            later members of its own cluster. Links it receives from earlier
            members are not counted against this cap.

    Returns:
        ``{"type": "force", "nodes": [...], "links": [...]}``.
    """
    nodes = []
    members = defaultdict(list)
    for idx, (sentence, label) in enumerate(zip(sentences, labels)):
        nodes.append({
            "name": sentence,
            "symbolSize": 10,
            "category": int(label) if label >= 0 else 0,
            "itemStyle": {"color": color_for_label(label)},
        })
        members[label].append(idx)

    links = []
    for label, group in members.items():
        # Noise points share label -1 but are not a cluster; linking them
        # would draw them as one fake, fully-connected group.
        if label < 0:
            continue
        # ``group`` is in ascending index order, so the slice selects the
        # first ``max_edges_per_node`` later members of the cluster.
        for pos, i in enumerate(group):
            for j in group[pos + 1:pos + 1 + max_edges_per_node]:
                links.append({"source": sentences[i], "target": sentences[j]})
    return {"type": "force", "nodes": nodes, "links": links}
def generate_bubble_chart(sentences, labels):
    """Summarise cluster sizes as bubble-chart data.

    ``sentences`` is accepted for signature symmetry with the other chart
    builders but is not read. Each distinct label becomes one data point, in
    order of first appearance. Its ``value`` is the number of sentences with
    that label, and its name is ``"簇<label>"`` for non-negative labels or
    ``"噪声"`` (noise) otherwise.
    """
    sizes = {}
    for label in labels:
        sizes[label] = sizes.get(label, 0) + 1

    points = []
    for label, size in sizes.items():
        if label >= 0:
            name = f"簇{label}"
        else:
            name = "噪声"
        points.append({
            "name": name,
            "value": size,
            "itemStyle": {"color": color_for_label(label)},
        })
    return {"type": "bubble", "series": [{"type": "scatter", "data": points}]}
def generate_umap_plot(embeddings, labels):
    """Project *embeddings* to 2-D with UMAP and return scatter-plot data.

    The projection uses ``random_state=42``, and each of the two axes is
    min-max scaled with ``MinMaxScaler``. Every point records its scaled
    ``x``/``y``, its integer cluster label, and the colour from
    ``color_for_label``.
    """
    projection = umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings)
    normalized = MinMaxScaler().fit_transform(projection)

    points = []
    for (px, py), label in zip(normalized, labels):
        points.append({
            "x": float(px),
            "y": float(py),
            "label": int(label),
            "itemStyle": {"color": color_for_label(label)},
        })
    return {"type": "scatter", "series": [{"data": points}]}
def _split_lines(text):
    """Return the non-blank, whitespace-stripped lines of *text*."""
    return [line.strip() for line in text.strip().splitlines() if line.strip()]


def _read_upload(file_obj):
    """Return the text of an uploaded file given as a path or a file-like wrapper."""
    # Newer Gradio versions pass the upload as a filepath string; older ones
    # pass a tempfile wrapper whose ``.name`` is the path.
    if isinstance(file_obj, (str, os.PathLike)):
        path = file_obj
    else:
        path = file_obj.name
    # utf-8-sig also reads plain UTF-8, and strips a leading BOM that would
    # otherwise end up glued to the first sentence.
    with open(path, "r", encoding="utf-8-sig") as f:
        return f.read()


def _write_csv(df):
    """Write *df* to a fresh temporary .csv file and return its path."""
    fd, path = tempfile.mkstemp(suffix=".csv", prefix="clusters_")
    os.close(fd)
    # utf-8-sig writes a BOM so spreadsheet tools detect UTF-8 (Chinese text).
    df.to_csv(path, index=False, encoding="utf-8-sig")
    return path


def _failure(message):
    """Return the 6-value output tuple with only the status message set."""
    return (message, None, None, None, None, None)


def process(text_input, file_obj):
    """Cluster the sentences from the textbox and/or an uploaded text file.

    Lines from the uploaded file come first, then lines from ``text_input``.
    Blank lines are dropped and duplicates removed, keeping first occurrence.

    Returns:
        A 6-tuple matching the Gradio outputs: (score text, DataFrame,
        force-graph JSON, bubble-chart JSON, UMAP JSON, CSV file path). If
        the file cannot be read, or fewer than two distinct sentences remain,
        only the message is set and the other five values are ``None``.
    """
    sentences = []
    if file_obj is not None:
        try:
            sentences.extend(_split_lines(_read_upload(file_obj)))
        except Exception as e:  # UI boundary: report any read/decode error
            return _failure(f"❌ 文件读取失败: {str(e)}")
    if text_input:
        sentences.extend(_split_lines(text_input))

    sentences = list(dict.fromkeys(sentences))  # dedupe, keep order
    if len(sentences) < 2:
        return _failure("⚠️ 请输入至少两个有效句子进行聚类")

    labels, embeddings, scores = cluster_sentences(sentences)
    df = pd.DataFrame({"句子": sentences, "簇ID": labels})
    force_json = generate_force_graph(sentences, labels)
    bubble_json = generate_bubble_chart(sentences, labels)
    umap_json = generate_umap_plot(embeddings, labels)
    return (
        f"✅ Silhouette: {scores['silhouette']:.4f}, DB: {scores['db']:.4f}",
        df,
        json.dumps(force_json, ensure_ascii=False, indent=2),
        json.dumps(bubble_json, ensure_ascii=False, indent=2),
        json.dumps(umap_json, ensure_ascii=False, indent=2),
        # gr.File expects a path to a file, not the CSV text itself.
        _write_csv(df),
    )
def csv_download(csv_str):
    """Return *csv_str* as an in-memory binary stream encoded as UTF-8 with a BOM.

    The stream is positioned at the start, ready to be read.
    """
    stream = io.BytesIO()
    stream.write(bytes(csv_str, "utf-8-sig"))
    stream.seek(0)
    return stream
# Gradio UI: inputs (textbox + .txt upload), a run button, and the six outputs
# that ``process`` fills in order.
with gr.Blocks() as demo:
    gr.Markdown("# 中文句子语义聚类 Demo")
    with gr.Row():
        text_input = gr.Textbox(label="输入多句子(每行一句)", lines=8)
        file_input = gr.File(label="上传文本文件 (.txt)", file_types=['.txt'])
    btn = gr.Button("开始聚类")
    output_score = gr.Textbox(label="聚类指标", interactive=False)
    output_table = gr.Dataframe(headers=["句子", "簇ID"], interactive=False)
    output_force = gr.JSON(label="力导图数据")
    output_bubble = gr.JSON(label="气泡图数据")
    output_umap = gr.JSON(label="UMAP二维数据")
    # The CSV export is delivered as this component's value (a file path).
    # Assigning a function to ``output_csv.download`` registers no handler,
    # and it would shadow the component's own attribute, so it is not done.
    output_csv = gr.File(label="导出CSV")
    btn.click(
        fn=process,
        inputs=[text_input, file_input],
        outputs=[output_score, output_table, output_force, output_bubble, output_umap, output_csv],
    )

# Launch only when run as a script, so importing this module (tests, reload
# tooling) does not start a server.
if __name__ == "__main__":
    demo.launch()