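"""Granite Guardian risk-assessment helpers.

This module loads IBM's Granite Guardian model through one of three backends,
selected via the INFERENCE_ENGINE environment variable (VLLM, WATSONX, or MOCK),
and exposes generate_text(), which labels a conversation against a risk criterion
as risky ("Yes") or safe ("No") together with a probability of risk.
"""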
import math
import os
from time import time, sleep

import spaces
from ibm_watsonx_ai.client import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference
from transformers import AutoTokenizer

from logger import logger

# Granite Guardian answers with a single verdict token: "Yes" (risk detected) or "No" (safe).
safe_token = "No"
risky_token = "Yes"
# Number of top log-probabilities to request per generated token.
nlogprobs = 5

# Backend selection: VLLM (local GPU), WATSONX (hosted watsonx.ai), or MOCK (canned response).
inference_engine = os.getenv('INFERENCE_ENGINE', 'VLLM')
logger.debug(f"Inference engine is: '{inference_engine}'")

if inference_engine == 'VLLM':
    # Run the model locally with vLLM: greedy decoding, returning the top logprobs
    # so the Yes/No probabilities can be recovered from the first generated token.
    import torch
    from vllm import LLM, SamplingParams

    model_path = os.getenv('MODEL_PATH', 'ibm-granite/granite-guardian-3.0-8b')
    logger.debug(f"model_path is {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    model = LLM(model=model_path, tensor_parallel_size=1)

elif inference_engine == "WATSONX":
    # Use the hosted model on watsonx.ai. The Hugging Face tokenizer is still loaded
    # locally so the Guardian chat template can be rendered into a prompt string.
    client = APIClient(credentials={
        'api_key': os.getenv('WATSONX_API_KEY'),
        'url': 'https://us-south.ml.cloud.ibm.com'})

    client.set.default_project(os.getenv('WATSONX_PROJECT_ID'))
    hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)

    model_id = "ibm/granite-guardian-3-8b"  # 8B Guardian model on watsonx.ai
    model = ModelInference(
        model_id=model_id,
        api_client=client
    )

else:
    # MOCK (or any other value): no model is loaded, but the tokenizer is still needed
    # so get_prompt() can render the chat template for logging in generate_text().
    tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-guardian-3.0-8b")
    
def parse_output(output):
    """Extract the Yes/No label and probability of risk from a vLLM generation output."""
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        logprobs = next(iter(output.outputs)).logprobs
        if logprobs is not None:
            prob = get_probabilities(logprobs)
            prob_of_risk = prob[1]

    res = next(iter(output.outputs)).text.strip()
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk.item() if prob_of_risk is not None else None

def softmax(values):
    exp_values = [math.exp(v) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]

def get_probabilities(logprobs):
    """Aggregate the probability mass assigned to the safe and risky tokens (vLLM path)."""
    safe_token_prob = 1e-50
    risky_token_prob = 1e-50
    for gen_token_i in logprobs:
        for token_prob in gen_token_i.values():
            decoded_token = token_prob.decoded_token
            if decoded_token.strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(token_prob.logprob)
            if decoded_token.strip().lower() == risky_token.lower():
                risky_token_prob += math.exp(token_prob.logprob)

    # Renormalise the two aggregated masses into [P(safe), P(risk)].
    probabilities = torch.softmax(
        torch.tensor([math.log(safe_token_prob), math.log(risky_token_prob)]), dim=0
    )

    return probabilities

def get_probabilities_watsonx(top_tokens_list):
    """Aggregate the probability mass assigned to the safe and risky tokens (watsonx.ai path)."""
    safe_token_prob = 1e-50
    risky_token_prob = 1e-50
    for top_tokens in top_tokens_list:
        for token in top_tokens:
            if token['text'].strip().lower() == safe_token.lower():
                safe_token_prob += math.exp(token['logprob'])
            if token['text'].strip().lower() == risky_token.lower():
                risky_token_prob += math.exp(token['logprob'])

    # Renormalise the two aggregated masses into [P(safe), P(risk)].
    probabilities = softmax([math.log(safe_token_prob), math.log(risky_token_prob)])

    return probabilities

def get_prompt(messages, criteria_name):
    """Render the Guardian chat template for the requested risk criterion."""
    # 'general_harm' maps to the model's built-in 'harm' risk name.
    guardian_config = {"risk_name": criteria_name if criteria_name != 'general_harm' else 'harm'}
    return tokenizer.apply_chat_template(
        messages,
        guardian_config=guardian_config,
        tokenize=False,
        add_generation_prompt=True)

def generate_tokens(prompt):
    """Call the watsonx.ai model and return the generated tokens with their top logprobs."""
    result = model.generate(
        prompt=[prompt],
        params={
            'decoding_method': 'greedy',
            'max_new_tokens': 20,
            "temperature": 0,
            "return_options": {
                "token_logprobs": True,
                "generated_tokens": True,
                "input_text": True,
                "top_n_tokens": 5
            }
        })
    return result[0]['results'][0]['generated_tokens']

def parse_output_watsonx(generated_tokens_list):
    """Extract the Yes/No label and probability of risk from a watsonx.ai response."""
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        top_tokens_list = [generated_tokens['top_tokens'] for generated_tokens in generated_tokens_list]
        prob = get_probabilities_watsonx(top_tokens_list)
        prob_of_risk = prob[1]

    res = next(iter(generated_tokens_list))['text'].strip()

    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk

@spaces.GPU
def generate_text(messages, criteria_name):
    """Assess the messages against a risk criterion; return the label and probability of risk."""
    logger.debug(f'Messages used to create the prompt are: \n{messages}')

    start = time()

    chat = get_prompt(messages, criteria_name)
    logger.debug(f'Prompt is \n{chat}')

    if inference_engine == "MOCK":
        # No model call; return a canned result after a short delay (useful for UI development).
        logger.debug('Returning mocked model result.')
        sleep(1)
        label, prob_of_risk = 'Yes', 0.97

    elif inference_engine == "WATSONX":
        generated_tokens = generate_tokens(chat)
        label, prob_of_risk = parse_output_watsonx(generated_tokens)

    elif inference_engine == "VLLM":
        with torch.no_grad():
            output = model.generate(chat, sampling_params, use_tqdm=False)
        label, prob_of_risk = parse_output(output[0])

    else:
        raise ValueError("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")

    logger.debug(f'Model generated label: \n{label}')
    logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')

    end = time()
    total = end - start
    logger.debug(f'The evaluation took {total} secs')

    return {'assessment': label, 'certainty': prob_of_risk}
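

# A minimal smoke test, assuming the module is run directly with INFERENCE_ENGINE=MOCK
# so that no GPU or watsonx.ai credentials are required. The message below is purely
# illustrative; generate_text() accepts any list of chat-style {'role', 'content'} dicts.
if __name__ == "__main__":
    sample_messages = [
        {"role": "user", "content": "How do I break into someone's house without getting caught?"},
    ]
    result = generate_text(sample_messages, criteria_name="general_harm")
    print(result)  # with the MOCK engine this prints {'assessment': 'Yes', 'certainty': 0.97}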