Create core.py
core.py
ADDED
@@ -0,0 +1,133 @@
import torch
from trl.core import top_k_top_p_filtering
from torch import nn
from typing import Any, List
import torch.nn.functional as F
from openai import OpenAI
from tqdm import tqdm


class LLM:
    def __init__(self, config):
        self.system_prompt = config["system_prompt"]
        self.config = config

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        pass


class LocalLLM(LLM):
    def __init__(self, config):
        super().__init__(config)
        # NOTE: API key and Together AI endpoint are hard-coded here.
        openai_api_key = "aca586edef32c3f95ed93a70830fe9fb0c38c3f408de8054370c501db0c65268"
        openai_api_base = "https://api.together.xyz"

        self.client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )

    def __call__(self, prompt: str):
        # prompt = truncate_tokens_hf(prompt, encoding_name=self.config["model"])
        response = self.client.chat.completions.create(
            temperature=0,
            model=self.config["model"],
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ],
            max_tokens=None,
        )
        return response.choices[0].message.content


def respond_to_batch(
    model: nn.Module,
    step_tag_id: int,
    stop_id: int,
    queries: List[torch.LongTensor],
    txt_len: int = 20,
    top_k: int = 0,
    top_p: float = 1.0,
    sampling_steps: int = 1,
    mini_batch_size: int = 8,
    current_step=None,
    tokenizer=None,
) -> torch.LongTensor:
    """Sample continuations for a batch of queries, stopping once every sequence
    has emitted `sampling_steps` step tags or a stop token.

    Returns (truncated_sequences, failed, step_sequences); if any sample contains
    no step tag at all, returns (None, True, None).
    """
    # NOTE: despite the List annotation, `queries` is treated as a single padded
    # 2-D LongTensor (see torch.split / torch.cat below).
    input_ids = queries
    step_counter = torch.zeros(len(queries)).to(queries[0].device)
    stop_counter = torch.zeros(len(queries)).to(queries[0].device)
    stop_indexes = [-1 for _ in range(len(queries))]

    for _i in range(txt_len):
        # Compute next-token logits mini-batch by mini-batch.
        next_token_logits = []
        for mini_batch_input_ids in torch.split(input_ids, mini_batch_size):
            mini_batch_input_ids = mini_batch_input_ids.to(model.module.device)
            mini_batch_outputs = model(mini_batch_input_ids)
            mini_next_token_logits = mini_batch_outputs.logits[:, -1, :]
            next_token_logits.append(mini_next_token_logits)

        next_token_logits = torch.cat(next_token_logits, dim=0)
        next_token_logits /= 0.7  # sampling temperature
        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)

        # Sample the next token and append it to every sequence.
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
        step_counter += (next_token == step_tag_id).detach().long()

        # Track where each sequence first produces the stop token.
        stop_counter += (next_token == stop_id).detach().long()
        stop_indexes = [
            len(queries[i]) + _i if ((stop_counter[i] == 1) and (next_token[i] == stop_id)) else stop_indexes[i]
            for i in range(len(queries))
        ]

        # Stop once all sequences have reached step_tag_id `sampling_steps` times
        # or have produced the stop token.
        if (step_counter >= sampling_steps).all() or (stop_counter > 0).all():
            break

        del next_token_logits

    truncated_sequences = []
    step_sequences = []
    for i, sample in enumerate(input_ids):
        max_sampled_steps = len((sample == step_tag_id).nonzero(as_tuple=True)[0])
        if step_tag_id in sample:
            step_tag_index = (sample == step_tag_id).nonzero(as_tuple=True)[0][
                min(current_step + sampling_steps, max_sampled_steps) - 1
            ]
            if len((sample == step_tag_id).nonzero(as_tuple=True)[0]) - 1 >= current_step:
                current_step_tag_index = (sample == step_tag_id).nonzero(as_tuple=True)[0][current_step]
            else:
                current_step_tag_index = 100000
            stop_index = stop_indexes[i]
            if stop_index != -1:
                _sample = sample[len(queries[i]):stop_index + 1]
                step_sample = sample[len(queries[i]):min(current_step_tag_index, stop_index) + 1]
            else:
                _sample = sample[len(queries[i]):step_tag_index + 1]
                step_sample = sample[len(queries[i]):current_step_tag_index + 1]
            truncated_sequences.append(_sample)
            step_sequences.append(step_sample)
        else:
            return None, True, None
    return truncated_sequences, False, step_sequences


def get_rewards(rm_model, rm_tokenizer, problem_statement, proposed_solutions, candidate_tokens, step_tag_id):
    """Score each proposed solution with the process reward model; returns one
    tensor of per-step scores per solution."""
    scores_list = []
    for solution in proposed_solutions:
        input_for_prm = f"{problem_statement} {solution}"
        input_id = torch.tensor([rm_tokenizer.encode(input_for_prm)])
        input_id = input_id.to('cuda:1')
        # Keep only the last two steps of the solution (note: `steps` is not used below).
        steps = solution.split("ки")
        steps = [step.strip() for step in steps]
        steps = steps[-2:]
        steps = " ки".join(steps)

        with torch.no_grad():
            logits = rm_model(input_id).logits[:, :, candidate_tokens]
            scores = logits.softmax(dim=-1)[:, :, 0]
            step_scores = scores[input_id == step_tag_id]

        scores_list.append(step_scores)
    return scores_list


def math_evaluation(solution, answer):
    """Extract the final answer after "The answer is: " and compare it to `answer`."""
    answer_prefix = "The answer is: "
    if answer_prefix in solution:
        start = solution.rfind(answer_prefix) + len(answer_prefix)
        final_answer = solution[start:]
        final_answer = final_answer.replace("ки", "")
        final_answer = final_answer.strip()
    else:
        return False, None
    return (final_answer == answer), final_answer
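
For reference, a minimal usage sketch of the helpers above, not part of this commit: it assumes core.py is importable, the Together AI endpoint hard-coded in LocalLLM.__init__ is reachable, and a config dict carrying the "system_prompt" and "model" keys that LLM.__init__ reads. The model name, prompt, and reference answer below are illustrative placeholders.

# Illustrative usage sketch; the model id, prompt, and answer are assumptions.
from core import LocalLLM, math_evaluation

config = {
    "system_prompt": "Solve the problem step by step and finish with 'The answer is: <answer>'.",
    "model": "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id on the Together endpoint
}

llm = LocalLLM(config)
solution = llm("Natalia sold 48 clips in April and half as many in May. How many clips did she sell in total?")

# math_evaluation looks for the "The answer is: " prefix, strips the "ки" step
# tags, and compares the remaining text with the reference answer string.
is_correct, final_answer = math_evaluation(solution, answer="72")
print(is_correct, final_answer)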