MindfulMedia_Mentor / llama_guard.py
from dotenv import load_dotenv, find_dotenv
import os
import requests
# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN' set.
load_dotenv(find_dotenv())
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# Load the Llama Guard unsafe-content category definitions used in the moderation prompt.
with open("data/processed/llama_guard-unsafe_files.txt", "r") as file:
    unsafe_categories = file.read()
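# Assumed layout (illustrative only; the file under data/processed/ is authoritative):
# the Llama Guard taxonomy is typically written as numbered categories, e.g.
#   O1: Violence and Hate.
#   Should not
#   - ...
#   O2: Sexual Content.
#   ...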


def moderation_prompt_for_chat(chat):
    # For simplicity, we assume the chat list correctly alternates "user" and "assistant" turns.
    # Add error checking here if needed.
    conversation = [turn["content"] for turn in chat]
    # An even number of turns means the last message came from the agent, so the
    # safety assessment targets the Agent; otherwise it targets the User.
    role = "Agent" if len(conversation) % 2 == 0 else "User"
    prompt = f"""[INST] Task: Check if there is unsafe content in '{role}' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
{unsafe_categories}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

"""
    # Alternate User/Agent turns, inserting two newlines between each.
    for i, m in enumerate(conversation):
        turn_role = "User" if i % 2 == 0 else "Agent"
        prompt += f"{turn_role}: {m}\n\n"
    prompt += f"""<END CONVERSATION>

Provide your safety assessment for {role} in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
    return prompt
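

# Hypothetical helper (an illustrative sketch, not part of the original file): one way
# to parse the assessment format the prompt above requests, i.e. 'safe' or 'unsafe'
# on the first line and a comma-separated list of violated categories on the second.
def parse_guard_assessment(generated_text):
    lines = [line.strip() for line in generated_text.strip().splitlines() if line.strip()]
    verdict = lines[0].lower() if lines else ""
    categories = []
    if verdict == "unsafe" and len(lines) > 1:
        categories = [c.strip() for c in lines[1].split(",") if c.strip()]
    return verdict, categories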


def query(payload):
    # Hugging Face Inference Endpoint hosting the Llama Guard model.
    API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    bearer_txt = f"Bearer {HUGGINGFACEHUB_API_TOKEN}"
    headers = {
        "Accept": "application/json",
        "Authorization": bearer_txt,
        "Content-Type": "application/json",
    }
    # A timeout keeps the app from hanging indefinitely if the endpoint is unresponsive.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
    return response.json()
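

# Note (assumption, not from the original file): a Hugging Face text-generation
# Inference Endpoint typically responds with a JSON list such as
# [{"generated_text": "..."}], so callers would usually read output[0]["generated_text"].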


def moderate_chat(chat):
    # Build the Llama Guard prompt and send it to the endpoint with
    # near-deterministic sampling settings.
    prompt = moderation_prompt_for_chat(chat)
    output = query(
        {
            "inputs": prompt,
            "parameters": {
                "top_k": 1,
                "top_p": 0.2,
                "temperature": 0.1,
                "max_new_tokens": 512,
            },
        }
    )
    return output
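

# Minimal usage sketch (assumptions: the endpoint above is reachable and returns the
# standard text-generation payload; the example conversation is illustrative only).
if __name__ == "__main__":
    example_chat = [
        {"role": "user", "content": "I feel anxious after scrolling social media all night. Any advice?"},
    ]
    result = moderate_chat(example_chat)
    print(result)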