granite-guardian / model.py
grahamwhiteuk's picture
Revert "feat: temporarily switch out to 2b model"
0caab14 verified
raw
history blame
7.36 kB
import math
import os
from time import sleep, time
import spaces
import torch
from ibm_watsonx_ai.client import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference
from transformers import AutoModelForCausalLM, AutoTokenizer
from logger import logger
# from vllm import LLM, SamplingParams
# Answer tokens the guardian model is expected to emit: "No" means the input
# was judged safe, "Yes" means it was judged risky.
safe_token = "No"
risky_token = "Yes"
# Top-k size used when reading per-step scores from the local model in
# parse_output; generate_text also passes it as max_new_tokens.
nlogprobs = 20
# Backend selector read from the environment; generate_text accepts
# "VLLM" (default), "WATSONX" or "MOCK".
inference_engine = os.getenv("INFERENCE_ENGINE", "VLLM")
logger.debug(f"Inference engine is: '{inference_engine}'")
# Load the tokenizer/model for the selected backend at import time.
# "MOCK" (or any other value) loads nothing here.
if inference_engine == "VLLM":
    # Despite the engine name, this path loads the model with transformers and
    # places it on CUDA; the vLLM equivalents are kept commented out below.
    device = torch.device("cuda")
    model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
    logger.debug(f"model_path is {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    # model = LLM(model=model_path, tensor_parallel_size=1)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model = model.to(device).eval()
elif inference_engine == "WATSONX":
    # Remote inference: the API key and project id come from environment variables.
    client = APIClient(
        credentials={"api_key": os.getenv("WATSONX_API_KEY"), "url": "https://us-south.ml.cloud.ibm.com"}
    )
    client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
    # The Hugging Face tokenizer is only needed to render the chat template
    # locally (get_prompt); generation itself is done by the watsonx.ai model.
    # NOTE(review): this path is hard-coded and ignores MODEL_PATH — confirm it
    # should stay in sync with model_id below.
    hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
    model_id = "ibm/granite-guardian-3-8b"  # watsonx.ai identifier of the deployed model
    model = ModelInference(model_id=model_id, api_client=client)
def parse_output(output, input_len):
    """Turn a Hugging Face ``generate`` result into a (label, probability-of-risk) pair.

    Args:
        output: Return value of ``model.generate(..., return_dict_in_generate=True,
            output_scores=True)``; its ``scores`` and ``sequences`` attributes are read.
        input_len: Number of prompt tokens. Everything in the first row of
            ``output.sequences`` after this offset is decoded as the answer.

    Returns:
        A tuple ``(label, prob_of_risk)``. ``label`` is ``risky_token`` or
        ``safe_token`` when the decoded answer equals one of them
        (case-insensitively), otherwise ``"Failed"``. ``prob_of_risk`` is a
        Python float (the risky entry of ``get_probablities``), or ``None``
        when ``nlogprobs`` is not positive.
    """
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        # Top-k candidates for every generation step except the last one.
        # NOTE(review): the dropped last step is presumably the end-of-sequence
        # token — confirm.
        # NOTE(review): ``output.scores`` holds pre-softmax scores, which
        # get_probablities exponentiates as if they were log-probabilities.
        list_index_logprobs_i = [
            torch.topk(token_i, k=nlogprobs, largest=True, sorted=True)
            for token_i in list(output.scores)[:-1]
        ]
        prob = get_probablities(list_index_logprobs_i)
        # Convert to a Python float here so the None case below never calls .item().
        prob_of_risk = prob[1].item()

    # Decode only the generated continuation of the first sequence.
    res = tokenizer.decode(output.sequences[:, input_len:][0], skip_special_tokens=True).strip()
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk
def get_probablities(logprobs):
    """Collapse per-step top-k scores into a two-element ``[P(safe), P(risky)]`` tensor.

    Args:
        logprobs: Iterable of ``torch.topk`` results, one per generation step.
            Only row 0 of each ``values`` / ``indices`` pair is read.

    Returns:
        ``torch.softmax`` (over dim 0) of the logs of the accumulated safe and
        risky masses, so the two entries sum to 1.
    """
    # Tiny floors keep math.log finite when a token never shows up.
    safe_mass = 1e-50
    risky_mass = 1e-50
    wanted_safe = safe_token.lower()
    wanted_risky = risky_token.lower()

    for step in logprobs:
        step_scores = step.values.tolist()[0]
        step_ids = step.indices.tolist()[0]
        for score, token_id in zip(step_scores, step_ids):
            candidate = tokenizer.convert_ids_to_tokens(token_id).strip().lower()
            if candidate == wanted_safe:
                safe_mass += math.exp(score)
            if candidate == wanted_risky:
                risky_mass += math.exp(score)

    log_masses = torch.tensor([math.log(safe_mass), math.log(risky_mass)])
    return torch.softmax(log_masses, dim=0)
def softmax(values):
    """Return the softmax of a sequence of real numbers as a list of floats.

    The maximum is subtracted before exponentiating (the standard max-shift
    trick). The mathematical result is unchanged, but ``math.exp`` can no
    longer overflow on large inputs, and very negative inputs can no longer
    all underflow to 0.0 and cause a division by zero.

    Args:
        values: Iterable of finite numbers.

    Returns:
        A list of the same length with non-negative entries summing to 1.
        An empty input yields an empty list.
    """
    values = list(values)
    if not values:
        return []
    shift = max(values)
    exp_values = [math.exp(v - shift) for v in values]
    # The largest term is exp(0) == 1, so total >= 1 and the division is safe.
    total = sum(exp_values)
    return [v / total for v in exp_values]
def get_probablities_watsonx(top_tokens_list):
    """Turn watsonx ``top_tokens`` candidates into ``[P(safe), P(risky)]``.

    Args:
        top_tokens_list: One list per generated token, each holding dicts with
            ``"text"`` and ``"logprob"`` keys.

    Returns:
        A two-element list produced by ``softmax`` over the logs of the summed
        safe and risky probability masses.
    """
    safe_key = safe_token.lower()
    risky_key = risky_token.lower()
    # Floors keep math.log finite when an answer token never appears.
    safe_mass = risky_mass = 1e-50

    for candidates in top_tokens_list:
        for candidate in candidates:
            text = candidate["text"].strip().lower()
            if text == safe_key:
                safe_mass += math.exp(candidate["logprob"])
            if text == risky_key:
                risky_mass += math.exp(candidate["logprob"])

    return softmax([math.log(safe_mass), math.log(risky_mass)])
def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=False, return_tensors=None):
    """Render *messages* through the tokenizer's chat template for one risk.

    Args:
        messages: Chat messages passed to ``tokenizer.apply_chat_template``.
        criteria_name: Risk to evaluate. ``"general_harm"`` is sent as
            ``"harm"``; every other name is sent unchanged.
        tokenize: Forwarded to ``apply_chat_template``.
        add_generation_prompt: Forwarded to ``apply_chat_template``.
        return_tensors: Forwarded to ``apply_chat_template``.

    Returns:
        Whatever ``apply_chat_template`` produces for these options.
    """
    risk_name = "harm" if criteria_name == "general_harm" else criteria_name
    rendered = tokenizer.apply_chat_template(
        messages,
        guardian_config={"risk_name": risk_name},
        tokenize=tokenize,
        add_generation_prompt=add_generation_prompt,
        return_tensors=return_tensors,
    )
    logger.debug(f'prompt is\n{rendered}')
    return rendered
