File size: 5,754 Bytes
5269ad1
f97dae7
5b7f169
 
 
 
f97dae7
 
026d799
5b7f169
 
182a21a
bce909e
 
 
5269ad1
f97dae7
5269ad1
 
5b7f169
f97dae7
 
5b7f169
 
 
33193a0
2e81d77
bce909e
026d799
 
2cb730a
f97dae7
5b7f169
 
 
 
 
f97dae7
 
 
5b7f169
 
 
 
5269ad1
 
 
 
 
 
 
 
 
 
f97dae7
 
5269ad1
 
 
 
 
 
 
5b7f169
f97dae7
 
 
 
 
5b7f169
5269ad1
 
 
 
 
 
 
 
f97dae7
5269ad1
 
5b7f169
5269ad1
 
 
5b7f169
f97dae7
 
 
 
 
5b7f169
 
 
 
f97dae7
 
 
 
 
5b7f169
2e81d77
5b7f169
2e81d77
5b7f169
 
 
2e81d77
c786139
f97dae7
 
 
 
5b7f169
 
f97dae7
5b7f169
 
 
 
 
f97dae7
 
 
 
 
5b7f169
f97dae7
 
 
5b7f169
f97dae7
 
 
 
 
 
 
 
 
2e81d77
5b7f169
477d968
2e81d77
5b7f169
 
2e81d77
f97dae7
2e81d77
5b7f169
f97dae7
5b7f169
 
f97dae7
5b7f169
 
 
f97dae7
 
5269ad1
5b7f169
f97dae7
bce909e
 
5269ad1
f97dae7
 
 
d46878a
5b7f169
 
 
20a9c66
 
5b7f169
d46878a
5b7f169
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import math
import os
from time import sleep, time

import spaces
import torch
from ibm_watsonx_ai.client import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference
from transformers import AutoModelForCausalLM, AutoTokenizer

from logger import logger

# from vllm import LLM, SamplingParams


# Answer strings the guard model is expected to emit: parse_output* map "No" to
# the safe label and "Yes" to the risky label (anything else becomes "Failed").
safe_token = "No"
risky_token = "Yes"
# Number of top alternative tokens per generated position used to estimate the
# probability of risk (passed as `logprobs` to the commented-out vLLM
# SamplingParams; the watsonx request hard-codes the same value 5).
# Values <= 0 skip probability extraction in the parse_output* helpers.
nlogprobs = 5

# Backend selector, read once at import time. generate_text accepts "VLLM",
# "WATSONX" and "MOCK", but only the two branches below load `tokenizer` and
# `model`; for "MOCK" (or any other value) neither name is defined here.
inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
logger.debug(f"Inference engine is: '{inference_engine}'")

if inference_engine == "VLLM":
    # Despite the engine name, this branch currently loads a Hugging Face
    # transformers causal LM (float16, device_map="auto"); the vLLM objects
    # below are commented out.
    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
    logger.debug(f"model_path is {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    # model = LLM(model=model_path, tensor_parallel_size=1)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")

elif inference_engine == "WATSONX":
    # Remote inference through IBM watsonx.ai (us-south endpoint). The API key and
    # project id come from WATSONX_API_KEY / WATSONX_PROJECT_ID; they are not
    # checked for presence here.
    client = APIClient(
        credentials={"api_key": os.getenv("WATSONX_API_KEY"), "url": "https://us-south.ml.cloud.ibm.com"}
    )

    client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
    # The Hugging Face tokenizer is only used locally to render the chat
    # template (get_prompt); MODEL_PATH is not consulted in this branch.
    hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)

    # watsonx.ai model identifier; generation itself happens remotely.
    model_id = "ibm/granite-guardian-3-8b"  # 8B Model: "ibm/granite-guardian-3-8b"
    model = ModelInference(model_id=model_id, api_client=client)


def parse_output(output):
    """Extract the verdict label and probability of risk from a vLLM-style result.

    Args:
        output: object with an ``outputs`` iterable whose first element exposes
            ``text`` (the generated answer) and ``logprobs`` (a sequence of
            per-position ``{token_id: candidate}`` mappings, or ``None``).

    Returns:
        ``(label, prob_of_risk)``: ``label`` is ``risky_token``, ``safe_token``
        or ``"Failed"`` when the answer matches neither; ``prob_of_risk`` is a
        float, or ``None`` when probabilities are disabled (``nlogprobs <= 0``)
        or the result carries no logprobs.
    """
    # Read the first completion once instead of re-iterating output.outputs.
    completion = next(iter(output.outputs))

    # Previously this name was never initialised, so a missing-logprobs result
    # crashed with UnboundLocalError; now it falls back to None like the
    # watsonx parser does.
    prob_of_risk = None
    if nlogprobs > 0 and completion.logprobs is not None:
        # get_probablities returns a 2-element tensor [P(safe), P(risky)].
        prob_of_risk = get_probablities(completion.logprobs)[1].item()

    res = completion.text.strip()
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk


