|
|
import gradio as gr |
|
|
import requests |
|
|
import json |
|
|
import pandas as pd |
|
|
import io |
|
|
|
|
|
|
|
|
# Endpoint of the remote inference server that runs the patch operation.
SERVER_URL = "http://103.235.229.133:7000/predict"

# The model is fixed server-side; this name is displayed in the UI for
# reference only and is never sent in the request payload.
FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"

# One example row used to pre-populate the input table.
# Column order: [subject, relation, object, prompt_source].
DEFAULT_DATAFRAME_VALUE = [
    ["France", "capital city of", "Paris", "she traveled to France for the first time"],
]
|
|
|
|
|
|
|
|
def call_api(patched_layer_number, only_final_result, data_df):
    """Collect inputs from the Gradio UI, send them to the inference API,
    and return the parsed results.

    Args:
        patched_layer_number: Target layer number for the patch operation.
        only_final_result: If True, the API returns only the final
            aggregated result instead of per-layer details.
        data_df: pandas DataFrame from the Gradio table with columns
            [subject, relation, object, prompt_source].

    Returns:
        A tuple ``(result_json, output_df)`` — the full JSON response and a
        DataFrame built from its ``"result"`` field (empty DataFrame when the
        field is missing or empty).

    Raises:
        gr.Error: On empty input, HTTP error responses, connection failures,
            or any other unexpected failure (Gradio surfaces these in the UI).
    """
    if data_df is None or data_df.empty:
        raise gr.Error("Input data cannot be empty! Please add at least one row to the table.")

    # Prepend a zero-padded 'id' column so returned rows can be matched
    # back to the submitted rows. Work on a copy to avoid mutating the
    # DataFrame Gradio handed us.
    data_df_with_id = data_df.copy()
    data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])

    # Serialize the table as CSV text — the format the API expects.
    output = io.StringIO()
    data_df_with_id.to_csv(output, index=False)
    csv_data = output.getvalue()

    payload = {
        "data_csv": csv_data,
        "patched_layer_number": patched_layer_number,
        "only_final_result": only_final_result
    }

    try:
        # FIX: removed stream=True from the original call. The body is
        # consumed immediately via .json(), so streaming gave no benefit and
        # risked leaving the connection unreleased if an error fired before
        # the body was read. Long timeout because inference can be slow.
        response = requests.post(SERVER_URL, json=payload, timeout=300)
        response.raise_for_status()
        result_json = response.json()

        dataframe_result_data = result_json.get("result", [])

        # Only build a table when the API returned a non-empty list of rows.
        if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
            output_df = pd.DataFrame(dataframe_result_data)
        else:
            output_df = pd.DataFrame()

        return result_json, output_df

    except requests.exceptions.HTTPError as e:
        # Server responded with a 4xx/5xx — include status and body for debugging.
        error_details = f"API returned an error (Status Code: {e.response.status_code}):\n{e.response.text}"
        raise gr.Error(f"Request failed: {error_details}")
    except requests.exceptions.RequestException as e:
        # DNS failure, refused connection, timeout, etc.
        raise gr.Error(f"Failed to connect to the server: {e}")
    except Exception as e:
        # Catch-all at the UI boundary so the user always sees a message
        # instead of a silent failure.
        raise gr.Error(f"An unknown error occurred: {e}")
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# UI definition. Component creation order inside the `with` context determines
# the on-screen layout, so statements here must not be reordered.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
    # Page title and short description.
    gr.Markdown(
        """
        # ๐ Ml_patch Function Inference Interface
        An interactive web interface designed for the `Ml_patch` function.
        Please configure the model parameters and enter your test data below.
        """
    )

    # --- Section 1: API parameters -----------------------------------------
    with gr.Column():
        gr.Markdown("### 1. Configure Parameters")
        with gr.Group():
            # Model is fixed server-side; shown for information only.
            gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")

            # Integer-only layer selector (precision=0 forbids decimals);
            # value is forwarded as payload["patched_layer_number"].
            patched_layer_number_input = gr.Number(
                label="Patched Layer Number (patched_layer_number)",
                value=3,
                precision=0,
                info="Specify the target layer number for patching."
            )

            # Forwarded as payload["only_final_result"].
            only_final_result_input = gr.Checkbox(
                label="Only Return Final Result (only_final_result)",
                value=False,
                info="If checked, the API will only return the final aggregated result, not detailed info for all layers."
            )

        # --- Section 2: editable input table -------------------------------
        gr.Markdown("### 2. Input Data")
        gr.Markdown("โน๏ธ **Instructions**: Click the **โ๏ธ Edit** button at the bottom right of the table, then click **+ Add row** to add a new row, or the trash can icon to delete a row.")
        gr.Markdown("**[subject,relation,object]**: A factual triple.")
        gr.Markdown("**prompt_source**: A sentence which contains subject.")
        # Fixed 4 columns matching the CSV schema call_api builds; rows are
        # user-editable and start from the bundled example.
        data_input = gr.DataFrame(
            label="Data Table",
            headers=["subject", "relation", "object", "prompt_source"],
            value=DEFAULT_DATAFRAME_VALUE,
            row_count=(5, "dynamic"),
            col_count=(4, "fixed"),
            interactive=True
        )

        submit_btn = gr.Button("Submit for Inference", variant="primary")

    # --- Section 3: results ------------------------------------------------
    with gr.Column():
        gr.Markdown("### 3. View Inference Results")
        gr.Markdown("**layer_source**: Starting from this layer of **prompt_source**, we extracted the hidden states corresponding to the subject.")
        gr.Markdown("**is_correct_patched**: Whether the final answer is correct. If **object** is in **generations**, we think the final answer is correct.")
        gr.Markdown("**generations**: The output obtained after the model continues inference following a patch operation.")
        with gr.Tabs():
            with gr.TabItem("Results Table"):
                # Read-only table populated from the API's "result" field.
                output_dataframe = gr.DataFrame(
                    label="Inference Result Details",
                    interactive=False,

                    headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
                )
            with gr.TabItem("Raw JSON Response"):
                # Full unprocessed response, useful for debugging.
                output_json = gr.JSON(label="Full JSON response from API")

    # Wire the button: inputs/outputs order must match call_api's signature
    # and its (json, dataframe) return tuple.
    submit_btn.click(
        fn=call_api,
        inputs=[patched_layer_number_input, only_final_result_input, data_input],
        outputs=[output_json, output_dataframe]
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script (not on import).
    demo.launch()
|
|
|