Spaces:
Running
on
Zero
Running
on
Zero
Martín Santillán Cooper
commited on
Commit
•
f97dae7
1
Parent(s):
f492568
prepare for openshift deployment
Browse files- .dockerignore +8 -0
- .env.example +3 -1
- .gitignore +3 -2
- Dockerfile +7 -0
- cicd/build.sh +2 -0
- cicd/deploy.sh +3 -0
- cicd/push_image.sh +2 -0
- cicd/run.sh +3 -0
- deployment.yaml +86 -0
- requirements.txt +4 -1
- run_cicd.sh +3 -0
- src/app.py +2 -1
- src/logger.py +1 -1
- src/model.py +97 -17
- src/utils.py +4 -23
.dockerignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
.*
|
3 |
+
*.yml
|
4 |
+
*.yaml
|
5 |
+
*.sh
|
6 |
+
*.md
|
7 |
+
__pycache__/
|
8 |
+
flagged/
|
.env.example
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
MODEL_PATH='../dmf_models/granite-guardian-8b-pipecleaner-r241024a'
|
2 |
USE_CONDA='true'
|
3 |
-
|
|
|
|
|
|
1 |
MODEL_PATH='../dmf_models/granite-guardian-8b-pipecleaner-r241024a'
|
2 |
USE_CONDA='true'
|
3 |
+
INFERENCE_ENGINE='' # one of [WATSONX, MOCK, VLLM]
|
4 |
+
WATSONX_API_KEY=""
|
5 |
+
WATSONX_PROJECT_ID=""
|
.gitignore
CHANGED
@@ -2,5 +2,6 @@
|
|
2 |
.env
|
3 |
parse.py
|
4 |
unparsed_catalog.json
|
5 |
-
__pycache__
|
6 |
-
logs
|
|
|
|
2 |
.env
|
3 |
parse.py
|
4 |
unparsed_catalog.json
|
5 |
+
__pycache__/
|
6 |
+
logs.txt
|
7 |
+
secrets.yaml
|
Dockerfile
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.12-slim
|
2 |
+
WORKDIR /usr/src/app
|
3 |
+
COPY . .
|
4 |
+
RUN pip --disable-pip-version-check --no-cache-dir --no-input install -r requirements.txt
|
5 |
+
ENV GRADIO_SERVER_NAME="0.0.0.0"
|
6 |
+
EXPOSE 7860
|
7 |
+
CMD ["python", "src/app.py"]
|
cicd/build.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
docker build --platform=linux/amd64 . -t granite-guardian
|
2 |
+
docker tag granite-guardian us.icr.io/research3/granite-guardian
|
cicd/deploy.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
ibmcloud cr login
|
2 |
+
oc delete -f deployment.yaml
|
3 |
+
oc apply -f deployment.yaml
|
cicd/push_image.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
ibmcloud target -g aipt-experiments
|
2 |
+
docker push us.icr.io/research3/granite-guardian
|
cicd/run.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
./build.sh
|
2 |
+
./push_image.sh
|
3 |
+
./deploy.sh
|
deployment.yaml
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
apiVersion: apps/v1
|
2 |
+
kind: Deployment
|
3 |
+
metadata:
|
4 |
+
name: granite-guardian-pod
|
5 |
+
labels:
|
6 |
+
app: granite-guardian
|
7 |
+
spec:
|
8 |
+
selector:
|
9 |
+
matchLabels:
|
10 |
+
run: granite-guardian
|
11 |
+
replicas: 1
|
12 |
+
template:
|
13 |
+
metadata:
|
14 |
+
labels:
|
15 |
+
run: granite-guardian
|
16 |
+
spec:
|
17 |
+
containers:
|
18 |
+
- name: granite-guardian
|
19 |
+
image: us.icr.io/research3/granite-guardian
|
20 |
+
resources:
|
21 |
+
limits:
|
22 |
+
cpu: 1
|
23 |
+
memory: 2Gi
|
24 |
+
requests:
|
25 |
+
cpu: 1
|
26 |
+
memory: 2Gi
|
27 |
+
ports:
|
28 |
+
- containerPort: 7860
|
29 |
+
env:
|
30 |
+
- name: WATSONX_API_KEY
|
31 |
+
valueFrom:
|
32 |
+
secretKeyRef:
|
33 |
+
name: granite-guardian-secrets
|
34 |
+
key: WATSONX_API_KEY
|
35 |
+
- name: WATSONX_PROJECT_ID
|
36 |
+
valueFrom:
|
37 |
+
secretKeyRef:
|
38 |
+
name: granite-guardian-secrets
|
39 |
+
key: WATSONX_PROJECT_ID
|
40 |
+
- name: INFERENCE_ENGINE
|
41 |
+
valueFrom:
|
42 |
+
secretKeyRef:
|
43 |
+
name: granite-guardian-secrets
|
44 |
+
key: INFERENCE_ENGINE
|
45 |
+
imagePullSecrets:
|
46 |
+
- name: all-icr-io
|
47 |
+
---
|
48 |
+
apiVersion: v1
|
49 |
+
kind: Service
|
50 |
+
metadata:
|
51 |
+
name: granite-guardian-service
|
52 |
+
spec:
|
53 |
+
type: NodePort
|
54 |
+
sessionAffinity: "ClientIP"
|
55 |
+
selector:
|
56 |
+
run: granite-guardian
|
57 |
+
ports:
|
58 |
+
- port: 80
|
59 |
+
targetPort: 7860
|
60 |
+
protocol: TCP
|
61 |
+
---
|
62 |
+
apiVersion: networking.k8s.io/v1
|
63 |
+
kind: Ingress
|
64 |
+
metadata:
|
65 |
+
annotations:
|
66 |
+
ingress.kubernetes.io/allow-http: 'false'
|
67 |
+
ingress.kubernetes.io/ssl-redirect: 'true'
|
68 |
+
kubernetes.io/ingress.class: f5
|
69 |
+
virtual-server.f5.com/balance: round-robin
|
70 |
+
virtual-server.f5.com/ip: 9.12.246.36
|
71 |
+
virtual-server.f5.com/partition: RIS3-INT-OCP-DAL12
|
72 |
+
virtual-server.f5.com/clientssl: '[ { "bigIpProfile": "/Common/BlueMix" } ]'
|
73 |
+
name: granite-guardian-ingress
|
74 |
+
namespace: granite-guardian
|
75 |
+
spec:
|
76 |
+
rules:
|
77 |
+
- host: granite-guardian.bx.cloud9.ibm.com
|
78 |
+
http:
|
79 |
+
paths:
|
80 |
+
- backend:
|
81 |
+
service:
|
82 |
+
name: granite-guardian-service
|
83 |
+
port:
|
84 |
+
number: 80
|
85 |
+
path: /
|
86 |
+
pathType: ImplementationSpecific
|
requirements.txt
CHANGED
@@ -1,4 +1,7 @@
|
|
1 |
-
gradio
|
2 |
python-dotenv
|
3 |
tqdm
|
4 |
jinja2
|
|
|
|
|
|
|
|
1 |
+
gradio>=4,<5
|
2 |
python-dotenv
|
3 |
tqdm
|
4 |
jinja2
|
5 |
+
ibm_watsonx_ai
|
6 |
+
transformers
|
7 |
+
gradio_modal
|
run_cicd.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
./cicd/build.sh
|
2 |
+
./cicd/push_image.sh
|
3 |
+
./cicd/deploy.sh
|
src/app.py
CHANGED
@@ -112,6 +112,7 @@ def on_show_prompt_click(criteria, context, user_message, assistant_message, sta
|
|
112 |
|
113 |
messages = get_messages(test_case=test_case, sub_catalog_name=state['selected_sub_catalog'])
|
114 |
prompt = get_prompt(messages, criteria_name)
|
|
|
115 |
prompt = prompt.replace('<', '<').replace('>', '>').replace('\\n', '<br>')
|
116 |
return gr.Markdown(prompt)
|
117 |
|
@@ -155,7 +156,7 @@ with gr.Blocks(
|
|
155 |
),
|
156 |
head=head_style,
|
157 |
fill_width=False,
|
158 |
-
css=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'styles.css')
|
159 |
) as demo:
|
160 |
|
161 |
state = gr.State(value={
|
|
|
112 |
|
113 |
messages = get_messages(test_case=test_case, sub_catalog_name=state['selected_sub_catalog'])
|
114 |
prompt = get_prompt(messages, criteria_name)
|
115 |
+
print(prompt)
|
116 |
prompt = prompt.replace('<', '<').replace('>', '>').replace('\\n', '<br>')
|
117 |
return gr.Markdown(prompt)
|
118 |
|
|
|
156 |
),
|
157 |
head=head_style,
|
158 |
fill_width=False,
|
159 |
+
css=os.path.join(os.path.dirname(os.path.abspath(__file__)), './styles.css')
|
160 |
) as demo:
|
161 |
|
162 |
state = gr.State(value={
|
src/logger.py
CHANGED
@@ -7,6 +7,6 @@ stream_handler = logging.StreamHandler()
|
|
7 |
stream_handler.setLevel(logging.DEBUG)
|
8 |
logger.addHandler(stream_handler)
|
9 |
|
10 |
-
file_handler = logging.FileHandler('logs')
|
11 |
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
|
12 |
logger.addHandler(file_handler)
|
|
|
7 |
stream_handler.setLevel(logging.DEBUG)
|
8 |
logger.addHandler(stream_handler)
|
9 |
|
10 |
+
file_handler = logging.FileHandler('logs.txt')
|
11 |
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
|
12 |
logger.addHandler(file_handler)
|
src/model.py
CHANGED
@@ -2,13 +2,20 @@ import os
|
|
2 |
from time import time, sleep
|
3 |
from logger import logger
|
4 |
import math
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
safe_token = "No"
|
7 |
-
|
8 |
nlogprobs = 5
|
9 |
|
10 |
-
|
11 |
-
|
|
|
|
|
12 |
import torch
|
13 |
from vllm import LLM, SamplingParams
|
14 |
from transformers import AutoTokenizer
|
@@ -18,6 +25,21 @@ if not mock_model_call:
|
|
18 |
sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
|
19 |
model = LLM(model=model_path, tensor_parallel_size=1)
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
def parse_output(output):
|
22 |
label, prob = None, None
|
23 |
|
@@ -28,8 +50,8 @@ def parse_output(output):
|
|
28 |
prob_of_risk = prob[1]
|
29 |
|
30 |
res = next(iter(output.outputs)).text.strip()
|
31 |
-
if
|
32 |
-
label =
|
33 |
elif safe_token.lower() == res.lower():
|
34 |
label = safe_token
|
35 |
else:
|
@@ -37,6 +59,11 @@ def parse_output(output):
|
|
37 |
|
38 |
return label, prob_of_risk.item()
|
39 |
|
|
|
|
|
|
|
|
|
|
|
40 |
def get_probablities(logprobs):
|
41 |
safe_token_prob = 1e-50
|
42 |
unsafe_token_prob = 1e-50
|
@@ -45,7 +72,7 @@ def get_probablities(logprobs):
|
|
45 |
decoded_token = token_prob.decoded_token
|
46 |
if decoded_token.strip().lower() == safe_token.lower():
|
47 |
safe_token_prob += math.exp(token_prob.logprob)
|
48 |
-
if decoded_token.strip().lower() ==
|
49 |
unsafe_token_prob += math.exp(token_prob.logprob)
|
50 |
|
51 |
probabilities = torch.softmax(
|
@@ -54,6 +81,20 @@ def get_probablities(logprobs):
|
|
54 |
|
55 |
return probabilities
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def get_prompt(messages, criteria_name):
|
58 |
guardian_config = {"risk_name": criteria_name if criteria_name != 'general_harm' else 'harm'}
|
59 |
return tokenizer.apply_chat_template(
|
@@ -62,26 +103,65 @@ def get_prompt(messages, criteria_name):
|
|
62 |
tokenize=False,
|
63 |
add_generation_prompt=True)
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def generate_text(messages, criteria_name):
|
67 |
-
logger.debug(f'Messages are: \n{messages}')
|
68 |
-
|
69 |
-
mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
|
70 |
-
if mock_model_call:
|
71 |
-
logger.debug('Returning mocked model result.')
|
72 |
-
sleep(1)
|
73 |
-
return {'assessment': 'Yes', 'certainty': 0.97}
|
74 |
|
75 |
start = time()
|
|
|
76 |
chat = get_prompt(messages, criteria_name)
|
77 |
logger.debug(f'Prompt is \n{chat}')
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
|
80 |
-
|
|
|
81 |
|
82 |
-
|
|
|
|
|
83 |
|
84 |
-
|
|
|
|
|
85 |
|
86 |
logger.debug(f'Model generated label: \n{label}')
|
87 |
logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')
|
|
|
2 |
from time import time, sleep
|
3 |
from logger import logger
|
4 |
import math
|
5 |
+
import os
|
6 |
+
from ibm_watsonx_ai.client import APIClient
|
7 |
+
from ibm_watsonx_ai.foundation_models import ModelInference
|
8 |
+
from transformers import AutoTokenizer
|
9 |
+
import math
|
10 |
|
11 |
safe_token = "No"
|
12 |
+
risky_token = "Yes"
|
13 |
nlogprobs = 5
|
14 |
|
15 |
+
inference_engine = os.getenv('INFERENCE_ENGINE')
|
16 |
+
logger.debug(f"Inference engine is: '{inference_engine}'")
|
17 |
+
|
18 |
+
if inference_engine == 'VLLM':
|
19 |
import torch
|
20 |
from vllm import LLM, SamplingParams
|
21 |
from transformers import AutoTokenizer
|
|
|
25 |
sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
|
26 |
model = LLM(model=model_path, tensor_parallel_size=1)
|
27 |
|
28 |
+
elif inference_engine == "WATSONX":
|
29 |
+
client = APIClient(credentials={
|
30 |
+
'api_key': os.getenv('WATSONX_API_KEY'),
|
31 |
+
'url': 'https://us-south.ml.cloud.ibm.com'})
|
32 |
+
|
33 |
+
client.set.default_project(os.getenv('WATSONX_PROJECT_ID'))
|
34 |
+
hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
|
36 |
+
|
37 |
+
model_id = "ibm/granite-guardian-3-8b" # 8B Model: "ibm/granite-guardian-3-8b"
|
38 |
+
model = ModelInference(
|
39 |
+
model_id=model_id,
|
40 |
+
api_client=client
|
41 |
+
)
|
42 |
+
|
43 |
def parse_output(output):
|
44 |
label, prob = None, None
|
45 |
|
|
|
50 |
prob_of_risk = prob[1]
|
51 |
|
52 |
res = next(iter(output.outputs)).text.strip()
|
53 |
+
if risky_token.lower() == res.lower():
|
54 |
+
label = risky_token
|
55 |
elif safe_token.lower() == res.lower():
|
56 |
label = safe_token
|
57 |
else:
|
|
|
59 |
|
60 |
return label, prob_of_risk.item()
|
61 |
|
62 |
+
def softmax(values):
|
63 |
+
exp_values = [math.exp(v) for v in values]
|
64 |
+
total = sum(exp_values)
|
65 |
+
return [v / total for v in exp_values]
|
66 |
+
|
67 |
def get_probablities(logprobs):
|
68 |
safe_token_prob = 1e-50
|
69 |
unsafe_token_prob = 1e-50
|
|
|
72 |
decoded_token = token_prob.decoded_token
|
73 |
if decoded_token.strip().lower() == safe_token.lower():
|
74 |
safe_token_prob += math.exp(token_prob.logprob)
|
75 |
+
if decoded_token.strip().lower() == risky_token.lower():
|
76 |
unsafe_token_prob += math.exp(token_prob.logprob)
|
77 |
|
78 |
probabilities = torch.softmax(
|
|
|
81 |
|
82 |
return probabilities
|
83 |
|
84 |
+
def get_probablities_watsonx(top_tokens_list):
|
85 |
+
safe_token_prob = 1e-50
|
86 |
+
risky_token_prob = 1e-50
|
87 |
+
for top_tokens in top_tokens_list:
|
88 |
+
for token in top_tokens:
|
89 |
+
if token['text'].strip().lower() == safe_token.lower():
|
90 |
+
safe_token_prob += math.exp(token['logprob'])
|
91 |
+
if token['text'].strip().lower() == risky_token.lower():
|
92 |
+
risky_token_prob += math.exp(token['logprob'])
|
93 |
+
|
94 |
+
probabilities = softmax([math.log(safe_token_prob), math.log(risky_token_prob)])
|
95 |
+
|
96 |
+
return probabilities
|
97 |
+
|
98 |
def get_prompt(messages, criteria_name):
|
99 |
guardian_config = {"risk_name": criteria_name if criteria_name != 'general_harm' else 'harm'}
|
100 |
return tokenizer.apply_chat_template(
|
|
|
103 |
tokenize=False,
|
104 |
add_generation_prompt=True)
|
105 |
|
106 |
+
def generate_tokens(prompt):
|
107 |
+
result = model.generate(
|
108 |
+
prompt=[prompt],
|
109 |
+
params={
|
110 |
+
'decoding_method':'greedy',
|
111 |
+
'max_new_tokens': 20,
|
112 |
+
"temperature": 0,
|
113 |
+
"return_options": {
|
114 |
+
"token_logprobs": True,
|
115 |
+
"generated_tokens": True,
|
116 |
+
"input_text": True,
|
117 |
+
"top_n_tokens": 5
|
118 |
+
}
|
119 |
+
})
|
120 |
+
return result[0]['results'][0]['generated_tokens']
|
121 |
+
|
122 |
+
def parse_output_watsonx(generated_tokens_list):
|
123 |
+
label, prob_of_risk = None, None
|
124 |
+
|
125 |
+
if nlogprobs > 0:
|
126 |
+
top_tokens_list = [generated_tokens['top_tokens'] for generated_tokens in generated_tokens_list]
|
127 |
+
prob = get_probablities_watsonx(top_tokens_list)
|
128 |
+
prob_of_risk = prob[1]
|
129 |
+
|
130 |
+
res = next(iter(generated_tokens_list))['text'].strip()
|
131 |
+
|
132 |
+
if risky_token.lower() == res.lower():
|
133 |
+
label = risky_token
|
134 |
+
elif safe_token.lower() == res.lower():
|
135 |
+
label = safe_token
|
136 |
+
else:
|
137 |
+
label = "Failed"
|
138 |
+
|
139 |
+
return label, prob_of_risk
|
140 |
|
141 |
def generate_text(messages, criteria_name):
|
142 |
+
logger.debug(f'Messages used to create the prompt are: \n{messages}')
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
start = time()
|
145 |
+
|
146 |
chat = get_prompt(messages, criteria_name)
|
147 |
logger.debug(f'Prompt is \n{chat}')
|
148 |
+
|
149 |
+
if inference_engine=="MOCK":
|
150 |
+
logger.debug('Returning mocked model result.')
|
151 |
+
sleep(1)
|
152 |
+
label, prob_of_risk = 'Yes', 0.97
|
153 |
|
154 |
+
elif inference_engine=="WATSONX":
|
155 |
+
generated_tokens = generate_tokens(chat)
|
156 |
+
label, prob_of_risk = parse_output_watsonx(generated_tokens)
|
157 |
|
158 |
+
elif inference_engine=="VLLM":
|
159 |
+
with torch.no_grad():
|
160 |
+
output = model.generate(chat, sampling_params, use_tqdm=False)
|
161 |
|
162 |
+
label, prob_of_risk = parse_output(output[0])
|
163 |
+
else:
|
164 |
+
raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
|
165 |
|
166 |
logger.debug(f'Model generated label: \n{label}')
|
167 |
logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')
|
src/utils.py
CHANGED
@@ -1,27 +1,6 @@
|
|
1 |
-
import json
|
2 |
-
from jinja2 import Template
|
3 |
import argparse
|
4 |
import os
|
5 |
|
6 |
-
# with open('prompt_templates.json', mode='r', encoding="utf-8") as f:
|
7 |
-
# prompt_templates = json.load(f)
|
8 |
-
|
9 |
-
# def assessment_prompt(content):
|
10 |
-
# return {"role": "user", "content": content}
|
11 |
-
|
12 |
-
# def get_prompt_template(test_case, sub_catalog_name):
|
13 |
-
# test_case_name = test_case['name']
|
14 |
-
# if sub_catalog_name == 'harmful_content_in_user_prompt':
|
15 |
-
# template_type = 'prompt'
|
16 |
-
# elif sub_catalog_name == 'harmful_content_in_assistant_response':
|
17 |
-
# template_type = 'prompt_response'
|
18 |
-
# elif sub_catalog_name == 'rag_hallucination_risks':
|
19 |
-
# template_type = test_case_name
|
20 |
-
# return prompt_templates[f'{test_case_name}>{template_type}']
|
21 |
-
|
22 |
-
# def get_prompt_from_test_case(test_case, sub_catalog_name):
|
23 |
-
# return assessment_prompt(Template(get_prompt_template(test_case, sub_catalog_name)).render(**test_case))
|
24 |
-
|
25 |
def get_messages(test_case, sub_catalog_name) -> list[dict[str,str]]:
|
26 |
messages = []
|
27 |
|
@@ -76,14 +55,16 @@ def get_evaluated_component(sub_catalog_name, criteria_name):
|
|
76 |
return component
|
77 |
|
78 |
def to_title_case(input_string):
|
79 |
-
if input_string == 'rag_hallucination_risks':
|
80 |
return 'RAG Hallucination Risks'
|
81 |
return ' '.join(word.capitalize() for word in input_string.split('_'))
|
82 |
|
|
|
|
|
|
|
83 |
def to_snake_case(text):
|
84 |
return text.lower().replace(" ", "_")
|
85 |
|
86 |
-
|
87 |
def load_command_line_args():
|
88 |
parser = argparse.ArgumentParser()
|
89 |
parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")
|
|
|
|
|
|
|
1 |
import argparse
|
2 |
import os
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
def get_messages(test_case, sub_catalog_name) -> list[dict[str,str]]:
|
5 |
messages = []
|
6 |
|
|
|
55 |
return component
|
56 |
|
57 |
def to_title_case(input_string):
|
58 |
+
if input_string == 'rag_hallucination_risks':
|
59 |
return 'RAG Hallucination Risks'
|
60 |
return ' '.join(word.capitalize() for word in input_string.split('_'))
|
61 |
|
62 |
+
def capitalize_first_word(input_string):
|
63 |
+
return ' '.join(word.capitalize() if i == 0 else word for i, word in enumerate(input_string.split('_')))
|
64 |
+
|
65 |
def to_snake_case(text):
|
66 |
return text.lower().replace(" ", "_")
|
67 |
|
|
|
68 |
def load_command_line_args():
|
69 |
parser = argparse.ArgumentParser()
|
70 |
parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")
|