from typing import Iterator

from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from loguru import logger

from settings import *


def download_model():
    """Download the model file from the Hugging Face Hub and return its local path."""
    logger.info("Downloading model")
    file = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    logger.info("Downloaded.")
    return file


try:
    if MODEL_PATH is None:
        MODEL_PATH = download_model()
except Exception as e:
    logger.error(f"Error while preparing the model file: {e}")
    raise SystemExit(1)


def load_llm() -> Llama:
    """Instantiate the llama.cpp model using the constants from settings."""
    logger.info("Loading LLM")
    return Llama(
        model_path=MODEL_PATH,
        n_ctx=MAX_INPUT_TOKEN_LENGTH,
        n_batch=LLAMA_N_BATCH,
        n_gpu_layers=LLAMA_N_GPU_LAYERS,
        seed=LLAMA_SEED,
        rms_norm_eps=LLAMA_RMS_NORM_EPS,
        verbose=LLAMA_VERBOSE,
    )


llm = load_llm()
logger.info("LLM loaded")


def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
    """Build a USER/ASSISTANT-style prompt from the system prompt, history, and new message."""
    prompt = ""
    for q, a in chat_history:
        prompt += f"USER: {q}\nASSISTANT: {a}\n\n"
    prompt += f"USER: {message}\nASSISTANT:"
    return system_prompt + "\n\n" + prompt


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    """Return the number of tokens in the fully assembled prompt."""
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = llm.tokenize(prompt.encode("utf-8"))
    return len(input_ids)


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.6,
        top_p: float = 0.9,
        top_k: int = 49,
        repeat_penalty: float = 1.0) -> Iterator[str]:
    """Stream the model's reply, yielding the accumulated text after each new chunk."""
    llm.reset()
    prompt = get_prompt(message, chat_history, system_prompt)
    logger.info("New prompt")
    logger.info(f"{prompt}")

    # Stop generation as soon as the model starts a new conversational turn.
    stop = ["USER:", "ASSISTANT:"]

    outputs = []
    for text in llm(prompt,
                    max_tokens=max_new_tokens,
                    stop=stop,
                    temperature=temperature,
                    top_p=top_p,
                    # Mirostat v2 sampling is enabled below, which largely supersedes
                    # top_k/top_p; the parameter is still passed through for completeness.
                    top_k=top_k,
                    repeat_penalty=repeat_penalty,
                    mirostat_mode=2,
                    mirostat_tau=8.0,
                    mirostat_eta=0.2,
                    stream=True):
        outputs.append(text['choices'][0]['text'])
        yield ''.join(outputs)
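
# Minimal usage sketch (not part of the original module, and the question and
# system prompt below are hypothetical placeholders): run() yields the
# accumulated reply after each streamed chunk, so a caller can print
# progressive updates or hand the iterator to a chat UI. Assumes the settings
# constants point at a valid model.
if __name__ == "__main__":
    history: list[tuple[str, str]] = []
    question = "What is llama.cpp?"
    system = "You are a helpful assistant."
    logger.info(f"Prompt tokens: {get_input_token_length(question, history, system)}")
    reply = ""
    for partial in run(question, history, system):
        reply = partial  # each yield is the full text generated so far
    print(reply)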