|
import gradio as gr |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import cv2 |
|
from io import BytesIO |
|
import uuid |
|
import os |
|
import requests |
|
import json |
|
|
|
|
|
# Airtable credentials and table endpoints come from the environment;
# each name is None when its variable is not set.
AIRTABLE_TOKEN = os.getenv("AIRTABLE_TOKEN")
SUMMARY_TABLE_URL = os.getenv("SUMMARY_TABLE_URL")
QUERIES_TABLE_URL = os.getenv("QUERIES_TABLE_URL")
|
|
|
def send_to_airtable(data, url, timeout=10):
    """POST one JSON payload to an Airtable endpoint and return the reply text.

    Logging is best-effort: the game must keep working when Airtable is slow,
    unreachable, or not configured (``url`` is None when its environment
    variable is unset). Network failures are therefore reported and returned
    as text instead of propagating into the Gradio event handlers, where they
    would abort a handler halfway through mutating session state.

    Args:
        data: JSON-serialisable payload, e.g. ``{"fields": {...}}``.
        url: Airtable table endpoint.
        timeout: seconds to wait for the server before giving up.

    Returns:
        The response body text, or a description of the failure if the
        request could not be completed.
    """
    headers = {
        "Authorization": f"Bearer {AIRTABLE_TOKEN}",
        "Content-Type": "application/json",
    }
    payload = json.dumps(data)
    try:
        # Without a timeout, requests can block indefinitely on a stalled server.
        response = requests.post(url, headers=headers, data=payload, timeout=timeout)
    except requests.RequestException as exc:
        # Covers connection errors, timeouts and invalid/missing URLs.
        message = f"Airtable request failed: {exc}"
        print(message)
        return message

    print(response.text)
    return response.text
|
|
|
def log_final_decision(activity, objective, budget_left, session_id, url):
    """Send a session's final-decision record to the Airtable table at *url*.

    Returns the result of ``send_to_airtable``.
    """
    fields = dict(
        activity=activity,
        objective=objective,
        budget_left=budget_left,
        session_id=session_id,
    )
    return send_to_airtable({"fields": fields}, url)
|
|
|
|
|
def log_query(session_id, x, y, url):
    """Send one evaluated point (x, y) for a session to the Airtable table at *url*.

    Returns the result of ``send_to_airtable``.
    """
    record = dict(session_id=session_id, x=x, y=y)
    return send_to_airtable({"fields": record}, url)
|
|
|
def draw_points(points, query_count, session_id=""):
    """Render the sampled points as an RGB image for the Gradio plot.

    Points are coloured along the viridis colormap in query order and labelled
    with their 1-based index.

    Args:
        points: sequence of ``(x, y)`` pairs in the order they were queried.
        query_count: single-element list holding the number of queries made so
            far; the title shows ``query_count[0] + 1``.
        session_id: identifier shown in the title.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` in RGB channel order.
    """
    # Build the figure with the object-oriented API rather than pyplot: pyplot
    # keeps global "current figure" state, which is unsafe when a web server
    # handles several requests concurrently.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 8))
    ax = fig.subplots()

    if points:
        cmap = plt.get_cmap('viridis')
        count = len(points)
        for i, (x_val, y_val) in enumerate(points):
            ax.scatter(x_val, y_val, color=cmap(i / count))
            # Put the index label slightly above its marker.
            ax.text(x_val, y_val + 50, str(i + 1), fontsize=7, ha='center')

    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.set_ylim(-1000, 1000)
    ax.grid(True)
    ax.axhline(0, color='black', linewidth=0.5)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.set_title(f'Sampled Points - Query {query_count[0] + 1} ({session_id})')

    buf = BytesIO()
    fig.savefig(buf, format='png')
    encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    # cv2 decodes colour images in BGR order; Gradio treats numpy images as RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
def f(x, scale=1000):
    """Hidden objective evaluated for each query and for the final guess.

    ``x`` is divided by ``scale``. The result goes into a quartic with roots at
    1, 2, 3 and 4 (scaled units) plus a sum of sinusoids, and the total is
    squashed through ``tanh``, so the output is bounded by +/-800. Works
    elementwise on NumPy arrays as well as on scalars.
    """
    t = x / scale
    # Expanded form of (t-1)(t-2)(t-3)(t-4), scaled down by 4000.
    quartic = (t**4 - 10*t**3 + 35*t**2 - 50*t + 24) / 4000
    waves = 800 * np.sin(abs(t)/10) + 500 * np.cos(t/8) + 200 * np.sin(t/2)
    return 800 * np.tanh((quartic + waves) / 1000)
