Tutorial: Use Llama-Guard-3-8B as a classifier AND get safety categories

#21, opened by LuisVasquezBSC

The problem:

  • Even though Llama Guard 3's model card states that "it can be used to classify content", in reality this model is NOT a classifier. It is a model for next-token prediction, and its verdict is text that it generates.
  • The recommended way to use this model as a classifier is to take the probability of generating the "unsafe" token and use it as a score.
  • If the model is not confident about whether a response is unsafe, the text it generates is not a reliable classification. For example, if the probability of the "unsafe" token is 60%, then running the model three times (with naive sampling) on the same content could give you "unsafe" twice and "safe" once. Over many runs, "unsafe" appears about 60% of the time.
  • Inspecting the texts generated by the "moderate" function from the model card, one notices that they start with whitespace. For example, the model generates "\n\nunsafe\nS1" or "\n\nsafe" instead of "unsafe\nS1" or "safe".
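A toy simulation of the sampling argument above, assuming (for simplicity) that only "safe" and "unsafe" can be sampled as the verdict token and that "unsafe" has probability 0.6. `sample_labels` is a hypothetical helper for illustration, not part of any library.

```python
import random


def sample_labels(p_unsafe: float, n: int, seed: int = 0) -> list[str]:
    """Simulate n naive-sampling runs of the verdict token.

    Toy assumption: only "safe" and "unsafe" can be sampled, and "unsafe"
    has probability p_unsafe.
    """
    rng = random.Random(seed)
    return ["unsafe" if rng.random() < p_unsafe else "safe" for _ in range(n)]


p_unsafe = 0.6  # the probability used in the example above

# A few runs on the same content can disagree with each other
print(sample_labels(p_unsafe, n=3, seed=1))

# Over many runs, the fraction of "unsafe" labels approaches p_unsafe
labels = sample_labels(p_unsafe, n=10_000)
fraction_unsafe = labels.count("unsafe") / len(labels)
print(f"fraction of 'unsafe' labels: {fraction_unsafe:.3f}")
```

Any single generated label throws away this uncertainty, while the probability itself keeps it.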

Why current solutions do not work:

  1. The recommendation to use compute_transition_scores only yields the probabilities of generated tokens. If the model generates "safe", one cannot obtain the probability of the "unsafe" token.
  2. In this discussion, two ways of obtaining probabilities from the model are proposed:
    2.1. The first method produces only the probabilities of the generated tokens, like (1).
    2.2. The second method produces the probabilities of the first generated token, which, as mentioned above, is always whitespace.
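To make point (1) concrete, here is a pure-Python sketch on a made-up three-token vocabulary (the logits are invented for illustration). A per-generated-token score, which is what compute_transition_scores reports (as log-probabilities when called with normalize_logits=True), only covers the token that was picked; the full softmax over the vocabulary also covers "unsafe" when it was not generated.

```python
import math


def softmax(logits: dict[str, float]) -> dict[str, float]:
    """Numerically stable softmax over a {token: logit} mapping."""
    largest = max(logits.values())
    exps = {tok: math.exp(v - largest) for tok, v in logits.items()}
    total = sum(exps.values())
    return {tok: e / total for tok, e in exps.items()}


# Made-up logits for the verdict position over a tiny hypothetical vocabulary
verdict_logits = {"safe": 3.0, "unsafe": 1.0, "S1": -2.0}
probs = softmax(verdict_logits)

# Greedy decoding generates the most likely token ("safe" here)
generated = max(probs, key=probs.get)

# Per-generated-token score: only the log-probability of the chosen token
generated_logprob = math.log(probs[generated])

# The full distribution also contains tokens that were NOT generated
p_unsafe = probs["unsafe"]
print(generated, round(generated_logprob, 3), round(p_unsafe, 3))
```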

The solution

This implementation produces:

  • The probability of the "unsafe" token, even if it was not generated
  • The text generated by the model

How?

  1. Given a conversation, apply the chat template to convert it into a prompt string for next-token prediction
  2. Append the whitespace ("\n\n") that the model would otherwise generate first, forcing one of "safe" or "unsafe" to be the first generated token
  3. Get the probability distribution of the first generated token over the whole vocabulary, and extract the entry for the "unsafe" token
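The three steps can be sketched without downloading the model by using a toy stand-in. `toy_next_token_logits` is hypothetical and its logits are made up; it only mimics the behaviour described above (whitespace first, verdict second). It also shows why reading the first-token probability without step 2 (method 2.2) is uninformative.

```python
import math

NEWLINES = "\n\n"


def toy_next_token_logits(prompt: str) -> dict[str, float]:
    """Hypothetical stand-in for the model, with made-up logits.

    Like the behaviour described above, it first emits two newlines
    and only then the verdict token.
    """
    if not prompt.endswith(NEWLINES):
        return {NEWLINES: 12.0, "safe": 1.0, "unsafe": 0.5}
    return {NEWLINES: -5.0, "safe": 1.0, "unsafe": 0.5}


def unsafe_probability(prompt: str) -> float:
    """Step 3: softmax over the next-token logits, then read off "unsafe"."""
    logits = toy_next_token_logits(prompt)
    largest = max(logits.values())
    exps = {tok: math.exp(v - largest) for tok, v in logits.items()}
    return exps["unsafe"] / sum(exps.values())


# Step 1 (placeholder): in the real code this string comes from apply_chat_template
chat_prompt = "<conversation formatted by the chat template>"

# Without step 2, the next position is the whitespace, so p("unsafe") is tiny
p_first_token = unsafe_probability(chat_prompt)
# Step 2: append the whitespace so that the next position is the verdict
p_verdict = unsafe_probability(chat_prompt + NEWLINES)
print(f"{p_first_token:.2e} vs {p_verdict:.3f}")
```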

Implementation

# Initialization like in the model card
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Llama-Guard-3-8B"
device = "cuda"
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=device)

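# Get the id of the "unsafe" token, this will later be used to extract its probability
UNSAFE_TOKEN_ID = tokenizer.convert_tokens_to_ids("unsafe")
# Sanity check: "unsafe" must be a single entry in the vocabulary
assert UNSAFE_TOKEN_ID is not None and UNSAFE_TOKEN_ID != tokenizer.unk_token_id, '"unsafe" is not a single token'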


def moderate(chat):
    
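    # Format the conversation as a prompt string (not tokenized yet)
    prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    # Append the whitespace that the model would otherwise generate first.
    # Now the next predicted token will be either "safe" or "unsafe".
    prompt += "\n\n"

    # The chat template already contains the BOS token, so do not add a second one
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(device)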
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        output_logits=True,  # get logits
    )
    ######
    # Get generated text
    ######

    # Number of tokens that correspond to the input prompt
    input_length = inputs.input_ids.shape[1]
    # Ignore the tokens from the input to get the tokens generated by the model
    generated_token_ids = outputs.sequences[:, input_length:].cpu()
    generated_text = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)

    ######
    # Get Probability of "unsafe" token
    ######
   
    # First generated token is either "safe" or "unsafe".
    # outputs.logits holds one (batch_size, vocab_size) tensor per generated token;
    # use the logits of the first one to calculate the probabilities.
    first_token_logits = outputs.logits[0]
    # Compute the softmax in float32 to avoid bfloat16 rounding of small probabilities
    first_token_probs = torch.softmax(first_token_logits.float(), dim=-1)

    # From the probabilities of all tokens, extract the one for the "unsafe" token.
    unsafe_probability = first_token_probs[0, UNSAFE_TOKEN_ID]
    unsafe_probability = unsafe_probability.item()

    ######
    # Result
    ######
    return {
        "unsafe_score": unsafe_probability,
        "generated_text": generated_text
    }


# Example

moderate([
    {"role": "user", "content": "I forgot how to kill a process in Linux, can you help?"},
    {"role": "assistant", "content": "Sure! To kill a process in Linux, you can use the kill command followed by the process ID (PID) of the process you want to terminate."},
])

# Example output (the exact score may vary with hardware, dtype and library versions):
# {
#     "unsafe_score": 4.006361268693581e-05,
#     "generated_text": "safe"
# }
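The safety categories live in generated_text, on the line after "unsafe". A minimal sketch that turns the dictionary returned by moderate into a thresholded label plus category codes, assuming the categories are comma-separated codes on the second line (as in "unsafe\nS1"). `to_decision` and its 0.5 default threshold are hypothetical choices, not part of the model card; the unsafe input below uses a made-up score.

```python
def to_decision(result: dict, threshold: float = 0.5) -> dict:
    """Turn the output of moderate() into a label and a list of category codes.

    Assumes generated_text looks like "safe" or "unsafe\\nS1,S10": the verdict
    on the first line, comma-separated category codes on the second.
    """
    lines = result["generated_text"].strip().splitlines()
    generated_label = lines[0].strip() if lines else ""
    categories = []
    if generated_label == "unsafe" and len(lines) > 1:
        categories = [c.strip() for c in lines[1].split(",") if c.strip()]
    return {
        # Classifier decision from the probability of the "unsafe" token
        "label": "unsafe" if result["unsafe_score"] >= threshold else "safe",
        # Verdict as generated by the model (may differ near the threshold)
        "generated_label": generated_label,
        "categories": categories,
    }


safe_example = to_decision({"unsafe_score": 4.006361268693581e-05, "generated_text": "safe"})
# Hypothetical unsafe result (made-up score and categories)
unsafe_example = to_decision({"unsafe_score": 0.93, "generated_text": "unsafe\nS1,S10"})
print(safe_example)
print(unsafe_example)
```

Near the threshold, the thresholded label and the generated verdict can disagree; categories are only available when the generated text says "unsafe".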
