from application.llm.base import BaseLLM
from application.core.settings import settings

class LlamaCpp(BaseLLM):

    def __init__(self, api_key, llm_name=settings.MODEL_PATH, **kwargs):
        # api_key is unused for local inference but kept for BaseLLM compatibility.
        try:
            from llama_cpp import Llama
        except ImportError:
            raise ImportError("Please install llama_cpp using pip install llama-cpp-python")

        # Store the model on the instance rather than in a module-level global,
        # so separate LlamaCpp instances do not overwrite each other's model.
        self.llama = Llama(model_path=llm_name, n_ctx=2048)

    def gen(self, model, engine, messages, stream=False, **kwargs):
        # The first message carries the retrieved context and the last message
        # carries the user's question; fold both into one instruction prompt.
        context = messages[0]['content']
        user_question = messages[-1]['content']
        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

        result = self.llama(prompt, max_tokens=150, echo=False)

        # Return only the text after the answer marker, in case the model
        # echoes any of the prompt scaffolding.
        return result['choices'][0]['text'].split('### Answer \n')[-1]

    def gen_stream(self, model, engine, messages, stream=True, **kwargs):
        context = messages[0]['content']
        user_question = messages[-1]['content']
        prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

        # With stream=True, llama-cpp-python returns an iterator of partial
        # completions; yield each text chunk as it is produced.
        result = self.llama(prompt, max_tokens=150, echo=False, stream=stream)

        for item in result:
            for choice in item['choices']:
                yield choice['text']
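

# Usage sketch (illustrative, not part of the class). This assumes
# settings.MODEL_PATH points at a local model file that llama-cpp-python can
# load, and that messages follow the [{'role', 'content'}] convention the
# gen()/gen_stream() methods above expect. The model/engine arguments are
# ignored by this backend, so None is passed for both.
if __name__ == "__main__":
    llm = LlamaCpp(api_key=None)
    messages = [
        {"role": "system", "content": "You answer questions about the docs."},
        {"role": "user", "content": "What does this project do?"},
    ]

    # Blocking, single-shot generation.
    print(llm.gen(model=None, engine=None, messages=messages))

    # Streaming generation: chunks are printed as they arrive.
    for chunk in llm.gen_stream(model=None, engine=None, messages=messages):
        print(chunk, end="", flush=True)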