Update app.py
app.py
CHANGED
@@ -4,17 +4,17 @@ import json
 import pandas as pd
 import io
 
-# -----------------
-#
+# ----------------- Configure Your Server API Address -----------------
+# Replace the IP address and port with your own
 SERVER_URL = "http://103.235.229.133:7000/predict"
-#
+# ---------------------------------------------------------------------
 
-# -----------------
-#
+# ----------------- Fixed Model Name -----------------
+# As per requirements, the model name is fixed and not a UI input.
 FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
-#
+# ----------------------------------------------------
 
-#
+# Default input data for easy testing
 DEFAULT_DATAFRAME_VALUE = [
     ["France", "capital city of", "Paris", "she traveled to France for the first time"],
     ["Germany", "capital city of", "Berlin", "Germany is a country in Central Europe"],
@@ -23,29 +23,34 @@ DEFAULT_DATAFRAME_VALUE = [
 
 def call_api(only_final_result, data_df):
     """
-
-
+    This function is called by Gradio. It collects inputs from the UI,
+    formats them for the API, and sends the request.
     """
     if data_df is None or data_df.empty:
-        raise gr.Error("
+        raise gr.Error("Input data cannot be empty! Please add at least one row to the table.")
 
+    # Create a copy and add an 'id' column
     data_df_with_id = data_df.copy()
     data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])
 
+    # Convert DataFrame to CSV string
     output = io.StringIO()
     data_df_with_id.to_csv(output, index=False)
     csv_data = output.getvalue()
 
+    # Prepare the payload for the API
     payload = {
         "data_csv": csv_data,
         "only_final_result": only_final_result
     }
 
     try:
+        # Send the request to the server
         response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
-        response.raise_for_status()
+        response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
         result_json = response.json()
 
+        # Extract the result data for the DataFrame
         dataframe_result_data = result_json.get("result", [])
 
         if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
@@ -56,65 +61,61 @@ def call_api(only_final_result, data_df):
         return result_json, output_df
 
     except requests.exceptions.HTTPError as e:
-        error_details = f"API
-        raise gr.Error(f"
+        error_details = f"API returned an error (Status Code: {e.response.status_code}):\n{e.response.text}"
+        raise gr.Error(f"Request failed: {error_details}")
     except requests.exceptions.RequestException as e:
-        raise gr.Error(f"
+        raise gr.Error(f"Failed to connect to the server: {e}")
     except Exception as e:
-        raise gr.Error(f"
+        raise gr.Error(f"An unknown error occurred: {e}")
 
 
-# ---
+# --- Use Gradio Blocks for a more flexible and beautiful interface ---
 with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
     gr.Markdown(
         """
-        # 🚀 Ml_patch
-
+        # 🚀 Ml_patch Function Inference Interface
+        An interactive web interface designed for the `Ml_patch` function.
+        Please configure the model parameters and enter your test data below.
         """
     )
 
     with gr.Column():
-        gr.Markdown("### 1.
+        gr.Markdown("### 1. Configure Parameters")
        with gr.Group():
-            gr.Markdown(f"
+            gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
             only_final_result_input = gr.Checkbox(
-                label="
+                label="Only Return Final Result (only_final_result)",
                 value=False,
-                info="
+                info="If checked, the API will only return the final aggregated result, not the detailed information for all layers."
             )
 
-        gr.Markdown("### 2.
-        gr.Markdown("ℹ️
+        gr.Markdown("### 2. Input Data")
+        gr.Markdown("ℹ️ **Instructions**: Click the **✏️ Edit** button at the bottom right of the table, then click **+ Add row** to add a new row, or the trash can icon to delete a row.")
 
-        # -------------------- The main changes are here --------------------
-        # Removed the 'height' parameter because it is not supported in older versions of Gradio.
-        # We give the table more initial vertical space by increasing the first argument of row_count (the initial number of rows).
         data_input = gr.DataFrame(
-            label="
+            label="Data Table",
             headers=["subject", "relation", "object", "prompt_source"],
             value=DEFAULT_DATAFRAME_VALUE,
-            #
-            # (
+            # Set the initial number of rows to 5 to give the table more vertical space on load.
+            # (rows_to_display, "dynamic" allows user to add/remove rows)
             row_count=(5, "dynamic"),
             col_count=(4, "fixed"),
             interactive=True
-            # The unsupported 'height' parameter has been removed
         )
-        # -------------------- End of changes --------------------
 
-        submit_btn = gr.Button("
+        submit_btn = gr.Button("Submit for Inference", variant="primary")
 
     with gr.Column():
-        gr.Markdown("### 3.
+        gr.Markdown("### 3. View Inference Results")
         with gr.Tabs():
-            with gr.TabItem("
+            with gr.TabItem("Results Table"):
                 output_dataframe = gr.DataFrame(
-                    label="
+                    label="Inference Result Details",
                     interactive=False,
                     headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
                 )
-            with gr.TabItem("
-                output_json = gr.JSON(label="
+            with gr.TabItem("Raw JSON Response"):
+                output_json = gr.JSON(label="Full JSON response from API")
 
     submit_btn.click(
         fn=call_api,
@@ -122,6 +123,6 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
         outputs=[output_json, output_dataframe]
     )
 
-#
+# Launch the Gradio app
 if __name__ == "__main__":
     demo.launch()
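For reference, below is a minimal standalone client sketch of the request that call_api sends to SERVER_URL. The payload keys ("data_csv", "only_final_result") and the "result" field read from the response mirror the code in this commit; the server's exact response schema beyond that is an assumption.

# Minimal client sketch, assuming the /predict endpoint accepts the same payload that call_api builds.
import io

import pandas as pd
import requests

SERVER_URL = "http://103.235.229.133:7000/predict"

rows = [["France", "capital city of", "Paris", "she traveled to France for the first time"]]
df = pd.DataFrame(rows, columns=["subject", "relation", "object", "prompt_source"])
df.insert(0, "id", [f"{i:03d}" for i in range(len(df))])  # same zero-padded id column the app adds

buf = io.StringIO()
df.to_csv(buf, index=False)  # serialize to the CSV string the API expects in "data_csv"

payload = {"data_csv": buf.getvalue(), "only_final_result": False}
response = requests.post(SERVER_URL, json=payload, timeout=300)
response.raise_for_status()
print(response.json().get("result", []))  # list of per-row results, as app.py reads it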