File size: 2,176 Bytes
bd9870c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from dotenv import load_dotenv, find_dotenv
import os
import requests

# Module setup: runs once at import time.
# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
load_dotenv(find_dotenv())
# Indexing (rather than .get) makes a missing token fail immediately with
# KeyError instead of surfacing later as an authentication error.
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# Text inserted between the <BEGIN/END UNSAFE CONTENT CATEGORIES> markers of
# every moderation prompt. The path is relative to the current working
# directory. The encoding is pinned so the decoded text does not depend on the
# platform's locale default.
# NOTE(review): assumes the file is UTF-8 (plain ASCII qualifies) - confirm.
with open(
    "data/processed/llama_guard-unsafe_files.txt", "r", encoding="utf-8"
) as file:
    unsafe_categories = file.read()


def moderation_prompt_for_chat(chat, categories=None):
    """Build the Llama Guard style moderation prompt for a conversation.

    The conversation is assumed to alternate strictly between the user and
    the agent, starting with the user; the ``"role"`` values of the turns are
    not inspected, only their position. No validation is performed.

    Args:
        chat: Sequence of turns, each a mapping with a ``"content"`` entry
            holding the message text.
        categories: Text placed between the unsafe-content-category markers.
            When ``None`` (the default) the module-level
            ``unsafe_categories`` loaded at import time is used.

    Returns:
        The complete prompt string, wrapped in ``[INST] ... [/INST]``.

    Raises:
        KeyError: If a turn has no ``"content"`` entry.
    """
    if categories is None:
        categories = unsafe_categories

    messages = [turn["content"] for turn in chat]

    # The role being assessed is the author of the last message: an even
    # number of messages means the agent spoke last. An empty conversation
    # therefore yields "Agent", as 0 is even.
    assessed_role = "Agent" if len(messages) % 2 == 0 else "User"

    # NOTE: the wording below (including "according our safety policy") is
    # reproduced verbatim on purpose; do not "correct" the prompt text.
    header = f"""[INST] Task: Check if there is unsafe content in '{assessed_role}' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
{categories}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

"""

    # Alternate User/Agent turns, each followed by a blank line.
    speakers = ("User", "Agent")
    transcript = "".join(
        f"{speakers[index % 2]}: {message}\n\n"
        for index, message in enumerate(messages)
    )

    footer = f"""<END CONVERSATION>

Provide your safety assessment for {assessed_role} in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""

    return header + transcript + footer


def query(payload, timeout=120):
    """POST *payload* as JSON to the inference endpoint and return its reply.

    Args:
        payload: JSON-serialisable request body; sent via ``json=``.
        timeout: Seconds to wait for the server, passed straight through to
            ``requests.post``. Without a timeout a stalled endpoint would
            block the caller indefinitely. Pass ``None`` to wait without
            limit.

    Returns:
        The response body decoded from JSON. The HTTP status code is not
        checked, so a JSON error body sent by the endpoint is returned to the
        caller like any other reply.

    Raises:
        requests.exceptions.Timeout: If the server does not respond within
            *timeout*.
        requests.exceptions.RequestException: For other transport failures.
        ValueError: Presumably raised by ``response.json()`` when the body is
            not valid JSON - exact exception class depends on the installed
            ``requests`` version.
    """
    API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}",
        "Content-Type": "application/json",
    }
    response = requests.post(
        API_URL, headers=headers, json=payload, timeout=timeout
    )

    return response.json()


def moderate_chat(chat):
    """Run a conversation through the moderation endpoint.

    The prompt is produced by ``moderation_prompt_for_chat`` and submitted
    with ``query`` together with a fixed set of generation parameters.

    Args:
        chat: Conversation turns, forwarded unchanged to
            ``moderation_prompt_for_chat``.

    Returns:
        Whatever ``query`` returns for the request (the decoded JSON reply).
    """
    # Fixed sampling configuration sent with every moderation request.
    generation_settings = {
        "top_k": 1,
        "top_p": 0.2,
        "temperature": 0.1,
        "max_new_tokens": 512,
    }
    request_body = {
        "inputs": moderation_prompt_for_chat(chat),
        "parameters": generation_settings,
    }
    return query(request_body)