File size: 4,575 Bytes
00e4075 bd9870c 00e4075 bd9870c 00e4075 bd9870c 00e4075 2edb6cf 00e4075 bd9870c 00e4075 bd9870c 00e4075 bd9870c 00e4075 bd9870c 00e4075 bd9870c 00e4075 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Docs:- https://huggingface.co/meta-llama/LlamaGuard-7b
from dotenv import load_dotenv, find_dotenv
import os
import requests
# Hugging Face id of the Llama Guard model (see the docs link above); presumably
# the model served by the inference endpoint used below — TODO confirm.
model_id = "meta-llama/LlamaGuard-7b"

# Load variables from a .env file (located by find_dotenv) into the environment,
# then read the API token used to authenticate endpoint requests.
# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
load_dotenv(find_dotenv())
try:
    HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
except KeyError as err:
    # Same exception type as before, but with a message that says how to fix it
    # instead of a bare KeyError at import time.
    raise KeyError(
        "HUGGINGFACEHUB_API_TOKEN is not set; export it or add it to a .env file"
    ) from err
# Location of the Llama Guard category list. The path is relative, so it is
# resolved against the current working directory, not this file's directory.
UNSAFE_CATEGORIES_PATH = "data/llama_guard-unsafe_categories.txt"


def read_unsafe_categories(path=UNSAFE_CATEGORIES_PATH):
    """Return the text of the unsafe-categories file at *path*.

    Reading is best-effort so that importing this module never fails: if the
    file is missing, unreadable, or not valid UTF-8, a message is printed and
    an empty string is returned. Note that an empty string leaves the
    moderation prompt without any policy categories.

    Args:
        path: File to read; defaults to ``UNSAFE_CATEGORIES_PATH``.

    Returns:
        The file's contents, or ``""`` on any read or decode error.
    """
    try:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # could mis-decode or reject non-ASCII text in the category file.
        with open(path, "r", encoding="utf-8") as file:
            return file.read()
    except FileNotFoundError:
        print("File not found. Please check the file path.")
    except PermissionError:
        print("Permission denied. You don't have the rights to read the file.")
    except (OSError, UnicodeDecodeError) as e:
        # Any other I/O failure (e.g. the path is a directory) or a bad encoding.
        print(f"An error occurred while reading the file: {e}")
    return ""


unsafe_categories = read_unsafe_categories()
def moderation_prompt_for_chat(chat):
    """Build the Llama Guard moderation prompt for a single user message.

    The module-level ``unsafe_categories`` text is embedded as the safety
    policy, *chat* is placed as the only 'User' turn, and the model is asked
    to answer 'safe' or 'unsafe' on the first line (plus the violated
    categories on a second line when unsafe).

    Args:
        chat: The user's message. It is interpolated verbatim through string
            formatting and is not escaped.

    Returns:
        The prompt text, wrapped in ``[INST] ... [/INST]``, lines joined by
        newlines with no trailing newline.
    """
    prompt_lines = [
        "[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.",
        "<BEGIN UNSAFE CONTENT CATEGORIES>",
        f"{unsafe_categories}",
        "<END UNSAFE CONTENT CATEGORIES>",
        "<BEGIN CONVERSATION>",
        f"User : {chat}",
        "<END CONVERSATION>",
        "Provide your safety assessment for 'User' in the above conversation:",
        "- First line must read 'safe' or 'unsafe'.",
        "- If unsafe, a second line must include a comma-separated list of violated categories.[/INST]",
    ]
    return "\n".join(prompt_lines)
def query(payload, timeout=120):
    """POST *payload* as JSON to the Llama Guard inference endpoint.

    Args:
        payload: JSON-serializable request body, e.g.
            ``{"inputs": prompt, "parameters": {...}}``.
        timeout: Seconds ``requests`` waits for the server before giving up.
            Without a timeout a stalled endpoint blocks the caller forever;
            pass ``None`` to wait indefinitely anyway.

    Returns:
        ``(decoded_json, None)`` on success, or ``(None, error_message)`` on an
        HTTP error status, connection failure, timeout, other request error,
        or a response body that is not valid JSON. The error message is also
        printed.
    """
    API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()  # turn 4xx/5xx statuses into HTTPError
        return response.json(), None
    except requests.exceptions.HTTPError as http_err:
        error_message = f"HTTP error occurred: {http_err}"
    except requests.exceptions.ConnectionError:
        # Also catches ConnectTimeout, which subclasses ConnectionError.
        error_message = "Could not connect to the API endpoint."
    except requests.exceptions.Timeout:
        error_message = f"The API endpoint did not respond within {timeout} seconds."
    except (requests.exceptions.RequestException, ValueError) as err:
        # ValueError covers a response body that is not valid JSON.
        error_message = f"An error occurred: {err}"
    print(error_message)
    return None, error_message
