Commit
•
06cae5a
1
Parent(s):
300c000
Update app.py
Browse files
app.py
CHANGED
@@ -4,52 +4,33 @@ import matplotlib.pyplot as plt
|
|
4 |
import cv2
|
5 |
from io import BytesIO
|
6 |
import uuid
|
7 |
-
import csv
|
8 |
import os
|
9 |
import requests
|
10 |
import json
|
11 |
|
|
|
12 |
AIRTABLE_TOKEN = os.environ.get("AIRTABLE_TOKEN")
|
|
|
|
|
13 |
|
14 |
-
def send_to_airtable(data):
|
15 |
-
url = "https://api.airtable.com/v0/appHgBulDLVnlf3uS/simt_ap05032024"
|
16 |
headers = {
|
17 |
"Authorization": f"Bearer {AIRTABLE_TOKEN}",
|
18 |
"Content-Type": "application/json"
|
19 |
}
|
20 |
data = json.dumps(data)
|
21 |
-
print(data)
|
22 |
response = requests.post(url, headers=headers, data=data)
|
23 |
print(response.text)
|
24 |
return response.text
|
25 |
|
26 |
-
|
27 |
-
|
|
|
28 |
|
29 |
-
# Function to initialize the CSV file if it does not exist
|
30 |
-
def initialize_data_file():
|
31 |
-
if not os.path.exists(data_file_path):
|
32 |
-
with open(data_file_path, 'w', newline='') as file:
|
33 |
-
writer = csv.writer(file)
|
34 |
-
writer.writerow(['session_id', 'x', 'y'])
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
def log_data(session_id, x, y):
|
40 |
-
with open(data_file_path, 'a', newline='') as file:
|
41 |
-
writer = csv.writer(file)
|
42 |
-
writer.writerow([session_id, x, y])
|
43 |
-
|
44 |
-
#send to airtable
|
45 |
-
data = {
|
46 |
-
"fields": {
|
47 |
-
"session_id": session_id,
|
48 |
-
"x": x,
|
49 |
-
"y": y
|
50 |
-
}
|
51 |
-
}
|
52 |
-
send_to_airtable(data)
|
53 |
|
54 |
def draw_points(points, query_count, session_id=""):
|
55 |
plt.figure(figsize=(12, 8))
|
@@ -59,7 +40,6 @@ def draw_points(points, query_count, session_id=""):
|
|
59 |
for i, (x_val, y_val) in enumerate(points):
|
60 |
plt.scatter(x_val, y_val, color=cmap(i / len(points)))
|
61 |
plt.text(x_val, y_val + 50, str(i + 1), fontsize=7, ha='center')
|
62 |
-
|
63 |
plt.xlabel('x')
|
64 |
plt.ylabel('f(x)')
|
65 |
plt.ylim(-1000, 1000)
|
@@ -82,50 +62,91 @@ def f(x, scale=1000):
|
|
82 |
y_value = 800 * np.tanh((polynomial + sinusoidal) / 1000)
|
83 |
return y_value
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
-
def
|
87 |
session_id = session_state['session_id'] if session_state['session_id'] else str(uuid.uuid4())[:6]
|
88 |
-
session_state['session_id'] = session_id
|
89 |
|
90 |
if query_count[0] >= nqueries:
|
91 |
-
return None, query_count, "Query limit reached", "", session_state
|
92 |
|
93 |
try:
|
94 |
x = float(x)
|
95 |
except ValueError:
|
96 |
-
|
|
|
97 |
|
98 |
y_value = f(x)
|
99 |
points.append((x, y_value))
|
100 |
-
|
101 |
img = draw_points(points, query_count, session_id)
|
102 |
query_count[0] += 1
|
103 |
remaining_queries = nqueries - query_count[0]
|
104 |
points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
|
105 |
-
|
|
|
106 |
|
107 |
with gr.Blocks() as app:
|
108 |
query_count = gr.State([0])
|
109 |
points = gr.State([])
|
110 |
-
session_state = gr.State({'session_id':
|
111 |
-
|
112 |
with gr.Row():
|
113 |
x_input = gr.Number(label="X")
|
114 |
-
submit_button = gr.Button("Submit")
|
|
|
|
|
|
|
|
|
115 |
|
116 |
with gr.Row():
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
119 |
|
120 |
with gr.Row():
|
121 |
-
|
|
|
122 |
|
123 |
submit_button.click(
|
124 |
-
|
125 |
inputs=[x_input, query_count, points, session_state],
|
126 |
-
outputs=[output_image, query_count, remaining_queries_label, points_display, session_state]
|
127 |
)
|
128 |
|
129 |
-
|
|
|
|
|
|
|
|
|
130 |
|
|
|
131 |
|
|
|
|
|
|
4 |
import cv2
|
5 |
from io import BytesIO
|
6 |
import uuid
|
|
|
7 |
import os
|
8 |
import requests
|
9 |
import json
|
10 |
|
11 |
+
|
12 |
AIRTABLE_TOKEN = os.environ.get("AIRTABLE_TOKEN")
|
13 |
+
SUMMARY_TABLE_URL = os.environ.get("SUMMARY_TABLE_URL")
|
14 |
+
QUERIES_TABLE_URL = os.environ.get("QUERIES_TABLE_URL")
|
15 |
|
16 |
+
def send_to_airtable(data, url):
|
|
|
17 |
headers = {
|
18 |
"Authorization": f"Bearer {AIRTABLE_TOKEN}",
|
19 |
"Content-Type": "application/json"
|
20 |
}
|
21 |
data = json.dumps(data)
|
|
|
22 |
response = requests.post(url, headers=headers, data=data)
|
23 |
print(response.text)
|
24 |
return response.text
|
25 |
|
26 |
+
def log_final_decision(activity, objective, budget_left, session_id, url):
|
27 |
+
data = {"fields": {"activity": activity, "objective": objective, "budget_left": budget_left, "session_id": session_id}}
|
28 |
+
return send_to_airtable(data, url)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
def log_query(session_id, x, y, url):
|
32 |
+
data = {"fields": {"session_id": session_id, "x": x, "y": y}}
|
33 |
+
return send_to_airtable(data, url)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def draw_points(points, query_count, session_id=""):
|
36 |
plt.figure(figsize=(12, 8))
|
|
|
40 |
for i, (x_val, y_val) in enumerate(points):
|
41 |
plt.scatter(x_val, y_val, color=cmap(i / len(points)))
|
42 |
plt.text(x_val, y_val + 50, str(i + 1), fontsize=7, ha='center')
|
|
|
43 |
plt.xlabel('x')
|
44 |
plt.ylabel('f(x)')
|
45 |
plt.ylim(-1000, 1000)
|
|
|
62 |
y_value = 800 * np.tanh((polynomial + sinusoidal) / 1000)
|
63 |
return y_value
|
64 |
|
65 |
+
def finalize_decision(global_minimizer, points, session_state, query_count, nqueries=10, activity="group_1"):
|
66 |
+
session_id = session_state['session_id'] if session_state['session_id'] else str(uuid.uuid4())[:6]
|
67 |
+
remaining_queries = nqueries - query_count[0]
|
68 |
+
try:
|
69 |
+
x_star = float(global_minimizer)
|
70 |
+
y_value = f(x_star)
|
71 |
+
y_star = y_value - 10 * remaining_queries
|
72 |
+
session_end_msg = f"Session ends with remaining query {remaining_queries}, and objective {y_value:.2f} - {remaining_queries}*10 = {y_star:.2f}."
|
73 |
+
|
74 |
+
points.append((x_star, y_value))
|
75 |
+
log_query(session_id, x_star, y_value, QUERIES_TABLE_URL)
|
76 |
+
img = draw_points(points, query_count, session_id)
|
77 |
+
points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
|
78 |
+
#add "--- final decision" string to the final points_list element
|
79 |
+
points_list += "--- final decision"
|
80 |
+
|
81 |
+
log_final_decision(activity, y_value, remaining_queries, session_id, SUMMARY_TABLE_URL)
|
82 |
+
|
83 |
+
return [session_end_msg, session_end_msg, gr.Button("Submit", interactive=False), gr.Button("Finalize Decision", interactive=False), img, points_list]
|
84 |
+
except ValueError:
|
85 |
+
img = draw_points(points, query_count, session_id)
|
86 |
+
points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
|
87 |
+
msg = "Invalid input: Please enter a numeric value for x*"
|
88 |
+
return [msg, msg, gr.Button("Submit", interactive=True), gr.Button("Finalize Decision", interactive=True), img, points_list]
|
89 |
|
90 |
+
def submit_query(x, query_count, points, session_state, nqueries=10):
|
91 |
session_id = session_state['session_id'] if session_state['session_id'] else str(uuid.uuid4())[:6]
|
92 |
+
session_state['session_id'] = session_id
|
93 |
|
94 |
if query_count[0] >= nqueries:
|
95 |
+
return None, query_count, "Query limit reached", "", session_state, "Query limit reached"
|
96 |
|
97 |
try:
|
98 |
x = float(x)
|
99 |
except ValueError:
|
100 |
+
msg = "Invalid input: Please enter a numeric value"
|
101 |
+
return None, query_count, msg, "", session_state, msg
|
102 |
|
103 |
y_value = f(x)
|
104 |
points.append((x, y_value))
|
105 |
+
log_query(session_id, x, y_value, QUERIES_TABLE_URL)
|
106 |
img = draw_points(points, query_count, session_id)
|
107 |
query_count[0] += 1
|
108 |
remaining_queries = nqueries - query_count[0]
|
109 |
points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
|
110 |
+
msg = f"{remaining_queries} queries remaining"
|
111 |
+
return img, query_count, msg, points_list, session_state, msg
|
112 |
|
113 |
with gr.Blocks() as app:
|
114 |
query_count = gr.State([0])
|
115 |
points = gr.State([])
|
116 |
+
session_state = gr.State({'session_id': str(uuid.uuid4())[:6]})
|
117 |
+
|
118 |
with gr.Row():
|
119 |
x_input = gr.Number(label="X")
|
120 |
+
submit_button = gr.Button("Submit", interactive=True)
|
121 |
+
|
122 |
+
with gr.Row():
|
123 |
+
remaining_queries_label = gr.Label()
|
124 |
+
points_display = gr.Textbox(label="Queried Points", lines=10, interactive=False, value="")
|
125 |
|
126 |
with gr.Row():
|
127 |
+
output_image = gr.Image(label="Plot", width="100vw")
|
128 |
+
|
129 |
+
with gr.Row():
|
130 |
+
global_minimizer_input = gr.Textbox(label="Global Minimizer (x*)")
|
131 |
+
finalize_button = gr.Button("Finalize Decision")
|
132 |
|
133 |
with gr.Row():
|
134 |
+
final_label = gr.Label()
|
135 |
+
|
136 |
|
137 |
submit_button.click(
|
138 |
+
submit_query,
|
139 |
inputs=[x_input, query_count, points, session_state],
|
140 |
+
outputs=[output_image, query_count, remaining_queries_label, points_display, session_state, final_label]
|
141 |
)
|
142 |
|
143 |
+
finalize_button.click(
|
144 |
+
finalize_decision,
|
145 |
+
inputs=[global_minimizer_input, points, session_state, query_count],
|
146 |
+
outputs=[remaining_queries_label, final_label, submit_button, finalize_button, output_image, points_display]
|
147 |
+
)
|
148 |
|
149 |
+
app.launch(share=False)
|
150 |
|
151 |
+
#save the output into airtable
|
152 |
+
#save the queries into the correct table
|