# ML_Patch / app.py
# Author: lllyx
# Commit: 3d60ca8 (verified)
import gradio as gr
import requests
import json
import pandas as pd
import io
# ----------------- Configure Your Server API Address -----------------
# Backend inference endpoint; all requests from this UI are POSTed here.
SERVER_URL = "http://103.235.229.133:7000/predict"
# ---------------------------------------------------------------------
# ----------------- Fixed Model Name -----------------
# Display-only: shown in the UI; the server decides which model actually runs.
FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
# ----------------------------------------------------
# Default input data for easy testing
# Column order matches the DataFrame headers below:
# [subject, relation, object, prompt_source]
DEFAULT_DATAFRAME_VALUE = [
    ["France", "capital city of", "Paris", "she traveled to France for the first time"],
]
# --- KEY CHANGE 1: Update the function signature to accept the new parameter ---
def call_api(patched_layer_number, only_final_result, data_df):
    """
    Gradio callback: package the UI inputs into a JSON payload, POST it to
    the inference server, and return the results for both output widgets.

    Args:
        patched_layer_number: Integer layer index to patch (from gr.Number).
        only_final_result: If True, ask the API to return only the final
            aggregated result, not per-layer details (from gr.Checkbox).
        data_df: pandas DataFrame of input rows with columns
            [subject, relation, object, prompt_source] (from gr.DataFrame).

    Returns:
        Tuple of (result_json, output_df): the raw JSON response dict and a
        DataFrame built from its "result" list (empty on missing/bad data).

    Raises:
        gr.Error: On empty input, HTTP error responses, connection failures,
            or any other unexpected failure (surfaced to the user by Gradio).
    """
    if data_df is None or data_df.empty:
        raise gr.Error("Input data cannot be empty! Please add at least one row to the table.")

    # Work on a copy so the UI-bound DataFrame is not mutated, and prepend a
    # zero-padded 'id' column that keys each row for the server.
    data_df_with_id = data_df.copy()
    data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])

    # to_csv() with no buffer argument returns the CSV text directly —
    # no StringIO round-trip needed.
    csv_data = data_df_with_id.to_csv(index=False)

    # --- KEY CHANGE 2: Add the new parameter to the API payload ---
    payload = {
        "data_csv": csv_data,
        "patched_layer_number": patched_layer_number,
        "only_final_result": only_final_result,
    }

    try:
        # Note: stream=True was removed — response.json() reads the whole
        # body anyway, so streaming only delayed error detection and could
        # tie up the pooled connection.
        response = requests.post(SERVER_URL, json=payload, timeout=300)
        response.raise_for_status()
        result_json = response.json()

        # "result" is expected to be a list of row dicts; fall back to an
        # empty table when it is absent, empty, or of an unexpected type.
        dataframe_result_data = result_json.get("result", [])
        if isinstance(dataframe_result_data, list) and dataframe_result_data:
            output_df = pd.DataFrame(dataframe_result_data)
        else:
            output_df = pd.DataFrame()
        return result_json, output_df
    except requests.exceptions.HTTPError as e:
        # Server replied with a non-2xx status: show the status and body.
        error_details = f"API returned an error (Status Code: {e.response.status_code}):\n{e.response.text}"
        raise gr.Error(f"Request failed: {error_details}") from e
    except requests.exceptions.RequestException as e:
        # DNS failure, refused connection, timeout, etc.
        raise gr.Error(f"Failed to connect to the server: {e}") from e
    except Exception as e:
        # Catch-all boundary: surface anything unexpected to the UI.
        raise gr.Error(f"An unknown error occurred: {e}") from e
# --- Use Gradio Blocks for a more flexible and beautiful interface ---
# Build the UI declaratively; component creation order defines the layout.
with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
    gr.Markdown(
        """
    # ๐Ÿš€ Ml_patch Function Inference Interface
    An interactive web interface designed for the `Ml_patch` function.
    Please configure the model parameters and enter your test data below.
    """
    )
    # --- Section 1: inference parameters ---
    with gr.Column():
        gr.Markdown("### 1. Configure Parameters")
        with gr.Group():
            gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
            # --- KEY CHANGE 3: Add the new Gradio input component ---
            patched_layer_number_input = gr.Number(
                label="Patched Layer Number (patched_layer_number)",
                value=3,  # Set a reasonable default value
                precision=0,  # This ensures the input is an integer
                info="Specify the target layer number for patching."
            )
            only_final_result_input = gr.Checkbox(
                label="Only Return Final Result (only_final_result)",
                value=False,
                info="If checked, the API will only return the final aggregated result, not detailed info for all layers."
            )
        # --- Section 2: editable input table of factual triples + prompt ---
        gr.Markdown("### 2. Input Data")
        gr.Markdown("โ„น๏ธ **Instructions**: Click the **โœ๏ธ Edit** button at the bottom right of the table, then click **+ Add row** to add a new row, or the trash can icon to delete a row.")
        gr.Markdown("**[subject,relation,object]**: A factual triple.")
        gr.Markdown("**prompt_source**: A sentence which contains subject.")
        data_input = gr.DataFrame(
            label="Data Table",
            headers=["subject", "relation", "object", "prompt_source"],
            value=DEFAULT_DATAFRAME_VALUE,
            row_count=(5, "dynamic"),  # start with 5 rows, user may add/remove
            col_count=(4, "fixed"),    # columns match the headers; not editable
            interactive=True
        )
        submit_btn = gr.Button("Submit for Inference", variant="primary")
    # --- Section 3: results, as a table and as the raw JSON payload ---
    with gr.Column():
        gr.Markdown("### 3. View Inference Results")
        gr.Markdown("**layer_source**: Starting from this layer of **prompt_source**, we extracted the hidden states corresponding to the subject.")
        gr.Markdown("**is_correct_patched**: Whether the final answer is correct. If **object** is in **generations**, we think the final answer is correct.")
        gr.Markdown("**generations**: The output obtained after the model continues inference following a patch operation.")
        with gr.Tabs():
            with gr.TabItem("Results Table"):
                output_dataframe = gr.DataFrame(
                    label="Inference Result Details",
                    interactive=False,
                    # You might want to update headers if the output changes based on this parameter
                    headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
                )
            with gr.TabItem("Raw JSON Response"):
                output_json = gr.JSON(label="Full JSON response from API")
    # --- KEY CHANGE 4: Update the click event to include the new input ---
    # The order of inputs here MUST match the order of arguments in the `call_api` function.
    submit_btn.click(
        fn=call_api,
        inputs=[patched_layer_number_input, only_final_result_input, data_input],
        outputs=[output_json, output_dataframe]
    )
# Launch the Gradio app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()