|
import gradio as gr |
|
from gradio_client import Client, handle_file |
|
import seaborn as sns |
|
import matplotlib.pyplot as plt |
|
import os |
|
import pandas as pd |
|
from io import StringIO |
|
from linePlot import plot_stacked_time_series, plot_emotion_topic_grid |
|
|
|
|
|
# Hugging Face access token taken from the environment; None when the
# variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Single shared Gradio client for the remote comparison-agent Space. Every
# handler below issues its `predict` calls through this instance.
client = Client(
    "mangoesai/Elections_Comparison_Agent_V4.1",
    hf_token=HF_TOKEN,
)
|
|
|
|
|
|
|
|
|
|
|
def stream_chat_with_rag(
    message: str,
    client_name: str
):
    """Send a question to the remote RAG agent and return its text answer.

    Parameters
    ----------
    message : str
        The user's question; forwarded to the endpoint as ``query``.
    client_name : str
        The election-year selection from the UI; forwarded as ``election_year``.

    Returns
    -------
    The first element of the ``/process_query`` response. The second element
    is only printed to stdout for debugging; it is not returned.
    """
    response = client.predict(
        query=message,
        election_year=client_name,
        api_name="/process_query",
    )
    answer, fig = response

    # Debug output for the server logs.
    for header, payload in (
        ("Raw answer from API:", answer),
        ("top works from API:", fig),
    ):
        print(header)
        print(payload)

    return answer
|
|
|
|
|
|
|
|
|
def heatmap(top_n):
    """Fetch the emotion/topic pivot table from the remote Space and plot it.

    The ``/get_heatmap_pivot_table`` endpoint returns a dict such as::

        {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
         'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
                  ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
                  ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
                  ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
         'metadata': None}

    The ``'Index'`` column holds the row labels and becomes the DataFrame
    index; the remaining headers become the heatmap columns.

    Parameters
    ----------
    top_n : int
        Number of top items requested from the endpoint; also shown in the title.

    Returns
    -------
    matplotlib.figure.Figure
        The rendered heatmap figure.
    """
    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table"
    )
    print(pivot_table)
    print(type(pivot_table))

    df = pd.DataFrame(pivot_table['data'], columns=pivot_table['headers'])
    df.set_index('Index', inplace=True)

    # Use an explicit figure/axes pair instead of pyplot's implicit "current
    # figure". Gradio may run handlers concurrently, and plt.gcf() could
    # otherwise hand back a figure another request is drawing on.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(df,
                ax=ax,
                cmap='YlOrRd',
                cbar_kws={'label': 'Weighted Frequency'},
                square=True)

    ax.set_title(f'Top {top_n} Emotions vs Topics Weighted Frequency')
    ax.set_xlabel('Topics')
    ax.set_ylabel('Emotions')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()

    # Remove the figure from pyplot's registry so repeated refreshes do not
    # accumulate open figures. The Figure object remains renderable (savefig)
    # for the Plot component.
    plt.close(fig)
    return fig
|
|
|
def linePlot_time_series(viz_type, weight, top_n):
    """Fetch the time-series table used by the line plots from the remote Space.

    Parameters
    ----------
    viz_type : str
        Visualization type, forwarded to the ``/linePlot_time_series`` endpoint.
    weight : float
        Score-vs-frequency weight, forwarded to the endpoint.
    top_n : int
        Number of top items, forwarded to the endpoint.

    Returns
    -------
    pandas.DataFrame
        A frame built from the payload's ``'data'`` rows, with its
        ``'headers'`` as column names.
    """
    result = client.predict(
        viz_type=viz_type,
        weight=weight,
        top_n=top_n,
        api_name="/linePlot_time_series"
    )

    print("============== timeseries df transfer from pivate to public ===============")
    print(result)
    print(type(result))

    # Build the frame from this call's own payload. `pivot_table` is not
    # defined in this function, and referencing it raised NameError.
    # NOTE(review): assumes the endpoint returns the same {'headers', 'data'}
    # dict shape as /get_heatmap_pivot_table; confirm against the Space.
    df = pd.DataFrame(result['data'], columns=result['headers'])

    return df
