from __future__ import annotations

import json
import os

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel

SYS_PREFIX = "<<SYS>>\n"
SYS_POSTFIX = "\n<</SYS>>\n\n"
INST_PREFIX = "<s>[INST] "
INST_POSTFIX = " "
OUTPUT_PREFIX = "[/INST] "
OUTPUT_POSTFIX = "</s>"
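
# These delimiters follow the LLaMA-2 chat prompt convention and are concatenated with
# the conversation history in _get_llama_style_input below. As an illustrative sketch
# (not produced verbatim by this file), a single turn assembled from them looks like:
#   <<SYS>>\n{system_prompt}\n<</SYS>>\n\n<s>[INST] {user_message} [/INST] {assistant_reply}</s>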


def download(repo_id, filename, retry=10):
    # Reuse a previously downloaded model if it is recorded in the local registry file.
    if os.path.exists("./models/downloaded_models.json"):
        with open("./models/downloaded_models.json", "r") as f:
            downloaded_models = json.load(f)
        if repo_id in downloaded_models:
            return downloaded_models[repo_id]["path"]
    else:
        downloaded_models = {}
    while retry > 0:
        try:
            model_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir="models",
                resume_download=True,
            )
            downloaded_models[repo_id] = {"path": model_path}
            with open("./models/downloaded_models.json", "w") as f:
                json.dump(downloaded_models, f)
            break
        except Exception:
            print("Error downloading model, retrying...")
            retry -= 1
    if retry == 0:
        raise Exception("Error downloading model, please try again later.")
    return model_path


class LLaMA_Client(BaseLLMModel):
    def __init__(self, model_name, lora_path=None, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)

        self.max_generation_token = 1000
        if model_name in MODEL_METADATA:
            path_to_model = download(
                MODEL_METADATA[model_name]["repo_id"],
                MODEL_METADATA[model_name]["filelist"][0],
            )
        else:
            dir_to_model = os.path.join("models", model_name)
            # Look for any .gguf file in the dir_to_model directory and its subdirectories.
            path_to_model = None
            for root, dirs, files in os.walk(dir_to_model):
                for file in files:
                    if file.endswith(".gguf"):
                        path_to_model = os.path.join(root, file)
                        break
                if path_to_model is not None:
                    break
            if path_to_model is None:
                raise FileNotFoundError(
                    f"No .gguf model file found under {dir_to_model}"
                )
        self.system_prompt = ""

        if lora_path is not None:
            lora_path = os.path.join("lora", lora_path)
            self.model = Llama(model_path=path_to_model, lora_path=lora_path)
        else:
            self.model = Llama(model_path=path_to_model)

    def _get_llama_style_input(self):
        context = []
        for conv in self.history:
            if conv["role"] == "system":
                context.append(SYS_PREFIX + conv["content"] + SYS_POSTFIX)
            elif conv["role"] == "user":
                context.append(
                    INST_PREFIX + conv["content"] + INST_POSTFIX + OUTPUT_PREFIX
                )
            else:
                context.append(conv["content"] + OUTPUT_POSTFIX)
        return "".join(context)

    def get_answer_at_once(self):
        context = self._get_llama_style_input()
        response = self.model(
            context,
            max_tokens=self.max_generation_token,
            stop=[],
            echo=False,
            stream=False,
        )
        # llama_cpp returns an OpenAI-style completion dict; extract the generated text.
        answer = response["choices"][0]["text"]
        return answer, len(answer)

    def get_answer_stream_iter(self):
        context = self._get_llama_style_input()
        # Stream tokens from llama_cpp and yield the accumulated text so far.
        completion_chunks = self.model(
            context,
            max_tokens=self.max_generation_token,
            stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX, OUTPUT_POSTFIX],
            echo=False,
            stream=True,
        )
        partial_text = ""
        for chunk in completion_chunks:
            response = chunk["choices"][0]["text"]
            partial_text += response
            yield partial_text