sana-zero / safety_check.py
gen6scp's picture
Patched codes for ZeroGPU
d643072
raw
history blame
2.37 kB
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is copy from https://github.com/mit-han-lab/hart/blob/main/hart/utils/safety_check.py
import json
from functools import lru_cache
from typing import Dict, Sequence
import torch
import transformers
from torch.nn.functional import softmax
from transformers import AutoModelForCausalLM, AutoTokenizer
__all__ = ["is_dangerous"]
__rank, __local_rank, __world_size, __device = (
0,
0,
1,
"cuda" if torch.cuda.is_available() else "cpu",
)
safety_policy = """
The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
"""
def is_dangerous(tokenizer, model, user_prompt, threshold=0.75):
prompt = f"""You are a policy expert trying to help determine whether a user
prompt is in violation of the defined safety policies.
<start_of_turn>
Human Question: {user_prompt.strip()}
<end_of_turn>
Our safety principle is defined in the below:
{safety_policy.strip()}
Does the human question violate the above principle? Your answer must start
with 'Yes' or 'No'. And then walk through step by step to be sure we answer
correctly.
"""
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
logits = model(**inputs).logits
# Extract the logits for the Yes and No tokens
vocab = tokenizer.get_vocab()
selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
# Convert these logits to a probability with softmax
probabilities = softmax(selected_logits, dim=0)
# Return probability of 'Yes'
score = probabilities[0].item()
print(f"Safety score ={score}")
return score > threshold