import json
import math
import os
import uuid
from io import BytesIO

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
from matplotlib.figure import Figure

AIRTABLE_TOKEN = os.environ.get("AIRTABLE_TOKEN")
SUMMARY_TABLE_URL = os.environ.get("SUMMARY_TABLE_URL")
QUERIES_TABLE_URL = os.environ.get("QUERIES_TABLE_URL")
# Seconds to wait for Airtable before giving up, so a slow or unreachable
# service cannot freeze a UI callback indefinitely.
AIRTABLE_TIMEOUT = 10


def send_to_airtable(data, url):
    """POST ``data`` as JSON to an Airtable table ``url``.

    Logging is best-effort: if the token or URL is not configured, or the
    HTTP request fails, a message is printed and ``None`` is returned instead
    of raising, so a logging problem never breaks the interactive session.

    Returns:
        The response body text when the request completed, otherwise ``None``.
    """
    if not url or not AIRTABLE_TOKEN:
        print("Airtable logging skipped: AIRTABLE_TOKEN or table URL is not set")
        return None
    headers = {
        "Authorization": f"Bearer {AIRTABLE_TOKEN}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(
            url, headers=headers, data=json.dumps(data), timeout=AIRTABLE_TIMEOUT
        )
    except requests.RequestException as err:
        print(f"Airtable logging failed: {err}")
        return None
    print(response.text)
    return response.text


def log_final_decision(activity, objective, budget_left, session_id, url):
    """Log one session's final answer (objective value and unused queries)."""
    data = {"fields": {"activity": activity, "objective": objective,
                       "budget_left": budget_left, "session_id": session_id}}
    return send_to_airtable(data, url)


def log_query(session_id, x, y, url):
    """Log a single evaluated point ``(x, y)`` for ``session_id``."""
    data = {"fields": {"session_id": session_id, "x": x, "y": y}}
    return send_to_airtable(data, url)


def draw_points(points, query_count, session_id=""):
    """Render the queried points as an RGB ``uint8`` image array.

    Points are coloured along the viridis colormap in query order and labelled
    with their 1-based index. The title shows ``query_count[0] + 1``.

    A standalone ``Figure`` is used instead of pyplot's global "current
    figure" so concurrent requests cannot draw into each other's plot, and
    nothing needs closing if rendering fails part-way.
    """
    fig = Figure(figsize=(12, 8))
    ax = fig.subplots()
    cmap = plt.get_cmap('viridis')
    n_points = len(points)
    for i, (x_val, y_val) in enumerate(points):
        ax.scatter(x_val, y_val, color=cmap(i / n_points))
        ax.text(x_val, y_val + 50, str(i + 1), fontsize=7, ha='center')
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.set_ylim(-1000, 1000)
    ax.grid(True)
    ax.axhline(0, color='black', linewidth=0.5)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.set_title(f'Sampled Points - Query {query_count[0] + 1} ({session_id})')

    buf = BytesIO()
    fig.savefig(buf, format='png')
    encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    bgr = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    # OpenCV decodes to BGR channel order; gr.Image expects RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def f(x, scale=1000):
    """Hidden objective the participants try to minimise.

    A quartic plus sinusoids in ``x / scale``, squashed by ``tanh`` into the
    open interval (-800, 800). For very large Python-float ``x`` the quartic
    term can raise ``OverflowError``.
    """
    x_scaled = x / scale
    polynomial = (x_scaled**4 - 10*x_scaled**3 + 35*x_scaled**2 - 50*x_scaled + 24) / 4000
    sinusoidal = (800 * np.sin(abs(x_scaled)/10) + 500 * np.cos(x_scaled/8)
                  + 200 * np.sin(x_scaled/2))
    y_value = 800 * np.tanh((polynomial + sinusoidal) / 1000)
    return y_value


def _ensure_session_id(session_state):
    """Return this session's id, creating and storing a short random one if unset."""
    if not session_state.get('session_id'):
        session_state['session_id'] = str(uuid.uuid4())[:6]
    return session_state['session_id']


def _format_points(points):
    """Format the queried points one per line as ``x: <int>, y: <3 decimals>``."""
    return '\n'.join(f"x: {x:.0f}, y: {y:.3f}" for x, y in points)


def _evaluate(raw):
    """Parse ``raw`` as a number and evaluate ``f`` there.

    Returns:
        ``(x, y)`` as Python floats, or ``None`` when the input is missing,
        non-numeric, non-finite, or large enough that ``f`` overflows. None of
        those could be plotted or logged as valid JSON.
    """
    try:
        x = float(raw)
        y = float(f(x))
    except (TypeError, ValueError, OverflowError):
        return None
    if not (math.isfinite(x) and math.isfinite(y)):
        return None
    return x, y


