lllyx commited on
Commit
f63b31c
·
verified ·
1 Parent(s): a628b4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -39
app.py CHANGED
@@ -4,17 +4,17 @@ import json
4
  import pandas as pd
5
  import io
6
 
7
- # ----------------- 配置你的服务器API地址 -----------------
8
- # IP地址和端口替换为你自己的
9
  SERVER_URL = "http://103.235.229.133:7000/predict"
10
- # -------------------------------------------------------
11
 
12
- # ----------------- 固定的模型名称 -----------------
13
- # 根据要求,模型名称是固定的,不再作为UI输入
14
  FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
15
- # -------------------------------------------------------
16
 
17
- # 默认的输入数据,方便用户测试
18
  DEFAULT_DATAFRAME_VALUE = [
19
  ["France", "capital city of", "Paris", "she traveled to France for the first time"],
20
  ["Germany", "capital city of", "Berlin", "Germany is a country in Central Europe"],
@@ -23,29 +23,34 @@ DEFAULT_DATAFRAME_VALUE = [
23
 
24
  def call_api(only_final_result, data_df):
25
  """
26
- 这个函数会被Gradio调用,它负责收集UI上的输入,
27
- 将其转换成API需要的格式,然后发送请求。
28
  """
29
  if data_df is None or data_df.empty:
30
- raise gr.Error("输入数据不能为空!请在表格中至少添加一行数据。")
31
 
 
32
  data_df_with_id = data_df.copy()
33
  data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])
34
 
 
35
  output = io.StringIO()
36
  data_df_with_id.to_csv(output, index=False)
37
  csv_data = output.getvalue()
38
 
 
39
  payload = {
40
  "data_csv": csv_data,
41
  "only_final_result": only_final_result
42
  }
43
 
44
  try:
 
45
  response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
46
- response.raise_for_status()
47
  result_json = response.json()
48
 
 
49
  dataframe_result_data = result_json.get("result", [])
50
 
51
  if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
@@ -56,65 +61,61 @@ def call_api(only_final_result, data_df):
56
  return result_json, output_df
57
 
58
  except requests.exceptions.HTTPError as e:
59
- error_details = f"API返回错误 (状态码: {e.response.status_code}):\n{e.response.text}"
60
- raise gr.Error(f"请求失败: {error_details}")
61
  except requests.exceptions.RequestException as e:
62
- raise gr.Error(f"连接服务器失败: {e}")
63
  except Exception as e:
64
- raise gr.Error(f"发生未知错误: {e}")
65
 
66
 
67
- # --- 使用Gradio Blocks创建更灵活、更美观的界面 ---
68
  with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
69
  gr.Markdown(
70
  """
71
- # 🚀 Ml_patch 函数推理接口
72
- 一个为 `Ml_patch` 函数设计的交互式Web界面。请在下方配置模型参数并输入您要测试的数据。
 
73
  """
74
  )
75
 
76
  with gr.Column():
77
- gr.Markdown("### 1. 配置参数")
78
  with gr.Group():
79
- gr.Markdown(f"**当前模型 (固定):** `{FIXED_MODEL_NAME}`")
80
  only_final_result_input = gr.Checkbox(
81
- label="仅返回最终结果 (only_final_result)",
82
  value=False,
83
- info="勾选此项,API将只返回最终聚合结果,而不是所有层的详细信息。"
84
  )
85
 
86
- gr.Markdown("### 2. 输入数据")
87
- gr.Markdown("ℹ️ **操作提示**: 点击表格右下方的 **✏️ Edit** 按钮,然后点击 **+ Add row** 来添加新行,或点击垃圾桶图标删除行。")
88
 
89
- # -------------------- 主要修改部分在这里 --------------------
90
- # 移除了 'height' 参数,因为它在旧版Gradio中不受支持。
91
- # 我们通过增加 row_count 的第一个参数(初始行数)来为表格提供更多初始垂直空间。
92
  data_input = gr.DataFrame(
93
- label="数据表格",
94
  headers=["subject", "relation", "object", "prompt_source"],
95
  value=DEFAULT_DATAFRAME_VALUE,
96
- # 将初始行数从3增加到5,使表格在加载时就更高。
97
- # (要显示的行数, "dynamic"表示用户可以增删行)
98
  row_count=(5, "dynamic"),
99
  col_count=(4, "fixed"),
100
  interactive=True
101
- # 已移除不受支持的 'height' 参数
102
  )
103
- # -------------------- 修改结束 --------------------
104
 
105
- submit_btn = gr.Button("提交推理", variant="primary")
106
 
107
  with gr.Column():
108
- gr.Markdown("### 3. 查看推理结果")
109
  with gr.Tabs():
110
- with gr.TabItem("结果表格"):
111
  output_dataframe = gr.DataFrame(
112
- label="推理结果详情",
113
  interactive=False,
114
  headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
115
  )
116
- with gr.TabItem("原始JSON响应"):
117
- output_json = gr.JSON(label="API 返回的完整JSON")
118
 
119
  submit_btn.click(
120
  fn=call_api,
@@ -122,6 +123,6 @@ with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280p
122
  outputs=[output_json, output_dataframe]
123
  )
124
 
125
- # 启动Gradio应用
126
  if __name__ == "__main__":
