mansurarief commited on
Commit
06cae5a
1 Parent(s): 300c000

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -45
app.py CHANGED
@@ -4,52 +4,33 @@ import matplotlib.pyplot as plt
4
  import cv2
5
  from io import BytesIO
6
  import uuid
7
- import csv
8
  import os
9
  import requests
10
  import json
11
 
 
12
  AIRTABLE_TOKEN = os.environ.get("AIRTABLE_TOKEN")
 
 
13
 
14
- def send_to_airtable(data):
15
- url = "https://api.airtable.com/v0/appHgBulDLVnlf3uS/simt_ap05032024"
16
  headers = {
17
  "Authorization": f"Bearer {AIRTABLE_TOKEN}",
18
  "Content-Type": "application/json"
19
  }
20
  data = json.dumps(data)
21
- print(data)
22
  response = requests.post(url, headers=headers, data=data)
23
  print(response.text)
24
  return response.text
25
 
26
- # Path to the CSV file for storing session data
27
- data_file_path = 'queries_data.csv'
 
28
 
29
- # Function to initialize the CSV file if it does not exist
30
- def initialize_data_file():
31
- if not os.path.exists(data_file_path):
32
- with open(data_file_path, 'w', newline='') as file:
33
- writer = csv.writer(file)
34
- writer.writerow(['session_id', 'x', 'y'])
35
 
36
- # Ensure the data file is initialized
37
- initialize_data_file()
38
-
39
- def log_data(session_id, x, y):
40
- with open(data_file_path, 'a', newline='') as file:
41
- writer = csv.writer(file)
42
- writer.writerow([session_id, x, y])
43
-
44
- #send to airtable
45
- data = {
46
- "fields": {
47
- "session_id": session_id,
48
- "x": x,
49
- "y": y
50
- }
51
- }
52
- send_to_airtable(data)
53
 
54
  def draw_points(points, query_count, session_id=""):
55
  plt.figure(figsize=(12, 8))
@@ -59,7 +40,6 @@ def draw_points(points, query_count, session_id=""):
59
  for i, (x_val, y_val) in enumerate(points):
60
  plt.scatter(x_val, y_val, color=cmap(i / len(points)))
61
  plt.text(x_val, y_val + 50, str(i + 1), fontsize=7, ha='center')
62
-
63
  plt.xlabel('x')
64
  plt.ylabel('f(x)')
65
  plt.ylim(-1000, 1000)
@@ -82,50 +62,91 @@ def f(x, scale=1000):
82
  y_value = 800 * np.tanh((polynomial + sinusoidal) / 1000)
83
  return y_value
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- def custom_function(x, query_count, points, session_state, nqueries=10):
87
  session_id = session_state['session_id'] if session_state['session_id'] else str(uuid.uuid4())[:6]
88
- session_state['session_id'] = session_id # Update state with the current or new session ID
89
 
90
  if query_count[0] >= nqueries:
91
- return None, query_count, "Query limit reached", "", session_state
92
 
93
  try:
94
  x = float(x)
95
  except ValueError:
96
- return None, query_count, "Invalid input: Please enter a numeric value", "", session_state
 
97
 
98
  y_value = f(x)
99
  points.append((x, y_value))
100
- log_data(session_id, x, y_value)
101
  img = draw_points(points, query_count, session_id)
102
  query_count[0] += 1
103
  remaining_queries = nqueries - query_count[0]
104
  points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
105
- return img, query_count, f"{remaining_queries} queries remaining", points_list, session_state
 
106
 
107
  with gr.Blocks() as app:
108
  query_count = gr.State([0])
109
  points = gr.State([])
110
- session_state = gr.State({'session_id': None}) # Initialize state for session_id
111
-
112
  with gr.Row():
113
  x_input = gr.Number(label="X")
114
- submit_button = gr.Button("Submit")
 
 
 
 
115
 
116
  with gr.Row():
117
- remaining_queries_label = gr.Label(elem_id="remaining_queries")
118
- points_display = gr.Textbox(label="Queried Points", lines=10, interactive=False, value="", elem_id="queried_points")
 
 
 
119
 
120
  with gr.Row():
121
- output_image = gr.Image(label="Plot", elem_id="plot_image", width="100vw")
 
122
 
123
  submit_button.click(
124
- custom_function,
125
  inputs=[x_input, query_count, points, session_state],
126
- outputs=[output_image, query_count, remaining_queries_label, points_display, session_state]
127
  )
128
 
129
- app.launch(share=True)
 
 
 
 
130
 
 
131
 
 
 
 
4
  import cv2
5
  from io import BytesIO
6
  import uuid
 
7
  import os
8
  import requests
9
  import json
10
 
11
+
12
  AIRTABLE_TOKEN = os.environ.get("AIRTABLE_TOKEN")
13
+ SUMMARY_TABLE_URL = os.environ.get("SUMMARY_TABLE_URL")
14
+ QUERIES_TABLE_URL = os.environ.get("QUERIES_TABLE_URL")
15
 
16
+ def send_to_airtable(data, url):
 
17
  headers = {
18
  "Authorization": f"Bearer {AIRTABLE_TOKEN}",
19
  "Content-Type": "application/json"
20
  }
21
  data = json.dumps(data)
 
22
  response = requests.post(url, headers=headers, data=data)
