Spaces: Running on Zero
grahamwhiteuk committed
Commit bce909e
Parent(s): c786139
fix: remove sampling params
- model.py +6 -3
- requirements.txt +0 -1
model.py CHANGED

@@ -7,10 +7,12 @@ import torch
 from ibm_watsonx_ai.client import APIClient
 from ibm_watsonx_ai.foundation_models import ModelInference
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from vllm import LLM, SamplingParams

 from logger import logger

+# from vllm import LLM, SamplingParams
+
+
 safe_token = "No"
 risky_token = "Yes"
 nlogprobs = 5
@@ -23,7 +25,7 @@ if inference_engine == "VLLM":
     model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
     logger.debug(f"model_path is {model_path}")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
-    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
+    # sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
     # model = LLM(model=model_path, tensor_parallel_size=1)
     model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")

@@ -158,7 +160,8 @@ def generate_text(messages, criteria_name):

     elif inference_engine == "VLLM":
         with torch.no_grad():
-            output = model.generate(chat, sampling_params, use_tqdm=False)
+            # output = model.generate(chat, sampling_params, use_tqdm=False)
+            output = model.generate(chat, use_tqdm=False)

         label, prob_of_risk = parse_output(output[0])
     else:
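For context (not part of the commit): a minimal sketch of how the removed SamplingParams(temperature=0.0, logprobs=nlogprobs) behaviour could be approximated with the transformers generate API that this Space now relies on. The prompt text and max_new_tokens value are illustrative assumptions, not taken from model.py.

# Sketch only: greedy decoding plus per-token logprobs with transformers,
# approximating the removed vLLM SamplingParams(temperature=0.0, logprobs=5).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ibm-granite/granite-guardian-3.0-8b"  # same default as MODEL_PATH
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="auto"
)

# Hypothetical prompt for illustration.
inputs = tokenizer("Is this user message risky?", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(
        **inputs,
        do_sample=False,            # greedy decoding, equivalent to temperature=0.0
        max_new_tokens=20,          # assumed value for the sketch
        output_scores=True,         # keep per-step logits so logprobs can be derived
        return_dict_in_generate=True,
    )

# Top-5 log-probabilities of the first generated token, analogous to nlogprobs = 5.
logprobs = torch.log_softmax(out.scores[0].float(), dim=-1)
top5 = torch.topk(logprobs, k=5, dim=-1)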
requirements.txt CHANGED

@@ -3,4 +3,3 @@ python-dotenv
 transformers
 accelerate
 ibm_watsonx_ai
-vllm