File size: 7,726 Bytes
96911b6
 
0ed32cd
16d491d
96911b6
1cf4aa1
290ac64
 
375dd21
 
6ce4a98
96911b6
 
 
 
 
a708fda
96911b6
9200125
96911b6
 
 
 
 
3d7d31a
96911b6
 
89661b3
96911b6
bb3ba32
96911b6
 
 
 
 
 
 
 
b1adea5
 
4f6e76c
375dd21
 
 
 
 
 
96911b6
 
d766d8b
 
 
8edaa73
b36c45b
0ed32cd
028cea4
0ed32cd
d766d8b
4bc7cb3
51319c6
 
 
 
 
 
 
 
 
 
 
1cf4aa1
51319c6
366588b
51319c6
0ed32cd
 
9fc9533
0ed32cd
 
 
 
 
 
 
 
 
 
 
 
0f36100
a708fda
4f6e76c
 
 
 
 
 
 
 
a708fda
 
4ebe04e
6ce4a98
4ebe04e
 
 
 
 
 
cedf8bf
4f6e76c
 
 
 
 
 
 
 
 
 
4ebe04e
 
a708fda
e1b9d08
bb3ba32
3d7d31a
a708fda
d766d8b
a708fda
a59e807
a708fda
d766d8b
a59e807
a708fda
d766d8b
 
 
 
 
a708fda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb3ba32
 
e1b9d08
bb3ba32
e1b9d08
 
 
 
 
bb3ba32
e1b9d08
bb3ba32
e1b9d08
 
3d7d31a
e1b9d08
bb3ba32
e1b9d08
bb3ba32
 
 
 
 
 
 
e1b9d08
bb3ba32
e1b9d08
 
 
a708fda
 
bb3ba32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d766d8b
 
 
 
 
bb3ba32
 
a708fda
 
 
 
d766d8b
 
 
 
 
a708fda
 
4ebe04e
a708fda
 
 
bb3ba32
 
dd27210
c9150f4
e1b9d08
4bc7cb3
e1b9d08
bb3ba32
 
96911b6
d766d8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import gradio as gr
from gradio_client import Client, handle_file
import seaborn as sns
import matplotlib.pyplot as plt
import os
import pandas as pd
from io import StringIO, BytesIO
import base64
import json
import plotly.io as pio
# from linePlot import plot_stacked_time_series, plot_emotion_topic_grid

# Hugging Face access token, taken from the HF_TOKEN environment variable
# (None when the variable is unset). Set the variable rather than pasting a
# literal token into this file, so the secret never lands in version control.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Shared Gradio client for the remote Space that serves every endpoint used
# below (/process_query, /get_heatmap_pivot_table, /linePlot_3C1).
client = Client(
    "mangoesai/Elections_Comparison_Agent_V4.1",
    hf_token=HF_TOKEN,
)

# client_name = ['2016 Election','2024 Election', 'Comparison two years']



def _plotly_figure_from_payload(plot_payload):
    """Rebuild a Plotly figure from the plot value returned by ``/process_query``.

    The payload is expected to be a dict whose ``'plot'`` entry holds the
    figure serialized as Plotly JSON (a string, or an already-decoded dict).
    NOTE(review): this shape is inferred from how this file reads it —
    confirm against the remote API.

    Returns:
        A Plotly ``Figure``, or ``None`` when the API returned no plot.
    """
    if not plot_payload:
        return None
    plot_json = plot_payload.get("plot")
    if plot_json is None:
        return None
    if not isinstance(plot_json, str):
        # Already decoded into a dict; re-serialize because pio.from_json wants text.
        plot_json = json.dumps(plot_json)
    return pio.from_json(plot_json)


