|
import gradio as gr |
|
from gradio_client import Client, handle_file |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
import os |
|
import pandas as pd |
|
from io import StringIO, BytesIO |
|
import base64 |
|
import json |
|
import plotly.graph_objects as go |
|
|
|
|
|
|
|
|
|
# Hugging Face access token taken from the environment; None if HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Client for the remote Gradio Space whose API endpoints this app calls.
client = Client("mangoesai/Elections_Comparison_Agent_V4.1", hf_token=HF_TOKEN)
|
|
|
|
|
|
|
def stream_chat_with_rag(
    message: str,
    history: list,
    year: str
):
    """Answer a chat message by querying the remote RAG endpoint.

    Used as the ``gr.ChatInterface`` callback (``type="messages"``), with the
    selected election year supplied as an additional input.

    Args:
        message: The user's question, forwarded as ``query``.
        history: Prior conversation supplied by ChatInterface. It is neither
            used nor modified here: ChatInterface builds the displayed history
            from this function's return value, and appending entries in a
            different format (e.g. tuples) would corrupt a messages-format
            history.
        year: Election-year choice, forwarded as ``election_year``.

    Returns:
        The answer text returned by the ``/process_query`` endpoint.
    """
    # The endpoint returns (answer, sources); sources are not shown in the chat.
    answer, _sources = client.predict(
        query=message,
        election_year=year,
        api_name="/process_query",
    )

    print("Raw answer from API:")
    print(answer)

    return answer
|
|
|
def topic_plot_gener(message: str, year: str):
    """Fetch the top-words plot for a question from the remote Space.

    Args:
        message: Question whose retrieved sources are topicalized, forwarded
            as ``query``.
        year: Election-year choice, forwarded as ``election_year``.

    Returns:
        A ``plotly.graph_objects.Figure`` rebuilt from the endpoint's output.
    """
    payload = client.predict(
        query=message,
        election_year=year,
        api_name="/topics_plot_genera",
    )

    print(payload)

    # The endpoint's plot output arrives as a dict whose "plot" entry is the
    # figure serialized as JSON. Rebuild it with its layout as well as its
    # traces, so the title/axis labels/sizing set remotely are not dropped.
    plot_json = json.loads(payload["plot"])
    return go.Figure(data=plot_json["data"], layout=plot_json.get("layout"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def heatmap(top_n):
    """Fetch the emotion-vs-topic pivot table from the remote Space and plot it.

    Args:
        top_n: Forwarded as ``top_n`` to the ``/get_heatmap_pivot_table``
            endpoint and shown in the plot title.

    Returns:
        A matplotlib ``Figure`` holding the heatmap, for display in ``gr.Plot``.

    The endpoint returns a dataframe payload such as::

        {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
         'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
                  ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
                  ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
                  ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
         'metadata': None}

    The ``'Index'`` column carries the row (emotion) labels.
    """
    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table",
    )

    print(pivot_table)
    print(type(pivot_table))

    df = pd.DataFrame(pivot_table["data"], columns=pivot_table["headers"])
    df = df.set_index("Index")

    # Draw on an explicit Figure/Axes rather than pyplot's implicit "current
    # figure": callbacks may run in worker threads, where that global state
    # could be shared between requests.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        df,
        cmap="YlOrRd",
        cbar_kws={"label": "Weighted Frequency"},
        square=True,
        ax=ax,
    )

    ax.set_title(f"Top {top_n} Emotions vs Topics Weighted Frequency")
    ax.set_xlabel("Topics")
    ax.set_ylabel("Emotions")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()

    # Unregister the figure from pyplot so repeated refreshes do not pile up
    # open figures; the returned Figure object can still be rendered.
    plt.close(fig)
    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def linePlot(viz_type, weight, top_n):
    """Fetch the time-series plot image from the remote Space and wrap it in a Figure.

    Args:
        viz_type: Visualization type, forwarded to ``/linePlot_3C1``.
        weight: Score-vs-frequency weight, forwarded to the endpoint.
        top_n: Number of items, forwarded to the endpoint; also scales the
            figure height (``2 * top_n`` inches).

    Returns:
        ``(figure, status)``: a matplotlib ``Figure`` displaying the decoded
        image, and the endpoint's second output (shown as status text).
    """
    result = client.predict(
        viz_type=viz_type,
        weight=weight,
        top_n=top_n,
        api_name="/linePlot_3C1",
    )
    plot_payload = result[0]
    status = result[1]

    # The image presumably arrives as a data URL ("data:...;base64,<payload>").
    # Keep everything after the first comma; base64 never contains a comma,
    # so this also works if the prefix is ever omitted.
    encoded = plot_payload["plot"].split(",", 1)[-1]
    img = plt.imread(BytesIO(base64.b64decode(encoded)), format="PNG")

    fig, ax = plt.subplots(figsize=(12, 2 * top_n), dpi=150)
    ax.imshow(img)
    ax.axis("off")

    # No plt.show(): this runs inside a web server callback, where show() is
    # at best a no-op warning (non-interactive backend) and at worst blocks.
    # Closing only unregisters the figure from pyplot; it can still render.
    plt.close(fig)
    return fig, status