def softmax(values):
    """Return the softmax of *values* as a list of floats that sums to 1.

    The maximum is subtracted before exponentiating, which leaves the result
    unchanged (softmax is shift-invariant) but prevents ``math.exp`` from
    overflowing on large inputs. An empty input yields ``[]``.

    Args:
        values: iterable of real numbers (e.g. log-probabilities).

    Returns:
        list of floats, one per input value, in the same order.
    """
    # Materialise once so generators work even though we traverse twice.
    values = list(values)
    if not values:
        return []
    # softmax(x) == softmax(x - c); choosing c = max(x) keeps every exponent <= 0.
    peak = max(values)
    exp_values = [math.exp(v - peak) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]


def get_probablities(logprobs):
    """Turn per-position candidate logprobs into ``[P(safe), P(risky)]``.

    Every candidate whose decoded text (trimmed, case-insensitive) equals
    ``safe_token`` or ``risky_token`` contributes ``exp(logprob)`` to that
    class's mass. The two masses are renormalised via a softmax of their logs.

    Args:
        logprobs: sequence of mappings, one per generated position, whose values
            expose ``decoded_token`` and ``logprob`` attributes.

    Returns:
        torch tensor of shape (2,): probability of safe, then of risky.
    """
    safe_key = safe_token.lower()
    risky_key = risky_token.lower()

    # Tiny floors keep math.log defined when a class never appears.
    safe_mass = 1e-50
    risky_mass = 1e-50

    for position in logprobs:
        for candidate in position.values():
            normalized = candidate.decoded_token.strip().lower()
            if normalized == safe_key:
                safe_mass += math.exp(candidate.logprob)
            if normalized == risky_key:
                risky_mass += math.exp(candidate.logprob)

    log_masses = torch.tensor([math.log(safe_mass), math.log(risky_mass)])
    return torch.softmax(log_masses, dim=0)


def get_probablities_watsonx(top_tokens_list):
    """Turn watsonx ``top_tokens`` records into ``[P(safe), P(risky)]``.

    Each record is a dict with ``"text"`` and ``"logprob"`` keys. Records whose
    trimmed, lower-cased text matches ``safe_token`` or ``risky_token`` add
    ``exp(logprob)`` to that class's mass before renormalisation.

    Args:
        top_tokens_list: one list of top-token dicts per generated position.

    Returns:
        list of two floats: probability of safe, then of risky.
    """
    wanted_safe = safe_token.lower()
    wanted_risky = risky_token.lower()

    # Tiny floors keep math.log defined when a class never appears.
    safe_mass, risky_mass = 1e-50, 1e-50

    for position_candidates in top_tokens_list:
        for record in position_candidates:
            candidate_text = record["text"].strip().lower()
            if candidate_text == wanted_safe:
                safe_mass += math.exp(record["logprob"])
            if candidate_text == wanted_risky:
                risky_mass += math.exp(record["logprob"])

    return softmax([math.log(safe_mass), math.log(risky_mass)])


def get_prompt(messages, criteria_name):
    """Render *messages* into a prompt string with the loaded tokenizer's chat template.

    The risk name handed to the template is *criteria_name* unchanged, except
    that ``"general_harm"`` is sent as ``"harm"``. A generation prompt is
    appended and the result is returned as text (not token ids).
    """
    risk_name = "harm" if criteria_name == "general_harm" else criteria_name
    guardian_config = {"risk_name": risk_name}
    return tokenizer.apply_chat_template(
        messages,
        guardian_config=guardian_config,
        tokenize=False,
        add_generation_prompt=True,
    )


@spaces.GPU
def generate_tokens(prompt):
    """Send *prompt* to the watsonx.ai model and return its generated-token records.

    Decoding is greedy with up to 20 new tokens. The request asks for per-token
    logprobs and the top 5 alternatives at each position.

    Returns:
        The ``generated_tokens`` list of the first result of the response.
    """
    return_options = {
        "token_logprobs": True,
        "generated_tokens": True,
        "input_text": True,
        "top_n_tokens": 5,
    }
    generation_params = {
        "decoding_method": "greedy",
        "max_new_tokens": 20,
        "temperature": 0,
        "return_options": return_options,
    }
    response = model.generate(prompt=[prompt], params=generation_params)
    first_result = response[0]["results"][0]
    return first_result["generated_tokens"]


