| | import argparse |
| | import re |
| |
|
| | import torch |
| | import uvicorn |
| | from fastapi import FastAPI, Request |
| | from fastapi.responses import JSONResponse |
| |
|
| | from openrlhf.models import get_llm_for_sequence_regression |
| | from openrlhf.utils import get_tokenizer |
| | from openrlhf.utils.logging_utils import init_logger |
| |
|
| | logger = init_logger(__name__) |
| |
|
class RewardModelProxy:
    """Serves pairwise reward scores from a sequence-regression reward model."""

    def __init__(self, args):
        """Load the reward model and its tokenizer from `args`.

        :param args: argparse.Namespace carrying reward_pretrain,
            normalize_reward, bf16, load_in_4bit, value_head_prefix,
            disable_fast_tokenizer, max_len and batch_size (flash_attn is
            optional — see note below).
        """
        # Load the "reward" head variant; device_map="auto" lets HF place the
        # weights across available devices.
        self.reward_model = get_llm_for_sequence_regression(
            args.reward_pretrain,
            "reward",
            normalize_reward=args.normalize_reward,
            # NOTE(review): the CLI in this file defines --attn_implementation,
            # not --flash_attn, so reading args.flash_attn directly raised
            # AttributeError. Default to False when the flag is absent.
            use_flash_attention_2=getattr(args, "flash_attn", False),
            bf16=args.bf16,
            load_in_4bit=args.load_in_4bit,
            value_head_prefix=args.value_head_prefix,
            device_map="auto",
        )
        self.reward_model.eval()

        # Left padding so the scored final tokens align at the sequence end.
        self.tokenizer = get_tokenizer(
            args.reward_pretrain, self.reward_model, "left", None, use_fast=not args.disable_fast_tokenizer
        )
        self.max_length = args.max_len
        self.batch_size = args.batch_size

    def get_reward(self, queries):
        """Score chosen/reject pairs and report pairwise ranking accuracy.

        :param queries: DataFrame-like with "chosen_prompt", "chosen" and
            "reject" columns — assumed from the iterrows/column usage below;
            TODO confirm against callers (the /get_reward endpoint forwards
            the raw JSON "query" field here).
        :return: Tuple (all_scores, accuracy) where all_scores is a list of
            (chosen_score, reject_score) pairs and accuracy is the fraction
            of rows where the chosen answer out-scored the rejected one.
        """
        # Bug fix: the body referenced `df` but the parameter is `queries`,
        # so every call raised NameError.
        df = queries
        correct_count = 0
        total_count = len(df)
        # Bug fix: was initialized as `allscores` but appended/returned as
        # `all_scores`, raising NameError on the first iteration.
        all_scores = []

        for _, row in df.iterrows():
            chosen_query = row["chosen_prompt"] + " " + row["chosen"]
            reject_query = row["chosen_prompt"] + " " + row["reject"]

            scores = self.compare_queries(chosen_query, reject_query)
            all_scores.append(scores)

            chosen_score, reject_score = scores
            if chosen_score > reject_score:
                correct_count += 1

        # Guard against an empty input instead of dividing by zero.
        accuracy = correct_count / total_count if total_count > 0 else 0
        print(f"Current Accuracy: {accuracy * 100:.2f}%")
        return all_scores, accuracy

    def compare_queries(self, chosen_query, reject_query):
        """
        Compare the reward scores for chosen_query and reject_query.
        :param chosen_query: The query with the 'chosen' answer
        :param reject_query: The query with the 'reject' answer
        :return: Tuple (chosen_score, reject_score)
        """
        with torch.no_grad():
            inputs_chosen = self.tokenize_fn([chosen_query], device=self.reward_model.device)
            inputs_reject = self.tokenize_fn([reject_query], device=self.reward_model.device)

            # Each forward pass scores a single query; [0] unwraps the
            # batch-of-one result.
            chosen_score = self.reward_model(inputs_chosen["input_ids"], inputs_chosen["attention_mask"]).tolist()[0]
            reject_score = self.reward_model(inputs_reject["input_ids"], inputs_reject["attention_mask"]).tolist()[0]

        return chosen_score, reject_score

    def tokenize_fn(self, texts, device):
        """Tokenize `texts` (padded/truncated to self.max_length) and move the
        resulting tensors onto `device`."""
        batch = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=self.max_length,
            padding=True,
            truncation=True,
        )
        return {k: v.to(device) for k, v in batch.items()}
| |
|
| |
|
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser() |
| | |
| | parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") |
| | parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normazation") |
| | parser.add_argument("--value_head_prefix", type=str, default="value_head") |
| | parser.add_argument("--max_len", type=int, default="2048") |
| |
|
| | parser.add_argument("--port", type=int, default=5000, help="Port number for the server") |
| | parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server") |
| |
|
| | |
| | parser.add_argument("--load_in_4bit", action="store_true", default=False) |
| | parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") |
| | parser.add_argument( |
| | "--attn_implementation", |
| | type=str, |
| | default="flash_attention_2", |
| | help="Attention implementation (e.g., eager, flash_attention_2, flash_attention_3, kernels-community/vllm-flash-attn3)", |
| | ) |
| | parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) |
| | parser.add_argument("--batch_size", type=int, default=None) |
| |
|
| | args = parser.parse_args() |
| |
|
| | |
| | reward_model = RewardModelProxy(args) |
| | app = FastAPI() |
| |
|
| | @app.post("/get_reward") |
| | async def get_reward(request: Request): |
| | data = await request.json() |
| | queries = data.get("query") |
| | rewards = reward_model.get_reward(queries) |
| | result = {"rewards": rewards, "scores": rewards, "extra_logs": {"dummy_scores": rewards}} |
| | logger.info(f"Sent JSON: {result}") |
| | return JSONResponse(result) |
| |
|
| | uvicorn.run(app, host=args.host, port=args.port, log_level="info") |