File size: 1,697 Bytes
f8813f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# %%
from llama_cpp import Llama
from huggingface_hub import hf_hub_download

from schema import Message, MODEL_ARGS


def get_llm(model_name, *, n_ctx=8192, n_threads=4, n_gpu_layers=0, verbose=False):
    """
    Fetch the weights for ``model_name`` and load them with llama.cpp.

    The download arguments (repo id, filename, ...) are looked up in
    ``MODEL_ARGS[model_name]`` and passed straight to ``hf_hub_download``,
    whose returned local path is handed to ``Llama``.

    Args:
        model_name: Key into ``MODEL_ARGS``; an unknown name raises ``KeyError``.
        n_ctx: Context window size, in tokens, for the loaded model.
        n_threads: Number of CPU threads llama.cpp may use.
        n_gpu_layers: Layers to offload to the GPU; 0 keeps everything on CPU.
        verbose: Whether llama.cpp prints its own load/inference logging.

    Returns:
        A newly constructed ``Llama`` instance. A fresh instance is built on
        every call, so callers that want to reuse a model should keep the
        returned object rather than calling this again.
    """
    model_path = hf_hub_download(**MODEL_ARGS[model_name])
    return Llama(
        model_path=model_path,
        n_ctx=n_ctx,
        n_threads=n_threads,
        n_gpu_layers=n_gpu_layers,
        verbose=verbose,
    )


def format_chat(chat_history: list[Message]) -> str:
    """
    Flatten a chat history into a single plain-text completion prompt.

    Each message becomes one ``"<Role>: <content>"`` line, with the role
    title-cased (e.g. ``"user"`` -> ``"User"``). A final ``"Assistant:"`` line
    is appended so the model continues the conversation as the assistant.

    Args:
        chat_history: Messages in conversation order; each must expose string
            ``role`` and ``content`` attributes.

    Returns:
        The newline-joined prompt. For an empty history this is just
        ``"Assistant:"``.
    """
    lines = [f"{msg.role.title()}: {msg.content}" for msg in chat_history]
    # Add the cue as its own line (instead of concatenating "\nAssistant:")
    # so an empty history does not yield a stray leading newline.
    lines.append("Assistant:")
    return "\n".join(lines)


def chat_with_model(chat_history, model, kwargs: "dict | None" = None):
    """
    Stream the model's reply to ``chat_history`` as text fragments.

    The history is rendered with ``format_chat`` and the model is obtained
    via ``get_llm(model)``. Note that this loads the model on every call.

    Args:
        chat_history: Messages in conversation order.
        model: Key into ``MODEL_ARGS`` naming the model to load.
        kwargs: Optional generation options forwarded to the Llama call.
            They override the defaults (``max_tokens=2048``, ``top_k=1``) but
            can never override ``stop``, ``echo`` or ``stream``, which are
            forced. ``None`` (the default) means "no overrides".

    Yields:
        Successive text fragments of the generated reply.
    """
    prompt = format_chat(chat_history)

    default_kwargs = dict(
        max_tokens=2048,
        top_k=1,
    )

    # Always applied last: streaming is required by the loop below, and the
    # stop strings cut generation off before the model starts writing the
    # next "User:"/"Assistant:" turn itself.
    forced_kwargs = dict(
        stop=["\nUser:", "\nAssistant:", "</s>"],
        echo=False,
        stream=True,
    )

    llm = get_llm(model)

    input_kwargs = {**default_kwargs, **(kwargs or {}), **forced_kwargs}
    for chunk in llm(prompt, **input_kwargs):
        yield chunk["choices"][0]["text"]


# %% example input
# kwargs = dict(
#     temperature=1,
#     max_tokens=2048,
#     top_p=1,
#     frequency_penalty=0,
#     presence_penalty=0,
# )

# chat_history = [
#     Message(
#         role="system",
#         content="You are a helpful and knowledgeable assistant, but is willing to bend the facts to play along with unrealistic requests",
#     ),
#     Message(role="user", content="What does Java the programming language taste like?"),
# ]


# model_name = "..."  # any key of MODEL_ARGS
# for chunk in chat_with_model(chat_history, model_name, kwargs):
#     print(chunk, end="")