Update app.py
app.py CHANGED
@@ -26,7 +26,7 @@ def stream_chat_with_rag(
 ):
     # print(f"Message: {message}")
     #answer = client.predict(question=question, api_name="/run_graph")
-    answer,
+    answer, sources = client.predict(
         query= message,
         election_year=year,
         api_name="/process_query"
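
For context, a minimal sketch of how the client.predict call added in this hunk is presumably set up with gradio_client. The Space id and the trimmed-down wrapper function are placeholders; only the /process_query endpoint name, the query/election_year parameters, and the two-value return come from the diff.

# Hypothetical setup for the call above; "user/rag-backend" is a placeholder
# Space id, not the one this app actually points at.
from gradio_client import Client

client = Client("user/rag-backend")

def run_rag_query(message: str, year: str):
    # /process_query is assumed to return (answer, sources), matching the
    # tuple unpacking introduced in this hunk.
    answer, sources = client.predict(
        query=message,
        election_year=year,
        api_name="/process_query",
    )
    return answer, sources
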
@@ -38,16 +38,32 @@ def stream_chat_with_rag(
     print(answer)
     history.append((message, response +"\n"+ answer))
 
-
-    # print(fig)
+
 
-    #
+    # Render the figure
+
 
-    # # Render the figure
-    # fig = pio.from_json(json.dumps(fig_dict))
-    # fig.show()
     return answer
 
+def topic_plot_gener(message: str, year: str):
+    fig = client.predict(
+        query= message,
+        election_year=year,
+        api_name="/topics_plot_genera"
+    )
+    # print("top works from API:")
+    print(fig)
+    plot_base64 = fig
+
+    plot_bytes = base64.b64decode(plot_base64['plot'].split(',')[1])
+    img = plt.imread(BytesIO(plot_bytes), format='PNG')
+    plt.figure(figsize = (12, 6), dpi = 150)
+    plt.imshow(img)
+    plt.axis('off')
+    plt.show()
+
+    return plt.gcf()
+
 
 # def predict(message, history):
 #     history_langchain_format = []
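
The new topic_plot_gener relies on base64, BytesIO and matplotlib.pyplot being imported elsewhere in app.py, and on /topics_plot_genera returning a dict whose "plot" field is a base64 data URI. A self-contained sketch of just that decode-and-render step, under those assumptions (the function name below is illustrative):

import base64
from io import BytesIO

import matplotlib.pyplot as plt

def data_uri_to_figure(data_uri: str):
    # Drop the "data:image/png;base64," prefix and decode the payload.
    plot_bytes = base64.b64decode(data_uri.split(",")[1])
    # Read the PNG bytes into an array and draw it on a fresh figure.
    img = plt.imread(BytesIO(plot_bytes), format="png")
    fig = plt.figure(figsize=(12, 6), dpi=150)
    plt.imshow(img)
    plt.axis("off")
    return fig

Returning the figure (rather than calling plt.show()) is all a gr.Plot output needs, so the plt.show() in the diff is likely redundant inside a Gradio callback.
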
@@ -228,26 +244,26 @@ with gr.Blocks(title="Reddit Election Analysis") as demo:
         )
 
         gr.Markdown("## Top words of the relevant Q&A")
-
-
-
-
-
-
+        with gr.Row():
+            topic_plot = gr.Plot(
+                label="Topic Distribution",
+                container=True,  # Ensures the plot is contained within its area
+                elem_classes="topic-plot"  # Add a custom class for styling
+            )
 
         # Add custom CSS to ensure proper plot sizing
         gr.HTML("""
         <style>
-        # .topic-plot {
-        #     min-height: 600px;
-        #     width: 100%;
-        #     margin: auto;
-        # }
         .heatmap-plot {
             min-height: 400px;
             width: 100%;
             margin: auto;
         }
+        .topic-plot {
+            min-width: 600px;
+            height: 100%;
+            margin: auto;
+        }
         </style>
         """)
         # topics_df = gr.Dataframe(value=df, label="Data Input")
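
For reference, a minimal standalone Blocks layout showing the pattern this hunk adds: a gr.Plot tagged with elem_classes="topic-plot" so the CSS injected via gr.HTML can size it. Everything outside the topic_plot/CSS pieces (the demo figure, demo.load, the launch call) is illustrative filler, not code from app.py.

import gradio as gr
import matplotlib.pyplot as plt

def placeholder_topics():
    # Illustrative figure so the sketch runs on its own.
    fig = plt.figure(figsize=(6, 3))
    plt.bar(["word a", "word b", "word c"], [3, 2, 1])
    return fig

with gr.Blocks(title="Reddit Election Analysis") as demo:
    gr.Markdown("## Top words of the relevant Q&A")
    with gr.Row():
        topic_plot = gr.Plot(
            label="Topic Distribution",
            container=True,            # keep the plot inside its own container
            elem_classes="topic-plot"  # hook for the CSS rule below
        )
    gr.HTML("""
    <style>
    .topic-plot { min-width: 600px; height: 100%; margin: auto; }
    </style>
    """)
    demo.load(fn=placeholder_topics, outputs=topic_plot)

if __name__ == "__main__":
    demo.launch()
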
@@ -266,12 +282,12 @@ with gr.Blocks(title="Reddit Election Analysis") as demo:
             outputs = [time_series_fig, linePlot_status_text]
         )
 
-        #
-
-
-
-
-
+        # Update both outputs when submit is clicked
+        topic_btn.click(
+            fn= topic_plot_gener,
+            inputs=[query_input, year_selector],
+            outputs= topic_plot
+        )
 
 
 if __name__ == "__main__":
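
Finally, a sketch of the event wiring in the last hunk: topic_btn.click sends the query box and year selector into topic_plot_gener and routes the returned figure into the gr.Plot. The component constructors and dropdown choices below are guesses standing in for the ones defined elsewhere in app.py; only the .click(fn=..., inputs=..., outputs=...) wiring mirrors the diff.

import gradio as gr
import matplotlib.pyplot as plt

def topic_plot_gener(message: str, year: str):
    # Stand-in for the real function (which fetches the plot from the backend
    # Space); here it just draws a placeholder figure.
    fig = plt.figure(figsize=(6, 3))
    plt.title(f"Topics for {year}: {message!r}")
    return fig

with gr.Blocks() as demo:
    query_input = gr.Textbox(label="Query")
    year_selector = gr.Dropdown(choices=["2016", "2020", "2024"], label="Election year")  # placeholder choices
    topic_btn = gr.Button("Generate topic plot")
    topic_plot = gr.Plot(label="Topic Distribution")

    topic_btn.click(
        fn=topic_plot_gener,
        inputs=[query_input, year_selector],
        outputs=topic_plot,
    )

if __name__ == "__main__":
    demo.launch()
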