import os
from typing import Any, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from openai import OpenAI
from tqdm import tqdm

# Note: top_k_top_p_filtering is exported by older trl releases; newer trl
# versions may no longer provide it in trl.core.
from trl.core import top_k_top_p_filtering
class LLM:
    """Minimal interface around a chat model configured via a config dict."""

    def __init__(self, config):
        self.system_prompt = config["system_prompt"]
        self.config = config

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        raise NotImplementedError
class LocalLLM(LLM):
    def __init__(self, config):
        super().__init__(config)
        # Read the API key from the environment rather than hardcoding a secret
        # in source control. The variable name TOGETHER_API_KEY is a convention
        # assumed here; adjust it to your deployment.
        openai_api_key = os.environ.get("TOGETHER_API_KEY")
        openai_api_base = "https://api.together.xyz"

        self.client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )

    def __call__(self, prompt: str):
        # Greedy decoding (temperature=0) against an OpenAI-compatible endpoint.
        response = self.client.chat.completions.create(
            temperature=0,
            model=self.config["model"],
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ],
            max_tokens=None,
        )
        return response.choices[0].message.content
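
# Example usage (sketch; the model name and system prompt below are illustrative,
# and TOGETHER_API_KEY must be set in the environment):
#
#   llm = LocalLLM({"system_prompt": "You are a helpful math assistant.",
#                   "model": "meta-llama/Llama-3-8b-chat-hf"})
#   reply = llm("What is 12 * 7?")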
def respond_to_batch(
    model: nn.Module,
    step_tag_id: int,
    stop_id: int,
    queries: torch.LongTensor,
    txt_len: int = 20,
    top_k: int = 0,
    top_p: float = 1.0,
    sampling_steps: int = 1,
    mini_batch_size: int = 8,
    current_step=None,
    tokenizer=None,
) -> Tuple[Optional[List[torch.LongTensor]], bool, Optional[List[torch.LongTensor]]]:
    """Sample up to `txt_len` tokens per query from a wrapped (multi-GPU) model.

    Generation stops early once every sequence has produced `sampling_steps`
    step tags (`step_tag_id`) or a stop token (`stop_id`). Returns a tuple
    `(truncated_sequences, failed, step_sequences)`, where `failed` is True if
    any sample never emitted a step tag.
    """
    input_ids = queries
    step_counter = torch.zeros(len(queries)).to(queries[0].device)
    stop_counter = torch.zeros(len(queries)).to(queries[0].device)
    stop_indexes = [-1 for _ in range(len(queries))]

    for _i in range(txt_len):
        # Mini-batched forward pass to collect next-token logits for the whole batch.
        next_token_logits = []
        for mini_batch_input_ids in torch.split(input_ids, mini_batch_size):
            mini_batch_input_ids = mini_batch_input_ids.to(model.module.device)
            mini_batch_outputs = model(mini_batch_input_ids)
            mini_next_token_logits = mini_batch_outputs.logits[:, -1, :]
            next_token_logits.append(mini_next_token_logits)

        next_token_logits = torch.cat(next_token_logits, dim=0)
        next_token_logits /= 0.7  # fixed sampling temperature
        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)

        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
        step_counter += (next_token == step_tag_id).detach().long()

        # Record the absolute index of the first stop token in each sequence.
        stop_counter += (next_token == stop_id).detach().long()
        stop_indexes = [
            len(queries[i]) + _i if ((stop_counter[i] == 1) and (next_token[i] == stop_id)) else stop_indexes[i]
            for i in range(len(queries))
        ]

        # Stop once every sequence has emitted enough step tags or a stop token.
        if (step_counter >= sampling_steps).all() or (stop_counter > 0).all():
            break

    del next_token_logits

    truncated_sequences = []
    step_sequences = []
    for i, sample in enumerate(input_ids):
        max_sampled_steps = len((sample == step_tag_id).nonzero(as_tuple=True)[0])
        if step_tag_id in sample:
            step_tag_index = (sample == step_tag_id).nonzero(as_tuple=True)[0][min(current_step + sampling_steps, max_sampled_steps) - 1]
            if len((sample == step_tag_id).nonzero(as_tuple=True)[0]) - 1 >= current_step:
                current_step_tag_index = (sample == step_tag_id).nonzero(as_tuple=True)[0][current_step]
            else:
                # Sentinel: the current step was never completed within txt_len tokens,
                # so the step slice below runs to the end of the sample.
                current_step_tag_index = 100000
            stop_index = stop_indexes[i]
            if stop_index != -1:
                _sample = sample[len(queries[i]):stop_index + 1]
                step_sample = sample[len(queries[i]):min(current_step_tag_index, stop_index) + 1]
            else:
                _sample = sample[len(queries[i]):step_tag_index + 1]
                step_sample = sample[len(queries[i]):current_step_tag_index + 1]
            truncated_sequences.append(_sample)
            step_sequences.append(step_sample)
        else:
            # At least one sample produced no step tag at all: signal failure.
            return None, True, None
    return truncated_sequences, False, step_sequences
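
# Example call (sketch; how the step/stop token ids are obtained depends on the
# tokenizer in use, so the two lines deriving them below are assumptions):
#
#   step_tag_id = tokenizer.encode("ки")[-1]
#   stop_id = tokenizer.eos_token_id
#   responses, failed, steps = respond_to_batch(
#       wrapped_model, step_tag_id, stop_id, queries,
#       txt_len=512, sampling_steps=1, current_step=0,
#   )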
def get_rewards(rm_model, rm_tokenizer, problem_statement, proposed_solutions, candidate_tokens, step_tag_id):
    """Score each proposed solution with a process reward model (PRM).

    For every solution, the PRM logits over `candidate_tokens` are read off at
    each step-tag position, and the softmax probability of the first candidate
    token is returned as the per-step score.
    """
    scores_list = []
    for solution in proposed_solutions:
        input_for_prm = f"{problem_statement} {solution}"
        input_id = torch.tensor([rm_tokenizer.encode(input_for_prm)])
        input_id = input_id.to('cuda:1')  # the reward model is assumed to live on cuda:1

        # The last two "ки"-delimited steps are extracted here but are currently
        # unused: the full solution string above is what gets scored.
        steps = solution.split("ки")
        steps = [step.strip() for step in steps]
        steps = steps[-2:]
        steps = " ки".join(steps)

        with torch.no_grad():
            logits = rm_model(input_id).logits[:, :, candidate_tokens]
            scores = logits.softmax(dim=-1)[:, :, 0]
            step_scores = scores[input_id == step_tag_id]

        scores_list.append(step_scores)
    return scores_list
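
# Example of preparing the PRM inputs (sketch; this follows the common
# Math-Shepherd-style convention of scoring '+' vs. '-' at every "ки" step tag,
# so the exact token strings are assumptions about the reward model's tokenizer):
#
#   candidate_tokens = rm_tokenizer.encode("+ -")[1:]   # ids of '+' and '-'
#   step_tag_id = rm_tokenizer.encode("ки")[-1]
#   rewards = get_rewards(rm_model, rm_tokenizer, problem, solutions,
#                         candidate_tokens, step_tag_id)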
def math_evaluation(solution, answer):
    """Extract the final answer from a solution string and compare it to `answer`."""
    answer_prefix = "The answer is: "
    if answer_prefix in solution:
        start = solution.rfind(answer_prefix) + len(answer_prefix)
        final_answer = solution[start:]
        final_answer = final_answer.replace("ки", "")
        final_answer = final_answer.strip()
    else:
        return False, None
    return (final_answer == answer), final_answer
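
# Example (sketch):
#
#   solved, pred = math_evaluation("... The answer is: 42 ки", "42")
#   # solved -> True, pred -> "42"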