File size: 3,205 Bytes
c69cba4
cd726a4
c69cba4
 
 
 
 
 
 
 
0bd8ed4
c69cba4
 
19fc5ee
c69cba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
988981a
 
 
 
 
 
 
 
c69cba4
799264a
d22d549
c69cba4
 
 
799264a
 
 
c69cba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
from dataclasses import dataclass, field, asdict
from typing import Any, Union

from qa_engine import logger


def get_env(env_name: str, default: Any = None, warn: bool = True) -> str:
    """Read an environment variable, optionally falling back to a default.

    An unset variable and a variable set to the empty string are treated the
    same way: both count as "not found".

    Args:
        env_name: Name of the environment variable to look up.
        default: Value returned as-is (no conversion to ``str``) when the
            variable is not found. ``None`` means "no default".
        warn: Whether to log a warning when the default is used.

    Returns:
        The variable's value, or ``default`` when the variable is not found.

    Raises:
        ValueError: If the variable is not found and ``default`` is ``None``.
    """
    value = os.getenv(env_name)
    if value:
        return value

    # Nothing usable in the environment: a default is mandatory from here on.
    if default is None:
        raise ValueError(f'Cannot parse: {env_name}')

    if warn:
        logger.warning(
            f'Environment variable {env_name} not found. '
            f'Using the default value: {default}.'
        )
    return default


def _parse_int_list(raw: str) -> list[int]:
    """Parse a delimited string of integers into a list.

    Accepts comma- and/or whitespace-separated values, optionally wrapped in
    square brackets (e.g. ``'1,2'``, ``'1 2'``, ``'[1, 2]'``). An empty or
    blank string yields an empty list.

    Raises:
        ValueError: If any item is not a valid integer.
    """
    cleaned = raw.strip().strip('[]')
    return [int(item) for item in cleaned.replace(',', ' ').split()]


@dataclass
class Config:
    """Application configuration populated from environment variables.

    Most field defaults are computed via ``get_env`` while the class body is
    executed, i.e. when this module is imported. Fields whose ``get_env``
    call has no default are required: a missing variable raises
    ``ValueError`` at that point.

    After initialisation, ``prompt_template`` holds the text of
    ``config/prompt_templates/<prompt_template_name>.txt``. It is a plain
    attribute, not a dataclass field, so it is absent from ``asdict()``.

    Raises:
        ValueError: On instantiation, if the prompt template lacks the
            'context' or 'question' field, if sources are requested without
            using docs for context, or if ``num_relevant_docs`` is below 1.
        OSError: On instantiation, if the prompt template file cannot be read.
    """

    # NOTE(review): the boolean options below are parsed with eval(), which
    # executes the variable's contents as Python. It is kept so that existing
    # values keep working, but the environment must be trusted; values such
    # as lowercase 'true' raise NameError.

    # QA Engine config
    question_answering_model_id: str = get_env('QUESTION_ANSWERING_MODEL_ID')
    embedding_model_id: str = get_env('EMBEDDING_MODEL_ID')
    index_repo_id: str = get_env('INDEX_REPO_ID')
    prompt_template_name: str = get_env('PROMPT_TEMPLATE_NAME')
    use_docs_for_context: bool = eval(get_env('USE_DOCS_FOR_CONTEXT', 'True'))
    # Parsed with int(): the default is an int, and eval() only accepts
    # strings, so eval() here raised TypeError whenever the variable was unset.
    num_relevant_docs: int = int(get_env('NUM_RELEVANT_DOCS', 3))
    add_sources_to_response: bool = eval(get_env('ADD_SOURCES_TO_RESPONSE', 'True'))
    use_messages_in_context: bool = eval(get_env('USE_MESSAGES_IN_CONTEXT', 'True'))
    debug: bool = eval(get_env('DEBUG', 'True'))

    # Model config
    min_new_tokens: int = int(get_env('MIN_NEW_TOKENS', 64))
    max_new_tokens: int = int(get_env('MAX_NEW_TOKENS', 800))
    temperature: float = float(get_env('TEMPERATURE', 0.6))
    top_k: int = int(get_env('TOP_K', 50))
    top_p: float = float(get_env('TOP_P', 0.95))
    do_sample: bool = eval(get_env('DO_SAMPLE', 'True'))

    # Discord bot config - optional
    discord_token: str = get_env('DISCORD_TOKEN', '-', warn=False)
    # Parsed into integers so the value matches its annotation; previously a
    # set variable was stored as the raw string. The variable is read when a
    # Config is instantiated (default_factory), giving each instance its own
    # list.
    # NOTE(review): the expected format of DISCORD_CHANNEL_IDS is not shown
    # in this file; comma/whitespace-separated integers (optionally in
    # square brackets) are assumed - confirm against the deployment config.
    discord_channel_ids: list[int] = field(
        default_factory=lambda: _parse_int_list(
            get_env('DISCORD_CHANNEL_IDS', '', warn=False)
        )
    )
    num_last_messages: int = int(get_env('NUM_LAST_MESSAGES', 2, warn=False))
    use_names_in_context: bool = eval(get_env('USE_NAMES_IN_CONTEXT', 'False', warn=False))
    enable_commands: bool = eval(get_env('ENABLE_COMMANDS', 'True', warn=False))

    # App mode
    app_mode: str = get_env('APP_MODE', '-', warn=False) # 'gradio' or 'discord'

    def __post_init__(self):
        """Load the prompt template, validate the settings and log them."""
        prompt_template_file = f'config/prompt_templates/{self.prompt_template_name}.txt'
        # Explicit encoding so the template reads the same on every platform.
        with open(prompt_template_file, 'r', encoding='utf-8') as f:
            self.prompt_template = f.read()
        # validate config
        if 'context' not in self.prompt_template:
            raise ValueError("Prompt Template does not contain the 'context' field.")
        if 'question' not in self.prompt_template:
            raise ValueError("Prompt Template does not contain the 'question' field.")
        if not self.use_docs_for_context and self.add_sources_to_response:
            raise ValueError('Cannot add sources to response if not using docs in context')
        if self.num_relevant_docs < 1:
            raise ValueError('num_relevant_docs must be greater than 0')
        self.log()

    def asdict(self) -> dict:
        """Return the dataclass fields as a dictionary."""
        return asdict(self)

    def log(self) -> None:
        """Log every config field at info level, redacting the Discord token."""
        logger.info('Config:')
        for key, value in self.asdict().items():
            # Never write the bot token to the logs.
            if key == 'discord_token':
                value = '<redacted>'
            logger.info(f'{key}: {value}')