def stream_chat_with_rag(
    message: str,
    # history: list,
    client_name: str
):
    """Send a question to the remote RAG agent and return its answer and plot.

    Args:
        message: the user's question, forwarded as ``query``.
        client_name: election selector forwarded as ``election_year``
            (the UI offers "2016 Election", "2024 Election",
            "Comparison two years").

    Returns:
        An ``(answer, figure)`` tuple: the agent's answer exactly as returned
        by the API, and a Plotly figure rebuilt from the returned plot payload
        (``None`` when no plot was returned).
    """
    answer, plot_payload = client.predict(
        query=message,
        election_year=client_name,
        api_name="/process_query",
    )

    # Debugging: print the raw response so it shows up in the server logs.
    print("Raw answer from API:")
    print(answer)
    print("top works from API:")
    print(plot_payload)

    # The figure is handed back to Gradio for rendering in the browser;
    # calling fig.show() here would try to open a viewer on the server.
    fig = _plotly_figure_from_payload(plot_payload)
    return answer, fig


    

def heatmap(top_n):
    """Fetch the emotion-vs-topic pivot table from the API and draw it as a heatmap.

    Args:
        top_n: number of top items requested from ``/get_heatmap_pivot_table``;
            forwarded unchanged and also shown in the plot title.

    Returns:
        A matplotlib ``Figure`` containing the heatmap.

    The API returns the pivot table as a dict shaped like::

        {'headers': ['Index', 'economy', 'human rights', 'immigrant', 'politics'],
         'data': [['anger', 55880.0, 557679.0, 147766.0, 180094.0],
                  ['disgust', 26911.0, 123112.0, 64567.0, 46460.0],
                  ['fear', 51466.0, 188898.0, 113174.0, 150578.0],
                  ['neutral', 77005.0, 192945.0, 20549.0, 190793.0]],
         'metadata': None}
    """
    pivot_table = client.predict(
        top_n=top_n,
        api_name="/get_heatmap_pivot_table",
    )

    # The "Index" column holds the emotion labels; make it the row index so
    # the remaining columns (topics) become the heatmap's x axis.
    df = pd.DataFrame(pivot_table['data'], columns=pivot_table['headers'])
    df = df.set_index('Index')

    # Draw on an explicit Figure/Axes rather than pyplot's implicit
    # "current figure".
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        df,
        cmap='YlOrRd',
        cbar_kws={'label': 'Weighted Frequency'},
        square=True,
        ax=ax,
    )
    ax.set_title(f'Top {top_n} Emotions vs Topics Weighted Frequency')
    ax.set_xlabel('Topics')
    ax.set_ylabel('Emotions')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()

    # Unregister the figure from pyplot so figures don't accumulate across
    # requests; the Figure object itself stays valid for Gradio to render.
    plt.close(fig)
    return fig



# def decode_plot(plot_base64, top_n):
#     plot_bytes = base64.b64decode(plot_base64['plot'].split(',')[1])
#     img = plt.imread(BytesIO(plot_bytes), format='PNG')
#     plt.figure(figsize = (12, 2*top_n), dpi = 150)
#     plt.imshow(img)
#     plt.axis('off')
#     plt.show()
#     return plt.gcf()


def linePlot(viz_type, weight, top_n):
    """Fetch the time-series chart from the API and wrap it in a matplotlib Figure.

    Args:
        viz_type: visualization kind forwarded to ``/linePlot_3C1`` (the UI
            offers "emotions", "topics", "grid").
        weight: score-vs-frequency weight forwarded to the API.
        top_n: number of top items forwarded to the API; also scales the
            figure height (2 inches per item).

    Returns:
        ``(figure, description)``: a Figure displaying the PNG image returned
        by the API, and the API's accompanying text message.
    """
    result = client.predict(
        viz_type=viz_type,
        weight=weight,
        top_n=top_n,
        api_name="/linePlot_3C1",
    )
    # result is a tuple: (plot payload dict, description string).
    plot_payload, description = result[0], result[1]

    # The payload's 'plot' entry is a base64 PNG, expected as a data URL
    # ("data:image/png;base64,..."). Keep only what follows the first comma;
    # a bare base64 string (no comma) is used as-is instead of raising.
    encoded = plot_payload['plot'].split(',', 1)[-1]
    img = plt.imread(BytesIO(base64.b64decode(encoded)), format='PNG')

    fig, ax = plt.subplots(figsize=(12, 2 * top_n), dpi=150)
    ax.imshow(img)
    ax.axis('off')
    # No plt.show(): this runs on the server and Gradio renders the returned
    # Figure. Unregister it from pyplot so figures don't pile up per request.
    plt.close(fig)
    return fig, description

    

