MindfulMedia_Mentor / llama_guard.py
jaelin215's picture
updated path
2edb6cf verified
raw
history blame
4.58 kB
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Docs:- https://huggingface.co/meta-llama/LlamaGuard-7b
from dotenv import load_dotenv, find_dotenv
import os
import requests
model_id = "meta-llama/LlamaGuard-7b"
# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
load_dotenv(find_dotenv())
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
# updated on March 24th
# reading unsafe categories file
try:
with open("data/llama_guard-unsafe_categories.txt", "r") as file:
unsafe_categories = file.read()
except FileNotFoundError:
print("File not found. Please check the file path.")
unsafe_categories = (
"" # Or handle the missing file as appropriate for your application
)
except PermissionError:
print("Permission denied. You don't have the rights to read the file.")
unsafe_categories = "" # Or handle the lack of permissions as appropriate
except Exception as e: # Catch any other exceptions
print(f"An error occurred while reading the file: {e}")
unsafe_categories = "" # Or handle unexpected errors as appropriate
# added on March 24th
def moderation_prompt_for_chat(chat):
prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
{unsafe_categories}
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
User : {chat}
<END CONVERSATION>
Provide your safety assessment for 'User' in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories.[/INST]"""
return prompt
def query(payload):
API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
bearer_txt = f"Bearer {HUGGINGFACEHUB_API_TOKEN}"
headers = {
"Accept": "application/json",
"Authorization": bearer_txt,
"Content-Type": "application/json",
}
try:
response = requests.post(API_URL, headers=headers, json=payload)
response.raise_for_status() # This will raise an exception for HTTP error responses
return response.json(), None
except requests.exceptions.HTTPError as http_err:
error_message = f"HTTP error occurred: {http_err}"
print(error_message)
except requests.exceptions.ConnectionError:
error_message = "Could not connect to the API endpoint."
print(error_message)
except Exception as err:
error_message = f"An error occurred: {err}"
print(error_message)
return None, error_message
def query1(payload):
API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
bearer_txt = f"Bearer {HUGGINGFACEHUB_API_TOKEN}"
headers = {
"Accept": "application/json",
"Authorization": bearer_txt,
"Content-Type": "application/json",
}
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
def moderate_chat(chat):
prompt = moderation_prompt_for_chat(chat)
output, error_msg = query(
{
"inputs": prompt,
"parameters": {
"top_k": 1,
"top_p": 0.2,
"temperature": 0.1,
"max_new_tokens": 512,
},
}
)
print("Llamaguard prompt****", prompt)
print("Llamaguard output****", output)
return output, error_msg
# added on March 24th
def load_category_names_from_string(file_content):
"""Load category codes and names from a string into a dictionary."""
category_names = {}
lines = file_content.split("\n")
for line in lines:
if line.startswith("O"):
parts = line.split(":")
if len(parts) == 2:
code = parts[0].strip()
name = parts[1].strip()
category_names[code] = name
return category_names
def get_category_name(input_str):
"""Return the category name given a category code from an input string."""
# Load the category names from the file content
category_names = load_category_names_from_string(unsafe_categories)
# Extract the category code from the input string
category_code = input_str.split("\n")[1].strip()
# Find the full category name using the code
category_name = category_names.get(category_code, "Unknown Category")
# return f"{category_code} : {category_name}"
return f"{category_name}"