def parse_output_watsonx(generated_tokens_list):
    """Extract the verdict label and probability of risk from watsonx token records.

    Args:
        generated_tokens_list: per-position dicts as returned by
            ``generate_tokens``; each carries ``"text"`` and ``"top_tokens"``.

    Returns:
        ``(label, prob_of_risk)``: ``label`` is ``risky_token``, ``safe_token``
        or ``"Failed"``, judged from the first generated token only;
        ``prob_of_risk`` is a float, or ``None`` when ``nlogprobs <= 0``.
    """
    prob_of_risk = None
    if nlogprobs > 0:
        per_position_top = [record["top_tokens"] for record in generated_tokens_list]
        prob_of_risk = get_probablities_watsonx(per_position_top)[1]

    answer = next(iter(generated_tokens_list))["text"].strip().lower()

    if answer == risky_token.lower():
        label = risky_token
    elif answer == safe_token.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk


def _generate_with_transformers(chat):
    """Run the locally loaded transformers model on a rendered prompt.

    The VLLM engine currently loads a ``transformers`` causal LM (the vLLM
    objects are commented out at import time). This helper performs greedy
    generation and repackages the result in the attribute shape that
    ``parse_output`` reads: ``result.outputs[0].text`` and
    ``result.outputs[0].logprobs`` (one ``{token_id: candidate}`` dict per
    generated position, each candidate exposing ``decoded_token`` and
    ``logprob``; ``None`` when ``nlogprobs <= 0``).

    Args:
        chat: prompt string produced by ``get_prompt``.

    Returns:
        A ``SimpleNamespace`` mimicking a single vLLM request output.
    """
    from types import SimpleNamespace

    # The chat template already rendered the special tokens; don't add them twice.
    inputs = tokenizer(chat, return_tensors="pt", add_special_tokens=False).to(model.device)
    input_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        generated = model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=20,
            output_scores=True,
            return_dict_in_generate=True,
        )

    # Decode only the continuation, not the echoed prompt tokens.
    text = tokenizer.decode(generated.sequences[0, input_len:], skip_special_tokens=True)

    logprobs = None
    if nlogprobs > 0:
        logprobs = []
        for step_scores in generated.scores:
            # One score tensor per generated step; row 0 is our single prompt.
            step_logprobs = torch.log_softmax(step_scores[0].float(), dim=-1)
            top = torch.topk(step_logprobs, k=nlogprobs)
            logprobs.append(
                {
                    int(token_id): SimpleNamespace(
                        decoded_token=tokenizer.decode([int(token_id)]),
                        logprob=float(logprob),
                    )
                    for logprob, token_id in zip(top.values, top.indices)
                }
            )

    return SimpleNamespace(outputs=[SimpleNamespace(text=text, logprobs=logprobs)])


@spaces.GPU
def generate_text(messages, criteria_name):
    """Assess *messages* against *criteria_name* with the configured inference engine.

    Args:
        messages: chat messages rendered into a prompt by ``get_prompt``.
        criteria_name: risk name for the prompt (``"general_harm"`` is sent
            as ``"harm"``).

    Returns:
        dict with ``"assessment"`` (``risky_token``, ``safe_token`` or
        ``"Failed"``) and ``"certainty"`` (probability of risk, or ``None``
        when it could not be computed).

    Raises:
        Exception: if ``INFERENCE_ENGINE`` is not one of WATSONX, MOCK, VLLM.
    """
    # Validate first: for unknown engines no tokenizer is loaded at import time,
    # so building the prompt would fail with a NameError before this clearer
    # error could ever be raised.
    if inference_engine not in ("WATSONX", "MOCK", "VLLM"):
        raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")

    logger.debug(f"Messages used to create the prompt are: \n{messages}")

    start = time()

    if inference_engine == "MOCK":
        # MOCK mode loads no tokenizer at import time, so the prompt is not
        # rendered; a fixed result is returned after a short delay.
        logger.debug("Returning mocked model result.")
        sleep(1)
        label, prob_of_risk = "Yes", 0.97
    else:
        chat = get_prompt(messages, criteria_name)
        logger.debug(f"Prompt is \n{chat}")

        if inference_engine == "WATSONX":
            generated_tokens = generate_tokens(chat)
            label, prob_of_risk = parse_output_watsonx(generated_tokens)
        else:
            # "VLLM" is served by the transformers model loaded at import time;
            # the old vLLM-style model.generate(chat, use_tqdm=False) call is
            # rejected by a transformers model.
            output = _generate_with_transformers(chat)
            label, prob_of_risk = parse_output(output)

    logger.debug(f"Model generated label: \n{label}")
    logger.debug(f"Model prob_of_risk: \n{prob_of_risk}")

    end = time()
    total = end - start
    logger.debug(f"The evaluation took {total} secs")

    return {"assessment": label, "certainty": prob_of_risk}