lihongze8 committed
Commit bde5ca7
1 Parent(s): 822b0c8

Update app.py

Files changed (1): app.py +106 -42
app.py CHANGED
@@ -1,15 +1,28 @@
 import os
-import subprocess
 import sys
+import subprocess
 import torch
+from typing import Union
+from pydantic import BaseModel
+from fastapi import FastAPI
+import uvicorn
+
+# Gradio-related
 import gradio as gr
+
+# transformers-related
 from transformers import AutoTokenizer
 
-# 1. If the skywork-o1-prm-inference directory is missing locally (or in this Space), clone it
+##############################################################################
+# 1) If the skywork-o1-prm-inference directory is missing locally (or in this Space), clone it and add it to sys.path
+##############################################################################
 def setup_environment():
     if not os.path.exists("skywork-o1-prm-inference"):
-        print("Cloning repository...")
-        subprocess.run(["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"], check=True)
+        print("Cloning repository skywork-o1-prm-inference...")
+        subprocess.run(
+            ["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"],
+            check=True
+        )
     repo_path = os.path.abspath("skywork-o1-prm-inference")
     if repo_path not in sys.path:
         sys.path.append(repo_path)
@@ -17,23 +30,33 @@ def setup_environment():
 
 setup_environment()
 
-# 2. Import the project's modules only after the environment is ready
+##############################################################################
+# 2) Import the modules from the Skywork project
+##############################################################################
 from model_utils.prm_model import PRM_MODEL
 from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
 
-# 3. Model ID: replace with your own as needed
+##############################################################################
+# 3) Load the model and related resources
+##############################################################################
+# Change model_id as needed
 model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
+
+print(f"Loading tokenizer from {model_id} ...")
 tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+
+print(f"Loading model from {model_id} ...")
 model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
 
-# 4. Define the inference function
-def compute_rewards(problem, response):
-    # Prepare the data
+##############################################################################
+# 4) Define the inference function
+##############################################################################
+def compute_rewards(problem: str, response: str):
+    """Core inference function: feed problem + response to the model and return step_rewards."""
     data = {
         "problem": problem,
         "response": response
     }
-    # Format the input
     processed_data = [
         prepare_input(
             data["problem"],
@@ -50,12 +73,13 @@ def compute_rewards(problem, response):
         reward_flags,
         tokenizer.pad_token_id
     )
+
     input_ids = input_ids.to("cpu")
     attention_mask = attention_mask.to("cpu")
     if isinstance(reward_flags, torch.Tensor):
         reward_flags = reward_flags.to("cpu")
 
-    # Forward pass
+    # Forward inference
     with torch.no_grad():
         _, _, rewards = model(
             input_ids=input_ids,
@@ -67,38 +91,78 @@ def compute_rewards(problem, response):
     step_rewards = derive_step_rewards(rewards, reward_flags)
     return step_rewards[0]
 
-# 5. Gradio UI
-with gr.Blocks() as demo:
-    gr.Markdown("## PRM Reward Calculation\n"
-                "Enter the problem and response here to compute the reward for each step.")
-
-    problem_input = gr.Textbox(
-        label="Problem",
-        lines=5,
-        value=(
-            "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
-            "and bakes muffins for her friends every day with four. She sells the remainder "
-            "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
-            "does she make every day at the farmers' market?"
+##############################################################################
+# 5) Set up the FastAPI app: expose the /api/predict endpoint
+##############################################################################
+app = FastAPI(title="PRM Inference Server", description="FastAPI + Gradio App")
+
+# 5.1 Define the data structures for API input and output
+class InferenceData(BaseModel):
+    problem: str
+    response: str
+
+class InferenceOutput(BaseModel):
+    step_rewards: list
+
+@app.post("/api/predict", response_model=InferenceOutput)
+def predict(data: InferenceData):
+    """
+    POST directly to /api/predict with
+    body JSON { "problem": "...", "response": "..." } to get step_rewards.
+    """
+    rewards = compute_rewards(data.problem, data.response)
+    return InferenceOutput(step_rewards=rewards)
+
+##############################################################################
+# 6) Build the Gradio interface and mount it at the /gradio path
+##############################################################################
+def build_gradio_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown(
+            "## PRM Reward Calculation\n\n"
+            "Enter a `Problem` and a `Response`, then click the button below to get step_rewards."
         )
-    )
-    response_input = gr.Textbox(
-        label="Response",
-        lines=10,
-        value=(
-            "To determine how much money Janet makes every day at the farmers' market, "
-            "we need to follow these steps:\n1. ..."
+
+        problem_input = gr.Textbox(
+            label="Problem",
+            lines=5,
+            value=(
+                "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
+                "and bakes muffins for her friends every day with four. She sells the remainder "
+                "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
+                "does she make every day at the farmers' market?"
+            )
         )
-    )
-    output = gr.JSON(label="Step Rewards")
+        response_input = gr.Textbox(
+            label="Response",
+            lines=10,
+            value=(
+                "To determine how much money Janet makes every day at the farmers' market, "
+                "we need to follow these steps:\n1. ..."
+            )
+        )
+        output = gr.JSON(label="Step Rewards")
 
-    submit_btn = gr.Button("Compute Rewards")
-    submit_btn.click(
-        fn=compute_rewards,
-        inputs=[problem_input, response_input],
-        outputs=output
-    )
+        submit_btn = gr.Button("Compute Rewards")
+        submit_btn.click(
+            fn=compute_rewards,
+            inputs=[problem_input, response_input],
+            outputs=output
+        )
+
+    return demo
+
+demo = build_gradio_interface()
+
+# 6.1 Mount Gradio on the FastAPI app at the /gradio path
+app = gr.mount_gradio_app(app, demo, path="/gradio")
 
-# 6. Launch the Gradio server
-# Note: on Hugging Face Spaces, server_name="0.0.0.0" is recommended; ports 7860/7861 are common.
-demo.launch(server_name="0.0.0.0", server_port=7860)
+##############################################################################
+# 7) Launch with uvicorn in main
+##############################################################################
+# Note: on Hugging Face Spaces, `python app.py` is usually all that is needed to start listening.
+# If Spaces auto-detects and runs `app.py`, the `if __name__ == "__main__"` branch may not execute,
+# so it is enough that the `app` variable defined above exists.
+# An optional main is provided here for local debugging.
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=7860)
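For a quick smoke test of the new route, something like the following should work once this version of app.py is running. This is a minimal sketch, assuming the server is reachable at the default host/port from the __main__ block and that the requests package is installed (requests is not used by app.py itself):

import requests

# Payload fields mirror the InferenceData model defined in app.py.
payload = {
    "problem": "Janet's ducks lay 16 eggs per day. ...",
    "response": "To determine how much money Janet makes every day at the farmers' market, "
                "we need to follow these steps:\n1. ...",
}

# /api/predict and port 7860 come from this commit; adjust the host if the
# server runs elsewhere (e.g. behind a Space's public URL).
resp = requests.post("http://localhost:7860/api/predict", json=payload, timeout=120)
resp.raise_for_status()

# The response body matches InferenceOutput: {"step_rewards": [...]}.
print(resp.json()["step_rewards"])

Note that after this commit the Gradio UI no longer lives at the root path: it is mounted at /gradio, so the browser URL becomes http://localhost:7860/gradio.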