Spaces:
Running
on
Zero
Running
on
Zero
import math
import os
from time import sleep, time

from logger import logger

# Labels that the model's generated text is compared against.
safe_token = "No"
unsafe_token = "Yes"
# Number of log-probabilities requested per generated token.
nlogprobs = 5

# Only the exact string 'true' enables mock mode; any other value (or an
# unset variable) triggers the heavy imports and model load below.
mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'

if not mock_model_call:
    # Imported lazily so mock mode does not need these packages at import time.
    import torch
    from vllm import LLM, SamplingParams
    from transformers import AutoTokenizer

    # Example value: "granite-guardian-3b-pipecleaner-r241024a"
    model_path = os.getenv('MODEL_PATH')
    # Temperature 0.0 is passed together with the requested logprobs count.
    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
    model = LLM(model=model_path, tensor_parallel_size=1)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
def parse_output(output):
    """Extract the risk label and risk probability from one generation result.

    Args:
        output: A generation result exposing ``outputs``; only the first
            completion is read, through its ``text`` and ``logprobs``
            attributes.

    Returns:
        A ``(label, prob_of_risk)`` tuple. ``label`` is ``unsafe_token`` or
        ``safe_token`` when the stripped text matches one of them
        case-insensitively, otherwise ``"Failed"``. ``prob_of_risk`` is a
        float, or ``None`` when no log-probabilities are available
        (``nlogprobs`` is not positive, or the completion carries none).
    """
    completion = next(iter(output.outputs))

    # Start from None so the return never reads an unassigned local when
    # log-probabilities are disabled or missing.
    prob_of_risk = None
    if nlogprobs > 0 and completion.logprobs is not None:
        # get_probablities returns [safe, unsafe]; index 1 is the unsafe entry.
        prob_of_risk = get_probablities(completion.logprobs)[1].item()

    res = completion.text.strip().lower()
    if res == unsafe_token.lower():
        label = unsafe_token
    elif res == safe_token.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk
def get_probablities(logprobs):
    """Convert per-token log-probabilities into a [safe, unsafe] distribution.

    Args:
        logprobs: An iterable of mappings, one per generated token, whose
            values expose ``decoded_token`` and ``logprob`` attributes.

    Returns:
        A two-element torch tensor ``[p_safe, p_unsafe]`` obtained by
        softmax over the log of the accumulated probability mass for
        ``safe_token`` and ``unsafe_token``.
    """
    safe_key = safe_token.lower()
    unsafe_key = unsafe_token.lower()

    # Tiny non-zero seeds keep math.log defined when a token never appears.
    safe_mass = 1e-50
    unsafe_mass = 1e-50

    for position in logprobs:
        for candidate in position.values():
            text = candidate.decoded_token.strip().lower()
            if text == safe_key:
                safe_mass += math.exp(candidate.logprob)
            if text == unsafe_key:
                unsafe_mass += math.exp(candidate.logprob)

    log_masses = torch.tensor([math.log(safe_mass), math.log(unsafe_mass)])
    return torch.softmax(log_masses, dim=0)
def generate_text(prompt):
    """Run the risk assessment for a single chat message.

    Args:
        prompt: A mapping with a ``"content"`` entry; it is passed as the
            only message to the tokenizer's chat template.

    Returns:
        A dict with keys ``'assessment'`` (the parsed label) and
        ``'certainty'`` (the risk probability). When the ``MOCK_MODEL_CALL``
        environment variable equals ``'true'``, a fixed result is returned
        after a 3-second sleep and the model is not called.
    """
    logger.debug(f'Prompts content is: \n{prompt["content"]}')

    # The environment is re-read on every call rather than using the module flag.
    if os.getenv('MOCK_MODEL_CALL') == 'true':
        logger.debug('Returning mocked model result.')
        sleep(3)
        return {'assessment': 'Yes', 'certainty': 0.97}

    start = time()

    chat_text = tokenizer.apply_chat_template(
        [prompt],
        tokenize=False,
        add_generation_prompt=True,
    )
    with torch.no_grad():
        results = model.generate(chat_text, sampling_params, use_tqdm=False)

    label, prob_of_risk = parse_output(results[0])
    logger.debug(f'Model generated label: \n{label}')
    logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')

    # The elapsed time includes the two debug log calls above.
    total = time() - start
    logger.debug(f'The evaluation took {total} secs')

    return {'assessment': label, 'certainty': prob_of_risk}