OPRMs (Outcome & Process Reward Models) are trained to predict the correctness of each step at the position of its "\n\n" separator, as well as the correctness of the whole solution at the position of the "\<eos\>" token.

Usage:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "ScalableMath/llemma-7b-oprm-prm800k-level-1to3-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")

qa_example = """# Question

Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$

# Solution

To convert from rectangular to polar coordinates, I need to use the formulas $r = \sqrt{x^2 + y^2}$ and $\theta = \tan^{-1}(y/x).$

In this case, $x = 0$ and $y = 3,$ so I can plug them into the formulas.

For $r,$ I get $r = \sqrt{0^2 + 3^2} = \sqrt{9} = 3.$

For $\theta,$ I get $\theta = \tan^{-1}(3/0).$

This is undefined, since the tangent function is not defined at $0.$

However, I can use the fact that the point $(0,3)$ lies on the positive $y$-axis, which has an angle of $\pi/2$ radians or $90^\circ.$

Therefore, I can choose any angle in the range $(0,\pi/2)$ as the value of $\theta.$

I will choose $\theta = \pi/2,$ since it is the simplest and most natural choice.

Therefore, the polar coordinates of the point $(0,3)$ are $(3,\pi/2).$

# Answer

(3,\pi/2)"""

# token patterns marking the solution header and the step separators;
# [1:] skips the tokenizer's prefix token so only the raw "\n\n" pattern remains
begin_solution_tokens = tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
scoring_tokens = tokenizer.encode("\n\n", add_special_tokens=False)[1:]
eos_token = tokenizer.eos_token_id

input_ids = tokenizer.encode(qa_example)

# scan for the token positions at which the model is scored
begin_solution_flag = False
candidate_positions = []

for start_idx in range(len(input_ids)):
    # mark where the solution section starts; scoring begins from here
    if tuple(input_ids[start_idx:start_idx+len(begin_solution_tokens)]) == tuple(begin_solution_tokens):
        begin_solution_flag = True

    # every "\n\n" inside the solution is a candidate step-scoring position
    if begin_solution_flag and tuple(input_ids[start_idx:start_idx+len(scoring_tokens)]) == tuple(scoring_tokens):
        candidate_positions.append(start_idx)

    # the whole-solution score is read at the <eos> position
    if input_ids[start_idx] == eos_token:
        candidate_positions.append(start_idx)
        break

# drop the first and the second-to-last candidate positions:
# they are the "\n\n" right after "# Solution" and right after "# Answer",
# not step boundaries
del candidate_positions[0]
del candidate_positions[-2]
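# what remains: one scoring position per solution step, plus the final
# position that scores the solution as a whole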

input_tensor = torch.tensor([input_ids])
candidate_positions = torch.tensor(candidate_positions)

with torch.no_grad():
    logits = model(input_tensor).logits
    # the scalar score at each position is the mean of the logits over the
    # vocabulary dimension, squashed into a probability with a sigmoid
    scores = logits.mean(dim=-1)
    step_scores = scores[0][candidate_positions]
    step_probs = torch.sigmoid(step_scores)

print(step_probs)

# tensor([0.8093, 0.9566, 0.9872, 0.9890, 0.9797, 0.3090, 0.8044, 0.7677, 0.8105, 0.5247])
```
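
To rerank several sampled solutions you usually want one scalar per candidate. Below is a minimal sketch of aggregating the per-step probabilities into such a score; the `aggregate` helper and the strategy names are our illustration, not something the model prescribes:

```python
import torch

def aggregate(step_probs: torch.Tensor, strategy: str = "min") -> float:
    """Collapse per-step probabilities into one solution-level score.

    These strategies are illustrative assumptions, not part of the model:
    "min" scores a solution by its weakest step, "prod" multiplies the step
    probabilities, and "last" keeps only the final (whole-solution) score.
    """
    if strategy == "min":
        return step_probs.min().item()
    if strategy == "prod":
        return step_probs.prod().item()
    if strategy == "last":
        return step_probs[-1].item()
    raise ValueError(f"unknown strategy: {strategy}")

# the tensor printed in the example above
step_probs = torch.tensor([0.8093, 0.9566, 0.9872, 0.9890, 0.9797,
                           0.3090, 0.8044, 0.7677, 0.8105, 0.5247])
print(aggregate(step_probs, "min"))  # ~0.3090, the weakest step dominates
```

Picking the candidate with the highest aggregated score then implements best-of-n selection over multiple sampled solutions.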