"""Build a safety-moderation prompt for a chat and send it to a hosted inference endpoint."""

import os

import requests
from dotenv import find_dotenv, load_dotenv

# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
load_dotenv(find_dotenv())
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]

API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"

# Generous default: the request asks for up to 512 new tokens, but without any
# timeout a stalled endpoint would block the caller forever.
DEFAULT_TIMEOUT_SECONDS = 120

# The policy text is read once at import time and interpolated into every prompt.
with open("data/llama_guard-unsafe_files.txt", "r", encoding="utf-8") as file:
    unsafe_categories = file.read()


def moderation_prompt_for_chat(chat) -> str:
    """Return the moderation prompt for *chat*.

    Args:
        chat: Sequence of turns, each a mapping with a ``"content"`` key.
            Turns are labelled purely by position (even index -> "User",
            odd index -> "Agent"); any ``"role"`` key is not consulted.

    Returns:
        The full ``[INST] ... [/INST]`` prompt string. The role under
        assessment is the role of the last turn ("Agent" for an empty chat).
    """
    # For simplicity, we assume the chat correctly alternates "user" and
    # "assistant" roles. Feel free to include error checking if needed.
    conversation = [turn["content"] for turn in chat]

    # An even number of turns means the last speaker was the agent.
    target_role = "Agent" if len(conversation) % 2 == 0 else "User"

    # NOTE(review): this file was recovered from a whitespace-collapsed copy,
    # so the line breaks inside the prompt template below are reconstructed.
    # The published Llama Guard prompt also wraps the categories and the
    # conversation in BEGIN/END marker lines; none are present here - confirm
    # the exact template against the model card before relying on it.
    header = f"""[INST] Task: Check if there is unsafe content in '{target_role}' messages in conversations according our safety policy with the below categories.

{unsafe_categories}

"""

    # Alternate User/Agent turns, inserting 2 newlines between each
    turns = "".join(
        f"{'User' if index % 2 == 0 else 'Agent'}: {message}\n\n"
        for index, message in enumerate(conversation)
    )

    footer = f"""
Provide your safety assessment for {target_role} in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""

    return header + turns + footer


def query(payload, *, timeout=DEFAULT_TIMEOUT_SECONDS):
    """POST *payload* as JSON to the inference endpoint and return the decoded reply.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the endpoint before ``requests`` raises
            a timeout error. Pass ``None`` to wait indefinitely.

    Returns:
        The response body decoded from JSON. The HTTP status code is not
        checked, so an error payload sent by the endpoint is returned as-is.

    Raises:
        requests.exceptions.RequestException: On connection failure or timeout.
        ValueError: If the response body is not valid JSON.
    """
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}",
        "Content-Type": "application/json",
    }
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()


def moderate_chat(chat):
    """Build the moderation prompt for *chat*, submit it, and return the endpoint's JSON reply."""
    prompt = moderation_prompt_for_chat(chat)
    return query(
        {
            "inputs": prompt,
            "parameters": {
                "top_k": 1,
                "top_p": 0.2,
                "temperature": 0.1,
                "max_new_tokens": 512,
            },
        }
    )