|
|
|
|
|
def update_visualization(viz_type, weight, top_n):
    """
    Update visualization based on user inputs and selected visualization type

    Parameters:
    -----------
    viz_type : str
        Type of visualization to show ('emotions', 'topics', or 'grid').
        Any value other than 'emotions' or 'topics' falls through to the
        emotion-topic grid.
    weight : float
        Weight for scoring (0-1)
    top_n : int
        Number of top items to show

    Returns:
    --------
    tuple
        ``(figure, status_message)``. On any exception the figure is ``None``
        and the message is ``"Error: <exception text>"``.
    """
    import traceback

    try:
        series = linePlot_time_series(viz_type, weight, top_n)

        if viz_type == "emotions":
            fig = plot_stacked_time_series(
                series,
                f'Top {top_n} Emotions Popularity'
            )
            message = "Emotion time series updated"

        elif viz_type == "topics":
            fig = plot_stacked_time_series(
                series,
                f'Top {top_n} Topics Popularity'
            )
            message = "Topic time series updated"

        else:
            fig = plot_emotion_topic_grid(series, top_n)
            message = "Emotion-Topic grid updated"

        return fig, message

    except Exception as e:
        # This is the UI boundary: report the failure in the status box rather
        # than raising. Also print the full traceback so the root cause is
        # still visible in the server logs.
        traceback.print_exc()
        return None, f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(title="Reddit Election Analysis") as demo:
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            # Explicit default (matching top_n_slider below) so that clicking
            # "Refresh Heatmap" before choosing a value does not send
            # top_n=None to the backend.
            top_n = gr.Dropdown(
                choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                value=5,
                label="Top N Items",
            )
            fresh_btn = gr.Button("Refresh Heatmap")

        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,
                elem_classes="heatmap-plot"
            )

    gr.Markdown("# Get the time series of the Public sentiment & Social topic")
    with gr.Row():
        with gr.Column(scale=1):
            # NOTE(review): this dropdown is not wired to any event handler;
            # viz_dropdown below is the control actually passed to
            # update_visualization.
            lineGraph_type = gr.Dropdown(choices=['emotions', 'topics', '2Dmatrix'])

            weight_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
                label="Weight (Score vs. Frequency)"
            )

            top_n_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=5,
                step=1,
                label="Top N Items"
            )

            viz_dropdown = gr.Dropdown(
                choices=["emotions", "topics", "grid"],
                value="emotions",
                label="Visualization Type",
                info="Select the type of visualization to display"
            )
            linePlot_btn = gr.Button("Update Visualizations")
            linePlot_status_text = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=3):
            time_series_fig = gr.Plot()

    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")

    with gr.Row():
        with gr.Column():
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2016 Election"
            )

            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask about election comments or posts..."
            )

            submit_btn = gr.Button("Submit")

            gr.Markdown("""
            ## Example Questions:
            - Is there any comments don't like the election results
            - Summarize the main discussions about voting process
            - What are the common opinions about candidates?
            """)

        with gr.Column():
            output_text = gr.Textbox(
                label="Response",
                lines=20
            )

    gr.Markdown("## Top works of the relevant Q&A")
    with gr.Row():
        # NOTE(review): no event handler writes to output_plot.
        # stream_chat_with_rag prints the second value returned by
        # /process_query but does not return it.
        output_plot = gr.Plot(
            label="Topic Distribution",
            container=True,
            elem_classes="topic-plot"
        )

    gr.HTML("""
    <style>
    .topic-plot {
        min-height: 600px;
        width: 100%;
        margin: auto;
    }
    .heatmap-plot {
        min-height: 400px;
        width: 100%;
        margin: auto;
    }
    </style>
    """)

    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap
    )

    linePlot_btn.click(
        fn=update_visualization,
        inputs=[viz_dropdown, weight_slider, top_n_slider],
        outputs=[time_series_fig, linePlot_status_text]
    )

    submit_btn.click(
        fn=stream_chat_with_rag,
        inputs=[query_input, year_selector],
        outputs=output_text
    )


if __name__ == "__main__":
    demo.launch(share=True)