127
  demo.launch()
 
4
  import pandas as pd
5
  import io
6
 
7
+ # ----------------- Configure Your Server API Address -----------------
8
+ # Replace the IP address and port with your own
9
  SERVER_URL = "http://103.235.229.133:7000/predict"
10
+ # ---------------------------------------------------------------------
11
 
12
+ # ----------------- Fixed Model Name -----------------
13
+ # As per requirements, the model name is fixed and not a UI input.
14
  FIXED_MODEL_NAME = "Qwen2.5-1.5B-Instruct"
15
+ # ----------------------------------------------------
16
 
17
+ # Default input data for easy testing
18
  DEFAULT_DATAFRAME_VALUE = [
19
  ["France", "capital city of", "Paris", "she traveled to France for the first time"],
20
  ["Germany", "capital city of", "Berlin", "Germany is a country in Central Europe"],
 
23
 
24
  def call_api(only_final_result, data_df):
25
  """
26
+ This function is called by Gradio. It collects inputs from the UI,
27
+ formats them for the API, and sends the request.
28
  """
29
  if data_df is None or data_df.empty:
30
+ raise gr.Error("Input data cannot be empty! Please add at least one row to the table.")
31
 
32
+ # Create a copy and add an 'id' column
33
  data_df_with_id = data_df.copy()
34
  data_df_with_id.insert(0, 'id', [f'{i:03d}' for i in range(len(data_df_with_id))])
35
 
36
+ # Convert DataFrame to CSV string
37
  output = io.StringIO()
38
  data_df_with_id.to_csv(output, index=False)
39
  csv_data = output.getvalue()
40
 
41
+ # Prepare the payload for the API
42
  payload = {
43
  "data_csv": csv_data,
44
  "only_final_result": only_final_result
45
  }
46
 
47
  try:
48
+ # Send the request to the server
49
  response = requests.post(SERVER_URL, json=payload, timeout=300, stream=True)
50
+ response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
51
  result_json = response.json()
52
 
53
+ # Extract the result data for the DataFrame
54
  dataframe_result_data = result_json.get("result", [])
55
 
56
  if isinstance(dataframe_result_data, list) and len(dataframe_result_data) > 0:
 
61
  return result_json, output_df
62
 
63
  except requests.exceptions.HTTPError as e:
64
+ error_details = f"API returned an error (Status Code: {e.response.status_code}):\n{e.response.text}"
65
+ raise gr.Error(f"Request failed: {error_details}")
66
  except requests.exceptions.RequestException as e:
67
+ raise gr.Error(f"Failed to connect to the server: {e}")
68
  except Exception as e:
69
+ raise gr.Error(f"An unknown error occurred: {e}")
70
 
71
 
72
+ # --- Use Gradio Blocks for a more flexible and beautiful interface ---
73
  with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {max-width: 1280px !important; margin: auto;}") as demo:
74
  gr.Markdown(
75
  """
76
+ # 🚀 Ml_patch Function Inference Interface
77
+ An interactive web interface designed for the `Ml_patch` function.
78
+ Please configure the model parameters and enter your test data below.
79
  """
80
  )
81
 
82
  with gr.Column():
83
+ gr.Markdown("### 1. Configure Parameters")
84
  with gr.Group():
85
+ gr.Markdown(f"**Current Model (Fixed):** `{FIXED_MODEL_NAME}`")
86
  only_final_result_input = gr.Checkbox(
87
+ label="Only Return Final Result (only_final_result)",
88
  value=False,
89
+ info="If checked, the API will only return the final aggregated result, not the detailed information for all layers."
90
  )
91
 
92
+ gr.Markdown("### 2. Input Data")
93
+ gr.Markdown("ℹ️ **Instructions**: Click the **✏️ Edit** button at the bottom right of the table, then click **+ Add row** to add a new row, or the trash can icon to delete a row.")
94
 
 
 
 
95
  data_input = gr.DataFrame(
96
+ label="Data Table",
97
  headers=["subject", "relation", "object", "prompt_source"],
98
  value=DEFAULT_DATAFRAME_VALUE,
99
+ # Set the initial number of rows to 5 to give the table more vertical space on load.
100
+ # (rows_to_display, "dynamic" allows user to add/remove rows)
101
  row_count=(5, "dynamic"),
102
  col_count=(4, "fixed"),
103
  interactive=True
 
104
  )
 
105
 
106
+ submit_btn = gr.Button("Submit for Inference", variant="primary")
107
 
108
  with gr.Column():
109
+ gr.Markdown("### 3. View Inference Results")
110
  with gr.Tabs():
111
+ with gr.TabItem("Results Table"):
112
  output_dataframe = gr.DataFrame(
113
+ label="Inference Result Details",
114
  interactive=False,
115
  headers=["id", "subject", "relation", "object", "prompt_source", "prompt_target", "layer_source", "is_correct_patched", "generations"]
116
  )
117
+ with gr.TabItem("Raw JSON Response"):
118
+ output_json = gr.JSON(label="Full JSON response from API")
119
 
120
  submit_btn.click(
121
  fn=call_api,
 
123
  outputs=[output_json, output_dataframe]
124
  )
125
 
126
+ # Launch the Gradio app
127
  if __name__ == "__main__":
128
  demo.launch()