|
|
|
|
|
|
|
|
|
with gr.Blocks(title="Reddit Election Analysis") as demo:
    # --- Emotion vs topic heatmap (wired to heatmap()) ----------------------
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")

    with gr.Row():
        with gr.Column():
            top_n = gr.Dropdown(
                choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                label="Top N",
            )
            fresh_btn = gr.Button("Refresh Heatmap")

        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,
                elem_classes="heatmap-plot",
            )

    # --- Time series (wired to linePlot()) ----------------------------------
    gr.Markdown("# Get the time series of the Public sentiment & Social topic")

    with gr.Row():
        with gr.Column(scale=1):
            weight_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
                label="Weight (Score vs. Frequency)",
            )

            top_n_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=5,
                step=1,
                label="Top N Items",
            )

            viz_dropdown = gr.Dropdown(
                choices=["emotions", "topics", "grid"],
                value="emotions",
                label="Visualization Type",
                info="Select the type of visualization to display",
            )

            linePlot_btn = gr.Button("Update Visualizations")
            linePlot_status_text = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=3):
            time_series_fig = gr.Plot()

    # --- RAG chat (wired to stream_chat_with_rag()) -------------------------
    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")

    with gr.Row():
        with gr.Column(scale=1):
            # These values are forwarded verbatim to the remote API as
            # `election_year`, so they must not be reworded here.
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2024 Election",
            )

            gr.Markdown(
                """
                ## Example Questions:

                - Are there any comments that dislike the election results?
                - Summarize the main discussions about the voting process
                - What're the common opinions about candidates?
                - What're the common opinions about the immigrant topic?
                """
            )

        with gr.Column(scale=2):
            gr.ChatInterface(
                stream_chat_with_rag,
                type="messages",
                additional_inputs=[year_selector],
            )

    # --- Top words of the retrieved sources (wired to topic_plot_gener()) ---
    gr.Markdown("## Top words of the relevant Q&A")

    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(
                label="Your Question For Topicalize",
                placeholder="Copy and paste your question here to visualize the top words of the relevant topic",
            )
            topic_btn = gr.Button("Topicalize the RAG sources")

        with gr.Column(scale=2):
            topic_plot = gr.Plot(
                label="Top Words Distribution",
                container=True,
                elem_classes="topic-plot",
            )

    # Sizing rules for the plots tagged via `elem_classes` above.
    gr.HTML("""
    <style>
        .heatmap-plot {
            min-height: 400px;
            width: 100%;
            margin: auto;
        }
        .topic-plot {
            min-width: 600px;
            height: 100%;
            margin: auto;
        }
    </style>
    """)

    # Event listeners must be attached inside the Blocks context.
    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap,
    )

    linePlot_btn.click(
        fn=linePlot,
        inputs=[viz_dropdown, weight_slider, top_n_slider],
        outputs=[time_series_fig, linePlot_status_text],
    )

    topic_btn.click(
        fn=topic_plot_gener,
        inputs=[query_input, year_selector],
        outputs=topic_plot,
    )
|
|
|
|
|
def main() -> None:
    """Serve ``demo``; ``share=True`` asks Gradio to also create a public share link."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()