File size: 2,192 Bytes
f104fde
76a2333
 
f104fde
 
76a2333
 
 
 
 
 
 
 
 
f104fde
76a2333
 
 
 
f104fde
76a2333
949a808
f104fde
cb82b91
f104fde
 
 
 
 
 
 
 
76a2333
 
 
 
 
 
 
 
 
f104fde
 
 
76a2333
 
f104fde
 
 
 
 
 
 
 
 
 
76a2333
 
f104fde
76a2333
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Iterator
from llama_cpp import Llama
from huggingface_hub import hf_hub_download


def download_model(repo_id=None, filename=None):
    """Download a model file from the Hugging Face Hub and return its local path.

    Args:
        repo_id: Hub repository to download from. Defaults to the module-level
            ``model_repo``. It is looked up when the function is called, so
            that global may be defined after this function.
        filename: File inside the repository. Defaults to the module-level
            ``model_filename`` (also looked up at call time).

    Returns:
        The local filesystem path returned by ``hf_hub_download``.
    """
    # See https://github.com/OpenAccess-AI-Collective/ggml-webui/blob/main/tabbed.py
    # https://huggingface.co/spaces/kat33/llama.cpp/blob/main/app.py
    if repo_id is None:
        repo_id = model_repo
    if filename is None:
        filename = model_filename
    print(f"Downloading model: {repo_id}/{filename}")
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print("Downloaded " + local_path)
    return local_path

# Hugging Face Hub repository holding the GGML-format Chinese Llama-2 7B weights.
model_repo = "LinkSoul/Chinese-Llama-2-7b-ggml"
# Weight file to fetch from that repository (the q4_0 variant). Swap in the
# commented-out q8_0 file below to use that variant instead.
model_filename = "Chinese-Llama-2-7b.ggmlv3.q4_0.bin"
# model_filename = "Chinese-Llama-2-7b.ggmlv3.q8_0.bin"
# Runs at import time: the weights are downloaded before the model is built below.
model_path = download_model()

# Load Llama-2 through llama.cpp. n_ctx is the context-window size in tokens;
# verbose=False turns off llama.cpp's verbose output.
llm = Llama(model_path=model_path, n_ctx=4000, verbose=False)


def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    """Format a conversation as a single Llama-2 chat prompt.

    The system prompt (inserted verbatim) goes inside a ``<<SYS>>`` block.
    Each earlier (user, assistant) turn is stripped and closed with
    ``</s><s> [INST]``. The prompt ends with the stripped new *message*
    followed by ``[/INST]``, so the model continues with its answer.
    """
    header = f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n'
    past_turns = ''.join(
        f'{question.strip()} [/INST] {answer.strip()} </s><s> [INST] '
        for question, answer in chat_history
    )
    return header + past_turns + f'{message.strip()} [/INST]'

def generate(prompt, max_new_tokens, temperature, top_p, top_k):
    """Run one non-streaming completion of *prompt* with the global ``llm``.

    Generation stops at the ``</s>`` marker or after *max_new_tokens* tokens.
    The raw completion object returned by ``llm`` is passed back unchanged.
    """
    sampling_options = {
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
    }
    return llm(prompt, stop=["</s>"], stream=False, **sampling_options)


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    """Return the number of tokens in the fully formatted chat prompt.

    Builds the same prompt that ``run`` sends and counts the tokens the
    loaded model's tokenizer produces for its UTF-8 bytes.
    """
    encoded_prompt = get_prompt(message, chat_history, system_prompt).encode('utf-8')
    return len(llm.tokenize(encoded_prompt))


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50,
        stream: bool = False) -> Iterator[str]:
    """Generate the assistant's reply to *message*, given the conversation so far.

    Args:
        message: The new user message.
        chat_history: Earlier (user, assistant) turns, in order.
        system_prompt: Text placed in the prompt's ``<<SYS>>`` block.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature passed to the model.
        top_p: Nucleus-sampling threshold passed to the model.
        top_k: Top-k sampling cutoff passed to the model.
        stream: If False (default), yield the complete reply exactly once.
            If True, yield the accumulated reply after each streamed chunk,
            so every yielded value extends the previous one.

    Yields:
        The reply text: once in full, or as growing prefixes when streaming.
    """
    prompt = get_prompt(message, chat_history, system_prompt)

    if not stream:
        output = generate(prompt, max_new_tokens, temperature, top_p, top_k)
        yield output['choices'][0]['text']
        return

    # Streaming path: uses the same settings as generate(), but with stream=True
    # so llama.cpp returns an iterator of partial completions instead of one result.
    chunks = llm(prompt,
                 max_tokens=max_new_tokens,
                 stop=["</s>"],
                 temperature=temperature,
                 top_p=top_p,
                 top_k=top_k,
                 stream=True)
    reply = ''
    for chunk in chunks:
        reply += chunk['choices'][0]['text']
        yield reply