Martín Santillán Cooper commited on
Commit
2cecaad
1 Parent(s): aa94892

restructure files

Browse files
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio
2
  python-dotenv
3
  tqdm
4
  jinja2
 
1
+ gradio>=5.0
2
  python-dotenv
3
  tqdm
4
  jinja2
app.py → src/app.py RENAMED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  from dotenv import load_dotenv
3
 
4
- from utils import get_result_description, to_title_case, get_prompt_from_test_case, to_snake_case
 
5
  load_dotenv()
6
  import json
7
  from model import generate_text
@@ -99,15 +100,15 @@ head_style = """
99
  """
100
 
101
  with gr.Blocks(
102
- title='Granite Guardian',
103
- theme=gr.themes.Soft(
104
- primary_hue=ibm_blue,
105
  font=[gr.themes.GoogleFont("IBM Plex Sans"), gr.themes.GoogleFont('Source Sans 3')],
106
- ),
107
- head=head_style,
108
- fill_width=False,
109
- css='styles.css') as demo:
110
-
111
 
112
  state = gr.State(value={
113
  'selected_sub_catalog': 'harmful_content_in_user_prompt',
@@ -116,7 +117,7 @@ with gr.Blocks(
116
 
117
  starting_test_case = [t for sub_catalog_name, sub_catalog in catalog.items() for t in sub_catalog if t['name'] == state.value['selected_criteria_name'] and sub_catalog_name == state.value['selected_sub_catalog']][0]
118
 
119
- with gr.Row(elem_classes='title-row'):
120
  with gr.Column(scale=4):
121
  gr.HTML('<h2>IBM Granite Guardian 3.0</h2>', elem_classes='title')
122
  gr.HTML(elem_classes='system-description', value='<p>Granite Guardian models are specialized language models in the Granite family that allow you to detect harms and risks in generative AI systems. The Granite Guardian models can be used with any other large language models to make interactions with generative AI systems safe. Select an example in the left panel to see how the model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrieval-augmented generation. In this demo, we use granite-guardian-3.0-8B.</p>')
@@ -152,9 +153,17 @@ with gr.Blocks(
152
  user_message = gr.Textbox(label="User Prompt", lines=3, interactive=True, value=starting_test_case['user_message'], elem_classes=['input-box'])
153
  assistant_message = gr.Textbox(label="Assistant Response", lines=3, interactive=True, visible=False, value=starting_test_case['assistant_message'], elem_classes=['input-box'])
154
 
155
- submit_button = gr.Button("Evaluate", variant='primary',icon=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'send-white.png'), elem_classes='submit-button')
 
 
 
 
156
 
157
- result_text = gr.Textbox(label='Result', elem_classes=['result-text', 'read-only', 'input-box'], visible=False, value='')
 
 
 
 
158
 
159
  with Modal(visible=False, elem_classes='modal') as modal:
160
  prompt = gr.Markdown('')
@@ -173,11 +182,20 @@ with gr.Blocks(
173
  outputs=[result_text],
174
  scroll_to_output=True
175
  )
176
-
177
 
178
  for button in [t for sub_catalog_name, sub_catalog_buttons in catalog_buttons.items() for t in sub_catalog_buttons.values()]:
179
- button.click(update_selected_test_case, inputs=[button, state], outputs=[state])\
180
- .then(on_test_case_click, inputs=state, outputs={test_case_name, criteria, context, user_message, assistant_message, result_text}) \
181
- .then(change_button_color, None, [v for c in catalog_buttons.values() for v in c.values()])
 
 
 
 
 
 
 
 
 
 
182
 
183
  demo.launch(server_name='0.0.0.0')
 
1
  import gradio as gr
2
  from dotenv import load_dotenv
3
 
4
+ from utils import get_result_description, to_title_case, get_prompt_from_test_case, to_snake_case, load_command_line_args
5
+ load_command_line_args()
6
  load_dotenv()
7
  import json
8
  from model import generate_text
 
100
  """
101
 
102
  with gr.Blocks(
103
+ title='Granite Guardian',
104
+ theme=gr.themes.Soft(
105
+ primary_hue=ibm_blue,
106
  font=[gr.themes.GoogleFont("IBM Plex Sans"), gr.themes.GoogleFont('Source Sans 3')],
107
+ ),
108
+ head=head_style,
109
+ fill_width=False,
110
+ css=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'styles.css')
111
+ ) as demo:
112
 
113
  state = gr.State(value={
114
  'selected_sub_catalog': 'harmful_content_in_user_prompt',
 
117
 
118
  starting_test_case = [t for sub_catalog_name, sub_catalog in catalog.items() for t in sub_catalog if t['name'] == state.value['selected_criteria_name'] and sub_catalog_name == state.value['selected_sub_catalog']][0]
119
 
120
+ with gr.Row(elem_classes='header-row'):
121
  with gr.Column(scale=4):
122
  gr.HTML('<h2>IBM Granite Guardian 3.0</h2>', elem_classes='title')
123
  gr.HTML(elem_classes='system-description', value='<p>Granite Guardian models are specialized language models in the Granite family that allow you to detect harms and risks in generative AI systems. The Granite Guardian models can be used with any other large language models to make interactions with generative AI systems safe. Select an example in the left panel to see how the model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrieval-augmented generation. In this demo, we use granite-guardian-3.0-8B.</p>')
 
153
  user_message = gr.Textbox(label="User Prompt", lines=3, interactive=True, value=starting_test_case['user_message'], elem_classes=['input-box'])
154
  assistant_message = gr.Textbox(label="Assistant Response", lines=3, interactive=True, visible=False, value=starting_test_case['assistant_message'], elem_classes=['input-box'])
155
 
156
+ submit_button = gr.Button(
157
+ "Evaluate",
158
+ variant='primary',
159
+ icon=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'send-white.png'),
160
+ elem_classes='submit-button')
161
 
162
+ result_text = gr.Textbox(
163
+ label='Result',
164
+ elem_classes=['result-text', 'read-only', 'input-box'],
165
+ visible=False,
166
+ value='')
167
 
168
  with Modal(visible=False, elem_classes='modal') as modal:
169
  prompt = gr.Markdown('')
 
182
  outputs=[result_text],
183
  scroll_to_output=True
184
  )
 
185
 
186
  for button in [t for sub_catalog_name, sub_catalog_buttons in catalog_buttons.items() for t in sub_catalog_buttons.values()]:
187
+ button \
188
+ .click(
189
+ change_button_color,
190
+ inputs=None,
191
+ outputs=[v for c in catalog_buttons.values() for v in c.values()]) \
192
+ .then(
193
+ update_selected_test_case,
194
+ inputs=[button, state],
195
+ outputs=[state]) \
196
+ .then(
197
+ on_test_case_click,
198
+ inputs=state,
199
+ outputs={test_case_name, criteria, context, user_message, assistant_message, result_text})
200
 
201
  demo.launch(server_name='0.0.0.0')
logger.py → src/logger.py RENAMED
File without changes
model.py → src/model.py RENAMED
@@ -12,7 +12,7 @@ if not mock_model_call:
12
  import torch
13
  from vllm import LLM, SamplingParams
14
  from transformers import AutoTokenizer
15
- model_path = os.getenv('MODEL_PATH')#"granite-guardian-3b-pipecleaner-r241024a"
16
  sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
17
  model = LLM(model=model_path, tensor_parallel_size=1)
18
  tokenizer = AutoTokenizer.from_pretrained(model_path)
 
12
  import torch
13
  from vllm import LLM, SamplingParams
14
  from transformers import AutoTokenizer
15
+ model_path = os.getenv('MODEL_PATH') #"granite-guardian-3b-pipecleaner-r241024a"
16
  sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
17
  model = LLM(model=model_path, tensor_parallel_size=1)
18
  tokenizer = AutoTokenizer.from_pretrained(model_path)
send-white.png → src/send-white.png RENAMED
File without changes
styles.css → src/styles.css RENAMED
@@ -1,4 +1,4 @@
1
- .title-row {
2
  margin-bottom: 0.75rem;
3
  }
4
 
@@ -7,6 +7,10 @@
7
  margin-bottom: -0.25rem;
8
  }
9
 
 
 
 
 
10
  .title h2 {
11
  font-weight: 600;
12
  font-size: 30px;
@@ -29,7 +33,6 @@
29
  justify-content: flex-start;
30
  background-color: transparent;
31
  box-shadow: none;
32
-
33
  }
34
 
35
  .selected {
 
1
+ .header-row {
2
  margin-bottom: 0.75rem;
3
  }
4
 
 
7
  margin-bottom: -0.25rem;
8
  }
9
 
10
+ .title div {
11
+ overflow-y: hidden;
12
+ }
13
+
14
  .title h2 {
15
  font-weight: 600;
16
  font-size: 30px;
 
33
  justify-content: flex-start;
34
  background-color: transparent;
35
  box-shadow: none;
 
36
  }
37
 
38
  .selected {
utils.py → src/utils.py RENAMED
@@ -1,5 +1,7 @@
1
  import json
2
  from jinja2 import Template
 
 
3
 
4
  with open('prompt_templates.json', mode='r', encoding="utf-8") as f:
5
  prompt_templates = json.load(f)
@@ -60,3 +62,15 @@ def to_title_case(input_string):
60
 
61
  def to_snake_case(text):
62
  return text.lower().replace(" ", "_")
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  from jinja2 import Template
3
+ import argparse
4
+ import os
5
 
6
  with open('prompt_templates.json', mode='r', encoding="utf-8") as f:
7
  prompt_templates = json.load(f)
 
62
 
63
  def to_snake_case(text):
64
  return text.lower().replace(" ", "_")
65
+
66
+
67
+ def load_command_line_args():
68
+ parser = argparse.ArgumentParser()
69
+ parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")
70
+
71
+ # Parse arguments
72
+ args = parser.parse_args()
73
+
74
+ # Store the argument in an environment variable
75
+ if args.model_path is not None:
76
+ os.environ["MODEL_PATH"] = args.model_path