File size: 2,464 Bytes
b7d109d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from modules import shared
from modules.utils import get_available_models
from modules.models import load_model, unload_model
from modules.models_settings import (get_model_settings_from_yamls,
                                     update_model_parameters)

from extensions.openai.embeddings import get_embeddings_model_name
from extensions.openai.errors import *


def get_current_model_list() -> list:
    """Return a one-element list holding the currently selected model name.

    The name is read from ``shared.model_name``. It is the real
    chat/completions model and may be the string "None".
    """
    active_name = shared.model_name
    return [active_name]


def get_pseudo_model_list() -> list:
    """Return the dummy OpenAI model ids advertised alongside the real models.

    Many clients expect these ids to exist, so they are always included.
    A new list is built on every call.
    """
    dummy_ids = (
        'gpt-3.5-turbo',
        'text-embedding-ada-002',
    )
    return list(dummy_ids)


def load_model(model_name: str) -> dict:
    """Make ``model_name`` the active model and return a legacy engine record.

    Pseudo model ids, the embeddings model name and the already-current model
    are not reloaded; the engine record is returned for them unchanged. Any
    other name is treated as a real model: the current model is unloaded, the
    per-model YAML settings are applied to ``shared.settings``, and the model
    is loaded into ``shared.model`` / ``shared.tokenizer``.

    Args:
        model_name: Name of the model to switch to.

    Returns:
        A dict with the keys ``id``, ``object``, ``owner`` and ``ready``.

    Raises:
        OpenAIError: If the loader leaves ``shared.model`` falsy. In that
            case ``shared.model_name`` is reset to the string "None".
    """
    # This function has the same name as the ``load_model`` imported from
    # ``modules.models`` at the top of the file and therefore shadows it.
    # Import the real loader under a private alias so that the call below
    # does not recurse into this function (which returns a dict and would
    # break the two-value unpacking).
    from modules.models import load_model as _load_backend_model

    resp = {
        "id": model_name,
        "object": "engine",
        "owner": "self",
        "ready": True,
    }
    if model_name not in get_pseudo_model_list() + [get_embeddings_model_name()] + get_current_model_list():  # Real model only
        # No args. Maybe it works anyways!
        # TODO: hack some heuristics into args for better results

        shared.model_name = model_name
        unload_model()

        model_settings = get_model_settings_from_yamls(shared.model_name)
        shared.settings.update(model_settings)
        update_model_parameters(model_settings, initial=True)

        # An instruction template is only kept when running in instruct mode.
        if shared.settings['mode'] != 'instruct':
            shared.settings['instruction_template'] = None

        # NOTE(review): assumes the backend loader returns a
        # (model, tokenizer) pair, as the original unpacking implies.
        shared.model, shared.tokenizer = _load_backend_model(shared.model_name)

        if not shared.model:  # load failed.
            shared.model_name = "None"
            # Report the requested name; shared.model_name was just reset.
            raise OpenAIError(f"Model load failed for: {model_name}")

    return resp


def list_models(is_legacy: bool = False) -> dict:
    """Build the model listing response.

    The listing is, in order: the current model, the embeddings model, the
    pseudo model ids, then every model from ``get_available_models()``.

    Args:
        is_legacy: When true, emit "engine" records with a ``ready`` flag;
            the first record (the current model) is marked not ready if
            ``shared.model`` is falsy. Otherwise emit "model" records.

    Returns:
        A dict with ``object`` set to "list" and ``data`` holding the records.
    """
    # TODO: Lora's?
    model_ids = get_current_model_list() + [get_embeddings_model_name()] + get_pseudo_model_list() + get_available_models()

    if not is_legacy:
        entries = [
            {"id": model_id, "object": "model", "owned_by": "user", "permission": []}
            for model_id in model_ids
        ]
    else:
        entries = [
            {"id": model_id, "object": "engine", "owner": "user", "ready": True}
            for model_id in model_ids
        ]
        # The first entry is the current model; flag it when nothing is loaded.
        if not shared.model:
            entries[0]['ready'] = False

    return {
        "object": "list",
        "data": entries,
    }


def model_info(model_name: str) -> dict:
    """Return a "model" record for ``model_name``.

    The record carries the given name as ``id``, a fixed ``owned_by`` of
    "user" and a new empty ``permission`` list on every call.
    """
    info = {"id": model_name}
    info["object"] = "model"
    info["owned_by"] = "user"
    info["permission"] = []
    return info