# Author: mansurarief
# Commit 06cae5a: "Update app.py" (verified)
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import cv2
from io import BytesIO
import uuid
import os
import requests
import json
# Airtable credentials and table endpoints come from the environment;
# any variable that is not set is read as None.
AIRTABLE_TOKEN, SUMMARY_TABLE_URL, QUERIES_TABLE_URL = (
    os.environ.get(var_name)
    for var_name in ("AIRTABLE_TOKEN", "SUMMARY_TABLE_URL", "QUERIES_TABLE_URL")
)
def send_to_airtable(data, url, timeout=10):
    """POST a JSON payload to an Airtable table endpoint.

    Logging is best-effort: a missing endpoint or a network failure is
    reported on stdout instead of raising into the Gradio handler.

    Args:
        data: JSON-serializable payload (the logging helpers pass
            ``{"fields": {...}}``).
        url: Table endpoint URL. If empty/None (e.g. the env var is unset),
            nothing is sent.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The response body text, or "" if the request was skipped or failed.
    """
    if not url:
        print("Airtable URL not configured; skipping request.")
        return ""
    headers = {
        "Authorization": f"Bearer {AIRTABLE_TOKEN}",
        "Content-Type": "application/json",
    }
    payload = json.dumps(data)
    try:
        # Without a timeout a stalled request would block the UI callback forever.
        response = requests.post(url, headers=headers, data=payload, timeout=timeout)
    except requests.RequestException as err:
        print(f"Airtable request failed: {err}")
        return ""
    print(response.text)
    return response.text
def log_final_decision(activity, objective, budget_left, session_id, url):
    """Write one session's final-decision record to the given Airtable table.

    Returns whatever ``send_to_airtable`` returns for the request.
    """
    fields = dict(
        activity=activity,
        objective=objective,
        budget_left=budget_left,
        session_id=session_id,
    )
    return send_to_airtable({"fields": fields}, url)
def log_query(session_id, x, y, url):
    """Write a single (x, y) evaluation for a session to the given Airtable table.

    Returns whatever ``send_to_airtable`` returns for the request.
    """
    fields = dict(session_id=session_id, x=x, y=y)
    return send_to_airtable({"fields": fields}, url)
def draw_points(points, query_count, session_id=""):
    """Render the queried points as a numbered scatter plot image.

    Args:
        points: Sequence of (x, y) pairs in query order. Marker colour follows
            the viridis colormap by position, and each marker is labelled with
            its 1-based index.
        query_count: One-element list; the title shows ``query_count[0] + 1``.
        session_id: Text shown in the plot title.

    Returns:
        A uint8 image array in RGB channel order, ready for ``gr.Image``.
    """
    # A standalone Figure avoids pyplot's global "current figure" state, which
    # should not be shared across requests in a web server.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 8))
    ax = fig.subplots()
    cmap = plt.get_cmap('viridis')
    n_points = len(points)
    for i, (x_val, y_val) in enumerate(points):
        ax.scatter(x_val, y_val, color=cmap(i / n_points))
        ax.text(x_val, y_val + 50, str(i + 1), fontsize=7, ha='center')
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.set_ylim(-1000, 1000)
    ax.grid(True)
    ax.axhline(0, color='black', linewidth=0.5)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.set_title(f'Sampled Points - Query {query_count[0] + 1} ({session_id})')

    buf = BytesIO()
    fig.savefig(buf, format='png')
    encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    # OpenCV decodes to BGR, but gr.Image interprets numpy arrays as RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def f(x, scale=1000):
