import threading
import toml
from pathlib import Path

from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

from modules.llms import (
    LLMFactory,
    PromptFmt, PromptManager, PPManager, UIPPManager, LLMService
)

class ChatGPTFactory(LLMFactory):
    def __init__(self):
        pass

    def create_prompt_format(self):
        return ChatGPTChatPromptFmt()

    def create_prompt_manager(self, prompts_path: str=None):
        return ChatGPTPromptManager(prompts_path or Path('.') / 'prompts' / 'chatgpt_prompts.toml')
    
    def create_pp_manager(self):
        return ChatGPTChatPPManager()

    def create_ui_pp_manager(self):
        return GradioChatGPTChatPPManager()
    
    def create_llm_service(self):
        return ChatGPTService()
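
# Usage sketch (illustrative, not part of the original module): callers are
# expected to go through the factory so they depend only on the LLMFactory
# interface rather than on the ChatGPT-specific classes below. For example:
#
#   factory = ChatGPTFactory()
#   prompt_manager = factory.create_prompt_manager()  # defaults to prompts/chatgpt_prompts.toml
#   service = factory.create_llm_service()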
    

class ChatGPTChatPromptFmt(PromptFmt):
    @classmethod
    def ctx(cls, context):
        # Stub: context-to-prompt formatting is left unimplemented.
        pass

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        # Stub: single-exchange formatting is left unimplemented.
        pass


class ChatGPTPromptManager(PromptManager):
    _instance = None
    _lock = threading.Lock()
    _prompts = None

    def __new__(cls, prompts_path):
        if cls._instance is None:
            with cls._lock:
                # Double-checked locking: re-test under the lock so only one
                # thread ever constructs and initializes the singleton.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance.load_prompts(prompts_path)
        return cls._instance

    def load_prompts(self, prompts_path):
        self._prompts_path = prompts_path
        self.reload_prompts()

    def reload_prompts(self):
        assert self.prompts_path, "Prompt path is missing."
        self._prompts = toml.load(self.prompts_path)

    @property
    def prompts_path(self):
        return self._prompts_path
    
    @prompts_path.setter
    def prompts_path(self, prompts_path):
        self._prompts_path = prompts_path
        self.reload_prompts()

    @property
    def prompts(self):
        if self._prompts is None:
            # Lazily (re)parse the TOML file at the stored prompts_path.
            self.reload_prompts()
        return self._prompts
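
# Usage sketch (illustrative): because ChatGPTPromptManager is a singleton,
# every construction returns the same instance and the TOML file is parsed
# once per process; assigning to prompts_path reloads from the new file.
#
#   pm_a = ChatGPTPromptManager(Path('prompts') / 'chatgpt_prompts.toml')
#   pm_b = ChatGPTPromptManager(Path('prompts') / 'chatgpt_prompts.toml')
#   assert pm_a is pm_b  # second call reuses the existing instance
#   pm_a.prompts_path = Path('prompts') / 'alt_prompts.toml'  # hypothetical file; setter triggers a reload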


class ChatGPTChatPPManager(PPManager):
    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=None, truncate_size: int=None):
        # Stub: should assemble a prompt from the managed exchanges via fmt.
        pass


class GradioChatGPTChatPPManager(UIPPManager, ChatGPTChatPPManager):
    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        # Stub: should render the managed exchanges for the Gradio chat UI via fmt.
        pass
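
# Usage sketch (illustrative; assumes the `pingpong` package's conversation
# API, where a PPManager accumulates PingPong(ping, pong) exchanges via
# add_pingpong — check that package if the signature differs):
#
#   from pingpong import PingPong
#
#   ppm = GradioChatGPTChatPPManager()
#   ppm.add_pingpong(PingPong("Hi!", "Hello! How can I help?"))
#   uis = ppm.build_uis()  # Gradio-ready history once build_uis is implemented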

class ChatGPTService(LLMService):
    def make_params(self, mode="chat",
                          temperature=None,
                          candidate_count=None,
                          top_k=None,
                          top_p=None,
                          max_output_tokens=None,
                          use_filter=True):
        # Stub: should map these generation settings onto API request parameters.
        pass

    async def gen_text(
        self,
        prompt,
        mode="chat",  # "chat" or "text"
        parameters=None,
        use_filter=True
    ):
        # Stub: should call the backing model and return its generated text.
        pass
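
# Hedged sketch (an assumption, not this project's implementation): one way
# ChatGPTService.gen_text could be filled in with the pre-1.0 `openai` SDK,
# whose async entry point is openai.ChatCompletion.acreate. The model name
# and the parameter pass-through below are illustrative.
async def _gen_text_sketch(prompt: str, parameters: dict = None) -> str:
    import openai  # requires openai<1.0 and OPENAI_API_KEY in the environment

    response = await openai.ChatCompletion.acreate(
        model="gpt-3.5-turbo",                # illustrative model choice
        messages=[{"role": "user", "content": prompt}],
        **(parameters or {}),                 # e.g. temperature, top_p, max_tokens
    )
    # The old SDK returns an OpenAIObject that supports attribute access.
    return response.choices[0].message.content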