Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import cv2 | |
from io import BytesIO | |
import uuid | |
import csv | |
import os | |
import requests | |
import json | |
# Airtable access token, read once at import time from the environment;
# None when the AIRTABLE_TOKEN variable is not set.
AIRTABLE_TOKEN = os.environ.get("AIRTABLE_TOKEN")


def send_to_airtable(data, timeout=10):
    """POST one record to the Airtable table used for query logging.

    Args:
        data: JSON-serialisable payload, e.g. ``{"fields": {...}}``.
        timeout: Seconds to wait for Airtable before giving up. Without a
            timeout ``requests`` may block the calling handler indefinitely.

    Returns:
        The raw response body as text, or ``None`` when the request could
        not be completed (connection error, timeout, ...). Such failures are
        printed instead of raised, so a logging outage never breaks the app.
    """
    url = "https://api.airtable.com/v0/appHgBulDLVnlf3uS/simt_ap05032024"
    headers = {
        "Authorization": f"Bearer {AIRTABLE_TOKEN}",
        "Content-Type": "application/json",
    }
    payload = json.dumps(data)
    print(payload)
    try:
        response = requests.post(url, headers=headers, data=payload, timeout=timeout)
    except requests.RequestException as exc:
        # Best-effort: report and carry on rather than crash the caller.
        print(f"Airtable request failed: {exc}")
        return None
    print(response.text)
    return response.text
# CSV file (relative to the working directory) that accumulates one row per
# logged query.
data_file_path = 'queries_data.csv'


def initialize_data_file():
    """Create the CSV log with its header row unless the file already exists.

    An existing file is never touched, so rows from earlier runs survive.
    """
    if os.path.exists(data_file_path):
        return
    with open(data_file_path, 'w', newline='') as handle:
        csv.writer(handle).writerow(['session_id', 'x', 'y'])


# Runs once at import so the header row is in place before any data row is
# appended.
initialize_data_file()
def log_data(session_id, x, y):
    """Record one queried point locally and in Airtable.

    The row ``[session_id, x, y]`` is appended to the CSV file at
    ``data_file_path``. The same values are then forwarded to Airtable via
    ``send_to_airtable`` as ``{"fields": {"session_id", "x", "y"}}``.
    """
    row = [session_id, x, y]
    with open(data_file_path, 'a', newline='') as csv_handle:
        csv.writer(csv_handle).writerow(row)

    # Airtable's record payload nests the column values under "fields".
    record = {"fields": dict(zip(("session_id", "x", "y"), row))}
    send_to_airtable(record)
def draw_points(points, query_count, session_id=""):
    """Render the sampled points as a scatter plot and return it as an image.

    Args:
        points: Sequence of ``(x, y)`` pairs in query order.
        query_count: One-element list holding the number of queries made so
            far; the title shows ``query_count[0] + 1``.
        session_id: Shown in parentheses in the plot title.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with 3 colour channels in RGB
        order, suitable as the value of a ``gr.Image`` output.
    """
    # The object-oriented Figure API keeps no global "current figure", unlike
    # pyplot, whose shared state is unsafe when several requests plot at once.
    # The figure is also never registered with pyplot, so nothing leaks if an
    # exception interrupts rendering.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 8))
    ax = fig.add_subplot()
    if points:
        cmap = plt.get_cmap('viridis')
        n_points = len(points)
        for i, (x_val, y_val) in enumerate(points):
            # Colour encodes query order along the viridis colormap.
            ax.scatter(x_val, y_val, color=cmap(i / n_points))
            # 1-based query number, drawn just above the marker.
            ax.text(x_val, y_val + 50, str(i + 1), fontsize=7, ha='center')
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.set_ylim(-1000, 1000)
    ax.grid(True)
    ax.axhline(0, color='black', linewidth=0.5)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.set_title(f'Sampled Points - Query {query_count[0] + 1} ({session_id})')

    buf = BytesIO()
    fig.savefig(buf, format='png')
    encoded = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    bgr = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    # OpenCV decodes to BGR channel order; Gradio treats numpy images as RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def f(x, scale=1000):
    """Evaluate the function sampled by each query at ``x``.

    ``x`` is first divided by ``scale``. A small quartic term and a sum of
    three sinusoids are added, divided by 1000, and squashed through
    ``tanh``, so the result is bounded by +/-800. Scalars and numpy arrays
    (element-wise) are both accepted.
    """
    u = x / scale
    # (u-1)(u-2)(u-3)(u-4) expanded, heavily damped by the /4000.
    quartic = (u**4 - 10*u**3 + 35*u**2 - 50*u + 24) / 4000
    waves = 800 * np.sin(abs(u)/10) + 500 * np.cos(u/8) + 200 * np.sin(u/2)
    return 800 * np.tanh((quartic + waves) / 1000)
