from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger

from api.config import SETTINGS
from api.utils.compat import model_dump


def create_app() -> FastAPI:
    """ create fastapi app server """
    app = FastAPI()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app
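
# A hedged usage sketch (not part of the original module): the object returned by
# ``create_app`` is a plain FastAPI instance, so it can be served by any ASGI
# server. The module path below is an assumption about where this file lives:
#
#     import uvicorn
#     uvicorn.run("api.models:app", host="0.0.0.0", port=8000)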


def create_embedding_model():
    """ get embedding model from sentence-transformers. """
    if SETTINGS.tei_endpoint is not None:
        from openai import AsyncOpenAI
        client = AsyncOpenAI(base_url=SETTINGS.tei_endpoint, api_key="none")
    else:
        from sentence_transformers import SentenceTransformer
        client = SentenceTransformer(SETTINGS.embedding_name, device=SETTINGS.embedding_device)
    return client
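
# A hedged sketch of consuming the two possible client types returned above.
# Both calls are documented APIs of their libraries; the dispatch helper itself
# is illustrative and not part of this module:
#
#     from sentence_transformers import SentenceTransformer
#
#     async def embed(client, texts):
#         if isinstance(client, SentenceTransformer):
#             return client.encode(texts)  # numpy array of embeddings
#         # OpenAI-compatible TEI endpoint; the model name is usually ignored by TEI
#         resp = await client.embeddings.create(input=texts, model="tei")
#         return [d.embedding for d in resp.data]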


def create_generate_model():
    """ get generate model for chat or completion. """
    from api.core.default import DefaultEngine
    from api.adapter.model import load_model

    if SETTINGS.patch_type == "attention":
        from api.utils.patches import apply_attention_patch

        apply_attention_patch(use_memory_efficient_attention=True)
    elif SETTINGS.patch_type == "ntk":
        from api.utils.patches import apply_ntk_scaling_patch

        apply_ntk_scaling_patch(SETTINGS.alpha)

    include = {
        "model_name", "quantize", "device", "device_map", "num_gpus", "pre_seq_len",
        "load_in_8bit", "load_in_4bit", "using_ptuning_v2", "dtype", "resize_embeddings"
    }
    kwargs = model_dump(SETTINGS, include=include)

    model, tokenizer = load_model(
        model_name_or_path=SETTINGS.model_path,
        adapter_model=SETTINGS.adapter_model_path,
        **kwargs,
    )

    logger.info("Using default engine")

    return DefaultEngine(
        model,
        tokenizer,
        SETTINGS.device,
        model_name=SETTINGS.model_name,
        context_len=SETTINGS.context_length if SETTINGS.context_length > 0 else None,
        prompt_name=SETTINGS.chat_template,
        use_streamer_v2=SETTINGS.use_streamer_v2,
    )


def create_vllm_engine():
    """ get vllm generate engine for chat or completion. """
    try:
        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine
        from vllm.transformers_utils.tokenizer import get_tokenizer
        from api.core.vllm_engine import VllmEngine
    except ImportError:
        logger.warning("vllm could not be imported; no vllm engine was created")
        return None

    include = {
        "tokenizer_mode", "trust_remote_code", "tensor_parallel_size",
        "dtype", "gpu_memory_utilization", "max_num_seqs",
    }
    kwargs = model_dump(SETTINGS, include=include)
    engine_args = AsyncEngineArgs(
        model=SETTINGS.model_path,
        max_num_batched_tokens=SETTINGS.max_num_batched_tokens if SETTINGS.max_num_batched_tokens > 0 else None,
        max_model_len=SETTINGS.context_length if SETTINGS.context_length > 0 else None,
        quantization=SETTINGS.quantization_method,
        **kwargs,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(
        engine_args.tokenizer,
        tokenizer_mode=engine_args.tokenizer_mode,
        trust_remote_code=True,
    )

    logger.info("Using vllm engine")

    return VllmEngine(
        engine,
        tokenizer,
        SETTINGS.model_name,
        SETTINGS.chat_template,
        SETTINGS.context_length,
    )
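
# A hedged sketch (illustrative, not part of this module) of driving the async
# engine directly. ``generate(prompt, sampling_params, request_id)`` matches the
# vLLM versions this file appears to target; newer releases may differ:
#
#     from vllm import SamplingParams
#
#     async def stream(engine, prompt, request_id):
#         params = SamplingParams(temperature=0.7, max_tokens=256)
#         async for output in engine.generate(prompt, params, request_id):
#             print(output.outputs[0].text)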


def create_llama_cpp_engine():
    """ get llama.cpp generate engine for chat or completion. """
    try:
        from llama_cpp import Llama
        from api.core.llama_cpp_engine import LlamaCppEngine
    except ImportError:
        logger.warning("llama_cpp could not be imported; no llama.cpp engine was created")
        return None

    include = {
        "n_gpu_layers", "main_gpu", "tensor_split", "n_batch", "n_threads",
        "n_threads_batch", "rope_scaling_type", "rope_freq_base", "rope_freq_scale"
    }
    kwargs = model_dump(SETTINGS, include=include)
    engine = Llama(
        model_path=SETTINGS.model_path,
        n_ctx=SETTINGS.context_length if SETTINGS.context_length > 0 else 2048,
        **kwargs,
    )

    logger.info("Using llama.cpp engine")

    return LlamaCppEngine(engine, SETTINGS.model_name, SETTINGS.chat_template)
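
# A hedged sketch (not part of this module): the raw ``Llama`` object built above
# also supports direct completion calls via llama-cpp-python's documented API:
#
#     out = engine("Q: Name the planets in the solar system. A:", max_tokens=32)
#     print(out["choices"][0]["text"])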


def create_tgi_engine():
    """ get llama.cpp generate engine for chat or completion. """
    try:
        from text_generation import AsyncClient
        from api.core.tgi import TGIEngine
    except ImportError:
        logger.warning("text_generation could not be imported; no TGI engine was created")
        return None

    client = AsyncClient(SETTINGS.tgi_endpoint)
    logger.info("Using TGI engine")

    return TGIEngine(client, SETTINGS.model_name, SETTINGS.chat_template)
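
# A hedged sketch (not part of this module): the ``text_generation`` AsyncClient
# created above exposes ``generate`` and ``generate_stream`` directly:
#
#     resp = await client.generate("What is deep learning?", max_new_tokens=20)
#     print(resp.generated_text)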


# fastapi app
app = create_app()

# model for embedding
EMBEDDED_MODEL = create_embedding_model() if (SETTINGS.embedding_name and SETTINGS.activate_inference) else None

# engine used to generate chat / completion responses, selected by SETTINGS.engine
if (not SETTINGS.only_embedding) and SETTINGS.activate_inference:
    if SETTINGS.engine == "default":
        GENERATE_ENGINE = create_generate_model()
    elif SETTINGS.engine == "vllm":
        GENERATE_ENGINE = create_vllm_engine()
    elif SETTINGS.engine == "llama.cpp":
        GENERATE_ENGINE = create_llama_cpp_engine()
    elif SETTINGS.engine == "tgi":
        GENERATE_ENGINE = create_tgi_engine()
    else:
        # fail fast on an unrecognized engine name instead of leaving GENERATE_ENGINE undefined
        raise ValueError(f"Unknown engine: {SETTINGS.engine!r}")
else:
    GENERATE_ENGINE = None

# model names for special processing
EXCLUDE_MODELS = ["baichuan-13b", "baichuan2-13b", "qwen", "chatglm3"]