# tem-rho-swap / core.py
import os
from typing import Any, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from openai import OpenAI
from trl.core import top_k_top_p_filtering

class LLM:
    """Minimal chat-LLM interface; subclasses implement __call__."""

    def __init__(self, config):
        self.system_prompt = config["system_prompt"]
        self.config = config

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        raise NotImplementedError

class LocalLLM(LLM):
    """LLM behind an OpenAI-compatible endpoint (here, the Together AI API)."""

    def __init__(self, config):
        super().__init__(config)
        # Read the credential from the environment rather than hardcoding it.
        openai_api_key = os.environ["TOGETHER_API_KEY"]
        openai_api_base = "https://api.together.xyz"
        self.client = OpenAI(
            api_key=openai_api_key,
            base_url=openai_api_base,
        )

    def __call__(self, prompt: str) -> str:
        # prompt = truncate_tokens_hf(prompt, encoding_name=self.config["model"])
        response = self.client.chat.completions.create(
            temperature=0,
            model=self.config["model"],
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ],
            max_tokens=None,
        )
        return response.choices[0].message.content
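
# Usage sketch for LocalLLM (assumptions: TOGETHER_API_KEY is exported and the
# model name below is illustrative; any Together-served chat model would do):
#
#   llm = LocalLLM({
#       "system_prompt": "You are a careful math assistant.",
#       "model": "meta-llama/Llama-2-70b-chat-hf",
#   })
#   print(llm("Compute 12 * 7."))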

def respond_to_batch(
    model: nn.Module,
    step_tag_id: int,
    stop_id: int,
    queries: List[torch.LongTensor],
    txt_len: int = 20,
    top_k: int = 0,
    top_p: float = 1.0,
    sampling_steps: int = 1,
    mini_batch_size: int = 8,
    current_step=None,
    tokenizer=None,
) -> Tuple[Optional[List[torch.LongTensor]], bool, Optional[List[torch.LongTensor]]]:
"""Sample text from language model using multi-GPU."""
input_ids = queries
step_counter = torch.zeros(len(queries)).to(queries[0].device)
stop_counter = torch.zeros(len(queries)).to(queries[0].device)
stop_indexes = [-1 for _ in range(len(queries))]
    for _i in range(txt_len):
        # Run the forward pass in mini-batches to bound peak memory.
        next_token_logits = []
        for mini_batch_input_ids in torch.split(input_ids, mini_batch_size):
            mini_batch_input_ids = mini_batch_input_ids.to(model.module.device)
            mini_batch_outputs = model(mini_batch_input_ids)
            mini_next_token_logits = mini_batch_outputs.logits[:, -1, :]
            next_token_logits.append(mini_next_token_logits)
        next_token_logits = torch.cat(next_token_logits, dim=0)
        next_token_logits /= 0.7  # fixed sampling temperature
        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        # Sample the next token for every sequence in the batch.
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
        step_counter += (next_token == step_tag_id).detach().long()
        stop_counter += (next_token == stop_id).detach().long()
        # Record where each sequence first produced the stop token.
        stop_indexes = [
            len(queries[i]) + _i if (stop_counter[i] == 1 and next_token[i] == stop_id) else stop_indexes[i]
            for i in range(len(queries))
        ]
        # Stop once every sequence has emitted step_tag_id sampling_steps times
        # or every sequence has produced a stop token.
        if (step_counter >= sampling_steps).all() or (stop_counter > 0).all():
            break
        del next_token_logits
    truncated_sequences = []
    step_sequences = []
    for i, sample in enumerate(input_ids):
        step_positions = (sample == step_tag_id).nonzero(as_tuple=True)[0]
        max_sampled_steps = len(step_positions)
        if step_tag_id in sample:
            # Last step tag to keep after this round of sampling.
            step_tag_index = step_positions[min(current_step + sampling_steps, max_sampled_steps) - 1]
            if max_sampled_steps - 1 >= current_step:
                current_step_tag_index = step_positions[current_step]
            else:
                # Sentinel larger than any real index: keep the whole tail.
                current_step_tag_index = 100000
            stop_index = stop_indexes[i]
            if stop_index != -1:
                # Sequence finished: cut at the stop token.
                _sample = sample[len(queries[i]):stop_index + 1]
                step_sample = sample[len(queries[i]):min(current_step_tag_index, stop_index) + 1]
            else:
                # Otherwise cut at the last sampled step tag.
                _sample = sample[len(queries[i]):step_tag_index + 1]
                step_sample = sample[len(queries[i]):current_step_tag_index + 1]
            truncated_sequences.append(_sample)
            step_sequences.append(step_sample)
        else:
            # A sample contains no step tag at all: signal failure upstream.
            return None, True, None
    return truncated_sequences, False, step_sequences
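
# Usage sketch for respond_to_batch (assumptions: `policy` is wrapped in
# DataParallel/DDP so `model.module.device` resolves, `queries` is a padded
# (batch, seq_len) LongTensor, and step_tag_id/stop_id come from the tokenizer):
#
#   sequences, failed, step_sequences = respond_to_batch(
#       policy, step_tag_id, stop_id, queries,
#       txt_len=512, sampling_steps=2, mini_batch_size=8, current_step=0,
#   )
#   if not failed:
#       texts = [tokenizer.decode(seq) for seq in sequences]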

def get_rewards(rm_model, rm_tokenizer, problem_statement, proposed_solutions, candidate_tokens, step_tag_id):
    """Score each proposed solution with a process reward model (PRM)."""
    scores_list = []
    for solution in proposed_solutions:
        input_for_prm = f"{problem_statement} {solution}"
        input_id = torch.tensor([rm_tokenizer.encode(input_for_prm)])
        input_id = input_id.to('cuda:1')
        # Intended to score only the last two steps of the solution, but as
        # written `steps` is never used: the PRM input above already contains
        # the full solution.
        steps = solution.split("ки")
        steps = [step.strip() for step in steps]
        steps = steps[-2:]
        steps = " ки".join(steps)
        with torch.no_grad():
            # Probability of the first candidate token at every position,
            logits = rm_model(input_id).logits[:, :, candidate_tokens]
            scores = logits.softmax(dim=-1)[:, :, 0]
            # read off at the step-tag ("ки") positions only.
            step_scores = scores[input_id == step_tag_id]
            scores_list.append(step_scores)
    return scores_list
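
# Usage sketch for get_rewards (assumptions: a Math-Shepherd-style PRM where
# candidate_tokens = [good_token_id, bad_token_id] and step_tag_id is the
# tokenizer id of the step tag "ки"; the names below are hypothetical):
#
#   rewards = get_rewards(rm_model, rm_tokenizer, problem, solutions,
#                         candidate_tokens=[good_id, bad_id],
#                         step_tag_id=tag_id)
#   last_step_scores = [r[-1].item() for r in rewards]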

def math_evaluation(solution, answer):
    """Extract the final answer from `solution` and compare it to `answer`."""
    answer_prefix = "The answer is: "
    if answer_prefix in solution:
        start = solution.rfind(answer_prefix) + len(answer_prefix)
        final_answer = solution[start:]
        final_answer = final_answer.replace("ки", "")  # drop step tags
        final_answer = final_answer.strip()
    else:
        # No recognizable answer line, so the solution cannot be graded.
        return False, None
    return (final_answer == answer), final_answer
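
# Usage sketch: grading is an exact string match on the extracted answer, e.g.
#
#   math_evaluation("Step 1: ... The answer is: 42 ки", "42")  # -> (True, "42")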