def custom_function(x, query_count, points, session_state, nqueries=10):
    """Click handler: evaluate ``f`` at one user-supplied ``x``.

    Args:
        x: Raw value from the X input. May be ``None`` (an empty field) or a
            string when called directly; anything that is not a finite
            number is rejected.
        query_count: One-element list ``[n]`` with the queries used so far;
            incremented in place on a successful query.
        points: List of ``(x, y)`` pairs sampled so far; appended in place.
        session_state: Dict with key ``'session_id'``. A 6-character id is
            generated on first use and stored back into it.
        nqueries: Maximum number of queries allowed per session.

    Returns:
        ``(image, query_count, status_message, points_text, session_state)``.
        ``image`` is ``None`` and ``points_text`` is ``""`` when the query is
        refused (limit reached or invalid input).
    """
    session_id = session_state['session_id'] or str(uuid.uuid4())[:6]
    session_state['session_id'] = session_id

    if query_count[0] >= nqueries:
        return None, query_count, "Query limit reached", "", session_state

    invalid = (None, query_count, "Invalid input: Please enter a numeric value", "", session_state)
    try:
        x = float(x)
    except (TypeError, ValueError):
        # TypeError covers x=None, presumably what an empty numeric field
        # delivers; ValueError covers non-numeric strings.
        return invalid
    if not np.isfinite(x):
        # inf/nan cannot be plotted meaningfully and json.dumps would emit
        # NaN/Infinity, which is not valid JSON for the Airtable payload.
        return invalid

    y_value = f(x)
    points.append((x, y_value))
    try:
        log_data(session_id, x, y_value)
    except OSError as exc:
        # Logging is best-effort. File-system errors (and requests' network
        # errors, which derive from OSError) must not abort the handler after
        # the point is already in ``points``, or the query would go uncounted.
        print(f"Failed to log query for session {session_id}: {exc}")
    img = draw_points(points, query_count, session_id)
    query_count[0] += 1
    remaining_queries = nqueries - query_count[0]
    points_list = '\n'.join(f"x: {px:.0f}, y: {py:.3f}" for px, py in points)
    return img, query_count, f"{remaining_queries} queries remaining", points_list, session_state
# Gradio UI. The three gr.State values are passed to custom_function on every
# click; all except `points` are also returned by it.
with gr.Blocks() as app:
    # One-element list holding the number of queries used so far.
    query_count = gr.State([0])
    # List of (x, y) pairs sampled so far.
    # NOTE(review): `points` is not among the click outputs below; its updates
    # presumably persist only because custom_function appends to the list in
    # place. Confirm against Gradio's State semantics.
    points = gr.State([])
    # Holds the session id; None until custom_function generates one.
    session_state = gr.State({'session_id': None})
    # Row 1: numeric x input and the button that submits a query.
    with gr.Row():
        x_input = gr.Number(label="X")
        submit_button = gr.Button("Submit")
    # Row 2: status message ("N queries remaining" or an error) and the
    # queried points as text.
    with gr.Row():
        remaining_queries_label = gr.Label(elem_id="remaining_queries")
        points_display = gr.Textbox(label="Queried Points", lines=10, interactive=False, value="", elem_id="queried_points")
    # Row 3: the scatter plot image produced by custom_function.
    with gr.Row():
        output_image = gr.Image(label="Plot", elem_id="plot_image", width="100vw")
    # The outputs list must match the order of custom_function's return tuple:
    # (image, query_count, status message, points text, session_state).
    submit_button.click(
        custom_function,
        inputs=[x_input, query_count, points, session_state],
        outputs=[output_image, query_count, remaining_queries_label, points_display, session_state]
    )
# Starts the server as soon as this module runs (including on import).
# NOTE(review): share=True asks Gradio for a public share link; consider an
# `if __name__ == "__main__":` guard so importing the module does not launch.
app.launch(share=True)