23
  print(response.text)
24
  return response.text
25
 
26
+ def log_final_decision(activity, objective, budget_left, session_id, url):
27
+ data = {"fields": {"activity": activity, "objective": objective, "budget_left": budget_left, "session_id": session_id}}
28
+ return send_to_airtable(data, url)
29
 
 
 
 
 
 
 
30
 
31
+ def log_query(session_id, x, y, url):
32
+ data = {"fields": {"session_id": session_id, "x": x, "y": y}}
33
+ return send_to_airtable(data, url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def draw_points(points, query_count, session_id=""):
36
  plt.figure(figsize=(12, 8))
 
40
  for i, (x_val, y_val) in enumerate(points):
41
  plt.scatter(x_val, y_val, color=cmap(i / len(points)))
42
  plt.text(x_val, y_val + 50, str(i + 1), fontsize=7, ha='center')
 
43
  plt.xlabel('x')
44
  plt.ylabel('f(x)')
45
  plt.ylim(-1000, 1000)
 
62
  y_value = 800 * np.tanh((polynomial + sinusoidal) / 1000)
63
  return y_value
64
 
65
+ def finalize_decision(global_minimizer, points, session_state, query_count, nqueries=10, activity="group_1"):
66
+ session_id = session_state['session_id'] if session_state['session_id'] else str(uuid.uuid4())[:6]
67
+ remaining_queries = nqueries - query_count[0]
68
+ try:
69
+ x_star = float(global_minimizer)
70
+ y_value = f(x_star)
71
+ y_star = y_value - 10 * remaining_queries
72
+ session_end_msg = f"Session ends with remaining query {remaining_queries}, and objective {y_value:.2f} - {remaining_queries}*10 = {y_star:.2f}."
73
+
74
+ points.append((x_star, y_value))
75
+ log_query(session_id, x_star, y_value, QUERIES_TABLE_URL)
76
+ img = draw_points(points, query_count, session_id)
77
+ points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
78
+ #add "--- final decision" string to the final points_list element
79
+ points_list += "--- final decision"
80
+
81
+ log_final_decision(activity, y_value, remaining_queries, session_id, SUMMARY_TABLE_URL)
82
+
83
+ return [session_end_msg, session_end_msg, gr.Button("Submit", interactive=False), gr.Button("Finalize Decision", interactive=False), img, points_list]
84
+ except ValueError:
85
+ img = draw_points(points, query_count, session_id)
86
+ points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
87
+ msg = "Invalid input: Please enter a numeric value for x*"
88
+ return [msg, msg, gr.Button("Submit", interactive=True), gr.Button("Finalize Decision", interactive=True), img, points_list]
89
 
90
+ def submit_query(x, query_count, points, session_state, nqueries=10):
91
  session_id = session_state['session_id'] if session_state['session_id'] else str(uuid.uuid4())[:6]
92
+ session_state['session_id'] = session_id
93
 
94
  if query_count[0] >= nqueries:
95
+ return None, query_count, "Query limit reached", "", session_state, "Query limit reached"
96
 
97
  try:
98
  x = float(x)
99
  except ValueError:
100
+ msg = "Invalid input: Please enter a numeric value"
101
+ return None, query_count, msg, "", session_state, msg
102
 
103
  y_value = f(x)
104
  points.append((x, y_value))
105
+ log_query(session_id, x, y_value, QUERIES_TABLE_URL)
106
  img = draw_points(points, query_count, session_id)
107
  query_count[0] += 1
108
  remaining_queries = nqueries - query_count[0]
109
  points_list = '\n'.join([f"x: {p[0]:.0f}, y: {p[1]:.3f}" for p in points])
110
+ msg = f"{remaining_queries} queries remaining"
111
+ return img, query_count, msg, points_list, session_state, msg
112
 
113
  with gr.Blocks() as app:
114
  query_count = gr.State([0])
115
  points = gr.State([])
116
+ session_state = gr.State({'session_id': str(uuid.uuid4())[:6]})
117
+
118
  with gr.Row():
119
  x_input = gr.Number(label="X")
120
+ submit_button = gr.Button("Submit", interactive=True)
121
+
122
+ with gr.Row():
123
+ remaining_queries_label = gr.Label()
124
+ points_display = gr.Textbox(label="Queried Points", lines=10, interactive=False, value="")
125
 
126
  with gr.Row():
127
+ output_image = gr.Image(label="Plot", width="100vw")
128
+
129
+ with gr.Row():
130
+ global_minimizer_input = gr.Textbox(label="Global Minimizer (x*)")
131
+ finalize_button = gr.Button("Finalize Decision")
132
 
133
  with gr.Row():
134
+ final_label = gr.Label()
135
+
136
 
137
  submit_button.click(
138
+ submit_query,
139
  inputs=[x_input, query_count, points, session_state],
140
+ outputs=[output_image, query_count, remaining_queries_label, points_display, session_state, final_label]
141
  )
142
 
143
+ finalize_button.click(
144
+ finalize_decision,
145
+ inputs=[global_minimizer_input, points, session_state, query_count],
146
+ outputs=[remaining_queries_label, final_label, submit_button, finalize_button, output_image, points_display]
147
+ )
148
 
149
+ app.launch(share=False)
150
 
151
+ #save the output into airtable
152
+ #save the queries into the correct table