def query1(payload, timeout=120):
    """POST *payload* to the Llama Guard endpoint and return the decoded JSON.

    This variant does no error handling: HTTP error statuses are not raised
    (their body is decoded and returned like any other), and connection,
    timeout, and JSON-decoding errors propagate to the caller.

    Args:
        payload: JSON-serializable request body.
        timeout: Seconds ``requests`` waits for the server before raising
            ``requests.exceptions.Timeout``; ``None`` waits indefinitely
            (the previous behavior).

    Returns:
        Whatever ``response.json()`` decodes from the response body.
    """
    API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}",
        "Content-Type": "application/json",
    }
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
def moderate_chat(chat):
    """Classify *chat* with Llama Guard through the hosted endpoint.

    Builds the moderation prompt for *chat*, sends it with low-randomness
    generation settings (top_k=1, temperature=0.1), and echoes both the
    prompt and the raw endpoint output to stdout for debugging.

    Returns:
        The ``(output, error_msg)`` pair returned by ``query``.
    """
    llamaguard_prompt = moderation_prompt_for_chat(chat)
    generation_params = {
        "top_k": 1,
        "top_p": 0.2,
        "temperature": 0.1,
        "max_new_tokens": 512,
    }
    output, error_msg = query({"inputs": llamaguard_prompt, "parameters": generation_params})
    # Debug echo of exactly what was sent and what came back.
    for label, value in (
        ("Llamaguard prompt****", llamaguard_prompt),
        ("Llamaguard output****", output),
    ):
        print(label, value)
    return output, error_msg
def load_category_names_from_string(file_content):
    """Map Llama Guard category codes to their names.

    Parses category headings of the form ``O<number>: <name>`` (for example
    ``"O1: Violence and Hate."``) from *file_content*. A heading must start at
    the beginning of its line. Everything after the first ``:`` is the name,
    so names may themselves contain colons. Every other line is ignored:
    descriptions, bullet points, and prose that merely begins with "O"
    (such as ``"Only respond: ..."``).

    Args:
        file_content: Text of the unsafe-categories file.

    Returns:
        Dict from code (e.g. ``"O1"``) to its stripped name; if a code appears
        more than once, the last heading wins.
    """
    category_names = {}
    # splitlines() also copes with "\r\n" line endings.
    for line in file_content.splitlines():
        code, colon, name = line.partition(":")
        code = code.strip()
        # Require "O" followed only by digits before the colon so that prose
        # lines starting with "O" are not mistaken for category headings.
        if line.startswith("O") and colon and code[1:].isdigit():
            category_names[code] = name.strip()
    return category_names
def get_category_name(input_str):
    """Return the category name(s) named in a Llama Guard verdict.

    *input_str* is expected to be the model's text answer to the moderation
    prompt: first line ``safe``/``unsafe``, and (when unsafe) a second line
    holding a comma-separated list of category codes such as ``O1`` or
    ``O1,O3``. Codes are looked up in the categories parsed from the
    module-level ``unsafe_categories`` text.

    Surrounding whitespace is stripped first, so a verdict that starts with a
    blank line is still read correctly.

    Returns:
        The matching category name. When several codes are listed, their names
        are joined with ``", "``. Each code that is not found maps to
        ``"Unknown Category"``. If there is no non-empty second line (e.g. a
        plain ``"safe"`` verdict), ``"Unknown Category"`` is returned instead of
        raising ``IndexError``.
    """
    category_names = load_category_names_from_string(unsafe_categories)
    lines = input_str.strip().splitlines()
    if len(lines) < 2:
        return "Unknown Category"
    codes = [code.strip() for code in lines[1].split(",") if code.strip()]
    if not codes:
        return "Unknown Category"
    return ", ".join(category_names.get(code, "Unknown Category") for code in codes)
|