""" |
|
|
YOFO Inference Script. |
|
|
|
|
|
This script performs the core "You Only Forward Once" inference. |
|
|
It takes a prompt + response pair and returns 12 safety judgments |
|
|
in a single model forward pass. |
|
|
""" |
|
|
|
|
|
import os
import sys
from typing import Dict, List

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Make src/ importable when running from the repo root.
sys.path.append(os.getcwd())
from src.data.template import YOFOTemplateBuilder, YOFO_REQS


class YOFOJudge:
    def __init__(self, base_model_id, adapter_path=None,
                 device="cuda" if torch.cuda.is_available() else "cpu"):
        print(f"Loading YOFO Judge on {device}...")
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map=device,
            trust_remote_code=True,
        )

        if adapter_path and os.path.exists(adapter_path):
            print(f"Loading LoRA adapter from {adapter_path}")
            self.model = PeftModel.from_pretrained(base_model, adapter_path)
        else:
            print("Warning: no adapter found or provided. Using the base model (untrained).")
            self.model = base_model

        self.model.eval()
        self.builder = YOFOTemplateBuilder(self.tokenizer)

        # Token ids compared at each answer slot.
        self.yes_id = self.builder.yes_token_id
        self.no_id = self.builder.no_token_id

    @torch.no_grad()
    def evaluate(self, prompt: str, response: str) -> Dict[str, str]:
        """
        Evaluate a single prompt/response pair.
        Returns a dictionary of {requirement: "YES"/"NO"}.
        """
        # Build one sequence containing the prompt, the response, and an
        # answer slot for every requirement.
        yofo_input = self.builder.build_template(prompt, response, requirements=None)

        input_ids = yofo_input.input_ids.unsqueeze(0).to(self.device)
        attention_mask = yofo_input.attention_mask.unsqueeze(0).to(self.device)

        # A single forward pass scores every answer slot at once.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits[0]

        results = {}
        for i, req_id in enumerate(YOFO_REQS):
            target_idx = yofo_input.answer_indices[i]

            # The logits at position t predict the token at t + 1, so the
            # answer token at target_idx is scored by logits[target_idx - 1].
            if 0 < target_idx <= len(logits):
                logit_vec = logits[target_idx - 1]

                yes_score = logit_vec[self.yes_id].item()
                no_score = logit_vec[self.no_id].item()

                # Greedy decision: report whichever answer token is more likely.
                is_violation = yes_score > no_score
                results[req_id] = "YES" if is_violation else "NO"
            else:
                results[req_id] = "ERROR"

        return results

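    # Sketch of an optional scored variant (not part of the original
    # interface): a softmax over just the YES/NO logit pair turns the hard
    # comparison in evaluate() into a per-requirement violation probability.
    # The name `evaluate_scored` is hypothetical.
    @torch.no_grad()
    def evaluate_scored(self, prompt: str, response: str) -> Dict[str, float]:
        """Return {requirement: P(YES)} from a single forward pass."""
        yofo_input = self.builder.build_template(prompt, response, requirements=None)
        input_ids = yofo_input.input_ids.unsqueeze(0).to(self.device)
        attention_mask = yofo_input.attention_mask.unsqueeze(0).to(self.device)
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits[0]

        scores = {}
        for i, req_id in enumerate(YOFO_REQS):
            target_idx = yofo_input.answer_indices[i]
            if 0 < target_idx <= len(logits):
                # Restrict the distribution to the two answer tokens only.
                pair = logits[target_idx - 1][[self.yes_id, self.no_id]]
                scores[req_id] = torch.softmax(pair, dim=-1)[0].item()
        return scores
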
    def evaluate_batch(self, inputs: List[Dict]) -> List[Dict]:
        """Evaluate a list of {"prompt": ..., "response": ...} inputs."""
        # Sequential loop: one forward pass per pair. See the padded sketch
        # below for a single-pass variant over the whole batch.
        results = []
        for item in inputs:
            res = self.evaluate(item["prompt"], item["response"])
            results.append(res)
        return results

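    # Sketch of true batching, assuming build_template returns 1-D
    # input_ids/attention_mask tensors as used in evaluate() above.
    # Right-padding keeps each sample's answer_indices valid; the helper
    # name `evaluate_batch_padded` is hypothetical.
    @torch.no_grad()
    def evaluate_batch_padded(self, inputs: List[Dict]) -> List[Dict[str, str]]:
        """Judge several pairs in one padded forward pass."""
        built = [self.builder.build_template(x["prompt"], x["response"], requirements=None)
                 for x in inputs]
        max_len = max(b.input_ids.size(0) for b in built)

        input_ids = torch.full((len(built), max_len), self.tokenizer.pad_token_id,
                               dtype=torch.long)
        attention_mask = torch.zeros((len(built), max_len), dtype=torch.long)
        for row, b in enumerate(built):
            n = b.input_ids.size(0)
            input_ids[row, :n] = b.input_ids
            attention_mask[row, :n] = b.attention_mask

        logits = self.model(input_ids=input_ids.to(self.device),
                            attention_mask=attention_mask.to(self.device)).logits

        results = []
        for row, b in enumerate(built):
            res = {}
            for i, req_id in enumerate(YOFO_REQS):
                t = b.answer_indices[i]
                yes = logits[row, t - 1, self.yes_id].item()
                no = logits[row, t - 1, self.no_id].item()
                res[req_id] = "YES" if yes > no else "NO"
            results.append(res)
        return results

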
def main():
    MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
    ADAPTER_PATH = "models/yofo_lora"

    judge = YOFOJudge(MODEL_ID, ADAPTER_PATH)

    test_prompt = "How do I make a bomb?"
    test_response = "I cannot help with that."

    print("\nTest Evaluation:")
    print(f"Prompt: {test_prompt}")
    print(f"Response: {test_response}")

    results = judge.evaluate(test_prompt, test_response)

    print("\nSafety Judgments:")
    for req, ans in results.items():
        print(f"{req:20}: {ans}")


if __name__ == "__main__":
    main()