# Author: tungdop2 · commit 284cb2b ("init code") · 1.39 kB
import os
import torch
from vllm import LLM, SamplingParams
import logging
# Module-wide logging setup: INFO level, timestamped records.
# Note that logging.basicConfig is a no-op if the root logger already has
# handlers configured elsewhere.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)
class ChallengePromptGenerator:
    """Extend text prompts with a causal language model run through vLLM.

    Each prompt is prefixed with the tokenizer's BOS token (when it has one)
    and continued by the model. Each result is the original prompt plus the
    continuation, trimmed back to the last complete "."-terminated sentence.
    """

    def __init__(
        self,
        model_local_dir="checkpoint-15000",
    ):
        """Load the model with ``vllm.LLM``.

        Args:
            model_local_dir: Model path or name handed straight to
                ``vllm.LLM``. The default looks like a local checkpoint
                directory relative to the working directory (TODO confirm).
        """
        # Weights are loaded in bfloat16.
        self.generator = LLM(
            model_local_dir,
            dtype="bfloat16",
        )

    @staticmethod
    def _truncate_to_last_sentence(text):
        """Cut *text* just after its last "."; append "." if it has none.

        Empty text is returned unchanged.
        """
        if not text or text.endswith("."):
            return text
        cut = text.rfind(".")
        if cut == -1:
            # No complete sentence at all: keep the text (prompt included)
            # and terminate it, instead of collapsing it to a lone ".".
            return text.rstrip() + "."
        return text[: cut + 1]

    def infer_prompt(
        self,
        prompts,
        max_generation_length=77,
        beam_size=1,
        sampling_temperature=0.9,
        sampling_topk=1,
        sampling_topp=1,
    ):
        """Continue each prompt with the model.

        Args:
            prompts: Iterable of prompt strings.
            max_generation_length: Maximum number of new tokens per prompt
                (vLLM ``max_tokens``).
            beam_size: Values > 1 request beam search with this beam width
                (passed as vLLM ``best_of``).
                NOTE(review): older vLLM releases also require
                ``sampling_temperature=0``, ``sampling_topk=-1`` and
                ``sampling_topp=1`` for beam search. Newer releases reportedly
                dropped ``use_beam_search`` from ``SamplingParams``. Confirm
                against the pinned vLLM version.
            sampling_temperature: Sampling temperature.
            sampling_topk: Top-k cutoff. The default of 1 keeps only the
                most likely token, i.e. greedy decoding.
            sampling_topp: Nucleus (top-p) cutoff.

        Returns:
            A list with one string per prompt, in input order. Each string
            is the prompt followed by the generated text, cut back to the
            last "." (or with "." appended when the text contains none).
        """
        prompts = list(prompts)
        if not prompts:
            return []

        # Look the BOS token up once. If the tokenizer defines none, send the
        # bare prompt rather than a literal "None " prefix.
        bos_token = self.generator.get_tokenizer().bos_token
        model_inputs = [
            prompt if bos_token is None else f"{bos_token} {prompt}"
            for prompt in prompts
        ]

        sampling_kwargs = dict(
            max_tokens=max_generation_length,
            temperature=sampling_temperature,
            top_k=sampling_topk,
            top_p=sampling_topp,
        )
        if beam_size > 1:
            # In vLLM versions with use_beam_search, the beam width is
            # best_of, which must be > 1 when beam search is on.
            sampling_kwargs.update(use_beam_search=True, best_of=beam_size)
        # use_beam_search defaults to False, so it is only passed when needed.
        sampling_params = SamplingParams(**sampling_kwargs)

        outputs = self.generator.generate(model_inputs, sampling_params)
        # LLM.generate returns one RequestOutput per input, in input order,
        # so results can be paired with the original (un-prefixed) prompts.
        return [
            self._truncate_to_last_sentence(prompt + output.outputs[0].text)
            for prompt, output in zip(prompts, outputs)
        ]