huypn16 committed
Commit dcc7552
1 Parent(s): a71c743

Create core.py

Files changed (1)
  1. core.py +133 -0
core.py ADDED
@@ -0,0 +1,133 @@
import torch
from trl.core import top_k_top_p_filtering
from torch import nn
import torch.nn.functional as F
from typing import Any, List
from openai import OpenAI
from tqdm import tqdm


class LLM:
    """Base wrapper holding the system prompt and config shared by concrete clients."""

    def __init__(self, config):
        self.system_prompt = config["system_prompt"]
        self.config = config

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        pass


class LocalLLM(LLM):
    def __init__(self, config):
        super().__init__(config)
        # NOTE: credentials for the Together AI OpenAI-compatible endpoint are hardcoded here.
        openai_api_key = "aca586edef32c3f95ed93a70830fe9fb0c38c3f408de8054370c501db0c65268"
        openai_api_base = "https://api.together.xyz"

        self.client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )

    def __call__(self, prompt: str):
        # prompt = truncate_tokens_hf(prompt, encoding_name=self.config["model"])
        response = self.client.chat.completions.create(
            temperature=0,
            model=self.config["model"],
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ],
            max_tokens=None,
        )
        return response.choices[0].message.content

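# Usage sketch for LocalLLM (illustrative only): the config keys "system_prompt"
# and "model" mirror how they are read above; the model name and prompt below
# are assumed examples, not values fixed by this file.
#
#   config = {
#       "system_prompt": "You are a helpful math assistant.",
#       "model": "mistralai/Mistral-7B-Instruct-v0.2",
#   }
#   llm = LocalLLM(config)
#   reply = llm("Compute 12 * 7 and explain each step.")
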
def respond_to_batch(
    model: nn.Module, step_tag_id: int, stop_id: int, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0, sampling_steps: int = 1, mini_batch_size: int = 8, current_step=None, tokenizer=None,
) -> torch.LongTensor:
    """Sample text from the language model using multi-GPU mini-batches.

    Returns (truncated_sequences, failed, step_sequences); `failed` is True when
    some sequence never produced the step tag.
    """
    # `queries` is treated as a padded 2-D LongTensor below (it is split and concatenated).
    input_ids = queries
    # Per-sequence counters for step tags and stop tokens generated so far.
    step_counter = torch.zeros(len(queries)).to(queries[0].device)
    stop_counter = torch.zeros(len(queries)).to(queries[0].device)
    stop_indexes = [-1 for _ in range(len(queries))]

    for _i in range(txt_len):
        # Get logits for the next token, one mini-batch at a time.
        next_token_logits = []
        for mini_batch_input_ids in torch.split(input_ids, mini_batch_size):
            mini_batch_input_ids = mini_batch_input_ids.to(model.module.device)
            mini_batch_outputs = model(mini_batch_input_ids)
            mini_next_token_logits = mini_batch_outputs.logits[:, -1, :]
            next_token_logits.append(mini_next_token_logits)

        next_token_logits = torch.cat(next_token_logits, dim=0)
        next_token_logits /= 0.7  # fixed sampling temperature
        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)

        # Sample the next token and append it to every sequence.
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
        step_counter += (next_token == step_tag_id).detach().long()

        # Stop if all sequences have reached step_tag_id sampling_steps times,
        # or have all emitted the stop token; record where each stop token landed.
        stop_counter += (next_token == stop_id).detach().long()
        stop_indexes = [len(queries[i]) + _i if ((stop_counter[i] == 1) and (next_token[i] == stop_id)) else stop_indexes[i] for i in range(len(queries))]

        if (step_counter >= sampling_steps).all() or (stop_counter > 0).all():
            break

        del next_token_logits

    # Truncate each generated continuation at its step tag (or stop token).
    truncated_sequences = []
    step_sequences = []
    for i, sample in enumerate(input_ids):
        max_sampled_steps = len((sample == step_tag_id).nonzero(as_tuple=True)[0])
        if step_tag_id in sample:
            step_tag_index = (sample == step_tag_id).nonzero(as_tuple=True)[0][min(current_step + sampling_steps, max_sampled_steps) - 1]
            if len((sample == step_tag_id).nonzero(as_tuple=True)[0]) - 1 >= current_step:
                current_step_tag_index = (sample == step_tag_id).nonzero(as_tuple=True)[0][current_step]
            else:
                current_step_tag_index = 100000
            stop_index = stop_indexes[i]
            if stop_index != -1:
                _sample = sample[len(queries[i]):stop_index + 1]
                step_sample = sample[len(queries[i]):min(current_step_tag_index, stop_index) + 1]
            else:
                _sample = sample[len(queries[i]):step_tag_index + 1]
                step_sample = sample[len(queries[i]):current_step_tag_index + 1]
            truncated_sequences.append(_sample)
            step_sequences.append(step_sample)
        else:
            return None, True, None
    return truncated_sequences, False, step_sequences

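# Hypothetical call pattern for respond_to_batch (a sketch, not taken from this
# repo's training loop). Assumptions: `policy` is a causal LM wrapped in
# torch.nn.DataParallel (the code above reads model.module.device), `query_tensor`
# is a batched LongTensor of prompt token ids, `step_tag_id` is the id of the
# step-marker token, and the stop id is the tokenizer's EOS token.
#
#   responses, failed, step_responses = respond_to_batch(
#       policy,
#       step_tag_id=step_tag_id,
#       stop_id=tokenizer.eos_token_id,
#       queries=query_tensor,
#       txt_len=256,
#       top_k=50,
#       top_p=0.95,
#       sampling_steps=1,
#       current_step=0,
#       tokenizer=tokenizer,
#   )
#   if failed:
#       ...  # no step tag was generated for some sequence; re-sample or skip
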
def get_rewards(rm_model, rm_tokenizer, problem_statement, proposed_solutions, candidate_tokens, step_tag_id):
    """Score each proposed solution with the process reward model, one score per step tag."""
    scores_list = []
    for solution in proposed_solutions:
        input_for_prm = f"{problem_statement} {solution}"
        input_id = torch.tensor([rm_tokenizer.encode(input_for_prm)])
        input_id = input_id.to('cuda:1')
        # Build a string holding the last 2 steps of the solution.
        # NOTE: this `steps` string is currently unused; the full solution is scored below.
        steps = solution.split("ки")
        steps = [step.strip() for step in steps]
        steps = steps[-2:]
        steps = " ки".join(steps)

        with torch.no_grad():
            # Softmax over the candidate label tokens; take the first candidate's
            # probability at each step-tag position.
            logits = rm_model(input_id).logits[:, :, candidate_tokens]
            scores = logits.softmax(dim=-1)[:, :, 0]
            step_scores = scores[input_id == step_tag_id]

        scores_list.append(step_scores)
    return scores_list

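# Hedged usage sketch for get_rewards: the "ки" step tag and two-way candidate
# tokens look like a Math-Shepherd-style process reward model, where
# candidate_tokens are the ids of the "+" / "-" label tokens and step_tag_id is
# the id of "ки". The exact encoding calls below are assumptions for illustration.
#
#   candidate_tokens = rm_tokenizer.encode("+ -")[1:]   # ids of the '+' / '-' labels
#   step_tag_id = rm_tokenizer.encode("ки")[-1]
#   rewards = get_rewards(rm_model, rm_tokenizer, question, solutions,
#                         candidate_tokens, step_tag_id)
#   # rewards[i] holds one score per "ки" tag in solutions[i]
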
def math_evaluation(solution, answer):
    """Extract the final answer after "The answer is: " and compare it to the reference."""
    answer_prefix = "The answer is: "
    if answer_prefix in solution:
        start = solution.rfind(answer_prefix) + len(answer_prefix)
        final_answer = solution[start:]
        final_answer = final_answer.replace("ки", "")  # strip the step-tag marker
        final_answer = final_answer.strip()
    else:
        return False, None
    return (final_answer == answer), final_answer
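
# A small self-contained check of math_evaluation (example values only; "ки" is
# the step-tag marker stripped from the extracted answer above).
if __name__ == "__main__":
    sample_solution = "Step 1: 12 * 7 = 84 ки The answer is: 84 ки"
    is_correct, extracted = math_evaluation(sample_solution, "84")
    print(is_correct, extracted)  # expected: True 84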