File size: 4,036 Bytes
a521442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Docs:- https://huggingface.co/meta-llama/LlamaGuard-7b
from dotenv import load_dotenv, find_dotenv
import os
import requests
model_id = "meta-llama/LlamaGuard-7b"


# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
# NOTE(review): load_dotenv(find_dotenv()) presumably copies the variables of
# the located .env file into os.environ -- confirm against the python-dotenv docs.
load_dotenv(find_dotenv())
# Plain os.environ[...] lookup: importing this module raises KeyError when the
# variable is not set.
HUGGINGFACEHUB_API_TOKEN  = os.environ["HUGGINGFACEHUB_API_TOKEN"]

#updated on March 24th

# Read the unsafe-content category definitions. The text is embedded in the
# moderation prompt and is also parsed to map category codes to names.
# Reading is best-effort: on any failure a message is printed and the module
# continues with an empty string instead of failing at import time.
unsafe_categories = ""
try:
    # Explicit encoding: without it open() uses the locale's preferred
    # encoding, so the same file could decode differently across platforms.
    with open('./prompts/llama_guard-unsafe_categories.txt', 'r', encoding='utf-8') as file:
        unsafe_categories = file.read()
except FileNotFoundError:
    print("File not found. Please check the file path.")
except PermissionError:
    print("Permission denied. You don't have the rights to read the file.")
except Exception as e:  # Any other failure (e.g. a decoding error) also falls back to ""
    print(f"An error occurred while reading the file: {e}")

#added on March 24th
def moderation_prompt_for_chat(chat) -> str:
  """Build the moderation prompt for a single user message.

  The module-level ``unsafe_categories`` text (read when the function is
  called) and ``chat`` are interpolated into a fixed ``[INST] ... [/INST]``
  template that asks for a 'safe'/'unsafe' verdict on the first line and,
  if unsafe, a comma-separated list of violated categories on the second.

  Args:
      chat: The user message to assess; it is formatted into the prompt
          after ``User : ``.

  Returns:
      The complete prompt string.
  """
  # The two-space indentation and the blank lines inside the literal are part
  # of the string that is returned, not just source formatting.
  prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.

  <BEGIN UNSAFE CONTENT CATEGORIES>

  {unsafe_categories}

  <END UNSAFE CONTENT CATEGORIES>



  <BEGIN CONVERSATION>



  User : {chat}



  <END CONVERSATION>



  Provide your safety assessment for 'User' in the above conversation:

  - First line must read 'safe' or 'unsafe'.

  - If unsafe, a second line must include a comma-separated list of violated categories.[/INST]"""
  return prompt

def query(payload, timeout=60):
    """POST ``payload`` as JSON to the inference endpoint and return its reply.

    Args:
        payload: JSON-serialisable request body, passed to ``requests.post``
            through ``json=``.
        timeout: Timeout in seconds handed to ``requests.post``. Without one
            the call can block indefinitely on an unresponsive endpoint.
            Pass ``None`` to restore the previous wait-forever behaviour.

    Returns:
        A ``(result, error_message)`` tuple. On success ``result`` is the
        decoded JSON response and ``error_message`` is ``None``; on failure
        ``result`` is ``None`` and ``error_message`` describes the problem.
        Failures derived from ``Exception`` are printed and reported through
        the tuple rather than raised.
    """
    API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    headers = {
        "Accept": "application/json",
        "Authorization": f'Bearer {HUGGINGFACEHUB_API_TOKEN}',
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()  # Turn HTTP error statuses into HTTPError
        return response.json(), None
    except requests.exceptions.HTTPError as http_err:
        error_message = f"HTTP error occurred: {http_err}"
    except requests.exceptions.ConnectionError:
        # Listed before Timeout so that connection-phase timeouts keep
        # reporting the existing "could not connect" message.
        error_message = "Could not connect to the API endpoint."
    except requests.exceptions.Timeout:
        error_message = "The request to the API endpoint timed out."
    except Exception as err:
        # Boundary catch-all: covers e.g. a response body that is not JSON.
        error_message = f"An error occurred: {err}"

    print(error_message)
    return None, error_message


def moderate_chat(chat):
    """Submit a chat message to the moderation endpoint.

    The prompt is built with ``moderation_prompt_for_chat`` and sent through
    ``query`` together with a fixed set of generation parameters.

    Args:
        chat: The user message to moderate.

    Returns:
        The ``(output, error_msg)`` pair produced by ``query``.
    """
    generation_params = {
        "top_k": 1,
        "top_p": 0.2,
        "temperature": 0.1,
        "max_new_tokens": 512,
    }
    request_body = {
        "inputs": moderation_prompt_for_chat(chat),
        "parameters": generation_params,
    }

    result, failure = query(request_body)
    return result, failure


#added on March 24th
def load_category_names_from_string(file_content):
    """Parse category codes and names from text into a dictionary.

    A line is treated as a category definition when it starts with ``"O"``
    and contains a colon. The text before the first colon is the code and
    everything after it is the name; both are stripped of surrounding
    whitespace. All other lines are ignored.

    Args:
        file_content: The category definitions as a single string.

    Returns:
        A dict mapping each category code to its name. If a code appears
        more than once, the last definition wins.
    """
    category_names = {}
    for line in file_content.split('\n'):
        if not line.startswith("O"):
            continue
        # Split on the first colon only, so a name that itself contains a
        # colon is kept instead of being silently dropped.
        code, separator, name = line.partition(':')
        if not separator:
            continue
        category_names[code.strip()] = name.strip()
    return category_names

def get_category_name(input_str):
    """Return the category name(s) for the code(s) on the second line of a verdict.

    The moderation prompt asks for 'safe' or 'unsafe' on the first line and,
    if unsafe, a comma-separated list of violated categories on the second.
    Each code on that second line is looked up in the names parsed from the
    module-level ``unsafe_categories`` text.

    Args:
        input_str: The verdict text, with lines separated by ``"\\n"``.

    Returns:
        The category name for a single code, or the names joined with
        ``", "`` when several codes are listed. ``"Unknown Category"`` stands
        in for any code that is not defined, and is returned on its own when
        there is no second line or it contains no codes.
    """
    category_names = load_category_names_from_string(unsafe_categories)

    verdict_lines = input_str.split('\n')
    if len(verdict_lines) < 2:
        # A one-line verdict (e.g. just "safe") carries no category code.
        return "Unknown Category"

    # The second line may list several codes separated by commas.
    stripped_codes = (code.strip() for code in verdict_lines[1].split(','))
    category_codes = [code for code in stripped_codes if code]
    if not category_codes:
        return "Unknown Category"

    return ", ".join(
        category_names.get(code, "Unknown Category") for code in category_codes
    )