def finalize_decision(global_minimizer, points, session_state, query_count,
                      nqueries=10, activity="group_1"):
    """Score the user's proposed minimiser ``x*`` and end the session.

    The penalised score is ``f(x*) - 10 * remaining_queries``. The final point
    is appended and logged like a query, and a summary row is logged. Both
    buttons are then disabled.

    Returns:
        The 6 outputs wired in the UI: [status label, final label, submit
        button, finalize button, plot image, points text]. On invalid input
        both buttons stay enabled and an error message is shown instead.
    """
    session_id = _ensure_session_id(session_state)
    remaining_queries = nqueries - query_count[0]

    evaluated = _evaluate(global_minimizer)
    if evaluated is None:
        img = draw_points(points, query_count, session_id)
        msg = "Invalid input: Please enter a numeric value for x*"
        return [msg, msg, gr.Button("Submit", interactive=True),
                gr.Button("Finalize Decision", interactive=True), img,
                _format_points(points)]

    x_star, y_value = evaluated
    y_star = y_value - 10 * remaining_queries
    session_end_msg = (f"Session ends with remaining query {remaining_queries}, "
                       f"and objective {y_value:.2f} - {remaining_queries}*10 = {y_star:.2f}.")
    points.append((x_star, y_value))
    log_query(session_id, x_star, y_value, QUERIES_TABLE_URL)
    img = draw_points(points, query_count, session_id)
    # Tag the last line (the submitted x*) so it stands out in the list.
    points_list = _format_points(points) + "--- final decision"
    log_final_decision(activity, y_value, remaining_queries, session_id, SUMMARY_TABLE_URL)
    return [session_end_msg, session_end_msg, gr.Button("Submit", interactive=False),
            gr.Button("Finalize Decision", interactive=False), img, points_list]


def submit_query(x, query_count, points, session_state, nqueries=10):
    """Evaluate ``f`` at one user-chosen ``x`` and refresh the display.

    Returns:
        The 6 outputs wired in the UI: (plot image, query_count, status label,
        points text, session_state, final label). When the query is rejected
        (limit reached or invalid input), the plot and the points list are left
        unchanged via ``gr.update()`` rather than being cleared.
    """
    session_id = _ensure_session_id(session_state)
    if query_count[0] >= nqueries:
        msg = "Query limit reached"
        return gr.update(), query_count, msg, gr.update(), session_state, msg

    evaluated = _evaluate(x)
    if evaluated is None:
        msg = "Invalid input: Please enter a numeric value"
        return gr.update(), query_count, msg, gr.update(), session_state, msg

    x, y_value = evaluated
    points.append((x, y_value))
    log_query(session_id, x, y_value, QUERIES_TABLE_URL)
    # Drawn before the counter is bumped so the title shows this query's number.
    img = draw_points(points, query_count, session_id)
    query_count[0] += 1
    remaining_queries = nqueries - query_count[0]
    msg = f"{remaining_queries} queries remaining"
    return img, query_count, msg, _format_points(points), session_state, msg


with gr.Blocks() as app:
    query_count = gr.State([0])
    points = gr.State([])
    # gr.State deep-copies its initial value for every visitor, so a UUID
    # generated here (once, at import time) would be shared by all sessions.
    # Start empty; the callbacks assign a per-session id on first use.
    session_state = gr.State({'session_id': None})

    with gr.Row():
        x_input = gr.Number(label="X")
        submit_button = gr.Button("Submit", interactive=True)

    with gr.Row():
        remaining_queries_label = gr.Label()
        points_display = gr.Textbox(label="Queried Points", lines=10,
                                    interactive=False, value="")

    with gr.Row():
        output_image = gr.Image(label="Plot", width="100vw")

    with gr.Row():
        global_minimizer_input = gr.Textbox(label="Global Minimizer (x*)")
        finalize_button = gr.Button("Finalize Decision")

    with gr.Row():
        final_label = gr.Label()

    submit_button.click(
        submit_query,
        inputs=[x_input, query_count, points, session_state],
        outputs=[output_image, query_count, remaining_queries_label,
                 points_display, session_state, final_label],
    )

    finalize_button.click(
        finalize_decision,
        inputs=[global_minimizer_input, points, session_state, query_count],
        outputs=[remaining_queries_label, final_label, submit_button,
                 finalize_button, output_image, points_display],
    )

app.launch(share=False)