# app.py — last updated by Vera-ZWY in commit a244f3c ("Update app.py", verified).
import gradio as gr
from gradio_client import Client, handle_file
import seaborn as sns
import matplotlib.pyplot as plt
import os
import pandas as pd
from io import StringIO, BytesIO
import base64
import json
import plotly.graph_objects as go
# import plotly.io as pio
# from linePlot import plot_stacked_time_series, plot_emotion_topic_grid
# Hugging Face access token, read from the HF_TOKEN environment variable
# (for example a Space secret). Never hard-code a token in this file:
# committed source is easy to leak. If the variable is unset, HF_TOKEN is
# None and that None is passed straight through to the client below.
HF_TOKEN = os.getenv("HF_TOKEN")

# Remote Gradio Space that serves every endpoint this app calls.
SPACE_ID = "mangoesai/Elections_Comparison_Agent_V4.1"

# Shared client for the remote Space, created once at import time.
client = Client(SPACE_ID, hf_token=HF_TOKEN)
# query_input = ""
def stream_chat_with_rag(
    message: str,
    history: list,
    year: str,
):
    """Answer a chat message by querying the remote RAG agent.

    This is the ``gr.ChatInterface`` callback. Despite its name, it returns
    one complete answer rather than streaming partial output.

    Args:
        message: The user's question.
        history: Chat history supplied by ``gr.ChatInterface``. It is neither
            read nor modified here. The interface is configured with
            ``type="messages"``, so appending ad-hoc ``(user, bot)`` tuples to
            it would not match that format.
        year: Value of the "Select Election Year" radio, forwarded as
            ``election_year`` to the remote ``/process_query`` endpoint.

    Returns:
        The answer returned by the remote endpoint (first element of its
        ``(answer, sources)`` result).
    """
    answer, _sources = client.predict(
        query=message,
        election_year=year,
        api_name="/process_query",
    )
    # Debugging: log the raw response to the server console.
    print("Raw answer from API:")
    print(answer)
    return answer
def topic_plot_gener(message: str, year: str):
    """Fetch the top-words topic plot for ``message`` from the remote agent.

    Args:
        message: Question whose retrieved RAG sources should be topicalized.
        year: Election-year selection, forwarded as ``election_year`` to the
            remote ``/topics_plot_genera`` endpoint.

    Returns:
        plotly.graph_objects.Figure rebuilt from the remote payload, or None
        (an empty plot) when the endpoint returns nothing.
    """
    payload = client.predict(
        query=message,
        election_year=year,
        api_name="/topics_plot_genera",
    )
    print(payload)
    if not payload:
        return None

    # The remote plot output arrives as a dict whose "plot" entry holds the
    # Plotly figure serialized as JSON. Also accept an already-decoded dict.
    plot_spec = payload["plot"]
    if isinstance(plot_spec, str):
        plot_spec = json.loads(plot_spec)

    # Rebuild with the layout too: using only "data" silently dropped the
    # title, axis labels and sizing. skip_invalid tolerates properties that a
    # different Plotly version on the remote side may have emitted.
    return go.Figure(
        data=plot_spec.get("data", []),
        layout=plot_spec.get("layout"),
        skip_invalid=True,
    )
# return plt.gcf()
# def predict(message, history):
# history_langchain_format = []
# for msg in history:
# if msg['role'] == "user":
# history_langchain_format.append(HumanMessage(content=msg['content']))
# elif msg['role'] == "assistant":
# history_langchain_format.append(AIMessage(content=msg['content']))
# history_langchain_format.append(HumanMessage(content=message))
# gpt_response = llm(history_langchain_format)
# return gpt_response.content
def heatmap(top_n):
    """Render the emotions-vs-topics heatmap for the top ``top_n`` items.

    Fetches a pivot table from the remote ``/get_heatmap_pivot_table``
    endpoint and draws it with seaborn.

    Args:
        top_n: Value of the "Top N" dropdown. It is forwarded unchanged to the
            remote endpoint and shown in the plot title.

    Returns:
        matplotlib.figure.Figure containing the heatmap, for a ``gr.Plot``.

    Raises:
        gr.Error: If ``top_n`` has not been selected, or the endpoint returns
            no rows.
    """
    if top_n is None:
        # The dropdown may be empty; show a readable message in the UI instead
        # of forwarding None to the remote endpoint.
        raise gr.Error("Please select a Top N value before refreshing the heatmap.")

    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table",
    )
    print(pivot_table)
    print(type(pivot_table))
    # Example payload (presumably a serialized gr.Dataframe):
    # {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
    #  'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
    #           ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
    #           ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
    #           ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
    #  'metadata': None}
    df = pd.DataFrame(pivot_table["data"], columns=pivot_table["headers"])
    df = df.set_index("Index")
    if df.empty:
        raise gr.Error("The heatmap endpoint returned no data.")

    # Draw on an explicit Figure/Axes, then detach the figure from pyplot's
    # registry so repeated refreshes don't accumulate open figures.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        df,
        cmap="YlOrRd",
        cbar_kws={"label": "Weighted Frequency"},
        square=True,
        ax=ax,
    )
    ax.set_title(f"Top {top_n} Emotions vs Topics Weighted Frequency")
    ax.set_xlabel("Topics")
    ax.set_ylabel("Emotions")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    plt.close(fig)
    return fig