|
|
|
def finalize_decision(global_minimizer, points, session_state, query_count, nqueries=10, activity="group_1"):
    """Evaluate the player's final guess x* and close the session.

    The score shown to the player is ``f(x*) - 10 * remaining_queries``. On
    success, the final point is appended to ``points`` and logged to the
    queries table. A summary row (raw ``f(x*)`` and the number of remaining
    queries) is logged to the summary table.

    Args:
        global_minimizer: the player's x* as entered in the textbox.
        points: mutable list of ``(x, y)`` pairs queried so far; extended in
            place on success.
        session_state: dict holding ``'session_id'``; an id is generated and
            stored if it is empty.
        query_count: single-element list with the number of queries used.
        nqueries: total query budget.
        activity: label recorded with the summary row.

    Returns:
        List of six values for ``[remaining label, final label, submit button,
        finalize button, plot image, points text]``. On success both buttons
        are disabled. On invalid input they stay enabled and nothing is logged.
    """
    session_id = session_state['session_id'] or str(uuid.uuid4())[:6]
    # Keep a freshly generated id so later events in this session reuse it,
    # as submit_query does.
    session_state['session_id'] = session_id
    remaining_queries = nqueries - query_count[0]

    def points_text():
        return '\n'.join(f"x: {px:.0f}, y: {py:.3f}" for px, py in points)

    # Only the conversion belongs in the try: errors raised later (plotting,
    # logging) must not be reported to the player as "invalid input".
    try:
        x_star = float(global_minimizer)
        # float() accepts "nan"/"inf"; those would be logged as non-standard JSON.
        is_valid = bool(np.isfinite(x_star))
    except (TypeError, ValueError):
        is_valid = False

    if not is_valid:
        img = draw_points(points, query_count, session_id)
        msg = "Invalid input: Please enter a numeric value for x*"
        return [msg, msg, gr.Button("Submit", interactive=True), gr.Button("Finalize Decision", interactive=True), img, points_text()]

    y_value = f(x_star)
    y_star = y_value - 10 * remaining_queries
    session_end_msg = f"Session ends with remaining query {remaining_queries}, and objective {y_value:.2f} - {remaining_queries}*10 = {y_star:.2f}."

    points.append((x_star, y_value))
    log_query(session_id, x_star, y_value, QUERIES_TABLE_URL)
    img = draw_points(points, query_count, session_id)
    points_list = points_text() + "--- final decision"

    log_final_decision(activity, y_value, remaining_queries, session_id, SUMMARY_TABLE_URL)

    # The session is over: disable further queries and decisions.
    return [session_end_msg, session_end_msg, gr.Button("Submit", interactive=False), gr.Button("Finalize Decision", interactive=False), img, points_list]
|
|
|
def submit_query(x, query_count, points, session_state, nqueries=10):
    """Evaluate the objective at one player-chosen ``x`` and update the session.

    Args:
        x: value from the X number input (converted with ``float``).
        query_count: single-element list with the number of queries used;
            incremented in place on success.
        points: mutable list of ``(x, y)`` pairs; extended in place on success.
        session_state: dict holding ``'session_id'``; an id is generated and
            stored if it is empty.
        nqueries: total query budget.

    Returns:
        6-tuple ``(plot image, query_count, status message, points text,
        session_state, status message)``. Rejected submissions (budget
        exhausted, or non-numeric / non-finite input) leave ``points`` and
        ``query_count`` unchanged.
    """
    session_id = session_state['session_id'] or str(uuid.uuid4())[:6]
    session_state['session_id'] = session_id

    def points_text():
        return '\n'.join(f"x: {px:.0f}, y: {py:.3f}" for px, py in points)

    def rejected(msg):
        # Redraw rather than returning None / "" so a rejected submission does
        # not wipe the plot and point list the player needs to choose x*.
        img = draw_points(points, query_count, session_id)
        return img, query_count, msg, points_text(), session_state, msg

    if query_count[0] >= nqueries:
        return rejected("Query limit reached")

    # NOTE(review): gr.Number presumably yields None for an empty field, hence
    # TypeError. float() also accepts "nan"/"inf", which would be logged as
    # non-standard JSON.
    try:
        x = float(x)
        is_valid = bool(np.isfinite(x))
    except (TypeError, ValueError):
        is_valid = False
    if not is_valid:
        return rejected("Invalid input: Please enter a numeric value")

    y_value = f(x)
    points.append((x, y_value))
    log_query(session_id, x, y_value, QUERIES_TABLE_URL)
    # Draw before incrementing so the title's "Query N" is this query's 1-based number.
    img = draw_points(points, query_count, session_id)
    query_count[0] += 1

    remaining_queries = nqueries - query_count[0]
    msg = f"{remaining_queries} queries remaining"
    return img, query_count, msg, points_text(), session_state, msg
|
|
|
with gr.Blocks() as app:
    # Per-session state: Gradio copies each initial value for every session.
    query_count = gr.State([0])  # one-element list so handlers can update the count in place
    points = gr.State([])  # (x, f(x)) pairs sampled so far
    # Start with an empty id. An id generated here would be computed once at
    # startup and then copied into every session, so all players would log
    # under the same id. submit_query creates and stores a per-session id on
    # the first query instead.
    session_state = gr.State({'session_id': ''})

    with gr.Row():
        x_input = gr.Number(label="X")
        submit_button = gr.Button("Submit", interactive=True)

    with gr.Row():
        remaining_queries_label = gr.Label()
        points_display = gr.Textbox(label="Queried Points", lines=10, interactive=False, value="")

    with gr.Row():
        output_image = gr.Image(label="Plot", width="100vw")

    with gr.Row():
        global_minimizer_input = gr.Textbox(label="Global Minimizer (x*)")
        finalize_button = gr.Button("Finalize Decision")

    with gr.Row():
        final_label = gr.Label()

    # Output order must match the tuple returned by submit_query.
    submit_button.click(
        submit_query,
        inputs=[x_input, query_count, points, session_state],
        outputs=[output_image, query_count, remaining_queries_label, points_display, session_state, final_label],
    )

    # Output order must match the list returned by finalize_decision.
    finalize_button.click(
        finalize_decision,
        inputs=[global_minimizer_input, points, session_state, query_count],
        outputs=[remaining_queries_label, final_label, submit_button, finalize_button, output_image, points_display],
    )

app.launch(share=False)
|
|
|
|
|
|