x_scaled = x / scale
polynomial = (x_scaled**4 - 10*x_scaled**3 + 35*x_scaled**2 - 50*x_scaled + 24) / 4000
sinusoidal = 800 * np.sin(abs(x_scaled)/10) + 500 * np.cos(x_scaled/8) + 200 * np.sin(x_scaled/2)
y_value = 800 * np.tanh((polynomial + sinusoidal) / 1000)
return y_value
def finalize_decision(global_minimizer, points, session_state, query_count, nqueries=10, activity="group_1"):
    """Score the user's final guess x*, log it, and lock the controls.

    The reported score is ``f(x*) - 10 * unused_queries``. The final point is
    appended to the history, logged to the queries table, and summarised in
    the summary table.

    Args:
        global_minimizer: Text from the "Global Minimizer (x*)" box.
        points: List of (x, f(x)) pairs; the final point is appended in place.
        session_state: Dict with a 'session_id' entry; filled in if still empty.
        query_count: One-element list with the number of queries used so far.
        nqueries: Query budget, used to count unused queries.
        activity: Label written to the summary table.

    Returns:
        List matching the click outputs: [status label, final label, Submit
        button, Finalize button, plot image, points text]. On success both
        buttons are disabled; on invalid input they stay enabled.
    """
    session_id = session_state['session_id'] if session_state['session_id'] else str(uuid.uuid4())[:6]
    # Store the id (as submit_query does) so later calls in this session reuse it.
    session_state['session_id'] = session_id
    remaining_queries = nqueries - query_count[0]

    # Only the parse belongs in the try: a ValueError from plotting or logging
    # must not be reported to the user as "invalid input".
    try:
        x_star = float(global_minimizer)
    except (TypeError, ValueError):
        x_star = float('nan')
    if not np.isfinite(x_star):
        # Also rejects "nan"/"inf", which float() accepts.
        img = draw_points(points, query_count, session_id)
        points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
        msg = "Invalid input: Please enter a numeric value for x*"
        return [msg, msg, gr.Button("Submit", interactive=True), gr.Button("Finalize Decision", interactive=True), img, points_list]

    y_value = f(x_star)
    # Each unused query lowers the reported score by 10.
    y_star = y_value - 10 * remaining_queries
    session_end_msg = f"Session ends with remaining query {remaining_queries}, and objective {y_value:.2f} - {remaining_queries}*10 = {y_star:.2f}."
    points.append((x_star, y_value))
    log_query(session_id, x_star, y_value, QUERIES_TABLE_URL)
    img = draw_points(points, query_count, session_id)
    points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
    # Tag the last line (the point just appended) as the final decision.
    points_list += " --- final decision"
    # NOTE: the summary row stores the raw f(x*) as "objective"; the penalised
    # score is objective - 10 * budget_left.
    log_final_decision(activity, y_value, remaining_queries, session_id, SUMMARY_TABLE_URL)
    return [session_end_msg, session_end_msg, gr.Button("Submit", interactive=False), gr.Button("Finalize Decision", interactive=False), img, points_list]
def submit_query(x, query_count, points, session_state, nqueries=10):
    """Evaluate ``f`` at a user-chosen x, log it, and refresh the plot.

    Args:
        x: Value from the "X" number input (may be None if the box is empty).
        query_count: One-element list with the number of queries used;
            incremented in place on a successful query.
        points: List of (x, f(x)) pairs; appended to in place on success.
        session_state: Dict with a 'session_id' entry; filled in on first use.
        nqueries: Query budget per session.

    Returns:
        Tuple matching the click outputs: (plot image, query_count, status
        message, points text, session_state, status message). Rejected
        requests keep showing the existing plot and points list.
    """
    session_id = session_state['session_id'] if session_state['session_id'] else str(uuid.uuid4())[:6]
    session_state['session_id'] = session_id

    def _history_view():
        # Re-render the existing history so a rejected request does not
        # blank the plot and the queried-points list.
        img = draw_points(points, query_count, session_id)
        text = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
        return img, text

    if query_count[0] >= nqueries:
        msg = "Query limit reached"
        img, points_list = _history_view()
        return img, query_count, msg, points_list, session_state, msg

    try:
        x = float(x)
    except (TypeError, ValueError):
        # float(None) raises TypeError; treat it like any unparsable value.
        x = float('nan')
    if not np.isfinite(x):
        msg = "Invalid input: Please enter a numeric value"
        img, points_list = _history_view()
        return img, query_count, msg, points_list, session_state, msg

    y_value = f(x)
    points.append((x, y_value))
    log_query(session_id, x, y_value, QUERIES_TABLE_URL)
    # Drawn before the increment so the title shows this query's 1-based number.
    img = draw_points(points, query_count, session_id)
    query_count[0] += 1
    remaining_queries = nqueries - query_count[0]
    points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
    msg = f"{remaining_queries} queries remaining"
    return img, query_count, msg, points_list, session_state, msg
with gr.Blocks() as app:
    # Per-session state: each session starts from its own copy of these values.
    query_count = gr.State([0])  # one-element list so handlers can update it in place
    points = gr.State([])  # (x, f(x)) pairs in query order
    # Start with an empty id; submit_query / finalize_decision generate one on
    # first use. A uuid created here would be evaluated only once, when the app
    # is built, so every session would share (and log under) the same id.
    session_state = gr.State({'session_id': ''})

    with gr.Row():
        x_input = gr.Number(label="X")
        submit_button = gr.Button("Submit", interactive=True)
    with gr.Row():
        remaining_queries_label = gr.Label()
        points_display = gr.Textbox(label="Queried Points", lines=10, interactive=False, value="")
    with gr.Row():
        output_image = gr.Image(label="Plot", width="100vw")
    with gr.Row():
        global_minimizer_input = gr.Textbox(label="Global Minimizer (x*)")
        finalize_button = gr.Button("Finalize Decision")
    with gr.Row():
        final_label = gr.Label()

    submit_button.click(
        submit_query,
        inputs=[x_input, query_count, points, session_state],
        outputs=[output_image, query_count, remaining_queries_label, points_display, session_state, final_label],
    )
    finalize_button.click(
        finalize_decision,
        inputs=[global_minimizer_input, points, session_state, query_count],
        outputs=[remaining_queries_label, final_label, submit_button, finalize_button, output_image, points_display],
    )

app.launch(share=False)
# Airtable logging: final decisions are saved to the summary table
# (SUMMARY_TABLE_URL), and every query is saved to the queries table (QUERIES_TABLE_URL).