# def decode_plot(plot_base64, top_n):
# plot_bytes = base64.b64decode(plot_base64['plot'].split(',')[1])
# img = plt.imread(BytesIO(plot_bytes), format='PNG')
# plt.figure(figsize = (12, 2*top_n), dpi = 150)
# plt.imshow(img)
# plt.axis('off')
# plt.show()
# return plt.gcf()
def linePlot(viz_type, weight, top_n):
    """Fetch the time-series PNG from the remote agent and wrap it in a Figure.

    Args:
        viz_type: Visualization type ("emotions", "topics" or "grid" as wired
            in the UI).
        weight: Score-vs-frequency weight (0 to 1 in the UI slider).
        top_n: Number of items to plot. It is forwarded to the endpoint and
            also scales the figure height.

    Returns:
        tuple: (matplotlib Figure showing the image, description string
        returned by the endpoint).
    """
    result = client.predict(
        viz_type=viz_type,
        weight=weight,
        top_n=top_n,
        api_name="/linePlot_3C1",
    )
    # result is a tuple: (plot payload dict, description string). The payload's
    # "plot" entry holds the PNG as base64, presumably as a data URI
    # ("data:image/png;base64,<data>").
    plot_payload, description = result[0], result[1]
    encoded = plot_payload["plot"]
    # Accept both a full data URI and bare base64 (split(',')[1] crashed on
    # the latter).
    _prefix, sep, b64_data = encoded.partition(",")
    png_bytes = base64.b64decode(b64_data if sep else encoded)
    img = plt.imread(BytesIO(png_bytes), format="PNG")

    # Build the figure explicitly (no plt.show(): this runs on a server), then
    # release it from pyplot's registry so repeated updates don't leak figures.
    fig, ax = plt.subplots(figsize=(12, 2 * top_n), dpi=150)
    ax.imshow(img)
    ax.axis("off")
    plt.close(fig)
    return fig, description
# Example prompts shown next to the chat box. Kept flush-left so the Markdown
# list is never rendered as an indented code block.
_EXAMPLE_QUESTIONS_MD = """
## Example Questions:
- Are there any comments that dislike the election results?
- Summarize the main discussions about the voting process.
- What are the common opinions about the candidates?
- What are the common opinions about the immigration topic?
"""

# Create Gradio interface
with gr.Blocks(title="Reddit Election Analysis") as demo:
    # --- Emotion x topic heatmap -------------------------------------------
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            # Default to 5 (matching the time-series slider) so the first
            # "Refresh Heatmap" click does not send None.
            top_n = gr.Dropdown(
                choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                value=5,
                label="Top N",
            )
            fresh_btn = gr.Button("Refresh Heatmap")
        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,  # Ensures the plot is contained within its area
                elem_classes="heatmap-plot",  # Custom class, styled by the CSS below
            )

    # --- Time series -------------------------------------------------------
    gr.Markdown("# Get the time series of the Public sentiment & Social topic")
    with gr.Row():
        with gr.Column(scale=1):
            # Control panel
            weight_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
                label="Weight (Score vs. Frequency)",
            )
            top_n_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=5,
                step=1,
                label="Top N Items",
            )
            viz_dropdown = gr.Dropdown(
                choices=["emotions", "topics", "grid"],
                value="emotions",
                label="Visualization Type",
                info="Select the type of visualization to display",
            )
            linePlot_btn = gr.Button("Update Visualizations")
            linePlot_status_text = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=3):
            time_series_fig = gr.Plot()

    # --- RAG chat ----------------------------------------------------------
    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")
    with gr.Row():
        with gr.Column(scale=1):
            # The selected string is forwarded verbatim to the remote agent as
            # ``election_year`` (chat and topic plot), so do not reword choices.
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2024 Election",
            )
            gr.Markdown(_EXAMPLE_QUESTIONS_MD)
        with gr.Column(scale=2):
            gr.ChatInterface(
                stream_chat_with_rag,
                type="messages",
                additional_inputs=[year_selector],
            )

    # --- Top words of the RAG sources ---------------------------------------
    gr.Markdown("## Top words of the relevant Q&A")
    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(
                label="Your Question For Topicalize",
                placeholder="Copy and paste your question here to visualize the top words of the relevant topic",
            )
            topic_btn = gr.Button("Topicalize the RAG sources")
        with gr.Column(scale=2):
            topic_plot = gr.Plot(
                label="Top Words Distribution",
                container=True,  # Ensures the plot is contained within its area
                elem_classes="topic-plot",  # Custom class, styled by the CSS below
            )

    # Custom CSS for the plot components tagged via elem_classes above.
    gr.HTML("""
<style>
.heatmap-plot {
    min-height: 400px;
    width: 100%;
    margin: auto;
}
.topic-plot {
    min-width: 600px;
    height: 100%;
    margin: auto;
}
</style>
""")

    # --- Event wiring --------------------------------------------------------
    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap,
    )
    linePlot_btn.click(
        fn=linePlot,
        inputs=[viz_dropdown, weight_slider, top_n_slider],
        outputs=[time_series_fig, linePlot_status_text],
    )
    # Draw the top-words plot for the typed question and selected election year.
    topic_btn.click(
        fn=topic_plot_gener,
        inputs=[query_input, year_selector],
        outputs=topic_plot,
    )

if __name__ == "__main__":
    demo.launch(share=True)