# Author: 2catycm
# Commit 0f4db48: 初步结果 (preliminary results)
from typing import Dict
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import hypernetx as hnx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from io import BytesIO
import time
from utils.data_processor import load_data, process_data, build_hyperedges
from utils.visualizer import visualize_gmm, visualize_ratings
from utils.streamlit_hypergraph import hypergraph_visualization_component
def main():
    """Render the NeurIPS 2024 Bench paper Gaussian-mixture clustering page.

    The page contains:
      * a play/pause button and an "iteration" slider that choose which
        fitting step is passed to ``process_data`` / ``visualize_gmm``;
      * a sidebar with playback speed, figure size, sample count, the paper
        attribute handed to ``build_hyperedges`` and a Top-K / Top-P choice;
      * an optional hypergraph view and an optional Gaussian-mixture plot;
      * a table of the sampled papers.

    Auto-play works by incrementing ``st.session_state.iteration`` at the end
    of a script run, sleeping ``1 / speed`` seconds and calling
    ``st.rerun()``. It relies on the slider taking the new ``value`` as its
    default on the following run.
    """
    st.title("NeurIPS 2024 Bench Paper 高斯混合聚类分析")

    # Iteration bounds shared by the slider, the play callback and auto-play.
    slider_min = 1
    slider_max = 10

    # State that must survive reruns. ``iteration`` has to start inside
    # [slider_min, slider_max]: st.slider raises if its default ``value`` is
    # below ``min_value`` (the previous initial value of 0 crashed the first
    # page load).
    if 'play_state' not in st.session_state:
        st.session_state.play_state = False
    if 'iteration' not in st.session_state:
        st.session_state.iteration = slider_min

    def toggle_play():
        """Flip play/pause, rewinding to the first step if the run had finished."""
        if not st.session_state.play_state and st.session_state.iteration >= slider_max:
            st.session_state.iteration = slider_min  # rewind before replaying
        st.session_state.play_state = not st.session_state.play_state

    # Play/pause button. The label reflects the state at the start of this run.
    button_label = "暂停" if st.session_state.play_state else "开始拟合"
    st.button(button_label, on_click=toggle_play, key="play_button")

    # Iteration slider. Clamp the stored step so the default is always valid,
    # even for a session created before the bounds changed.
    current_step = min(max(st.session_state.iteration, slider_min), slider_max)
    iteration = st.slider("迭代步骤", min_value=slider_min, max_value=slider_max,
                          value=current_step, step=1,
                          key="iteration_slider")

    df = load_data()
    max_samples = len(df)
    if max_samples == 0:
        # Without rows, the sample-count slider below would get
        # max_value=0 < min_value=1 and raise.
        st.warning("没有可用的论文数据。")
        return

    # Sidebar controls.
    with st.sidebar:
        st.header("控制面板")
        # Auto-play speed: each step waits 1 / speed seconds before rerunning.
        speed = st.slider("拟合速度", min_value=0.1, max_value=2.0, value=1.0, step=0.1, key="speed_slider")
        # Figure size forwarded to the hypergraph component.
        draw_width = st.slider("绘图宽度", min_value=3, max_value=20, value=6, step=1, key="draw_width")
        draw_height = st.slider("绘图高度", min_value=3, max_value=20, value=6, step=1, key="draw_height")
        # The sample-count maximum is capped at 100 and by the dataset size.
        num_samples = st.slider("选择采样论文数量", min_value=1,
                                max_value=min(100, max_samples), value=min(10, max_samples), step=1)
        # Paper attribute passed to build_hyperedges (presumably used as the
        # node label in the hypergraph; confirm in utils.data_processor).
        display_attribute = st.selectbox(
            "选择显示 paper 的属性",
            ["order", "index", "id", "title", "keywords", "author"]
        )
        # Choose between keeping the top K clusters and keeping clusters up
        # to a cumulative probability P. Exactly one of top_k / top_p is set;
        # the other stays None.
        display_option = st.selectbox(
            "选择显示的选项",
            ["Top K Clusters", "Clusters Up To Probability P"]
        )
        if display_option == "Top K Clusters":
            max_k = 5
            top_k = st.slider("选择 K 值", min_value=1, max_value=max_k,
                              value=1, step=1)
            top_p = None
        else:
            top_k = None
            top_p = st.slider("选择 P 值", min_value=0.0, max_value=1.0, value=0.5, step=0.01)

    # Sample papers and compute their cluster assignments at this iteration.
    sampled_df, probabilities, paper_attributes = process_data(df, iteration, num_samples)
    hyperedges = build_hyperedges(probabilities, paper_attributes, display_attribute, top_k=top_k, top_p=top_p)
    hypergraph = hnx.Hypergraph(hyperedges)

    show_hypergraph = st.checkbox("显示超图", value=True, key="show_hyperedges")
    show_gaussian = st.checkbox("显示高斯分布", value=False, key="show_gaussian")
    if show_hypergraph:
        hypergraph_visualization_component(hypergraph, draw_width, draw_height)
    if show_gaussian:
        st.header("高斯混合分布聚类结果")
        fig_gmm = visualize_gmm(sampled_df, iteration)
        st.plotly_chart(fig_gmm, use_container_width=True)

    # Details of the sampled papers.
    st.header("采样论文详细信息")
    st.dataframe(sampled_df[["index", "title", "keywords", "rating_avg", "confidence_avg", "site"]])

    # Optional second visualization (disabled):
    # st.header("论文评分分布")
    # fig_bar = visualize_ratings(sampled_df)
    # st.plotly_chart(fig_bar, use_container_width=True)

    # Auto-play: advance one step per script run. This runs after everything
    # above has been drawn, so the current step stays visible during the sleep.
    if st.session_state.play_state:
        if st.session_state.iteration < slider_max:
            with st.spinner("正在播放..."):
                st.session_state.iteration += 1
                st.write(f"当前迭代次数: {st.session_state.iteration}")
                time.sleep(1 / speed)  # pacing controlled by the speed slider
            st.rerun()
        else:
            # Finished. Rerun so the button shows "开始拟合" again. Otherwise
            # the stale "暂停" label would, when clicked, rewind and restart
            # playback instead of pausing.
            st.session_state.play_state = False
            st.rerun()
# Entry point. `streamlit run` executes this script as the `__main__` module,
# so the guard keeps the app working while letting the module be imported
# (e.g. by tests) without rendering the page.
# For a wide layout, call st.set_page_config(layout="wide") before main().
if __name__ == "__main__":
    main()