@spaces.GPU
def generate_tokens(prompt):
    """Send *prompt* to the watsonx.ai model and return its generated-token records.

    Generation is greedy, limited to 20 new tokens, and asks the service for
    per-token log-probabilities plus the top 5 alternatives at each step.

    Args:
        prompt: Fully rendered prompt string (see get_prompt).

    Returns:
        The ``generated_tokens`` list of the first result in the response.
    """
    # NOTE(review): the body only calls the module-level ``model`` (a
    # ModelInference client when INFERENCE_ENGINE=WATSONX); confirm that
    # @spaces.GPU is actually needed for this remote call.
    return_options = {
        "token_logprobs": True,
        "generated_tokens": True,
        "input_text": True,
        "top_n_tokens": 5,
    }
    generation_params = {
        "decoding_method": "greedy",
        "max_new_tokens": 20,
        "temperature": 0,
        "return_options": return_options,
    }
    response = model.generate(prompt=[prompt], params=generation_params)
    first_result = response[0]["results"][0]
    return first_result["generated_tokens"]
def parse_output_watsonx(generated_tokens_list):
    """Turn watsonx generated-token records into a (label, probability-of-risk) pair.

    Args:
        generated_tokens_list: The ``generated_tokens`` list returned by
            generate_tokens. Every entry's ``"top_tokens"`` is read, and the
            first entry's ``"text"`` is read.

    Returns:
        ``(label, prob_of_risk)``. ``label`` is ``risky_token`` or ``safe_token``
        when the first generated token equals one of them case-insensitively,
        otherwise ``"Failed"`` (including when nothing was generated).
        ``prob_of_risk`` is the risky entry of get_probablities_watsonx, or
        ``None`` when ``nlogprobs`` is not positive.
    """
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        top_tokens_list = [generated_tokens["top_tokens"] for generated_tokens in generated_tokens_list]
        prob = get_probablities_watsonx(top_tokens_list)
        prob_of_risk = prob[1]

    # Only the first generated token is compared with the expected answers.
    # Use a default so an empty generation maps to "Failed" instead of letting
    # a bare StopIteration escape from next().
    first_token = next(iter(generated_tokens_list), None)
    if first_token is None:
        return "Failed", prob_of_risk

    res = first_token["text"].strip()
    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk
