grahamwhiteuk commited on
Commit
0caab14
1 Parent(s): 2e41a22

Revert "feat: temporarily switch out to 2b model"

Browse files

This reverts commit 2e41a220488ada8d0d858681b691d78ef41327d8.

Files changed (2) hide show
  1. app.py +1 -1
  2. model.py +19 -18
app.py CHANGED
@@ -205,7 +205,7 @@ with gr.Blocks(
205
  gr.HTML("<h2>IBM Granite Guardian 3.0</h2>", elem_classes="title")
206
  gr.HTML(
207
  elem_classes="system-description",
208
- value="<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrieval-augmented generation. In this demo, we use granite-guardian-3.0-2b.</p>",
209
  )
210
  with gr.Row(elem_classes="column-gap"):
211
  with gr.Column(scale=0, elem_classes="no-gap"):
 
205
  gr.HTML("<h2>IBM Granite Guardian 3.0</h2>", elem_classes="title")
206
  gr.HTML(
207
  elem_classes="system-description",
208
+ value="<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrieval-augmented generation. In this demo, we use granite-guardian-3.0-8b.</p>",
209
  )
210
  with gr.Row(elem_classes="column-gap"):
211
  with gr.Column(scale=0, elem_classes="no-gap"):
model.py CHANGED
@@ -23,7 +23,7 @@ logger.debug(f"Inference engine is: '{inference_engine}'")
23
  if inference_engine == "VLLM":
24
  device = torch.device("cuda")
25
 
26
- model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-2b")
27
  logger.debug(f"model_path is {model_path}")
28
  tokenizer = AutoTokenizer.from_pretrained(model_path)
29
  # sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
@@ -37,10 +37,10 @@ elif inference_engine == "WATSONX":
37
  )
38
 
39
  client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
40
- hf_model_path = "ibm-granite/granite-guardian-3.0-2b"
41
  tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
42
 
43
- model_id = "ibm/granite-guardian-3-2b" # 2b Model: "ibm/granite-guardian-3-2b"
44
  model = ModelInference(model_id=model_id, api_client=client)
45
 
46
 
@@ -48,14 +48,13 @@ def parse_output(output, input_len):
48
  label, prob_of_risk = None, None
49
  if nlogprobs > 0:
50
 
51
- list_index_logprobs_i = [
52
- torch.topk(token_i, k=nlogprobs, largest=True, sorted=True) for token_i in list(output.scores)[:-1]
53
- ]
54
  if list_index_logprobs_i is not None:
55
  prob = get_probablities(list_index_logprobs_i)
56
  prob_of_risk = prob[1]
57
 
58
- res = tokenizer.decode(output.sequences[:, input_len:][0], skip_special_tokens=True).strip()
59
  if risky_token.lower() == res.lower():
60
  label = risky_token
61
  elif safe_token.lower() == res.lower():
@@ -65,7 +64,6 @@ def parse_output(output, input_len):
65
 
66
  return label, prob_of_risk.item()
67
 
68
-
69
  def get_probablities(logprobs):
70
  safe_token_prob = 1e-50
71
  unsafe_token_prob = 1e-50
@@ -77,7 +75,9 @@ def get_probablities(logprobs):
77
  if decoded_token.strip().lower() == risky_token.lower():
78
  unsafe_token_prob += math.exp(logprob)
79
 
80
- probabilities = torch.softmax(torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0)
 
 
81
 
82
  return probabilities
83
 
@@ -87,7 +87,6 @@ def softmax(values):
87
  total = sum(exp_values)
88
  return [v / total for v in exp_values]
89
 
90
-
91
  def get_probablities_watsonx(top_tokens_list):
92
  safe_token_prob = 1e-50
93
  risky_token_prob = 1e-50
@@ -110,9 +109,9 @@ def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=Fa
110
  guardian_config=guardian_config,
111
  tokenize=tokenize,
112
  add_generation_prompt=add_generation_prompt,
113
- return_tensors=return_tensors,
114
  )
115
- logger.debug(f"prompt is\n{prompt}")
116
  return prompt
117
 
118
 
@@ -167,15 +166,18 @@ def generate_text(messages, criteria_name):
167
 
168
  elif inference_engine == "VLLM":
