from settings import *
from typing import Iterator
from llama_cpp import Llama
from huggingface_hub import hf_hub_download


# Download the model file from the Hugging Face Hub if no local path is set.
def download_model():
    print("Downloading model")
    file = hf_hub_download(
        repo_id=MODEL_REPO, filename=MODEL_FILENAME
    )
    print("Downloaded.")
    return file


try:
    if MODEL_PATH is None:
        MODEL_PATH = download_model()
except Exception as e:
    print(f"Error: {e}")
    raise SystemExit(1)

# Load the model once at import time; all parameters come from settings.
llm = Llama(model_path=MODEL_PATH,
            n_ctx=MAX_INPUT_TOKEN_LENGTH,
            n_batch=LLAMA_N_BATCH,
            n_gpu_layers=LLAMA_N_GPU_LAYERS,
            seed=LLAMA_SEED,
            rms_norm_eps=LLAMA_RMS_NORM_EPS,
            verbose=LLAMA_VERBOSE)


def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    # Build a plain-text prompt: the system prompt, then the history and the
    # new message as alternating "USER:" / "ASSISTANT:" turns.
    prompt = ""
    for q, a in chat_history:
        prompt += f"USER: {q}\nASSISTANT: {a}\n\n"
    prompt += f"USER: {message}\nASSISTANT:"
    return system_prompt + "\n\n" + prompt


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = llm.tokenize(prompt.encode('utf-8'))
    return len(input_ids)
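

# Illustrative sketch, not part of the original file: one way
# get_input_token_length() could be used to reject prompts that no longer fit
# in the n_ctx window before generation. The check_prompt_length name is
# hypothetical; MAX_INPUT_TOKEN_LENGTH comes from settings, as above.
def check_prompt_length(message: str, chat_history: list[tuple[str, str]],
                        system_prompt: str) -> None:
    n_tokens = get_input_token_length(message, chat_history, system_prompt)
    if n_tokens > MAX_INPUT_TOKEN_LENGTH:
        raise ValueError(
            f"Prompt is {n_tokens} tokens, exceeding the "
            f"{MAX_INPUT_TOKEN_LENGTH}-token context window."
        )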


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.6,
        top_p: float = 0.9,
        top_k: int = 49,
        repeat_penalty: float = 1.0) -> Iterator[str]:
    prompt = get_prompt(message, chat_history, system_prompt)
    stop = ["</s>"]
    outputs = []
    # Stream completion chunks and yield the accumulated text so far on each step.
    for chunk in llm(prompt,
                     max_tokens=max_new_tokens,
                     stop=stop,
                     temperature=temperature,
                     top_p=top_p,
                     top_k=top_k,
                     repeat_penalty=repeat_penalty,
                     stream=True):
        outputs.append(chunk['choices'][0]['text'])
        yield ''.join(outputs)
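

# Illustrative usage sketch, not part of the original file: run() yields the
# accumulated response on every step, so a caller only needs the last value.
# The message and system prompt below are placeholders.
if __name__ == "__main__":
    response = ""
    for response in run("Hello, who are you?", [],
                        "You are a helpful assistant."):
        pass
    print(response)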