# Create Gradio interface
with gr.Blocks(title="Reddit Election Analysis") as demo:
    # --- Section 1: emotion-vs-topic heatmap ------------------------------
    gr.Markdown("# Reddit Public sentiment & Social topic distribution ")
    with gr.Row():
        with gr.Column():
            # Explicit default so "Refresh Heatmap" always sends a valid top_n.
            top_n = gr.Dropdown(
                choices=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                value=5,
                label="Top N",
            )
            fresh_btn = gr.Button("Refresh Heatmap")

        with gr.Column():
            output_heatmap = gr.Plot(
                label="Top Public sentiment & Social topic Heatmap",
                container=True,  # keep the plot inside its own area
                elem_classes="heatmap-plot",  # sized by the custom CSS below
            )

    # --- Section 2: time series -------------------------------------------
    gr.Markdown("# Get the time series of the Public sentiment & Social topic")
    with gr.Row():
        with gr.Column(scale=1):
            # Control panel for linePlot.
            weight_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
                label="Weight (Score vs. Frequency)",
            )

            top_n_slider = gr.Slider(
                minimum=2,
                maximum=10,
                value=5,
                step=1,
                label="Top N Items",
            )

            viz_dropdown = gr.Dropdown(
                choices=["emotions", "topics", "grid"],
                value="emotions",
                label="Visualization Type",
                info="Select the type of visualization to display",
            )
            linePlot_btn = gr.Button("Update Visualizations")
            linePlot_status_text = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=3):
            time_series_fig = gr.Plot()

    # --- Section 3: question answering over election posts ----------------
    gr.Markdown("# Reddit Election Posts/Comments Analysis")
    gr.Markdown("Ask questions about election-related comments and posts")

    with gr.Row():
        with gr.Column():
            year_selector = gr.Radio(
                choices=["2016 Election", "2024 Election", "Comparison two years"],
                label="Select Election Year",
                value="2016 Election",
            )

            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask about election comments or posts...",
            )

            submit_btn = gr.Button("Submit")

            gr.Markdown("""
            ## Example Questions:
            - Is there any comments don't like the election results
            - Summarize the main discussions about voting process
            - What are the common opinions about candidates?
            """)
        with gr.Column():
            output_text = gr.Textbox(
                label="Response",
                lines=20,
            )

    gr.Markdown("## Top works of the relevant Q&A")
    with gr.Row():
        output_plot = gr.Plot(
            label="Topic Distribution",
            container=True,  # keep the plot inside its own area
            elem_classes="topic-plot",  # sized by the custom CSS below
        )

    # Custom CSS giving both plot areas a minimum height and full width.
    gr.HTML("""
        <style>
            .topic-plot {
                min-height: 600px;
                width: 100%;
                margin: auto;
            }
            .heatmap-plot {
                min-height: 400px;
                width: 100%;
                margin: auto;
            }
        </style>
    """)

    # --- Event wiring -----------------------------------------------------
    fresh_btn.click(
        fn=heatmap,
        inputs=top_n,
        outputs=output_heatmap,
    )

    linePlot_btn.click(
        fn=linePlot,
        inputs=[viz_dropdown, weight_slider, top_n_slider],
        outputs=[time_series_fig, linePlot_status_text],
    )

    # stream_chat_with_rag returns (answer, figure): route the answer to the
    # response textbox and the figure to the topic plot.
    submit_btn.click(
        fn=stream_chat_with_rag,
        inputs=[query_input, year_selector],
        outputs=[output_text, output_plot],
    )


if __name__ == "__main__":
    # share=True additionally requests a public share link for the app.
    demo.launch(share=True)