Tutorial: Use Llama-Guard-3-8B as a classifier AND get safety categories
The problem:

- Even though LlamaGuard3's model card states that it "can be used to classify content", this model is NOT a classifier. It is a next-token-prediction model.
- The recommended way to use it as a classifier is to use the probability of generating the `"unsafe"` token.
- If the model is not confident about the unsafety of a response, the text it generates is not reliable for classification. For example, if the probability of the `"unsafe"` token is 60%, applying the model three times (with naive sampling) to the same content would typically give you "unsafe" two times and "safe" one time (see the quick check after this list).
- Inspecting the texts generated by the model with the `moderate` function from the model card, one notices that they start with a whitespace token. For example, the model generates `"\n\nunsafe\nS1"` or `"\n\nsafe"` instead of `"unsafe\nS1"` or `"safe"`.
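To make the sampling point concrete, here is a quick back-of-the-envelope check (plain Python, no model needed; the 60% figure is just the illustrative value from the bullet above): with a 60% probability of `"unsafe"`, three naive samples of the label disagree with each other most of the time.

```python
# Probability that three independent "safe"/"unsafe" samples all agree,
# given P(unsafe) = 0.6 (the illustrative value from the example above)
p_unsafe = 0.6
p_all_agree = p_unsafe**3 + (1 - p_unsafe)**3
print(f"all three samples agree:   {p_all_agree:.2f}")    # 0.28
print(f"at least one disagreement: {1 - p_all_agree:.2f}")  # 0.72
```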
Why current solutions do not work:

1. The recommendation to use `compute_transition_scores` only produces probabilities for the tokens that were actually generated. So if the model generates `"safe"`, one cannot obtain the probability of the `"unsafe"` token (see the sketch after this list).
2. In this discussion there are two proposed ways of obtaining probabilities from the model:
   2.1. The first method produces the probabilities only of the generated tokens, like (1).
   2.2. The second method produces the probabilities of the first generated token, which, as mentioned above, is always whitespace.
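For reference, this is roughly what the `compute_transition_scores` route looks like. It is only a sketch, assuming `model`, `tokenizer`, and `inputs` are set up as in the implementation below; the point is that it yields one score per generated token, so there is no way to read off the probability of a token that was not generated.

```python
# Sketch of the compute_transition_scores approach (assumes `model`, `tokenizer`,
# and `inputs` from the implementation below). It returns one log-probability per
# *generated* token, so if the model emits "safe" there is no entry for "unsafe".
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    return_dict_in_generate=True,
    output_scores=True,
    pad_token_id=tokenizer.eos_token_id,
)
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
print(transition_scores.shape)  # (batch_size, num_generated_tokens)
```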
The solution

This implementation produces:

- The probability of the `"unsafe"` token, even if it was not generated
- The text generated by the model

How?

- Given a conversation, apply the chat template to convert it into a string for next-token prediction.
- Append the whitespace characters `"\n\n"`, forcing one of `"safe"` or `"unsafe"` to be the first generated token.
- Get the probabilities for the first generated token over the whole vocabulary, and extract the one corresponding to the `"unsafe"` token (a small sanity check of this assumption follows the list).
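Before relying on this, it is worth sanity-checking the assumptions behind the last step (a minimal sketch, assuming `tokenizer` is loaded as in the implementation below): `"unsafe"` should exist as a single vocabulary token, and the `"\n\n"` whitespace should tokenize separately from the label that follows it.

```python
# Sanity check (assumes `tokenizer` from the implementation below):
# "unsafe" should map to a real vocabulary id (not the unknown token),
# and "\n\n" should tokenize separately from the label that follows it.
unsafe_id = tokenizer.convert_tokens_to_ids("unsafe")
print(unsafe_id, tokenizer.convert_ids_to_tokens(unsafe_id))
print(tokenizer.tokenize("\n\nunsafe\nS1"))  # expect the whitespace and "unsafe" as separate tokens
```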
Implementation

```python
# Initialization like in the model card
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-Guard-3-8B"
device = "cuda"
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=device)

# Get the id of the "unsafe" token; this will later be used to extract its probability
UNSAFE_TOKEN_ID = tokenizer.convert_tokens_to_ids("unsafe")


def moderate(chat):
    prompt = tokenizer.apply_chat_template(chat, tokenize=False)

    # Skip the generation of whitespace.
    # Now the next predicted token will be either "safe" or "unsafe".
    prompt += "\n\n"

    inputs = tokenizer([prompt], return_tensors="pt").to(device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        output_logits=True,  # get logits for each generation step
    )

    ######
    # Get generated text
    ######

    # Number of tokens that correspond to the input prompt
    input_length = inputs.input_ids.shape[1]
    # Ignore the tokens from the input to get the tokens generated by the model
    generated_token_ids = outputs.sequences[:, input_length:].cpu()
    generated_text = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)

    ######
    # Get probability of the "unsafe" token
    ######

    # The first generated token is either "safe" or "unsafe".
    # Use its logits to calculate the probabilities over the whole vocabulary.
    first_token_logits = outputs.logits[0]
    first_token_probs = torch.softmax(first_token_logits, dim=-1)

    # From the probabilities of all tokens, extract the one for the "unsafe" token.
    unsafe_probability = first_token_probs[0, UNSAFE_TOKEN_ID].item()

    ######
    # Result
    ######
    return {
        "unsafe_score": unsafe_probability,
        "generated_text": generated_text,
    }


# Example
moderate([
    {"role": "user", "content": "I forgot how to kill a process in Linux, can you help?"},
    {"role": "assistant", "content": "Sure! To kill a process in Linux, you can use the kill command followed by the process ID (PID) of the process you want to terminate."},
])
# {
#     "unsafe_score": 4.006361268693581e-05,
#     "generated_text": "safe"
# }
```
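Finally, since the title also promises safety categories: when content is flagged, `generated_text` follows the model-card format of `"unsafe"` on the first line and the violated category codes on the next line. A small hypothetical helper (not part of the original implementation; the comma-separated category list is an assumption based on that format) can split this into a label and a list of categories:

```python
# Hypothetical helper: parse the label and category codes out of generated_text.
# Assumes the model-card output format: "safe", or "unsafe" followed by a line
# of category codes (treated here as comma-separated, e.g. "S1" or "S1,S2").
def parse_guard_output(generated_text):
    lines = generated_text.strip().split("\n")
    label = lines[0]  # "safe" or "unsafe"
    categories = lines[1].split(",") if len(lines) > 1 else []
    return label, categories

parse_guard_output("unsafe\nS1")  # ("unsafe", ["S1"])
parse_guard_output("safe")        # ("safe", [])
```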