File size: 4,012 Bytes
2cecaad
 
46a13bb
5b7f169
 
2e81d77
46a13bb
5b7f169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e81d77
912f740
5b7f169
b022d45
 
 
 
d717a69
026d799
b022d45
 
 
026d799
 
 
b022d45
 
 
5b7f169
e092bb5
 
5b7f169
e092bb5
5b7f169
 
 
e092bb5
 
 
 
 
 
 
5b7f169
e092bb5
cab16f9
5b7f169
781eee5
5b7f169
 
 
 
912f740
f97dae7
5b7f169
 
f97dae7
912f740
 
2cecaad
5b7f169
2cecaad
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import os


def get_messages(test_case, sub_catalog_name) -> list[dict[str, str]]:
    messages = []

    if sub_catalog_name == "harmful_content_in_user_prompt":
        messages.append({"role": "user", "content": test_case["user_message"]})
    elif sub_catalog_name == "harmful_content_in_assistant_response":
        messages.append({"role": "user", "content": test_case["user_message"]})
        messages.append({"role": "assistant", "content": test_case["assistant_message"]})
    elif sub_catalog_name == "rag_hallucination_risks":
        if test_case["name"] == "context_relevance":
            messages.append({"role": "user", "content": test_case["user_message"]})
            messages.append({"role": "context", "content": test_case["context"]})
        elif test_case["name"] == "groundedness":
            messages.append({"role": "context", "content": test_case["context"]})
            messages.append({"role": "assistant", "content": test_case["assistant_message"]})
        elif test_case["name"] == "answer_relevance":
            messages.append({"role": "user", "content": test_case["user_message"]})
            messages.append({"role": "assistant", "content": test_case["assistant_message"]})

    return messages


def get_result_description(sub_catalog_name, criteria_name):
    evaluated_component = get_evaluated_component(sub_catalog_name, criteria_name)
    messages = {
        "general_harm": f"Is the {evaluated_component} message harmful by common-sense?",
        "social_bias": f"Is the {evaluated_component} message prejudiced based on identity or characteristics?",
        "jailbreak": "Does the user message try to manipulate the AI to generate harmful, undesired, or inappropriate content?",
        "violence": f"Does the {evaluated_component} message contain content promoting physical, mental, or sexual harm?",
        "profanity": f"Does the {evaluated_component} message include offensive language or insults?",
        "unethical_behavior": f"Does the {evaluated_component} message include actions that violate moral or legal standards?",
        "answer_relevance": "Does the assistant response fail to address or properly answer the user question?",
        "context_relevance": "Is the retrieved context irrelevant to the user question or does not address their needs?",
        "groundedness": "Does the assistant response include claims or facts not supported by or contradicted by the provided context?",
    }
    return messages[criteria_name]


def get_evaluated_component(sub_catalog_name, criteria_name):
    component = None
    if sub_catalog_name == "harmful_content_in_user_prompt":
        component = "user"
    elif sub_catalog_name == "harmful_content_in_assistant_response":
        component = "assistant"
    elif sub_catalog_name == "rag_hallucination_risks":
        if criteria_name == "context_relevance":
            component = "context"
        elif criteria_name == "groundedness":
            component = "assistant"
        elif criteria_name == "answer_relevance":
            component = "assistant"
    if component is None:
        raise Exception("Something went wrong getting the evaluated component")
    return component


def to_title_case(input_string):
    if input_string == "rag_hallucination_risks":
        return "RAG Hallucination Risks"
    return " ".join(word.capitalize() for word in input_string.split("_"))


def capitalize_first_word(input_string):
    return " ".join(word.capitalize() if i == 0 else word for i, word in enumerate(input_string.split("_")))


def to_snake_case(text):
    return text.lower().replace(" ", "_")


def load_command_line_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")

    # Parse arguments
    args = parser.parse_args()

    # Store the argument in an environment variable
    if args.model_path is not None:
        os.environ["MODEL_PATH"] = args.model_path