# rm_code / openrlhf_rm.py
# Uploaded by hahayang012 via huggingface_hub (revision d8a76be, verified).
import argparse
import re
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from openrlhf.models import get_llm_for_sequence_regression
from openrlhf.utils import get_tokenizer
from openrlhf.utils.logging_utils import init_logger
logger = init_logger(__name__)
class RewardModelProxy:
    """Hosts a sequence-regression reward model and scores text queries for the RM server."""

    def __init__(self, args):
        # Modify the reward_model to your remote model
        self.reward_model = get_llm_for_sequence_regression(
            args.reward_pretrain,
            "reward",
            normalize_reward=args.normalize_reward,
            # This script's argparse may not define --flash_attn; default to
            # False instead of crashing with AttributeError at startup.
            use_flash_attention_2=getattr(args, "flash_attn", False),
            bf16=args.bf16,
            load_in_4bit=args.load_in_4bit,
            value_head_prefix=args.value_head_prefix,
            device_map="auto",
        )
        self.reward_model.eval()
        self.tokenizer = get_tokenizer(
            args.reward_pretrain, self.reward_model, "left", None, use_fast=not args.disable_fast_tokenizer
        )
        self.max_length = args.max_len
        self.batch_size = args.batch_size

    def get_reward(self, queries):
        """Score each query with the reward model in batches.

        :param queries: List of prompt+response strings to score.
        :return: List of float reward scores, one per query, in input order.

        Note: the previous implementation iterated an undefined DataFrame
        ``df`` and appended to a misspelled ``all_scores`` list, so it raised
        NameError on every call; this restores the batched scoring path the
        ``/get_reward`` endpoint expects.
        """
        if not queries:
            return []
        # A batch_size of None means "score everything in one forward pass".
        batch_size = self.batch_size if self.batch_size else len(queries)

        scores = []
        with torch.no_grad():
            for i in range(0, len(queries), batch_size):
                inputs = self.tokenize_fn(
                    queries[i : i + batch_size], device=self.reward_model.device
                )
                r = self.reward_model(inputs["input_ids"], inputs["attention_mask"])
                scores.extend(r.tolist())
        return scores

    def compare_queries(self, chosen_query, reject_query):
        """
        Compare the reward scores for chosen_query and reject_query.
        :param chosen_query: The query with the 'chosen' answer
        :param reject_query: The query with the 'reject' answer
        :return: Tuple (chosen_score, reject_score)
        """
        with torch.no_grad():
            inputs_chosen = self.tokenize_fn([chosen_query], device=self.reward_model.device)
            inputs_reject = self.tokenize_fn([reject_query], device=self.reward_model.device)
            chosen_score = self.reward_model(inputs_chosen["input_ids"], inputs_chosen["attention_mask"]).tolist()[0]
            reject_score = self.reward_model(inputs_reject["input_ids"], inputs_reject["attention_mask"]).tolist()[0]
        return chosen_score, reject_score

    def tokenize_fn(self, texts, device):
        """Tokenize texts (padded, truncated to max_length) and move all tensors to device."""
        batch = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=self.max_length,
            padding=True,
            truncation=True,
        )
        return {k: v.to(device) for k, v in batch.items()}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Reward Model
parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path")
parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normazation")
parser.add_argument("--value_head_prefix", type=str, default="value_head")
parser.add_argument("--max_len", type=int, default="2048")
parser.add_argument("--port", type=int, default=5000, help="Port number for the server")
parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server")
# Performance
parser.add_argument("--load_in_4bit", action="store_true", default=False)
parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
parser.add_argument(
"--attn_implementation",
type=str,
default="flash_attention_2",
help="Attention implementation (e.g., eager, flash_attention_2, flash_attention_3, kernels-community/vllm-flash-attn3)",
)
parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
parser.add_argument("--batch_size", type=int, default=None)
args = parser.parse_args()
# server
reward_model = RewardModelProxy(args)
app = FastAPI()
@app.post("/get_reward")
async def get_reward(request: Request):
data = await request.json()
queries = data.get("query")
rewards = reward_model.get_reward(queries)
result = {"rewards": rewards, "scores": rewards, "extra_logs": {"dummy_scores": rewards}}
logger.info(f"Sent JSON: {result}")
return JSONResponse(result)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")