grahamwhiteuk committed
Commit
bce909e
1 Parent(s): c786139

fix: remove sampling params

Files changed (2)
  1. model.py +6 -3
  2. requirements.txt +0 -1
model.py CHANGED
@@ -7,10 +7,12 @@ import torch
 from ibm_watsonx_ai.client import APIClient
 from ibm_watsonx_ai.foundation_models import ModelInference
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from vllm import LLM, SamplingParams
 
 from logger import logger
 
+# from vllm import LLM, SamplingParams
+
+
 safe_token = "No"
 risky_token = "Yes"
 nlogprobs = 5
@@ -23,7 +25,7 @@ if inference_engine == "VLLM":
     model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
     logger.debug(f"model_path is {model_path}")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
-    sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
+    # sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
     # model = LLM(model=model_path, tensor_parallel_size=1)
     model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
 
@@ -158,7 +160,8 @@ def generate_text(messages, criteria_name):
 
     elif inference_engine == "VLLM":
         with torch.no_grad():
-            output = model.generate(chat, sampling_params, use_tqdm=False)
+            # output = model.generate(chat, sampling_params, use_tqdm=False)
+            output = model.generate(chat, use_tqdm=False)
 
         label, prob_of_risk = parse_output(output[0])
     else:
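
Side note, not part of the commit: the removed SamplingParams(temperature=0.0, logprobs=nlogprobs) asked vLLM for greedy decoding plus the top-5 token log-probabilities. If the same behaviour is still wanted on the transformers path, a minimal sketch with the standard Hugging Face generate API could look like the lines below; here `chat` is assumed to be a plain prompt string, `max_new_tokens=20` is an arbitrary placeholder, and the variable names are hypothetical rather than taken from model.py.

# Hedged sketch only, reusing tokenizer, model, and nlogprobs as loaded above;
# it approximates the removed SamplingParams(temperature=0.0, logprobs=nlogprobs).
import torch

inputs = tokenizer(chat, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(
        **inputs,
        do_sample=False,                # greedy decoding, the temperature=0.0 analogue
        max_new_tokens=20,              # placeholder budget, not taken from the commit
        return_dict_in_generate=True,
        output_scores=True,             # keep per-step logits so log-probs can be derived
    )
# Top-n log-probabilities of the first generated token, analogous to
# vLLM's logprobs=nlogprobs for that position.
step_logprobs = torch.log_softmax(out.scores[0], dim=-1)
top_logprobs, top_token_ids = step_logprobs.topk(nlogprobs, dim=-1)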
requirements.txt CHANGED
@@ -3,4 +3,3 @@ python-dotenv
 transformers
 accelerate
 ibm_watsonx_ai
-vllm