"""Gradio demo for exploring a pre-trained BERTopic model of chat messages.

The app loads the ``davanstrien/chat_topics`` topic model and exposes four
tabs: per-topic word bar charts, free-text topic search, the topic
distribution plot, and a document-level visualization built from the English
assistant messages of the ``OpenAssistant/oasst1`` training split.
"""

import gradio as gr
from bertopic import BERTopic
from datasets import load_dataset
from functools import lru_cache


@lru_cache(maxsize=1)
def prep_dataset():
    """Return the ``text`` column of the English assistant messages.

    Loads the ``train`` split of ``OpenAssistant/oasst1`` and keeps only rows
    whose ``role`` is ``"assistant"`` and whose ``lang`` is ``"en"``.

    The result is memoised so the load and filter run at most once per
    process, however many times the documents are requested.
    """
    dataset = load_dataset("OpenAssistant/oasst1", split="train")
    # One pass with both conditions selects the same rows as filtering on
    # role and then on lang, without walking the dataset twice.
    assistant_ds_en = dataset.filter(
        lambda x: x["role"] == "assistant" and x["lang"] == "en"
    )
    return assistant_ds_en["text"]


topic_model = BERTopic.load("davanstrien/chat_topics")
fig = topic_model.visualize_topics()


def plot_docs():
    """Return the document visualization figure for the prepared texts.

    Passes ``sample=0.05`` to ``visualize_documents``.
    """
    docs = prep_dataset()
    return topic_model.visualize_documents(docs, sample=0.05)


def search_topic(text):
    """Return the topic-info rows for the topics most similar to ``text``.

    Args:
        text: Free-text query typed into the search box.

    Returns:
        The rows of ``topic_model.get_topic_info()`` whose ``Topic`` value is
        among the (up to five) topics returned by ``find_topics``. For an
        empty or whitespace-only query, an empty frame with the same columns
        is returned instead of querying the model.
    """
    topic_info = topic_model.get_topic_info()
    # The change event also fires when the box is cleared; there is nothing
    # meaningful to search for in that case, so show an empty table.
    if not text or not text.strip():
        return topic_info.iloc[0:0]
    similar_topics, _ = topic_model.find_topics(text, top_n=5)
    return topic_info[topic_info["Topic"].isin(similar_topics)]


def plot_topic_words(num_topics=9, n_words=5):
    """Return the bar chart of top words per topic.

    Args:
        num_topics: Forwarded as ``top_n_topics`` to ``visualize_barchart``.
        n_words: Forwarded as ``n_words`` to ``visualize_barchart``.
    """
    return topic_model.visualize_barchart(top_n_topics=num_topics, n_words=n_words)


with gr.Blocks() as demo:
    with gr.Tab("Topic words"):
        topic_number = gr.Slider(
            minimum=3, maximum=20, value=9, step=1, label="Number of topics"
        )
        # Initial chart uses the function defaults; the slider redraws it.
        plot = gr.Plot(plot_topic_words())
        topic_number.change(plot_topic_words, [topic_number], plot)
    with gr.Tab("Topic search"):
        text = gr.Textbox(lines=1, label="Search text")
        df = gr.DataFrame()
        text.change(search_topic, [text], df)
    with gr.Tab("Topic distribution"):
        gr.Plot(fig)
    with gr.Tab("Doc visualization"):
        # Built once while the UI is constructed.
        gr.Plot(plot_docs())

demo.launch(debug=True)