@spaces.GPU
def generate_text(messages, criteria_name):
    """Assess *messages* for the risk named *criteria_name* with the configured engine.

    The engine is chosen by the module-level ``inference_engine`` (read from
    the ``INFERENCE_ENGINE`` environment variable):

    * ``"MOCK"``: returns a fixed ``("Yes", 0.97)`` after a one-second sleep.
    * ``"WATSONX"``: renders the chat prompt and calls the watsonx.ai model.
    * ``"VLLM"``: runs the model loaded at import time via ``transformers``
      ``generate`` (the vLLM code path in this module is commented out).

    Args:
        messages: Chat messages handed to the tokenizer's chat template.
        criteria_name: Risk name; get_prompt maps ``"general_harm"`` to ``"harm"``.

    Returns:
        ``{"assessment": label, "certainty": prob_of_risk}``. ``label`` is
        ``"Yes"``, ``"No"`` or ``"Failed"``; ``prob_of_risk`` is the
        probability attributed to the risky answer, or ``None`` when it is
        not computed.

    Raises:
        ValueError: If ``inference_engine`` is not a supported value. This is a
            subclass of Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    logger.debug(f"Messages used to create the prompt are: \n{messages}")
    start = time()

    if inference_engine == "MOCK":
        label, prob_of_risk = _generate_mock()
    elif inference_engine == "WATSONX":
        label, prob_of_risk = _generate_watsonx(messages, criteria_name)
    elif inference_engine == "VLLM":
        label, prob_of_risk = _generate_vllm(messages, criteria_name)
    else:
        raise ValueError("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")

    logger.debug(f"Model generated label: \n{label}")
    logger.debug(f"Model prob_of_risk: \n{prob_of_risk}")

    end = time()
    total = end - start
    logger.debug(f"The evaluation took {total} secs")

    return {"assessment": label, "certainty": prob_of_risk}


def _generate_mock():
    """Return the canned MOCK result after a one-second pause."""
    logger.debug("Returning mocked model result.")
    sleep(1)
    return "Yes", 0.97


def _generate_watsonx(messages, criteria_name):
    """Render the prompt, query the watsonx.ai model and parse its tokens."""
    # NOTE(review): the prompt is rendered with add_generation_prompt=False,
    # whereas the local-model path uses True — confirm this is intended.
    chat = get_prompt(messages, criteria_name)
    logger.debug(f"Prompt is \n{chat}")
    generated_tokens = generate_tokens(chat)
    return parse_output_watsonx(generated_tokens)


def _generate_vllm(messages, criteria_name):
    """Run greedy generation on the locally loaded model and parse its output."""
    input_ids = get_prompt(
        messages,
        criteria_name,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    logger.debug(f"input_ids are: {input_ids}")
    input_len = input_ids.shape[1]
    logger.debug(f"input_len are: {input_len}")

    with torch.no_grad():
        # NOTE(review): max_new_tokens reuses nlogprobs (both 20 here); the two
        # settings look unrelated — confirm this coupling is intended.
        output = model.generate(
            input_ids,
            do_sample=False,
            max_new_tokens=nlogprobs,
            return_dict_in_generate=True,
            output_scores=True,
        )
    logger.debug(f"model output is:\n{output}")

    label, prob_of_risk = parse_output(output, input_len)
    logger.debug(f"label is are: {label}")
    logger.debug(f"prob_of_risk is are: {prob_of_risk}")
    return label, prob_of_risk