169
  # input_ids = get_prompt(
170
- # messages=messages,
171
- # criteria_name=criteria_name,
172
  # tokenize=True,
173
  # add_generation_prompt=True,
174
  # return_tensors="pt").to(model.device)
175
  guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
176
- logger.debug(f"guardian_config is: {guardian_config}")
177
  input_ids = tokenizer.apply_chat_template(
178
- messages, guardian_config=guardian_config, add_generation_prompt=True, return_tensors="pt"
 
 
 
179
  ).to(model.device)
180
  logger.debug(f"input_ids are: {input_ids}")
181
  input_len = input_ids.shape[1]
@@ -188,8 +190,7 @@ def generate_text(messages, criteria_name):
188
  do_sample=False,
189
  max_new_tokens=nlogprobs,
190
  return_dict_in_generate=True,
191
- output_scores=True,
192
- )
193
  logger.debug(f"model output is:\n{output}")
194
 
195
  label, prob_of_risk = parse_output(output, input_len)
 
23
  if inference_engine == "VLLM":
24
  device = torch.device("cuda")
25
 
26
+ model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
27
  logger.debug(f"model_path is {model_path}")
28
  tokenizer = AutoTokenizer.from_pretrained(model_path)
29
  # sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
 
37
  )
38
 
39
  client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
40
+ hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
41
  tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
42
 
43
+ model_id = "ibm/granite-guardian-3-8b" # 8B Model: "ibm/granite-guardian-3-8b"
44
  model = ModelInference(model_id=model_id, api_client=client)
45
 
46
 
 
48
  label, prob_of_risk = None, None
49
  if nlogprobs > 0:
50
 
51
+ list_index_logprobs_i = [torch.topk(token_i, k=nlogprobs, largest=True, sorted=True)
52
+ for token_i in list(output.scores)[:-1]]
 
53
  if list_index_logprobs_i is not None:
54
  prob = get_probablities(list_index_logprobs_i)
55
  prob_of_risk = prob[1]
56
 
57
+ res = tokenizer.decode(output.sequences[:,input_len:][0],skip_special_tokens=True).strip()
58
  if risky_token.lower() == res.lower():
59
  label = risky_token
60
  elif safe_token.lower() == res.lower():
 
64
 
65
  return label, prob_of_risk.item()
66
 
 
67
  def get_probablities(logprobs):
68
  safe_token_prob = 1e-50
69
  unsafe_token_prob = 1e-50
 
75
  if decoded_token.strip().lower() == risky_token.lower():
76
  unsafe_token_prob += math.exp(logprob)
77
 
78
+ probabilities = torch.softmax(
79
+ torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
80
+ )
81
 
82
  return probabilities
83
 
 
87
  total = sum(exp_values)
88
  return [v / total for v in exp_values]
89
 
 
90
  def get_probablities_watsonx(top_tokens_list):
91
  safe_token_prob = 1e-50
92
  risky_token_prob = 1e-50
 
109
  guardian_config=guardian_config,
110
  tokenize=tokenize,
111
  add_generation_prompt=add_generation_prompt,
112
+ return_tensors=return_tensors
113
  )
114
+ logger.debug(f'prompt is\n{prompt}')
115
  return prompt
116
 
117
 
 
166
 
167
  elif inference_engine == "VLLM":
168
  # input_ids = get_prompt(
169
+ # messages=messages,
170
+ # criteria_name=criteria_name,
171
  # tokenize=True,
172
  # add_generation_prompt=True,
173
  # return_tensors="pt").to(model.device)
174
  guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
175
+ logger.debug(f'guardian_config is: {guardian_config}')
176
  input_ids = tokenizer.apply_chat_template(
177
+ messages,
178
+ guardian_config=guardian_config,
179
+ add_generation_prompt=True,
180
+ return_tensors='pt'
181
  ).to(model.device)
182
  logger.debug(f"input_ids are: {input_ids}")
183
  input_len = input_ids.shape[1]
 
190
  do_sample=False,
191
  max_new_tokens=nlogprobs,
192
  return_dict_in_generate=True,
193
+ output_scores=True,)
 
194
  logger.debug(f"model output is:\n{output}")
195
 
196
  label, prob_of_risk = parse_output(output, input_len)